Source code for ai4water.postprocessing.explain._shap


from typing import Union, Callable, List

try:
    import shap
    from shap import Explanation
except ModuleNotFoundError:
    shap = None
    Explanation = None

import scipy as sp

try:
    import tensorflow.keras.backend as K
except ModuleNotFoundError:
    K = None

from ._explain import ExplainerMixin
from .utils import convert_ai4water_model
from ai4water.backend import sklearn_models, np, pd, os, plt, easy_mpl


class ShapExplainer(ExplainerMixin):
    """
    Wrapper around SHAP `explainers` and `plots` to draw and save all the plots
    for a given model.

    Attributes:
        features :
        train_summary : only for KernelExplainer
        explainer :
        shap_values :

    Methods
    -------
    - summary_plot
    - force_plot_single_example
    - dependence_plot_single_feature
    - force_plot_all

    Examples:
        >>> from ai4water.postprocessing import ShapExplainer
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn import linear_model
        >>> import shap
        ...
        >>> X, y = shap.datasets.diabetes()
        >>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
        >>> lin_regr = linear_model.LinearRegression()
        >>> lin_regr.fit(X_train, y_train)
        >>> explainer = ShapExplainer(lin_regr, X_test, X_train, num_means=10)
        >>> explainer()
    """

    allowed_explainers = [
        "Explainer",
        "DeepExplainer",
        "TreeExplainer",
        "KernelExplainer",
        "LinearExplainer",
        "AdditiveExplainer",
        "GPUTreeExplainer",
        "GradientExplainer",
        "PermutationExplainer",
        "SamplingExplainer",
        "PartitionExplainer"
    ]

    def __init__(
            self,
            model,
            data: Union[np.ndarray, pd.DataFrame, List[np.ndarray]],
            train_data: Union[np.ndarray, pd.DataFrame, List[np.ndarray]] = None,
            explainer: Union[str, Callable] = None,
            num_means: int = 10,
            path: str = None,
            feature_names: list = None,
            framework: str = None,
            layer: Union[int, str] = None,
            save: bool = True,
            show: bool = True,
    ):
        """
        Args:
            model :
                a Model/regressor/classifier from sklearn/xgboost/catboost/LightGBM/tensorflow/pytorch/ai4water.
                The model must have a `predict` method.
            data :
                Data on which to make the interpretation. Its dimension should be
                the same as that of the training data. It can be either training
                or test data.
            train_data :
                The data on which the `model` was trained. It is used to get
                train_summary. It can be a numpy array or a pandas DataFrame.
                Only required for scikit-learn based models.
            explainer : str
                the explainer to use. If not given, the explainer will be inferred.
            num_means : int
                Number of means, used by `shap.kmeans` to calculate train_summary.
                Only used when explainer is "KernelExplainer".
            path : str
                path to save the plots. By default, plots are saved in the
                current working directory.
            feature_names : list
                Names of features. Should only be given if train/test data is a
                numpy array.
            framework : str
                either "DL" or "ML", where "DL" means the `model` is a deep
                learning/neural-network based model and "ML" represents other
                models. For "DL" the explainer will be either "DeepExplainer" or
                "GradientExplainer". If not given, it will be inferred; in that
                case "DeepExplainer" is preferred over "GradientExplainer" for
                "DL" frameworks and "TreeExplainer" is preferred for "ML"
                frameworks.
            layer : Union[int, str]
                only relevant when framework is "DL" i.e. when the model consists
                of layers of neural networks.
            show:
                whether to show the plot or not
            save:
                whether to save the plot or not
        """
        assert shap is not None, """
shap package must be installed to use this class. Please install shap e.g. with
pip install shap"""

        test_data = maybe_to_dataframe(data, feature_names)
        train_data = maybe_to_dataframe(train_data, feature_names)

        super(ShapExplainer, self).__init__(path=path or os.getcwd(),
                                            data=test_data,
                                            features=feature_names,
                                            save=save,
                                            show=show
                                            )

        if train_data is None:
            self._check_data(test_data)
        else:
            self._check_data(train_data, test_data)

        model, framework, explainer, model_name = convert_ai4water_model(model,
                                                                         framework,
                                                                         explainer)
        self.is_sklearn = True
        if model_name not in sklearn_models:
            if model_name in ["XGBRegressor",
                              "XGBClassifier",
                              "LGBMRegressor",
                              "LGBMClassifier",
                              "CatBoostRegressor",
                              "CatBoostClassifier",
                              "XGBRFRegressor",
                              "XGBRFClassifier"
                              ]:
                self.is_sklearn = False
            elif not self._is_dl(model):
                raise ValueError(f"{model.__class__.__name__} is not a valid model")

        self._framework = self.infer_framework(model, framework, layer, explainer)

        self.model = model
        self.data = test_data
        self.layer = layer
        self.features = feature_names

        self.explainer = self._get_explainer(explainer, train_data=train_data, num_means=num_means)

        self.shap_values = self.get_shap_values(test_data)
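
    # Illustrative sketch (an assumption, not from the original source): the
    # explainer can also be named explicitly instead of being inferred, e.g.
    # for a fitted xgboost model,
    #
    #     >>> expl = ShapExplainer(xgb_model, X_test, explainer="TreeExplainer",
    #     ...                      path="shap_plots", show=False)
    #
    # and for a Keras model, passing explainer="GradientExplainer" together
    # with a `layer` index or name explains an intermediate layer's inputs.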

    @staticmethod
    def _is_dl(model):
        if hasattr(model, "count_params") or hasattr(model, "named_parameters"):
            return True
        return False

    @property
    def layer(self):
        return self._layer

    @layer.setter
    def layer(self, x):
        if x is not None:
            if not isinstance(x, str):
                assert isinstance(x, int), "layer must either be a string or an integer"
                assert x <= len(self.model.layers)  # todo, what about pytorch

        self._layer = x

    def map2layer(self, x, layer):
        feed_dict = dict(zip([self.model.layers[0].input], [x.copy()]))

        import tensorflow as tf
        if int(tf.__version__[0]) < 2:
            sess = K.get_session()
        else:
            sess = tf.compat.v1.keras.backend.get_session()

        if isinstance(layer, int):
            return sess.run(self.model.layers[layer].input, feed_dict)
        else:
            return sess.run(self.model.get_layer(layer).input, feed_dict)

    def infer_framework(self, model, framework, layer, explainer):
        if framework is not None:
            inf_framework = framework
        elif self._is_dl(model):
            inf_framework = "DL"
        elif isinstance(explainer, str) and explainer in ("DeepExplainer",
                                                          "GradientExplainer"):
            inf_framework = "DL"
        elif explainer.__class__.__name__ in ("DeepExplainer", "GradientExplainer"):
            inf_framework = "DL"
        else:
            inf_framework = "ML"

        assert inf_framework in ("ML", "DL")

        if inf_framework != "DL":
            assert layer is None

        if inf_framework == "DL" and isinstance(explainer, str):
            assert explainer in ("DeepExplainer",
                                 "GradientExplainer",
                                 "PermutationExplainer"), f"invalid explainer {explainer}"

        return inf_framework

    def _get_explainer(self,
                       explainer: Union[str, Callable],
                       num_means,
                       train_data
                       ):

        if explainer is not None:
            if callable(explainer):
                return explainer

            assert isinstance(explainer, str), f"explainer should be callable or string but" \
                                               f" it is {explainer.__class__.__name__}"
            assert explainer in self.allowed_explainers, f"{explainer} is not a valid explainer"

            if explainer == "KernelExplainer":
                explainer = self._get_kernel_explainer(train_data, num_means)

            elif explainer == "DeepExplainer":
                explainer = self._get_deep_explainer()

            elif explainer == "GradientExplainer":
                explainer = self._get_gradient_explainer()

            elif explainer == "PermutationExplainer":
                explainer = shap.PermutationExplainer(self.model, self.data)
            else:
                explainer = getattr(shap, explainer)(self.model)

        else:  # explainer is not given explicitly, we need to infer it
            explainer = self._infer_explainer_to_use(train_data, num_means)

        return explainer

    def _get_kernel_explainer(self, data, num_means):

        assert isinstance(num_means, int), f'num_means should be an integer but the given' \
                                           f' value is of type {num_means.__class__.__name__}'

        if data is None:
            raise ValueError("Provide train_data in order to use KernelExplainer.")

        self.train_summary = shap.kmeans(data, num_means)
        explainer = shap.KernelExplainer(self.model.predict, self.train_summary)

        return explainer

    def _infer_explainer_to_use(self, train_data, num_means):
        """Tries to infer the explainer to use from the type of model."""
        # todo, Fig 3 of Lundberg's Nature MI paper shows that TreeExplainer
        # performs better than KernelExplainer, so try to use supports_model_with_masker

        if self.model.__class__.__name__ in ["XGBRegressor", "LGBMRegressor",
                                             "CatBoostRegressor", "XGBRFRegressor"]:
            explainer = shap.TreeExplainer(self.model)

        elif self.model.__class__.__name__ in sklearn_models:
            explainer = self._get_kernel_explainer(train_data, num_means)

        elif self._framework == "DL":
            explainer = self._get_deep_explainer()
        else:
            raise ValueError(f"Can not infer explainer for model {self.model.__class__.__name__}."
                             f" Please specify the explainer by using the `explainer` keyword argument.")

        return explainer

    def _get_deep_explainer(self):
        data = self.data.values if isinstance(self.data, pd.DataFrame) else self.data
        return getattr(shap, "DeepExplainer")(self.model, data)

    def _get_gradient_explainer(self):

        if self.layer is None:
            # GradientExplainer is also possible without specifying a layer
            return shap.GradientExplainer(self.model, self.data)
        if isinstance(self.layer, int):
            return shap.GradientExplainer(
                (self.model.layers[self.layer].input, self.model.layers[-1].output),
                self.map2layer(self.data, self.layer))
        else:
            return shap.GradientExplainer(
                (self.model.get_layer(self.layer).input, self.model.layers[-1].output),
                self.map2layer(self.data, self.layer))

    def _check_data(self, *data):
        if self.single_source:
            for d in data:
                assert isinstance(d, (np.ndarray, pd.DataFrame)), f"""
                data must be a numpy array or pandas dataframe but it is of type {d.__class__.__name__}"""

            assert len(set([d.ndim for d in data])) == 1, "train and test data should have the same ndim"
            assert len(set([d.shape[-1] for d in data])) == 1, "train and test data should have the same input features"
            assert len(set([type(d) for d in data])) == 1, "train and test data should be of the same type"

        return

    def get_shap_values(self, data, **kwargs):

        if self.explainer.__class__.__name__ in ["Permutation"]:
            return self.explainer(data)

        elif self._framework == "DL":
            return self._shap_values_dl(data, **kwargs)

        return self.explainer.shap_values(data)

    def _shap_values_dl(self, data, ranked_outputs=None, **kwargs):
        """Gets the SHAP values."""
        data = data.values if isinstance(data, pd.DataFrame) else data
        if self.explainer.__class__.__name__ == "Deep":
            shap_values = self.explainer.shap_values(data, ranked_outputs=ranked_outputs, **kwargs)

        elif isinstance(self.explainer, shap.GradientExplainer) and self.layer is None:
            shap_values = self.explainer.shap_values(data, ranked_outputs=ranked_outputs, **kwargs)

        else:
            shap_values = self.explainer.shap_values(self.map2layer(data, self.layer),
                                                     ranked_outputs=ranked_outputs,
                                                     **kwargs)
        if ranked_outputs:
            shap_values, indexes = shap_values

        return shap_values

    def __call__(self,
                 force_plots=True,
                 plot_force_all=False,
                 dependence_plots=False,
                 beeswarm_plots=False,
                 heatmap=False,
                 ):
        """Draws and saves all the plots for a given sklearn model in the path.

        plot_force_all is set to False by default because it can cause a
        "Process finished" error; avoiding this error depends on a complex
        combination of scipy and numba versions.
        """

        if dependence_plots:
            for feature in self.features:
                self.dependence_plot_single_feature(feature, f"dependence_plot_{feature}")

        if force_plots:
            for i in range(self.data.shape[0]):
                self.force_plot_single_example(i, f"force_plot_{i}")

        if beeswarm_plots:
            self.beeswarm_plot()

        if plot_force_all:
            self.force_plot_all("force_plot_all")

        if heatmap:
            self.heatmap()

        self.summary_plot(name="summary_plot")

        return

    def summary_plot(
            self,
            plot_type: str = None,
            name: str = "summary_plot",
            **kwargs
    ):
        """
        Plots the `summary <https://shap-lrjball.readthedocs.io/en/latest/generated/shap.summary_plot.html#shap.summary_plot>`_
        plot of the SHAP package.

        Arguments:
            plot_type : str
                either "bar", "violin" or "dot"
            name:
                name of the saved file
            kwargs:
                any keyword arguments to shap.summary_plot
        """

        def _summary_plot(_shap_val, _data, _features, _name):
            plt.close('all')

            shap.summary_plot(_shap_val, _data, show=False, plot_type=plot_type,
                              feature_names=_features,
                              **kwargs)
            if self.save:
                plt.savefig(os.path.join(self.path, _name + "_bar"), dpi=300,
                            bbox_inches="tight")
            if self.show:
                plt.show()
            return

        shap_vals = self.shap_values
        if isinstance(shap_vals, list) and len(shap_vals) == 1:
            shap_vals = shap_vals[0]

        data = self.data

        if self.single_source:
            if data.ndim == 3:
                assert shap_vals.ndim == 3

                for lookback in range(data.shape[1]):
                    _summary_plot(shap_vals[:, lookback],
                                  data[:, lookback],
                                  self.features,
                                  _name=f"{name}_{lookback}")
            else:
                _summary_plot(shap_vals, data, self.features, name)
        else:
            # data is a list of data sources
            for idx, _data in enumerate(data):
                if _data.ndim == 3:
                    for lb in range(_data.shape[1]):
                        _summary_plot(shap_vals[idx][:, lb], _data[:, lb],
                                      self.features[idx],
                                      _name=f"{name}_{idx}_{lb}")
                else:
                    _summary_plot(shap_vals[idx], _data, self.features[idx],
                                  _name=f"{name}_{idx}")

        return
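
    # Usage sketch (illustrative, with an `explainer` built as in the class
    # docstring): a bar-style summary of mean absolute SHAP values is drawn with
    #
    #     >>> explainer.summary_plot(plot_type="bar")
    #
    # while leaving plot_type as None produces shap's default dot-style summary.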

    def force_plot_single_example(
            self,
            idx: int,
            name=None,
            **force_kws
    ):
        """Draws force_plot_ for a single example/row/sample/instance/data point.

        If the data is 3d and the shap values are 3d, they are unrolled/flattened
        before plotting.

        Arguments:
            idx:
                index of the example to use. It can be any value >=0
            name:
                name of the saved file
            force_kws :
                any keyword arguments for the force plot

        Returns:
            plotter object

        .. _force_plot:
            https://shap.readthedocs.io/en/latest/generated/shap.plots.force.html
        """
        shap_vals = self.shap_values
        if isinstance(shap_vals, list) and len(shap_vals) == 1:
            shap_vals = shap_vals[0]
        shap_vals = shap_vals[idx]

        if isinstance(self.data, np.ndarray):
            data = self.data[idx]
        else:
            data = self.data.iloc[idx, :]

        if self.explainer.__class__.__name__ == "Gradient":
            expected_value = [0]
        else:
            expected_value = self.explainer.expected_value

        features = self.features
        if data.ndim == 2 and shap_vals.ndim == 2:  # input was 3d i.e. the ml model uses 3d input
            features = self.unrolled_features

            expected_value = expected_value[0]  # todo
            shap_vals = shap_vals.reshape(-1,)
            data = data.reshape(-1,)

        plt.close('all')

        plotter = shap.force_plot(
            expected_value,
            shap_vals,
            data,
            feature_names=features,
            show=False,
            matplotlib=True,
            **force_kws
        )

        if self.save:
            name = name or f"force_plot_{idx}"
            plotter.savefig(os.path.join(self.path, name), dpi=300, bbox_inches="tight")

        if self.show:
            plotter.show()

        return plotter
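
    # Usage sketch: draw and save the force plot for the first test example,
    #
    #     >>> explainer.force_plot_single_example(0, name="force_plot_0")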

    def dependence_plot_all_features(self, **dependence_kws):
        """dependence plot for all features"""
        for feature in self.features:
            self.dependence_plot_single_feature(feature, f"dependence_plot_{feature}",
                                                **dependence_kws)
        return

    def dependence_plot_single_feature(self, feature, name="dependence_plot", **kwargs):
        """dependence_ plot for a single feature. See this_ .

        .. _dependence:
            https://slundberg.github.io/shap/notebooks/plots/dependence_plot.html

        .. _this:
            https://shap-lrjball.readthedocs.io/en/docs_update/generated/shap.dependence_plot.html
        """
        plt.close('all')

        if len(name) > 150:  # matplotlib raises an error if the filename is too long
            name = name[0:150]

        shap_values = self.shap_values
        if isinstance(shap_values, list) and len(shap_values) == 1:
            shap_values = shap_values[0]

        shap.dependence_plot(feature, shap_values, self.data, show=False, **kwargs)
        if self.save:
            plt.savefig(os.path.join(self.path, name), dpi=300, bbox_inches="tight")
        if self.show:
            plt.show()
        return
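
    # Usage sketch (the feature name is hypothetical and must exist in the data):
    #
    #     >>> explainer.dependence_plot_single_feature("bmi")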

    def force_plot_all(self, name="force_plot.html", save=True, show=True, **force_kws):
        """draws the force plot for all examples in the given data and saves it as html"""
        # the following scipy versions cause a kernel stoppage during calculation
        if sp.__version__ in ["1.4.1", "1.5.2", "1.7.1"]:
            print(f"force plot can not be plotted for scipy version {sp.__version__}. "
                  f"Please change your scipy version.")
            return

        shap_values = self.shap_values
        if isinstance(shap_values, list) and len(shap_values) == 1:
            shap_values = shap_values[0]

        plt.close('all')
        plot = shap.force_plot(self.explainer.expected_value, shap_values, self.data, **force_kws)

        if save:
            shap.save_html(os.path.join(self.path, name), plot)

        return

    def waterfall_plot_all_examples(
            self,
            name: str = "waterfall",
            **waterfall_kws
    ):
        """Plots the waterfall_ plot of the SHAP package.

        It plots for all the examples/instances in test_data.

        .. _waterfall:
            https://shap.readthedocs.io/en/latest/generated/shap.plots.waterfall.html
        """
        for i in range(len(self.data)):
            self.waterfall_plot_single_example(i, name=name, **waterfall_kws)

        return

    def waterfall_plot_single_example(
            self,
            example_index: int,
            name: str = "waterfall",
            max_display: int = 10,
    ):
        """draws and saves the waterfall_ plot for one example.

        The waterfall plot is based upon SHAP values and shows the contribution
        of each feature to the model's prediction, i.e. which feature pushed the
        prediction in which direction. It answers the question: why did the ML
        model not simply predict the mean of the training y, instead of what it
        predicted? The mean of the training observations that the ML model saw
        during training is called the base value or expected value.

        Arguments:
            example_index : int
                index of the example to use
            max_display : int
                maximum number of features to display
            name : str
                name of the plot

        .. _waterfall:
            https://shap.readthedocs.io/en/latest/generated/shap.plots.waterfall.html
        """
        if self.explainer.__class__.__name__ in ["Deep", "Kernel"]:
            shap_vals_as_exp = None
        else:
            shap_vals_as_exp = self.explainer(self.data)

        shap_values = self.shap_values
        if isinstance(shap_values, list) and len(shap_values) == 1:
            shap_values = shap_values[0]

        plt.close('all')

        if shap_vals_as_exp is None:

            features = self.features
            if not self.data_is_2d:
                features = self.unrolled_features

            # the waterfall plot expects an Explanation instance as its first
            # argument, which must have at least these attributes:
            # (values, data, feature_names, base_values)
            # https://github.com/slundberg/shap/issues/1420#issuecomment-715190610
            if not self.data_is_2d:  # if the original data is 3d, flatten it into a 1d array
                values = shap_values[example_index].reshape(-1,)
                data = self.data[example_index].reshape(-1,)
            else:
                values = shap_values[example_index]
                data = self.data.iloc[example_index]

            exp_value = self.explainer.expected_value
            if self.explainer.__class__.__name__ in ["Kernel"]:
                pass
            else:
                exp_value = exp_value[0]

            e = Explanation(
                values,
                base_values=exp_value,
                data=data,
                feature_names=features
            )

            shap.plots.waterfall(e, show=False, max_display=max_display)
        else:
            shap.plots.waterfall(shap_vals_as_exp[example_index], show=False,
                                 max_display=max_display)

        if self.save:
            plt.savefig(os.path.join(self.path, f"{name}_{example_index}"), dpi=300,
                        bbox_inches="tight")

        if self.show:
            plt.show()

        return
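
    # Usage sketch: a waterfall plot for the first example, showing at most
    # 15 features; remaining features are collapsed by shap into one row,
    #
    #     >>> explainer.waterfall_plot_single_example(0, max_display=15)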

    def scatter_plot_single_feature(
            self,
            feature: int,
            name: str = "scatter",
            **scatter_kws
    ):
        """scatter plot for a single feature"""

        if self.explainer.__class__.__name__ in ["Kernel"]:
            shap_values = Explanation(self.shap_values, data=self.data.values,
                                      feature_names=self.features)
        else:
            shap_values = self.explainer(self.data)

        shap.plots.scatter(shap_values[:, feature], show=False, **scatter_kws)
        if self.save:
            plt.savefig(os.path.join(self.path, f"{name}_{feature}"), dpi=300,
                        bbox_inches="tight")
        if self.show:
            plt.show()

        return

    def scatter_plot_all_features(self, name="scatter_plot", **scatter_kws):
        """draws scatter plots for all features"""
        if isinstance(self.data, pd.DataFrame):
            features = self.features
        else:
            features = [i for i in range(self.data.shape[-1])]

        for feature in features:
            self.scatter_plot_single_feature(feature, name=name, **scatter_kws)

        return

    def heatmap(self, name: str = 'heatmap', max_display=10):
        """Plots the heatmap_ and saves it.

        This can be drawn for xgboost/lgbm as well as for randomforest type
        models, but not for CatBoostRegressor, which is a todo.

        Note
        ----
        The line plot above the heatmap shows $-fx/max(abs(fx))$ where $fx$ is
        the mean SHAP value over all features. The length of $fx$ equals the
        number of examples in the data. Thus one point on this line is the mean
        of the SHAP values of all input features for one example, normalized by
        the maximum absolute value of $fx$.

        .. _heatmap:
            https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/heatmap.html
        """

        # if the heatmap is drawn with an np.ndarray it throws errors, therefore
        # convert it into a pandas DataFrame. It is more interpretable and does not hurt.
        try:
            shap_values = self._get_shap_values_locally()
        except (ValueError, AttributeError):
            # sometimes we are not able to calculate shap values the way it is
            # done inside '_get_shap_values_locally'
            shap_values = Explanation(
                self.shap_values,
                data=self.data.values,
                feature_names=self.features)

        # by default examples are ordered in such a way that examples with
        # similar explanations are grouped together.
        self._heatmap(shap_values, f"{name}_basic",
                      max_display=max_display)

        # sort by the maximum absolute value of a feature over all the examples
        self._heatmap(shap_values, f"{name}_sortby_maxabs",
                      max_display=max_display,
                      feature_values=shap_values.abs.max(0))

        # sorting by the sum of the SHAP values over all features gives a
        # complementary perspective on the data
        self._heatmap(shap_values, f"{name}_sortby_SumOfShap",
                      max_display=max_display,
                      instance_order=shap_values.sum(1))
        return
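
    # Usage sketch: one call saves three variants (basic, sorted by the maximum
    # absolute SHAP value per feature, and sorted by the per-example SHAP sum),
    #
    #     >>> explainer.heatmap(name="heatmap", max_display=8)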

    def _heatmap(self, shap_values, name, max_display=10, **kwargs):
        plt.close('all')

        # set show to False because we want to reset the xlabel
        shap.plots.heatmap(shap_values, show=False, max_display=max_display, **kwargs)
        plt.xlabel("Examples")

        if self.save:
            plt.savefig(os.path.join(self.path, name), dpi=300, bbox_inches="tight")
        if self.show:
            plt.show()
        return

    def _get_shap_values_locally(self):
        data = self.data
        if not isinstance(self.data, pd.DataFrame) and data.ndim == 2:
            data = pd.DataFrame(self.data, columns=self.features)

        # not using the global explainer because this explainer should receive
        # the data as well
        explainer = shap.Explainer(self.model, data)
        shap_values = explainer(data)

        return shap_values

    def beeswarm_plot(
            self,
            name: str = "beeswarm",
            max_display: int = 10,
            **kwargs
    ):
        """
        Draws the beeswarm_ plot of shap.

        Arguments:
            name : str
                name of the saved file
            max_display :
                maximum number of features to display
            kwargs :
                any keyword arguments for the shap.beeswarm plot

        .. _beeswarm:
            https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html
        """

        try:
            shap_values = self._get_shap_values_locally()
        except (ValueError, AttributeError):
            shap_values = Explanation(self.shap_values, data=self.data.values,
                                      feature_names=self.features)

        self._beeswarm_plot(shap_values, name=f"{name}_basic",
                            max_display=max_display, **kwargs)

        # find features with high impacts
        self._beeswarm_plot(shap_values, name=f"{name}_sortby_maxabs",
                            max_display=max_display,
                            order=shap_values.abs.max(0),
                            **kwargs)

        # plot the absolute values
        self._beeswarm_plot(shap_values.abs, name=f"{name}_abs_shapvalues",
                            max_display=max_display, **kwargs)
        return
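
    # Usage sketch:
    #
    #     >>> explainer.beeswarm_plot(max_display=8)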

    def _beeswarm_plot(self, shap_values, name, max_display=10, **kwargs):
        plt.close('all')
        shap.plots.beeswarm(shap_values, show=False, max_display=max_display, **kwargs)

        if self.save:
            plt.savefig(os.path.join(self.path, name), dpi=300, bbox_inches="tight")
        if self.show:
            plt.show()
        return

    def decision_plot(
            self,
            indices=None,
            name: str = "decision_",
            **decision_kwargs):
        """decision_ plot. For details see this blog_.

        .. _decision:
            https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/decision_plot.html

        .. _blog:
            https://towardsdatascience.com/introducing-shap-decision-plots-52ed3b4a1cba
        """
        shap_values = self.shap_values
        legend_location = "best"
        legend_labels = None
        if indices is not None:
            shap_values = shap_values[(indices), :]
            if len(shap_values) <= 10:
                legend_labels = indices
                legend_location = "lower right"

        if self.explainer.__class__.__name__ in ["Tree"]:
            shap.decision_plot(self.explainer.expected_value,
                               shap_values,
                               self.features,
                               legend_labels=legend_labels,
                               show=False,
                               legend_location=legend_location,
                               **decision_kwargs)
            if self.save:
                plt.savefig(os.path.join(self.path, name), dpi=300, bbox_inches="tight")
            if self.show:
                plt.show()
        else:
            raise NotImplementedError
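
    # Usage sketch (only implemented for TreeExplainer-backed models): plot the
    # decision paths of ten selected examples; with ten or fewer examples a
    # legend of example indices is added,
    #
    #     >>> explainer.decision_plot(indices=list(range(10)))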

    def plot_shap_values(
            self,
            interpolation=None,
            cmap="coolwarm",
            name: str = "shap_values",
    ):
        """Plots the SHAP values.

        Arguments:
            name:
                name of the saved file
            interpolation:
                interpolation argument to axis.imshow
            cmap:
                color map
        """
        shap_values = self.shap_values
        if isinstance(shap_values, list) and len(shap_values) == 1:
            shap_values: np.ndarray = shap_values[0]

        def plot_shap_values_single_source(_data, _shap_vals, _features, _name):

            if _data.ndim == 3 and _shap_vals.ndim == 3:  # input is 3d
                return imshow_3d(_shap_vals, _data, _features, name=_name,
                                 path=self.path, show=self.show, cmap=cmap)

            plt.close('all')
            fig, axis = plt.subplots()
            im = axis.imshow(_shap_vals.T,
                             aspect='auto',
                             interpolation=interpolation,
                             cmap=cmap
                             )
            if _features is not None:  # if imshow is successful then don't worry if features are None
                axis.set_yticks(np.arange(len(_features)))
                axis.set_yticklabels(_features)
            axis.set_ylabel("Features")
            axis.set_xlabel("Examples")

            fig.colorbar(im)
            if self.save:
                plt.savefig(os.path.join(self.path, _name), dpi=300, bbox_inches="tight")
            if self.show:
                plt.show()
            return

        if self.single_source:
            plot_shap_values_single_source(self.data, shap_values, self.features, name)
        else:
            for idx, d in enumerate(self.data):
                plot_shap_values_single_source(d, shap_values[idx],
                                               self.features[idx], f"{idx}_{name}")
        return

    def pdp_all_features(
            self,
            **pdp_kws
    ):
        """partial dependence plot of all features.

        Arguments:
            pdp_kws:
                any keyword arguments passed to pdp_single_feature
        """
        for feat in self.features:
            self.pdp_single_feature(feat, **pdp_kws)

        return

    def pdp_single_feature(
            self,
            feature_name: str,
            **pdp_kws
    ):
        """partial dependence plot using the SHAP package for a single feature."""

        shap_values = None
        if hasattr(self.shap_values, 'base_values'):
            shap_values = self.shap_values

        if self.model.__class__.__name__.startswith("XGB"):
            self.model.get_booster().feature_names = self.features

        fig = shap.partial_dependence_plot(
            feature_name,
            model=self.model.predict,
            data=self.data,
            model_expected_value=True,
            feature_expected_value=True,
            shap_values=shap_values,
            feature_names=self.features,
            show=False,
            **pdp_kws
        )

        if self.save:
            fname = f"pdp_{feature_name}"
            plt.savefig(os.path.join(self.path, fname), dpi=300, bbox_inches="tight")

        if self.show:
            plt.show()

        return fig
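
    # Usage sketch (the feature name is hypothetical):
    #
    #     >>> explainer.pdp_single_feature("bmi")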

def imshow_3d(values,
              data,
              feature_names: list,
              path,
              vmin=None,
              vmax=None,
              name="",
              show=False,
              cmap=None,
              ):

    num_examples, lookback, input_features = values.shape
    assert data.shape == values.shape

    for idx, feat in enumerate(feature_names):
        plt.close('all')
        fig, (ax1, ax2) = plt.subplots(2, sharex='all', figsize=(10, 12))

        yticklabels = [f"t-{int(i)}" for i in np.linspace(lookback - 1, 0, lookback)]
        axis, im = easy_mpl.imshow(data[:, :, idx].transpose(),
                                   yticklabels=yticklabels,
                                   ax=ax1,
                                   vmin=vmin,
                                   vmax=vmax,
                                   title=feat,
                                   cmap=cmap,
                                   show=False
                                   )
        fig.colorbar(im, ax=axis, orientation='vertical', pad=0.2)

        axis, im = easy_mpl.imshow(values[:, :, idx].transpose(),
                                   yticklabels=yticklabels,
                                   vmin=vmin, vmax=vmax,
                                   xlabel="Examples",
                                   title="SHAP Values",
                                   cmap=cmap,
                                   show=False,
                                   ax=ax2)

        fig.colorbar(im, ax=axis, orientation='vertical', pad=0.2)

        _name = f'{name}_{feat}_shap_values'
        plt.savefig(os.path.join(path, _name), dpi=400, bbox_inches='tight')

        if show:
            plt.show()
    return


def infer_framework(model):
    if hasattr(model, 'config') and 'backend' in model.config:
        framework = model.config['backend']

    elif type(model) is tuple:
        a, _ = model
        try:
            a.named_parameters()
            framework = 'pytorch'
        except AttributeError:
            framework = 'tensorflow'
    else:
        try:
            model.named_parameters()
            framework = 'pytorch'
        except AttributeError:
            framework = 'tensorflow'

    return framework


def maybe_to_dataframe(data, features=None) -> pd.DataFrame:
    if isinstance(data, np.ndarray) and isinstance(features, list) and data.ndim == 2:
        data = pd.DataFrame(data, columns=features)
    return data
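

# A minimal, self-contained usage sketch (illustrative only, not part of the
# library's public API). It mirrors the example in the class docstring and
# assumes shap and scikit-learn are installed.
if __name__ == "__main__":
    from sklearn.model_selection import train_test_split
    from sklearn import linear_model

    # shap ships a small diabetes regression dataset for demonstrations
    X, y = shap.datasets.diabetes()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                        random_state=0)

    lin_regr = linear_model.LinearRegression()
    lin_regr.fit(X_train, y_train)

    # a KernelExplainer is inferred for sklearn models; `num_means` controls
    # the size of the k-means summary of the training data
    expl = ShapExplainer(lin_regr, X_test, X_train, num_means=10, show=False)
    expl.summary_plot()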