# Source code for ai4water.postprocessing.explain._lime


from typing import Union

from ai4water.backend import np, pd, plt, os, lime

if lime is not None:
    from lime import lime_tabular

from ._explain import ExplainerMixin


class LimeExplainer(ExplainerMixin):
    """
    Wrapper around LIME module.

    Example:
        >>> from ai4water import Model
        >>> from ai4water.postprocessing import LimeExplainer
        >>> from ai4water.datasets import busan_beach
        >>> model = Model(model="GradientBoostingRegressor")
        >>> model.fit(data=busan_beach())
        >>> lime_exp = LimeExplainer(model=model,
        ...                          train_data=model.training_data()[0],
        ...                          data=model.test_data()[0],
        ...                          mode="regression")
        >>> lime_exp.explain_example(0)

    Attributes:
        explaination_objects : location explaination objects for each individual example/instance
    """

    def __init__(
            self,
            model,
            data,
            train_data,
            mode: str,
            explainer=None,
            path=None,
            feature_names: list = None,
            verbosity: Union[int, bool] = True,
            save: bool = True,
            show: bool = True,
            **kwargs
    ):
        """
        Arguments:
            model : the model to explain. The model must have `predict` method.
            data : the data to explain. This would typically be test data but it
                can be any data.
            train_data : the data on which the model was trained.
            mode : either of `regression` or `classification`
            explainer : The explainer to use. By default, LimeTabularExplainer is used.
            path : path where to save all the plots. By default, plots will be saved in
                current working directory.
            feature_names : name/names of features.
            verbosity : whether to print information or not.
            show: whether to show the plot or not
            save: whether to save the plot or not
        """
        self.model = model
        self.train_data = to_np(train_data)

        super(LimeExplainer, self).__init__(path=path or os.getcwd(),
                                            data=to_np(data),
                                            save=save,
                                            show=show,
                                            features=feature_names)

        self.mode = mode
        self.verbosity = verbosity
        self.explainer = self._get_explainer(explainer, **kwargs)

        # one explanation object per explained example, keyed by example index
        self.explaination_objects = {}

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, x):
        # None is allowed; anything else must be one of the two lime modes
        if x is not None:
            assert x in ["regression", "classification"], f"mode must be either regression or classification not {x}"
        self._mode = x

    def _get_explainer(self, proposed_explainer=None, **kwargs):
        """Build the lime explainer: LimeTabularExplainer for 2d data,
        RecurrentTabularExplainer for 3d data, or a user-specified class."""
        if proposed_explainer is None and self.data.ndim <= 2:
            lime_explainer = lime.lime_tabular.LimeTabularExplainer(
                self.train_data,
                feature_names=self.features,
                # class_names=['price'],
                # categorical_features=categorical_features,
                verbose=self.verbosity,
                mode=self.mode,
                **kwargs
            )
        elif proposed_explainer in lime.lime_tabular.__dict__.keys():
            # user gave the name of an explainer class defined in lime.lime_tabular
            lime_explainer = getattr(lime.lime_tabular, proposed_explainer)(
                self.train_data,
                feature_names=self.features,
                mode=self.mode,
                verbose=self.verbosity,
                **kwargs
            )
        elif self.data.ndim == 3:
            lime_explainer = lime.lime_tabular.RecurrentTabularExplainer(
                self.train_data,
                mode=self.mode,
                feature_names=self.features,
                verbose=self.verbosity,
                **kwargs
            )
        elif proposed_explainer is not None:
            # fall back to looking the explainer up on the top-level lime package
            lime_explainer = getattr(lime, proposed_explainer)(
                self.train_data,
                features=self.features,
                mode=self.mode,
                **kwargs
            )
        else:
            raise ValueError("Can not infer explainer. Please specify explainer to use.")

        return lime_explainer

    def __call__(self, *args, **kwargs):
        self.explain_all_examples(*args, **kwargs)
        return

    def explain_all_examples(self,
                             plot_type="pyplot",
                             name="lime_explaination",
                             num_features=None,
                             **kwargs
                             ):
        """
        Draws and saves plot for all examples of test_data.

        Arguments:
            plot_type :
            name :
            num_features :
            kwargs : any keyword argument for `explain_instance`

        An example here means an instance/sample/data point.
        """
        for i in range(len(self.data)):
            self.explain_example(i, plot_type=plot_type,
                                 name=f"{name}_{i}",
                                 num_features=num_features, **kwargs)
        return

    def explain_example(
            self,
            index: int,
            plot_type: str = "pyplot",
            name: str = "lime_explaination",
            num_features: int = None,
            colors=None,
            annotate=False,
            **kwargs
    ) -> plt.Figure:
        """
        Draws and saves plot for a single example of test_data.

        Arguments:
            index : index of test_data
            plot_type : either pyplot or html
            name : name with which to save the file
            num_features :
            colors :
            annotate : whether to annotate figure or not
            kwargs : any keyword argument for `explain_instance`

        Returns:
            matplotlib figure if plot_type="pyplot" and show is False.
        """
        assert plot_type in ("pyplot", "html")

        exp = self.explainer.explain_instance(self.data[index],
                                              self.model.predict,
                                              num_features=num_features or len(self.features),
                                              **kwargs
                                              )

        # keep the raw explanation so callers can inspect it later
        self.explaination_objects[index] = exp

        fig = None
        if plot_type == "pyplot":
            plt.close()
            fig = as_pyplot_figure(exp, colors=colors, example_index=index, annotate=annotate)
            if self.save:
                plt.savefig(os.path.join(self.path, f"{name}_{index}"), bbox_inches="tight")
            if self.show:
                plt.show()
        else:
            # BUGFIX: respect self.save here as well, consistent with the pyplot branch
            if self.save:
                exp.save_to_file(os.path.join(self.path, f"{name}_{index}"))

        return fig
