This notebook shows how to generate interpretable results using a deep learning model. The model used is the input-attention LSTM model of Qin et al. (2017), and the data come from the CAMELS-AUS (CAMELS Australia) dataset: streamflow of catchment 401203 is predicted from meteorological forcings.

Open In Colab

View Source on GitHub

[2]:
# Optional setup (e.g. on Colab): uncomment to install ai4water when it is
# missing. The `@<hash>` suffix pins the install to one specific git commit.
# try:
#     import ai4water
# except (ImportError, ModuleNotFoundError):
#     !pip install git+https://github.com/AtrCheema/AI4Water.git@b0cb440f1c5e28477e2d1ea6f3ece2f68851bd93
[3]:
import tensorflow as tf
# Switch TensorFlow to graph (TF1-compatible) mode. This has to run before any
# model is built, which is why it sits in its own cell ahead of the other imports.
# NOTE(review): presumably required by ai4water's InputAttentionModel, which
# builds symbolic placeholder tensors (see `model.inputs` below) — confirm.
tf.compat.v1.disable_eager_execution()
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll
  warnings.warn("loaded more than 1 DLL from .libs:"
[4]:
import os

from easy_mpl import hist
import matplotlib.pyplot as plt

from ai4water import InputAttentionModel
from ai4water.datasets import CAMELS_AUS
from ai4water.postprocessing import Interpret
from ai4water.utils.utils import get_version_info
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\sklearn\experimental\enable_hist_gradient_boosting.py:16: UserWarning: Since version 1.0, it is not needed to import enable_hist_gradient_boosting anymore. HistGradientBoostingClassifier and HistGradientBoostingRegressor are now stable and can be normally imported from sklearn.ensemble.
  warnings.warn(
[5]:
# Record library versions so the results below can be reproduced.
for name, version in get_version_info().items():
    print(name, version)
python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:16) [MSC v.1916 64 bit (AMD64)]
os nt
ai4water 1.06
lightgbm 3.3.1
tcn 3.4.0
catboost 0.26
xgboost 1.5.0
easy_mpl 0.21.3
SeqMetrics 1.3.3
tensorflow 2.7.0
keras.api._v2.keras 2.7.0
numpy 1.21.0
pandas 1.3.4
matplotlib 3.4.3
h5py 3.5.0
sklearn 1.0.1
shapefile 2.3.0
fiona 1.8.22
xarray 0.20.1
netCDF4 1.5.7
optuna 2.10.1
skopt 0.9.0
hyperopt 0.2.7
plotly 5.3.1
lime NotDefined
seaborn 0.11.2
[6]:
# Location of the local CAMELS-AUS copy. Set the CAMELS_AUS_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
CAMELS_AUS_PATH = os.environ.get("CAMELS_AUS_PATH", "F:\\data\\CAMELS\\CAMELS_AUS")
dataset = CAMELS_AUS(path=CAMELS_AUS_PATH)

# Meteorological forcings used as model inputs.
inputs = [
    'et_morton_point_SILO',
    'precipitation_AWAP',
    'tmax_AWAP',
    'tmin_AWAP',
    'vprp_AWAP',
    'rh_tmax_SILO',
    'rh_tmin_SILO',
]

# Prediction target (the _MLd suffix presumably denotes megalitres per day).
outputs = ['streamflow_MLd']

[7]:

# Catchment (station) to model. fetch() returns a long-format frame indexed by
# (time, dynamic_features) with one column per station.
STATION_ID = '401203'
data = dataset.fetch(STATION_ID, dynamic_features=inputs + outputs, as_dataframe=True)
data
[7]:
401203
time dynamic_features
1957-01-01 et_morton_point_SILO 8.062837
precipitation_AWAP 0.000000
tmax_AWAP 20.784480
tmin_AWAP 4.358533
vprp_AWAP 8.142806
... ... ...
2014-12-31 tmin_AWAP 4.522077
vprp_AWAP 8.885449
rh_tmax_SILO 26.714771
rh_tmin_SILO 75.013684
streamflow_MLd 322.791000

169472 rows × 1 columns

[8]:
# Pivot long -> wide: move the feature level of the index into the columns.
data = data.unstack()
# Columns are now (station, feature) pairs; keep only the feature name.
data.columns = [column[1] for column in data.columns]
data.head()
[8]:
et_morton_point_SILO precipitation_AWAP tmax_AWAP tmin_AWAP vprp_AWAP rh_tmax_SILO rh_tmin_SILO streamflow_MLd
time
1957-01-01 8.062837 0.0 20.784480 4.358533 8.142806 28.888577 88.900993 538.551
1957-01-02 8.519483 0.0 27.393169 4.835900 5.281136 23.516738 99.002080 531.094
1957-01-03 9.879688 0.0 28.945301 8.175408 12.920509 19.434872 77.429917 503.011
1957-01-04 6.744638 0.0 26.133843 7.017990 13.951027 42.350667 100.000000 484.512
1957-01-05 8.135359 0.0 21.450775 8.686258 12.168659 30.374862 87.634483 463.416
[9]:
# (number of daily time steps, number of features incl. the target)
data.shape
[9]:
(21184, 8)
[10]:
# Sanity check: how many streamflow values are negative (expected 0).
data['streamflow_MLd'].lt(0.0).sum()
[10]:
0
[11]:
# Sanity check: how many streamflow values are exactly zero.
data['streamflow_MLd'].eq(0.0).sum()
[11]:
0
[12]:
# Missing values per column.
data.isnull().sum()
[12]:
et_morton_point_SILO    0
precipitation_AWAP      0
tmax_AWAP               0
tmin_AWAP               0
vprp_AWAP               0
rh_tmax_SILO            0
rh_tmin_SILO            0
streamflow_MLd          0
dtype: int64
[13]:
# One panel per variable, all sharing the time axis.
_ = data.plot(
    subplots=True,
    sharex=True,
    figsize=(10, 10),
)
../../_images/_notebooks_model_interpretability_ia_12_0.png
[14]:


# Distribution of every variable; used to decide which inputs need a log transform.
_ = hist(
    data,
    share_axes=False,
    subplots_kws=dict(figsize=(12, 10)),
    edgecolor="k",
    grid=False,
)
../../_images/_notebooks_model_interpretability_ia_13_0.png
[15]:
# Strongly skewed inputs (see the histograms above); these are log-transformed.
skew_inputs = ['precipitation_AWAP', 'rh_tmin_SILO']
[16]:
# Remaining inputs; these get robust scaling instead of a log transform.
normal_inputs = [
    'et_morton_point_SILO',
    'tmax_AWAP',
    'tmin_AWAP',
    'vprp_AWAP',
    'rh_tmax_SILO',
]

# Each input should get exactly one x-transformation. The two groups must be
# disjoint and together cover `inputs`; otherwise a feature would presumably be
# left untransformed (or transformed twice) without any error.
assert not set(normal_inputs) & set(skew_inputs), "feature listed in both transformation groups"
assert set(normal_inputs) | set(skew_inputs) == set(inputs), "transformation groups do not cover all inputs"
[17]:


# Input-attention LSTM with a 15-step lookback window. Skewed inputs are
# log-transformed (with zeros replaced); the remaining inputs and the target
# use robust scaling. verbosity=0 keeps model building quiet; it is raised
# in the next cell.
model = InputAttentionModel(
    input_features=inputs,
    output_features=outputs,
    epochs=500,
    ts_args={'lookback': 15},
    lr=0.0001,
    batch_size=64,
    x_transformation=[
        {'method': 'robust', 'features': normal_inputs},
        {'method': 'log', 'replace_zeros': True, 'features': skew_inputs},
    ],
    y_transformation={'method': 'robust', 'features': outputs},
    verbosity=0,
)
building input attention model
[18]:
# Show shapes and training progress from here on.
model.verbosity = 1
[19]:
# Three model inputs: the lookback window of input features plus the encoder's
# initial cell state and initial hidden state (per the tensor names shown).
model.inputs
[19]:
[<tf.Tensor 'enc_input1:0' shape=(None, 15, 7) dtype=float32>,
 <tf.Tensor 'enc_first_cell_state_1:0' shape=(None, 20) dtype=float32>,
 <tf.Tensor 'enc_first_hidden_state_1:0' shape=(None, 20) dtype=float32>]
[20]:
# A single output: one predicted value (streamflow_MLd) per sample.
model.outputs
[20]:
[<tf.Tensor 'dense/BiasAdd:0' shape=(None, 1) dtype=float32>]
[21]:
# Train with a training/validation split of `data` (sizes printed below).
# After training, the weights from the epoch with the lowest val_loss are
# reloaded (see the "Successfully loaded weights ..." line in the output).
h = model.fit(data=data)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
inferred mode is regression. Ignore this message if the inferred mode is correct.
input_x shape:  [(2964, 15, 7), (2964, 20), (2964, 20)]
target shape:  (2964, 1)
Train on 11854 samples, validate on 2964 samples
Epoch 1/500
11584/11854 [============================>.] - ETA: 0s - loss: 1.4162
`Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
11854/11854 [==============================] - 8s 714us/sample - loss: 1.4206 - val_loss: 1.3942
Epoch 2/500
11854/11854 [==============================] - 2s 139us/sample - loss: 1.2633 - val_loss: 1.2610
Epoch 3/500
11854/11854 [==============================] - 2s 135us/sample - loss: 1.1647 - val_loss: 1.1965
Epoch 4/500
11854/11854 [==============================] - 2s 137us/sample - loss: 1.0976 - val_loss: 1.1317
Epoch 5/500
11854/11854 [==============================] - 2s 137us/sample - loss: 1.0283 - val_loss: 1.0351
Epoch 6/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.9542 - val_loss: 0.9281
Epoch 7/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.8775 - val_loss: 0.8254
Epoch 8/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.8192 - val_loss: 0.8105
Epoch 9/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.7794 - val_loss: 0.7468
Epoch 10/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.7499 - val_loss: 0.7336
Epoch 11/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.7350 - val_loss: 0.7174
Epoch 12/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.7127 - val_loss: 0.6835
Epoch 13/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.6968 - val_loss: 0.6707
Epoch 14/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.6839 - val_loss: 0.6873
Epoch 15/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.6707 - val_loss: 0.6563
Epoch 16/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.6585 - val_loss: 0.6397
Epoch 17/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.6456 - val_loss: 0.6442
Epoch 18/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.6353 - val_loss: 0.6147
Epoch 19/500
11854/11854 [==============================] - 2s 138us/sample - loss: 0.6292 - val_loss: 0.6106
Epoch 20/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.6114 - val_loss: 0.6112
Epoch 21/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.6054 - val_loss: 0.5812
Epoch 22/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5933 - val_loss: 0.5740
Epoch 23/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.5835 - val_loss: 0.5732
Epoch 24/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.5814 - val_loss: 0.5743
Epoch 25/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5716 - val_loss: 0.5570
Epoch 26/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.5600 - val_loss: 0.5420
Epoch 27/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5575 - val_loss: 0.5334
Epoch 28/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.5476 - val_loss: 0.5346
Epoch 29/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.5401 - val_loss: 0.5287
Epoch 30/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.5369 - val_loss: 0.5294
Epoch 31/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5258 - val_loss: 0.5168
Epoch 32/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.5196 - val_loss: 0.5205
Epoch 33/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5237 - val_loss: 0.5257
Epoch 34/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5184 - val_loss: 0.4979
Epoch 35/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.5091 - val_loss: 0.5001
Epoch 36/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.5053 - val_loss: 0.4887
Epoch 37/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.5022 - val_loss: 0.4937
Epoch 38/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4957 - val_loss: 0.4974
Epoch 39/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.4907 - val_loss: 0.4877
Epoch 40/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.4988 - val_loss: 0.4823
Epoch 41/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.4810 - val_loss: 0.4814
Epoch 42/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.4789 - val_loss: 0.4847
Epoch 43/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.4733 - val_loss: 0.4873
Epoch 44/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.4685 - val_loss: 0.4732
Epoch 45/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.4678 - val_loss: 0.4658
Epoch 46/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.4678 - val_loss: 0.4756
Epoch 47/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.4610 - val_loss: 0.4671
Epoch 48/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4581 - val_loss: 0.4685
Epoch 49/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.4530 - val_loss: 0.4531
Epoch 50/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.4504 - val_loss: 0.4582
Epoch 51/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4508 - val_loss: 0.4623
Epoch 52/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.4456 - val_loss: 0.4523
Epoch 53/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4472 - val_loss: 0.4735
Epoch 54/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.4396 - val_loss: 0.4418
Epoch 55/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.4362 - val_loss: 0.4537
Epoch 56/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.4386 - val_loss: 0.4551
Epoch 57/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4304 - val_loss: 0.4460
Epoch 58/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.4304 - val_loss: 0.4406
Epoch 59/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4252 - val_loss: 0.4436
Epoch 60/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.4191 - val_loss: 0.4399
Epoch 61/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4176 - val_loss: 0.4841
Epoch 62/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.4181 - val_loss: 0.4334
Epoch 63/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4216 - val_loss: 0.4413
Epoch 64/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.4149 - val_loss: 0.4349
Epoch 65/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4110 - val_loss: 0.4521
Epoch 66/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4103 - val_loss: 0.4411
Epoch 67/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4094 - val_loss: 0.4338
Epoch 68/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.4034 - val_loss: 0.4299
Epoch 69/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.4038 - val_loss: 0.4452
Epoch 70/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.4043 - val_loss: 0.4299
Epoch 71/500
11854/11854 [==============================] - 2s 128us/sample - loss: 0.3980 - val_loss: 0.4368
Epoch 72/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3934 - val_loss: 0.4216
Epoch 73/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3898 - val_loss: 0.4393
Epoch 74/500
11854/11854 [==============================] - 2s 140us/sample - loss: 0.3910 - val_loss: 0.4212
Epoch 75/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.3830 - val_loss: 0.4246
Epoch 76/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3870 - val_loss: 0.4238
Epoch 77/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3871 - val_loss: 0.4227
Epoch 78/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3833 - val_loss: 0.4674
Epoch 79/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3819 - val_loss: 0.4260
Epoch 80/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3779 - val_loss: 0.4349
Epoch 81/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3757 - val_loss: 0.4255
Epoch 82/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.3744 - val_loss: 0.4337
Epoch 83/500
11854/11854 [==============================] - 2s 138us/sample - loss: 0.3691 - val_loss: 0.4168
Epoch 84/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3664 - val_loss: 0.4478
Epoch 85/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3677 - val_loss: 0.4354
Epoch 86/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.3645 - val_loss: 0.4143
Epoch 87/500
11854/11854 [==============================] - 2s 137us/sample - loss: 0.3610 - val_loss: 0.4458
Epoch 88/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.3611 - val_loss: 0.4354
Epoch 89/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.3552 - val_loss: 0.4313
Epoch 90/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3575 - val_loss: 0.4219
Epoch 91/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3589 - val_loss: 0.4167
Epoch 92/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3539 - val_loss: 0.4354
Epoch 93/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3504 - val_loss: 0.4172
Epoch 94/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.3467 - val_loss: 0.4112
Epoch 95/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3434 - val_loss: 0.4275
Epoch 96/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3466 - val_loss: 0.4230
Epoch 97/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.3442 - val_loss: 0.4151
Epoch 98/500
11854/11854 [==============================] - 2s 138us/sample - loss: 0.3391 - val_loss: 0.4026
Epoch 99/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.3371 - val_loss: 0.4258
Epoch 100/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.3365 - val_loss: 0.4104
Epoch 101/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3341 - val_loss: 0.4150
Epoch 102/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3325 - val_loss: 0.4103
Epoch 103/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.3296 - val_loss: 0.4050
Epoch 104/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.3263 - val_loss: 0.4173
Epoch 105/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.3274 - val_loss: 0.4190
Epoch 106/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.3250 - val_loss: 0.4383
Epoch 107/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.3247 - val_loss: 0.4151
Epoch 108/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.3227 - val_loss: 0.4090
Epoch 109/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.3195 - val_loss: 0.4248
Epoch 110/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3385 - val_loss: 0.4488
Epoch 111/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3158 - val_loss: 0.4213
Epoch 112/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.3106 - val_loss: 0.4630
Epoch 113/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3114 - val_loss: 0.4202
Epoch 114/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3149 - val_loss: 0.4141
Epoch 115/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.3115 - val_loss: 0.4387
Epoch 116/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.3103 - val_loss: 0.4293
Epoch 117/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3046 - val_loss: 0.4426
Epoch 118/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.3074 - val_loss: 0.4322
Epoch 119/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3033 - val_loss: 0.4286
Epoch 120/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.3042 - val_loss: 0.4134
Epoch 121/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2986 - val_loss: 0.4158
Epoch 122/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2977 - val_loss: 0.4486
Epoch 123/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2980 - val_loss: 0.4874
Epoch 124/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2941 - val_loss: 0.4302
Epoch 125/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2941 - val_loss: 0.4457
Epoch 126/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2932 - val_loss: 0.4443
Epoch 127/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2933 - val_loss: 0.4550
Epoch 128/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2914 - val_loss: 0.4411
Epoch 129/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2878 - val_loss: 0.4522
Epoch 130/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2916 - val_loss: 0.4462
Epoch 131/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2880 - val_loss: 0.4247
Epoch 132/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2819 - val_loss: 0.4651
Epoch 133/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2831 - val_loss: 0.4537
Epoch 134/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2801 - val_loss: 0.4525
Epoch 135/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2818 - val_loss: 0.4321
Epoch 136/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.2819 - val_loss: 0.4884
Epoch 137/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.2775 - val_loss: 0.4500
Epoch 138/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2751 - val_loss: 0.4304
Epoch 139/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2721 - val_loss: 0.4742
Epoch 140/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2860 - val_loss: 0.4386
Epoch 141/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2747 - val_loss: 0.4449
Epoch 142/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2714 - val_loss: 0.4735
Epoch 143/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2784 - val_loss: 0.4466
Epoch 144/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2687 - val_loss: 0.5237
Epoch 145/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.2665 - val_loss: 0.4646
Epoch 146/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2660 - val_loss: 0.4590
Epoch 147/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2656 - val_loss: 0.4741
Epoch 148/500
11854/11854 [==============================] - 2s 138us/sample - loss: 0.2608 - val_loss: 0.4464
Epoch 149/500
11854/11854 [==============================] - 2s 139us/sample - loss: 0.2633 - val_loss: 0.4619
Epoch 150/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.2658 - val_loss: 0.4548
Epoch 151/500
11854/11854 [==============================] - 2s 139us/sample - loss: 0.2567 - val_loss: 0.4708
Epoch 152/500
11854/11854 [==============================] - 2s 139us/sample - loss: 0.2578 - val_loss: 0.4828
Epoch 153/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2558 - val_loss: 0.4885
Epoch 154/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2535 - val_loss: 0.4606
Epoch 155/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2567 - val_loss: 0.4817
Epoch 156/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2528 - val_loss: 0.4849
Epoch 157/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2534 - val_loss: 0.5077
Epoch 158/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2511 - val_loss: 0.5115
Epoch 159/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2472 - val_loss: 0.4841
Epoch 160/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.2503 - val_loss: 0.4702
Epoch 161/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2454 - val_loss: 0.5041
Epoch 162/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.2478 - val_loss: 0.5100
Epoch 163/500
11854/11854 [==============================] - 2s 136us/sample - loss: 0.2461 - val_loss: 0.5008
Epoch 164/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2447 - val_loss: 0.4676
Epoch 165/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2439 - val_loss: 0.4915
Epoch 166/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2424 - val_loss: 0.5039
Epoch 167/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2422 - val_loss: 0.5154
Epoch 168/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2443 - val_loss: 0.4931
Epoch 169/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2380 - val_loss: 0.5194
Epoch 170/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2379 - val_loss: 0.4996
Epoch 171/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2359 - val_loss: 0.5021
Epoch 172/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2367 - val_loss: 0.5350
Epoch 173/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2362 - val_loss: 0.4947
Epoch 174/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2322 - val_loss: 0.5131
Epoch 175/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2326 - val_loss: 0.4732
Epoch 176/500
11854/11854 [==============================] - 2s 135us/sample - loss: 0.2289 - val_loss: 0.4843
Epoch 177/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2305 - val_loss: 0.5144
Epoch 178/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2287 - val_loss: 0.5446
Epoch 179/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2320 - val_loss: 0.5356
Epoch 180/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2255 - val_loss: 0.5713
Epoch 181/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.2235 - val_loss: 0.5233
Epoch 182/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2235 - val_loss: 0.4942
Epoch 183/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2230 - val_loss: 0.5594
Epoch 184/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2289 - val_loss: 0.4998
Epoch 185/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2232 - val_loss: 0.5639
Epoch 186/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2230 - val_loss: 0.5492
Epoch 187/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2193 - val_loss: 0.5367
Epoch 188/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2191 - val_loss: 0.5768
Epoch 189/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2189 - val_loss: 0.5992
Epoch 190/500
11854/11854 [==============================] - 2s 129us/sample - loss: 0.2163 - val_loss: 0.5456
Epoch 191/500
11854/11854 [==============================] - 2s 133us/sample - loss: 0.2164 - val_loss: 0.5622
Epoch 192/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2182 - val_loss: 0.5184
Epoch 193/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2158 - val_loss: 0.5381
Epoch 194/500
11854/11854 [==============================] - 2s 131us/sample - loss: 0.2130 - val_loss: 0.5700
Epoch 195/500
11854/11854 [==============================] - 2s 132us/sample - loss: 0.2107 - val_loss: 0.5250
Epoch 196/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2119 - val_loss: 0.5583
Epoch 197/500
11854/11854 [==============================] - 2s 130us/sample - loss: 0.2094 - val_loss: 0.5669
Epoch 198/500
11854/11854 [==============================] - 2s 134us/sample - loss: 0.2119 - val_loss: 0.5317
../../_images/_notebooks_model_interpretability_ia_20_5.png
********** Successfully loaded weights from weights_098_0.40258.hdf5 file **********

Making predictions on training data

[22]:
# Predict on the training split; ai4water also produces the diagnostic plots shown below.
_ = model.predict_on_training_data(data=data)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
`Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
../../_images/_notebooks_model_interpretability_ia_22_2.png
../../_images/_notebooks_model_interpretability_ia_22_3.png
../../_images/_notebooks_model_interpretability_ia_22_4.png
invalid value encountered in log2
invalid value encountered in log1p
invalid value encountered in log1p
invalid value encountered in log1p
invalid value encountered in log1p
../../_images/_notebooks_model_interpretability_ia_22_6.png
../../_images/_notebooks_model_interpretability_ia_22_7.png
../../_images/_notebooks_model_interpretability_ia_22_8.png

Making predictions on test data

[23]:
# Predict on the held-out test split and show the same diagnostics.
_ = model.predict_on_test_data(data=data)
input_x shape:  [(6352, 15, 7), (6352, 20), (6352, 20)]
target shape:  (6352, 1)
../../_images/_notebooks_model_interpretability_ia_24_1.png
../../_images/_notebooks_model_interpretability_ia_24_2.png
../../_images/_notebooks_model_interpretability_ia_24_3.png
../../_images/_notebooks_model_interpretability_ia_24_4.png
../../_images/_notebooks_model_interpretability_ia_24_5.png
../../_images/_notebooks_model_interpretability_ia_24_6.png

Getting interpretable results

[24]:
# Input-attention weights per input feature, presumably averaged over the
# training samples (TODO confirm the aggregation in the ai4water docs).
_ = model.plot_avg_attentions_along_inputs(
    data=data,
    colorbar=True,
)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
../../_images/_notebooks_model_interpretability_ia_26_1.png
[29]:
# Attention/activation pattern for precipitation alongside its input series.
_ = model.plot_act_along_inputs(
    data=data,
    feature="precipitation_AWAP",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_27_2.png
[26]:
# Same view for Morton point evapotranspiration.
_ = model.plot_act_along_inputs(
    data=data,
    feature="et_morton_point_SILO",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_28_2.png
[30]:
# Same view for maximum temperature.
_ = model.plot_act_along_inputs(
    data=data,
    feature="tmax_AWAP",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 20), (11854, 20)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_29_2.png
[28]:
# Folder where ai4water wrote weights, predictions and figures for this run.
print('All results are saved in {}'.format(model.path))
All results are saved in D:\mytools\ai4water_examples\docs\source\_notebooks\results\20230212_205108
[ ]: