Input Attention LSTM

This notebook shows how to obtain interpretable results from a deep learning model, namely the Input Attention LSTM of Qin et al. (2017). The data come from the CAMELS-AUS (CAMELS Australia) dataset.

Open In Colab

View Source on GitHub

[2]:
# Some features used in this notebook are not available in the latest pip release of ai4water (1.06 at the
# time of writing); they will be included in the next release, 1.07. Until 1.07 is on pip, install ai4water
# from GitHub at the pinned commit below by uncommenting these lines.

# try:
#     import ai4water
# except (ImportError, ModuleNotFoundError):
#     %pip install git+https://github.com/AtrCheema/AI4Water.git@b0cb440f1c5e28477e2d1ea6f3ece2f68851bd93

[3]:
import tensorflow as tf
# Switch TensorFlow to graph (non-eager) execution before any model is built.
# The training traceback later in this notebook runs through keras' *_v1 training
# loop and a tf Session, consistent with this setting.
# NOTE(review): presumably required by ai4water's InputAttentionModel - confirm.
tf.compat.v1.disable_eager_execution()
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll
  warnings.warn("loaded more than 1 DLL from .libs:"
[4]:
import os

from easy_mpl import hist
import matplotlib.pyplot as plt

from ai4water import InputAttentionModel
from ai4water.datasets import CAMELS_AUS
from ai4water.postprocessing import Interpret
from ai4water.utils.utils import get_version_info
D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\sklearn\experimental\enable_hist_gradient_boosting.py:16: UserWarning: Since version 1.0, it is not needed to import enable_hist_gradient_boosting anymore. HistGradientBoostingClassifier and HistGradientBoostingRegressor are now stable and can be normally imported from sklearn.ensemble.
  warnings.warn(
[5]:
# Record the versions of ai4water and its dependencies used for this run,
# so the results below can be reproduced.
for package_name, package_version in get_version_info().items():
    print(package_name, package_version)
python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:16) [MSC v.1916 64 bit (AMD64)]
os nt
ai4water 1.07
lightgbm 3.3.1
tcn 3.4.0
catboost 0.26
xgboost 1.5.0
easy_mpl 0.21.3
SeqMetrics 1.3.3
tensorflow 2.7.0
keras.api._v2.keras 2.7.0
numpy 1.21.0
pandas 1.3.4
matplotlib 3.4.3
h5py 3.5.0
sklearn 1.0.1
shapefile 2.3.0
fiona 1.8.22
xarray 0.20.1
netCDF4 1.5.7
optuna 2.10.1
skopt 0.9.0
hyperopt 0.2.7
plotly 5.3.1
lime NotDefined
seaborn 0.11.2
[6]:
# Location of a local copy of CAMELS-AUS. Override it with the CAMELS_AUS_PATH
# environment variable. If the directory does not exist on this machine, pass
# path=None, which is the option to use when the data is not available locally
# (presumably ai4water then downloads it).
CAMELS_AUS_PATH = os.environ.get("CAMELS_AUS_PATH", "F:\\data\\CAMELS\\CAMELS_AUS")

dataset = CAMELS_AUS(
    path=CAMELS_AUS_PATH if os.path.isdir(CAMELS_AUS_PATH) else None
)

# Meteorological forcings used as model inputs.
inputs = ['et_morton_point_SILO',
          'precipitation_AWAP',
          'tmax_AWAP',
          'tmin_AWAP',
          'vprp_AWAP',
          'rh_tmax_SILO',
          'rh_tmin_SILO',
          ]

# Target variable.
outputs = ['streamflow_MLd']

[10]:

# Fetch the dynamic (time-varying) inputs and the target for site '401203'.
# The result is a long-format frame indexed by (time, dynamic_features).
data = dataset.fetch('401203', dynamic_features=inputs + outputs, as_dataframe=True)
data
[10]:
401203
time dynamic_features
1957-01-01 et_morton_point_SILO 8.062837
precipitation_AWAP 0.000000
tmax_AWAP 20.784480
tmin_AWAP 4.358533
vprp_AWAP 8.142806
... ... ...
2014-12-31 tmin_AWAP 4.522077
vprp_AWAP 8.885449
rh_tmax_SILO 26.714771
rh_tmin_SILO 75.013684
streamflow_MLd 322.791000

169472 rows × 1 columns

[11]:
# Reshape the long (time, feature) frame into a wide frame with one column per
# feature, keeping only the feature name as the column label.
# Guarded so that re-running this cell does not try to unstack an already-wide
# frame, whose index then has a single level (time).
if data.index.nlevels > 1:
    data = data.unstack()
    data.columns = [col[1] for col in data.columns.to_flat_index()]
data.head()
[11]:
et_morton_point_SILO precipitation_AWAP tmax_AWAP tmin_AWAP vprp_AWAP rh_tmax_SILO rh_tmin_SILO streamflow_MLd
time
1957-01-01 8.062837 0.0 20.784480 4.358533 8.142806 28.888577 88.900993 538.551
1957-01-02 8.519483 0.0 27.393169 4.835900 5.281136 23.516738 99.002080 531.094
1957-01-03 9.879688 0.0 28.945301 8.175408 12.920509 19.434872 77.429917 503.011
1957-01-04 6.744638 0.0 26.133843 7.017990 13.951027 42.350667 100.000000 484.512
1957-01-05 8.135359 0.0 21.450775 8.686258 12.168659 30.374862 87.634483 463.416
[12]:
# Rows are daily time steps after unstacking; columns are the 7 inputs plus 1 output.
data.shape
[12]:
(21184, 8)
[13]:
# Sanity check: number of negative streamflow values in the target series.
data['streamflow_MLd'].lt(0.0).sum()
[13]:
0
[14]:
# Sanity check: number of exactly-zero streamflow values in the target series.
data['streamflow_MLd'].eq(0.0).sum()
[14]:
0
[15]:
# Count missing values in each column before training.
data.isnull().sum()
[15]:
et_morton_point_SILO    0
precipitation_AWAP      0
tmax_AWAP               0
tmin_AWAP               0
vprp_AWAP               0
rh_tmax_SILO            0
rh_tmin_SILO            0
streamflow_MLd          0
dtype: int64
[16]:
# One subplot per column, all sharing the time axis, for a visual check of the full record.
_ = data.plot(subplots=True, sharex=True, figsize=(10, 10))
../../_images/_notebooks_model_interpretability_ia_12_0.png
[17]:


# Histogram of every column on its own (unshared) axes; presumably the basis
# for choosing which inputs receive a log transform below.
_ = hist(
    data,
    share_axes=False,
    subplots_kws={"figsize": (12, 10)},
    edgecolor="k",
    grid=False,
)
../../_images/_notebooks_model_interpretability_ia_13_0.png
[18]:
# Inputs with skewed distributions (presumably judged from the histograms
# above); these receive a log transformation when the model is built.
skew_inputs = [
    'precipitation_AWAP',
    'rh_tmin_SILO',
]
[19]:
# Every input that is not in `skew_inputs` gets robust scaling instead of a
# log transform. Deriving the list from `inputs` (preserving its order) keeps
# the two groups an exact partition of the input features, so no input is
# silently left untransformed or transformed twice.
assert set(skew_inputs) <= set(inputs), "skew_inputs contains names that are not in inputs"
normal_inputs = [feature for feature in inputs if feature not in skew_inputs]
[20]:


# Input-attention LSTM (Qin et al., 2017). All hyperparameters are fixed values
# chosen by the notebook author. The model is built silently (verbosity=0) and
# verbosity is raised in a later cell.
model = InputAttentionModel(
    enc_config={
        # Encoder sizes (all 62). The 62-wide initial cell/hidden-state inputs
        # listed by `model.inputs` follow from these settings.
        # NOTE(review): confirm which of n_h / n_s / m controls which state in
        # ai4water's InputAttentionModel documentation.
        'n_h': 62,
        'n_s': 62,
        'm': 62,
        'enc_lstm1_act': "elu",
        'enc_lstm2_act': "relu",
    },
    input_features=inputs,
    output_features=outputs,
    epochs=500,  # upper bound; the recorded run below was interrupted at epoch 70
    ts_args={'lookback': 15},  # past time steps per sample -> input shape (None, 15, n_inputs)
    lr=0.0049,
    batch_size=16,
    x_transformation=[
        # Robust scaling for the remaining inputs ...
        {'method': 'robust', 'features': normal_inputs},
        # ... and a log transform for the skewed ones. Precipitation contains
        # zeros, hence replace_zeros (presumably applied before taking the log).
        {'method': 'log', "replace_zeros": True, 'features': skew_inputs},
    ],
    y_transformation={'method': 'robust', 'features': outputs},
    verbosity=0,
)
building input attention model
[21]:
# The model was built with verbosity=0; raise it so fitting and prediction print progress.
model.verbosity=1
[22]:
# Symbolic inputs of the built model: the (lookback, n_inputs) sequence plus the
# initial encoder cell state and hidden state.
model.inputs
[22]:
[<tf.Tensor 'enc_input1:0' shape=(None, 15, 7) dtype=float32>,
 <tf.Tensor 'enc_first_cell_state_1:0' shape=(None, 62) dtype=float32>,
 <tf.Tensor 'enc_first_hidden_state_1:0' shape=(None, 62) dtype=float32>]
[23]:
# Single output tensor: one predicted target value per sample.
model.outputs
[23]:
[<tf.Tensor 'dense/BiasAdd:0' shape=(None, 1) dtype=float32>]
[24]:
# Train on `data`; ai4water splits it into training and validation samples
# (11854 / 2964 in the log below). `h` holds whatever fit returns.
# NOTE(review): in the recorded run, training was stopped manually
# (KeyboardInterrupt) during epoch 70 of 500, so all results below come from a
# partially trained model. A fresh Restart & Run All may take much longer.
h = model.fit(data=data)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
inferred mode is regression. Ignore this message if the inferred mode is correct.
input_x shape:  [(2964, 15, 7), (2964, 62), (2964, 62)]
target shape:  (2964, 1)
Train on 11854 samples, validate on 2964 samples
Epoch 1/500
11840/11854 [============================>.] - ETA: 0s - loss: 0.8860
`Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
11854/11854 [==============================] - 14s 1ms/sample - loss: 0.8859 - val_loss: 0.6613
Epoch 2/500
11854/11854 [==============================] - 6s 535us/sample - loss: 0.6698 - val_loss: 0.5784
Epoch 3/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.5649 - val_loss: 0.5355
Epoch 4/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.5270 - val_loss: 0.4771
Epoch 5/500
11854/11854 [==============================] - 6s 531us/sample - loss: 0.4819 - val_loss: 0.4564
Epoch 6/500
11854/11854 [==============================] - 6s 530us/sample - loss: 0.4615 - val_loss: 0.4084
Epoch 7/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.4076 - val_loss: 0.4688
Epoch 8/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.3892 - val_loss: 0.4032
Epoch 9/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.3519 - val_loss: 0.4458
Epoch 10/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.3585 - val_loss: 0.4269
Epoch 11/500
11854/11854 [==============================] - 6s 525us/sample - loss: 0.3369 - val_loss: 0.4451
Epoch 12/500
11854/11854 [==============================] - 6s 525us/sample - loss: 0.3166 - val_loss: 0.4580
Epoch 13/500
11854/11854 [==============================] - 6s 523us/sample - loss: 0.3213 - val_loss: 0.4300
Epoch 14/500
11854/11854 [==============================] - 6s 525us/sample - loss: 0.2915 - val_loss: 0.4084
Epoch 15/500
11854/11854 [==============================] - 7s 557us/sample - loss: 0.2741 - val_loss: 0.3997
Epoch 16/500
11854/11854 [==============================] - 6s 524us/sample - loss: 0.2666 - val_loss: 0.4371
Epoch 17/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.2748 - val_loss: 0.4323
Epoch 18/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.2599 - val_loss: 0.3841
Epoch 19/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.2283 - val_loss: 0.3901
Epoch 20/500
11854/11854 [==============================] - 6s 530us/sample - loss: 0.2249 - val_loss: 0.3655
Epoch 21/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.2213 - val_loss: 0.4015
Epoch 22/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.2080 - val_loss: 0.3814
Epoch 23/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.2048 - val_loss: 0.4125
Epoch 24/500
11854/11854 [==============================] - 6s 526us/sample - loss: 0.2009 - val_loss: 0.3687
Epoch 25/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.1784 - val_loss: 0.5114
Epoch 26/500
11854/11854 [==============================] - 6s 526us/sample - loss: 0.1840 - val_loss: 0.3700
Epoch 27/500
11854/11854 [==============================] - 7s 560us/sample - loss: 0.1833 - val_loss: 0.3600
Epoch 28/500
11854/11854 [==============================] - 6s 527us/sample - loss: 0.2156 - val_loss: 0.3684
Epoch 29/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1686 - val_loss: 0.3616
Epoch 30/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.1480 - val_loss: 0.3747
Epoch 31/500
11854/11854 [==============================] - 6s 535us/sample - loss: 0.1410 - val_loss: 0.4073
Epoch 32/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.1536 - val_loss: 0.3528
Epoch 33/500
11854/11854 [==============================] - 6s 531us/sample - loss: 0.1628 - val_loss: 0.3771
Epoch 34/500
11854/11854 [==============================] - 6s 534us/sample - loss: 0.1488 - val_loss: 0.3711
Epoch 35/500
11854/11854 [==============================] - 6s 535us/sample - loss: 0.1443 - val_loss: 0.3875
Epoch 36/500
11854/11854 [==============================] - 6s 534us/sample - loss: 0.1280 - val_loss: 0.4221
Epoch 37/500
11854/11854 [==============================] - 7s 566us/sample - loss: 0.1536 - val_loss: 0.3384
Epoch 38/500
11854/11854 [==============================] - 6s 541us/sample - loss: 0.1279 - val_loss: 0.3258
Epoch 39/500
11854/11854 [==============================] - 6s 538us/sample - loss: 0.1332 - val_loss: 0.4067
Epoch 40/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.1333 - val_loss: 0.3578
Epoch 41/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1152 - val_loss: 0.3633
Epoch 42/500
11854/11854 [==============================] - 6s 533us/sample - loss: 0.1113 - val_loss: 0.4077
Epoch 43/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1202 - val_loss: 0.3671
Epoch 44/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1155 - val_loss: 0.4412
Epoch 45/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1246 - val_loss: 0.3626
Epoch 46/500
11854/11854 [==============================] - 6s 531us/sample - loss: 0.1062 - val_loss: 0.3785
Epoch 47/500
11854/11854 [==============================] - 6s 530us/sample - loss: 0.1128 - val_loss: 0.4564
Epoch 48/500
11854/11854 [==============================] - 7s 555us/sample - loss: 0.1083 - val_loss: 0.4337
Epoch 49/500
11854/11854 [==============================] - 6s 534us/sample - loss: 0.1072 - val_loss: 0.3948
Epoch 50/500
11854/11854 [==============================] - 6s 540us/sample - loss: 0.0985 - val_loss: 0.4036
Epoch 51/500
11854/11854 [==============================] - 6s 538us/sample - loss: 0.1065 - val_loss: 0.4161
Epoch 52/500
11854/11854 [==============================] - 6s 540us/sample - loss: 0.1038 - val_loss: 0.3724
Epoch 53/500
11854/11854 [==============================] - 6s 535us/sample - loss: 0.1031 - val_loss: 0.4129
Epoch 54/500
11854/11854 [==============================] - 6s 530us/sample - loss: 0.1029 - val_loss: 0.3731
Epoch 55/500
11854/11854 [==============================] - 6s 529us/sample - loss: 0.0895 - val_loss: 0.4002
Epoch 56/500
11854/11854 [==============================] - 6s 534us/sample - loss: 0.0882 - val_loss: 0.3604
Epoch 57/500
11854/11854 [==============================] - 6s 534us/sample - loss: 0.1000 - val_loss: 0.3837
Epoch 58/500
11854/11854 [==============================] - 6s 542us/sample - loss: 0.1255 - val_loss: 0.3730
Epoch 59/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.0982 - val_loss: 0.3623
Epoch 60/500
11854/11854 [==============================] - 6s 544us/sample - loss: 0.0777 - val_loss: 0.3528
Epoch 61/500
11854/11854 [==============================] - 6s 530us/sample - loss: 0.0867 - val_loss: 0.3903
Epoch 62/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.0851 - val_loss: 0.4092
Epoch 63/500
11854/11854 [==============================] - 6s 532us/sample - loss: 0.1017 - val_loss: 0.4183
Epoch 64/500
11854/11854 [==============================] - 6s 537us/sample - loss: 0.0841 - val_loss: 0.4137
Epoch 65/500
11854/11854 [==============================] - 6s 531us/sample - loss: 0.0982 - val_loss: 0.3605
Epoch 66/500
11854/11854 [==============================] - 6s 531us/sample - loss: 0.1120 - val_loss: 0.3662
Epoch 67/500
11854/11854 [==============================] - 6s 536us/sample - loss: 0.0717 - val_loss: 0.3686
Epoch 68/500
11854/11854 [==============================] - 6s 535us/sample - loss: 0.0679 - val_loss: 0.3970
Epoch 69/500
11854/11854 [==============================] - 6s 533us/sample - loss: 0.0946 - val_loss: 0.3725
Epoch 70/500
 2368/11854 [====>.........................] - ETA: 4s - loss: 0.0926
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_124924\414651862.py in <module>
----> 1 h = model.fit(data=data)

D:\mytools\AI4Water\ai4water\_main.py in fit(self, x, y, data, callbacks, **kwargs)
    979             assert y is not None
    980
--> 981         return self.call_fit(x=x, y=y, data=data, callbacks=callbacks, **kwargs)
    982
    983     def call_fit(self,

D:\mytools\AI4Water\ai4water\_main.py in call_fit(self, x, y, data, callbacks, **kwargs)
   1016         if self.category == "DL":
   1017
-> 1018             history = self._fit(inputs,
   1019                                 outputs,
   1020                                 callbacks=callbacks,

D:\mytools\AI4Water\ai4water\_main.py in _fit(self, inputs, outputs, validation_data, validation_steps, callbacks, **kwargs)
    737                 _kwargs.pop(k)
    738
--> 739         self._call_fit_fn(
    740             **_kwargs,
    741             **kwargs,

D:\mytools\AI4Water\ai4water\_main.py in _call_fit_fn(self, x, **kwargs)
    658             self._model.test_step = MethodType(test_step, self._model)
    659
--> 660         return self.fit_fn(x, **kwargs)
    661
    662     def _fit(self,

D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\keras\engine\training_v1.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    775
    776     func = self._select_training_loop(x)
--> 777     return func.fit(
    778         self,
    779         x=x,

D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\keras\engine\training_arrays_v1.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    638       val_x, val_y, val_sample_weights = None, None, None
    639
--> 640     return fit_loop(
    641         model,
    642         inputs=x,

D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\keras\engine\training_arrays_v1.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    374
    375         # Get outputs.
--> 376         batch_outs = f(ins_batch)
    377         if not isinstance(batch_outs, list):
    378           batch_outs = [batch_outs]

D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\keras\backend.py in __call__(self, inputs)
   4184       self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
   4185
-> 4186     fetched = self._callable_fn(*array_vals,
   4187                                 run_metadata=self.run_metadata)
   4188     self._call_fetch_callbacks(fetched[-len(self._fetches):])

D:\C\Anaconda3\envs\tfcpu27_py39\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1481       try:
   1482         run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
-> 1483         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1484                                                self._handle, args,
   1485                                                run_metadata_ptr)

KeyboardInterrupt:

Making predictions on training data

[26]:
# Predict on the training split and draw regression, residual and prediction plots.
_ = model.predict_on_training_data(
    data=data,
    plots=["regression", "residual", "prediction"],
)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
../../_images/_notebooks_model_interpretability_ia_22_1.png
../../_images/_notebooks_model_interpretability_ia_22_2.png
../../_images/_notebooks_model_interpretability_ia_22_3.png
[27]:
# Nash-Sutcliffe efficiency (NSE) on the training split.
model.evaluate_on_training_data(data=data, metrics="nse")
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
[27]:
0.9517542124227427

Making predictions on test data

[28]:
# Predict on the held-out test split and draw the same three diagnostic plots.
_ = model.predict_on_test_data(
    data=data,
    plots=["regression", "residual", "prediction"],
)
input_x shape:  [(6352, 15, 7), (6352, 62), (6352, 62)]
target shape:  (6352, 1)
../../_images/_notebooks_model_interpretability_ia_25_1.png
../../_images/_notebooks_model_interpretability_ia_25_2.png
../../_images/_notebooks_model_interpretability_ia_25_3.png
[29]:
# NSE, R^2 and RMSE on the test split.
# NOTE(review): the recorded test NSE (~0.65) is far below the training NSE
# (~0.95), which suggests overfitting and/or the interrupted training run.
model.evaluate_on_test_data(data=data, metrics=["nse", "r2", "rmse"])
input_x shape:  [(6352, 15, 7), (6352, 62), (6352, 62)]
target shape:  (6352, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
[29]:
{'nse': 0.6512428046837175, 'r2': 0.6530917330311746, 'rmse': 741.789854001792}

Getting interpretable results

[30]:
# Presumably the attention weights averaged per input feature; computed on the
# training split (see the printed shapes).
_ = model.plot_avg_attentions_along_inputs(data=data,
                                          colorbar=True)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
../../_images/_notebooks_model_interpretability_ia_28_1.png
[31]:
# Build the test-split model inputs `x` (the printed shapes suggest a list of
# three arrays) and targets `y`.
# NOTE(review): neither x nor y is used by any later cell in this notebook.
x, y = model.test_data(data=data)
input_x shape:  [(6352, 15, 7), (6352, 62), (6352, 62)]
target shape:  (6352, 1)
[32]:
# Activation plot along the `precipitation_AWAP` input feature.
_ = model.plot_act_along_inputs(
    data=data,
    feature="precipitation_AWAP",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_30_2.png
[33]:
# Activation plot along the `et_morton_point_SILO` input feature.
_ = model.plot_act_along_inputs(
    data=data,
    feature="et_morton_point_SILO",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_31_2.png
[34]:
# Activation plot along the `tmax_AWAP` input feature.
_ = model.plot_act_along_inputs(
    data=data,
    feature="tmax_AWAP",
    cbar_params={"pad": 0.5},
    cmap="jet",
)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)
input_x shape:  [(11854, 15, 7), (11854, 62), (11854, 62)]
target shape:  (11854, 1)

            argument test is deprecated and will be removed in future. Please
            use 'predict_on_test_data' method instead.
../../_images/_notebooks_model_interpretability_ia_32_2.png
[35]:
# Report the directory (model.path) where ai4water stored this run's artefacts.
print(f'All results are saved in {model.path}')
All results are saved in D:\mytools\ai4water_examples\docs\source\_notebooks\model\results\20230222_191908
[ ]: