Model Interpretation
- class ai4water.postprocessing.interpret.Interpret(model, save: bool = False, show: bool = True)
Bases: Plot
Interprets the ai4water Model. The Interpret class differs from the methods in the explain module: it explains the behaviour of the model using constituents of the model itself, for example attention weights or feature importance.
- compare_xgb_f_imp(calculation_method='all', rescale=True, figsize: Optional[tuple] = None, backend: str = 'matplotlib', **kwargs)
Compares the various feature importance calculation methods built into XGBoost.
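XGBoost's Booster.get_score(importance_type=...) can report importance as 'weight', 'gain', 'cover', 'total_gain' or 'total_cover', and these are on very different scales. The following minimal, pure-Python sketch shows why rescaling helps when comparing them side by side. It assumes that rescale=True means normalising each method's scores to sum to 1, which may not be how ai4water does it; the helper names and the toy scores are hypothetical.

```python
from typing import Dict

# Mapping of feature name -> importance score, the same shape that
# xgboost.Booster.get_score(importance_type=...) returns.
Scores = Dict[str, float]


def rescale_to_unit_sum(scores: Scores) -> Scores:
    """Divide each score by the total so the method's scores sum to 1.

    This is an assumed reading of rescale=True, not ai4water's code.
    """
    total = sum(scores.values())
    if total == 0:
        return {name: 0.0 for name in scores}
    return {name: value / total for name, value in scores.items()}


def compare_methods(per_method: Dict[str, Scores], rescale: bool = True) -> Dict[str, Scores]:
    """Collect each method's scores, optionally rescaled onto a common scale."""
    return {
        method: rescale_to_unit_sum(scores) if rescale else dict(scores)
        for method, scores in per_method.items()
    }


# Made-up scores for two importance types; not output of a real model.
toy_scores = {
    "weight": {"rain": 30.0, "temp": 10.0},  # number of splits using the feature
    "gain": {"rain": 2.0, "temp": 6.0},      # average gain of those splits
}
table = compare_methods(toy_scores)
for method, scores in table.items():
    print(method, scores)
```

With these toy numbers, 'weight' ranks rain first while 'gain' ranks temp first; disagreement of this kind between importance types is exactly what a side-by-side comparison exposes.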
- get_enc_var_selection_weights(data, data_type: str = 'test')
Returns the encoder variable selection weights of a TFT model.
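In the Temporal Fusion Transformer, variable selection weights come from a softmax over the input features, so for every sample and encoder step they sum to 1. One common way to turn the (?, encoder_steps, input_features) array documented for tft_attention_components into a global ranking is to average over the first two axes, where ? is presumably the number of samples. A minimal pure-Python sketch, assuming the weights have already been obtained and converted to nested lists (for a numpy array, via its tolist() method); the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

# weights[sample][step][feature], i.e. the documented
# (?, encoder_steps, input_features) layout as nested lists.
Weights3D = Sequence[Sequence[Sequence[float]]]


def mean_feature_weights(weights: Weights3D) -> List[float]:
    """Average the selection weight of each feature over samples and steps."""
    n_features = len(weights[0][0])
    totals = [0.0] * n_features
    count = 0
    for sample in weights:
        for step in sample:
            for j, w in enumerate(step):
                totals[j] += w
            count += 1
    return [t / count for t in totals]


def rank_features(weights: Weights3D, names: Sequence[str]) -> List[Tuple[str, float]]:
    """Pair feature names with their mean weight, most important first."""
    means = mean_feature_weights(weights)
    return sorted(zip(names, means), key=lambda pair: pair[1], reverse=True)


# Toy weights: 2 samples x 2 encoder steps x 3 features; each innermost
# row sums to 1, as softmax-based selection weights do.
toy_weights = [
    [[0.5, 0.3, 0.2], [0.7, 0.2, 0.1]],
    [[0.6, 0.1, 0.3], [0.2, 0.2, 0.6]],
]
print(rank_features(toy_weights, ["rain", "temp", "wind"]))
```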
- interpret_attention_lstm(x=None, data=None, data_type: str = 'test')
- Parameters:
x – input data. If not given, the data argument must be given.
data – the data to use to interpret the model. It is only required when x is not given.
data_type – either 'training', 'test', 'validation' or 'all'. It is only used when the data argument is given.
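Attention-based LSTMs usually assign a score to each lookback time step and pass the scores through a softmax, so the resulting attention weights are non-negative and sum to 1; reading them as how strongly the model attended to each step is what makes them interpretable. The sketch below shows only that normalisation step in pure Python. The scores are hypothetical, and how ai4water's attention layer computes them is not covered here.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def most_attended_step(scores: Sequence[float]) -> int:
    """Index of the lookback step that receives the largest attention weight."""
    weights = softmax(scores)
    return max(range(len(weights)), key=weights.__getitem__)


# Hypothetical scores for a lookback of four time steps (oldest first);
# in a real model they would come from the attention layer.
scores = [0.1, 0.4, 2.0, 0.3]
weights = softmax(scores)
print([round(w, 3) for w in weights], most_attended_step(scores))
```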
- interpret_example_tft(example_index: int, x=None, data=None, data_type='test')
Interprets a single example using the TFT model.
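Interpreting one example amounts to picking that example out of every attention component. Going by the shapes documented for tft_attention_components, the sample axis (?) is the second axis of decoder_self_attn but the first axis of the variable selection weights, which is easy to get wrong. A minimal pure-Python sketch with hypothetical helper names; averaging over heads is a presentation choice of this sketch, not necessarily what interpret_example_tft plots.

```python
from typing import Dict, List

# Axis that holds the samples ('?') in each documented component shape.
# static_variable_selection_weights is omitted: its shape is not documented.
SAMPLE_AXIS = {
    "decoder_self_attn": 1,                   # (attention_heads, ?, ...)
    "encoder_variable_selection_weights": 0,  # (?, encoder_steps, input_features)
    "decoder_variable_selection_weights": 0,  # (?, decoder_steps, input_features)
}


def slice_example(components: Dict[str, list], example_index: int) -> Dict[str, list]:
    """Return one example from each component whose sample axis is known."""
    picked = {}
    for name, array in components.items():
        axis = SAMPLE_AXIS.get(name)
        if axis == 0:
            picked[name] = array[example_index]
        elif axis == 1:
            picked[name] = [head[example_index] for head in array]
    return picked


def mean_over_heads(per_head: List[List[List[float]]]) -> List[List[float]]:
    """Element-wise average of one attention matrix per head."""
    n_heads = len(per_head)
    rows, cols = len(per_head[0]), len(per_head[0][0])
    return [
        [sum(h[r][c] for h in per_head) / n_heads for c in range(cols)]
        for r in range(rows)
    ]


# Toy components: 2 heads x 2 examples x 2 x 2 attention, and
# 2 examples x 1 encoder step x 2 features of selection weights.
toy_components = {
    "decoder_self_attn": [
        [[[1.0, 0.0], [0.5, 0.5]], [[0.2, 0.8], [0.4, 0.6]]],
        [[[0.0, 1.0], [0.5, 0.5]], [[0.6, 0.4], [0.0, 1.0]]],
    ],
    "encoder_variable_selection_weights": [[[0.9, 0.1]], [[0.3, 0.7]]],
}
example = slice_example(toy_components, example_index=1)
print(mean_over_heads(example["decoder_self_attn"]))
print(example["encoder_variable_selection_weights"])
```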
- interpret_ft_transformer(x=None, data=None, data_type: str = 'test')
- Parameters:
x – input data. If not given, the data argument must be given.
data – the data to use to interpret the model. It is only required when x is not given.
data_type – either 'training', 'test', 'validation' or 'all'. It is only used when the data argument is given.
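The FT-Transformer authors suggest reading feature importance from the attention that the [CLS] token pays to each feature token, aggregated over samples, layers and heads. Whether interpret_ft_transformer follows exactly this recipe is an assumption. The sketch below takes one (n_tokens, n_tokens) attention matrix per sample, layer and head, assumes the CLS token sits at a known index, and, as its own choice, drops the CLS column and renormalises so that the importances sum to 1. The helper name is hypothetical.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def cls_feature_importance(attention_maps: Sequence[Matrix], cls_index: int) -> List[float]:
    """Aggregate the CLS token's attention to each feature token over all maps.

    attention_maps holds one (n_tokens, n_tokens) matrix per
    (sample, layer, head); row cls_index is the CLS token's attention.
    Summing and then normalising gives the same result as averaging first.
    Dropping the CLS column before normalising is a choice of this sketch.
    """
    n_tokens = len(attention_maps[0])
    totals = [0.0] * n_tokens
    for attn in attention_maps:
        for j, w in enumerate(attn[cls_index]):
            totals[j] += w
    feature_totals = [t for j, t in enumerate(totals) if j != cls_index]
    norm = sum(feature_totals)
    return [t / norm for t in feature_totals]


# Two toy attention maps over tokens [feature_0, feature_1, CLS];
# the CLS token is assumed to be last (index 2) here.
map_a = [[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.6, 0.2, 0.2]]
map_b = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.2, 0.4, 0.4]]
importance = cls_feature_importance([map_a, map_b], cls_index=2)
print(importance)
```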
- interpret_tab_transformer(x=None, data=None, data_type: str = 'test')
- Parameters:
x – input data. If not given, the data argument must be given.
data – the data to use to interpret the model. It is only required when x is not given.
data_type – either 'training', 'test', 'validation' or 'all'. It is only used when the data argument is given.
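interpret_attention_lstm, interpret_ft_transformer and interpret_tab_transformer share the same input contract: x takes precedence, data is required only when x is missing, and data_type only matters together with data. A minimal sketch of that decision logic as a hypothetical helper; the actual splitting of data into training, test or validation sets happens inside ai4water and is not reproduced here.

```python
from typing import Any, Optional, Tuple

VALID_DATA_TYPES = ("training", "test", "validation", "all")


def resolve_inputs(x: Optional[Any] = None, data: Optional[Any] = None,
                   data_type: str = "test") -> Tuple[str, Any]:
    """Decide where the model inputs come from (hypothetical helper).

    Returns a label for the source and the object to use: ("x", x) when x
    is given (data_type is then ignored), otherwise (data_type, data).
    """
    if x is not None:
        return "x", x
    if data is None:
        raise ValueError("either x or data must be given")
    if data_type not in VALID_DATA_TYPES:
        raise ValueError(f"data_type must be one of {VALID_DATA_TYPES}, got {data_type!r}")
    return data_type, data


print(resolve_inputs(x=[[1.0, 2.0]]))
print(resolve_inputs(data={"rain": [0.1]}, data_type="validation"))
```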
- interpret_tft(x=None, y=None, data=None, data_type='test')
Global interpretation of the TFT model.
- Parameters:
x – input data. If not given, the data argument must be given.
y – labels/target/true data corresponding to x. It is only used for plotting.
data – the data to use to interpret the model. It is only required when x is not given.
data_type – either 'training', 'test', 'validation' or 'all'. It is only used when the data argument is given.
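A global view of TFT attention can be obtained by averaging decoder_self_attn over its first two axes, the attention heads and the samples (?), leaving one matrix that summarises all examples. Averaging over heads fits TFT's interpretable multi-head attention, whose output equals the head-averaged attention matrix applied to values shared across heads. A minimal pure-Python sketch with a hypothetical helper name; ai4water's interpret_tft may aggregate differently.

```python
from typing import List, Sequence

# decoder_self_attn as nested lists: [head][sample][row][col].
Attn4D = Sequence[Sequence[Sequence[Sequence[float]]]]


def global_attention(decoder_self_attn: Attn4D) -> List[List[float]]:
    """Average decoder self-attention over heads (axis 0) and samples (axis 1).

    Input follows the documented (attention_heads, ?, rows, cols) layout;
    the result is a single (rows, cols) matrix summarising all examples.
    """
    rows = len(decoder_self_attn[0][0])
    cols = len(decoder_self_attn[0][0][0])
    total = [[0.0] * cols for _ in range(rows)]
    count = 0
    for head in decoder_self_attn:
        for sample in head:
            for r in range(rows):
                for c in range(cols):
                    total[r][c] += sample[r][c]
            count += 1
    return [[value / count for value in row] for row in total]


# Toy input: 2 heads x 2 samples x 1 query row x 2 key columns.
toy_attn = [
    [[[1.0, 0.0]], [[0.5, 0.5]]],    # head 0: sample 0, sample 1
    [[[0.0, 1.0]], [[0.25, 0.75]]],  # head 1: sample 0, sample 1
]
print(global_attention(toy_attn))
```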
- property model
- plot_feature_importance(importance=None, use_xgb=False, max_num_features=20, figsize=None, **kwargs)
Plots feature importance when the model is tree-based.
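Limiting the plot to max_num_features amounts to keeping the highest-scoring features. A minimal sketch using heapq.nlargest, assuming the importances are available as a {feature name: score} mapping (scikit-learn tree ensembles, for instance, expose per-feature scores through feature_importances_); the helper name and the scores are hypothetical.

```python
import heapq
from typing import Dict, List, Tuple


def top_features(importance: Dict[str, float], max_num_features: int = 20) -> List[Tuple[str, float]]:
    """Return at most max_num_features (name, score) pairs, largest score first."""
    return heapq.nlargest(max_num_features, importance.items(), key=lambda item: item[1])


# Made-up importance scores for four features.
toy_importance = {"rain": 0.5, "temp": 0.2, "wind": 0.3, "humidity": 0.0}
print(top_features(toy_importance, max_num_features=2))
```

heapq.nlargest avoids sorting the whole mapping, which only matters when there are many more features than max_num_features; for a handful of features, sorted(...)[:k] is equally fine.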
- tft_attention_components(x=None, data=None, data_type: str = 'test')
Gets the attention components of the TFT layer from ai4water’s Model.
- Parameters:
x – the input data to the model
data – raw data from which x/inputs are extracted.
data_type – the data to use to calculate attention components
- Returns:
dict – dictionary containing the attention components of the TFT as numpy arrays. The following four attention components are present in the dictionary:
decoder_self_attn: (attention_heads, ?, total_time_steps, 22)
static_variable_selection_weights:
encoder_variable_selection_weights: (?, encoder_steps, input_features)
decoder_variable_selection_weights: (?, decoder_steps, input_features)
str – a string indicating which data was used