def to_np(x) -> np.ndarray:
    """Return *x* as a numpy array.

    Accepts a ``pd.DataFrame``/``pd.Series`` (converted via ``.values``) or an
    ``np.ndarray`` (returned unchanged).

    Raises:
        TypeError: if *x* is of any other type.
    """
    if isinstance(x, (pd.DataFrame, pd.Series)):
        x = x.values
    elif not isinstance(x, np.ndarray):
        # raise instead of assert: asserts are stripped under `python -O`
        raise TypeError(
            f"expected pandas DataFrame/Series or numpy array, got {type(x).__name__}")

    return x


def as_pyplot_figure(
        inst_explainer,
        label=1,
        example_index=None,
        colors: Union[str, tuple, list] = None,
        annotate=False,
        **kwargs):
    """Returns the explanation as a pyplot figure.

    Will throw an error if you don't have matplotlib installed
    Args:
        inst_explainer : instance explainer
        label: desired label. If you ask for a label for which an explanation wasn't
            computed, will throw an exception. Will be ignored for regression explanations.
        colors : if tuple it must be names of two colors for +ve and -ve
        example_index :
        annotate : whether to annotate the figure or not?
        kwargs: keyword arguments, passed to domain_mapper

    Returns:
        pyplot figure (barchart).
    """
    textstr = f"""Prediction: {round(inst_explainer.predicted_value, 2)}
Local prediction: {round(inst_explainer.local_pred.item(), 2)}"""

    if colors is None:
        # default: pink for positive contributions, blue for negative
        colors = ([0.9375, 0.01171875, 0.33203125], [0.23828125, 0.53515625, 0.92578125])
    elif isinstance(colors, str):
        colors = (colors, colors)

    exp = inst_explainer.as_list(label=label, **kwargs)

    fig = plt.figure()

    vals = [x[1] for x in exp]
    names = [x[0] for x in exp]

    # lime returns features ordered by importance; reverse so the most
    # important bar ends up at the top of the horizontal bar chart
    vals.reverse()
    names.reverse()

    if isinstance(colors, tuple):
        colors = [colors[0] if x > 0 else colors[1] for x in vals]

    pos = np.arange(len(exp)) + .5
    h = plt.barh(pos, vals, align='center', color=colors)
    plt.yticks(pos, names)

    if inst_explainer.mode == "classification":
        title = 'Local explanation for class %s' % inst_explainer.class_names[label]
    else:
        title = f'Local explanation for example {example_index}'
    plt.title(title)

    plt.grid(linestyle='--', alpha=0.5)

    if annotate:
        # https://stackoverflow.com/a/59109053/5982232
        plt.legend(h, [textstr], loc="best", fancybox=True, framealpha=0.7,
                   handlelength=0, handletextpad=0)

    return fig