# Source code for ai4water.postprocessing._process_predictions



from typing import Union

from SeqMetrics import RegressionMetrics, ClassificationMetrics
from SeqMetrics.utils import plot_metrics

from ai4water.backend import easy_mpl as ep
from ai4water.backend import np, pd, mpl, plt, os, wandb, sklearn

from ai4water.utils.utils import AttribtueSetter
from ai4water.utils.utils import get_values
from ai4water.utils.utils import dateandtime_now, ts_features, dict_to_file

from ai4water.utils.visualizations import Plot, init_subplots
from ai4water.utils.visualizations import murphy_diagram, fdc_plot, edf_plot


# competitive skill score plot / bootstrap skill score plot as in MLAir
# rank histogram and reliability diagram for probabilistic forecasting models
# show availability plot of data
# classification report as in yellowbrick
# class prediction error as in yellowbrick
# discrimination threshold as in yellowbrick
# Friedman's H statistic https://blog.macuyiko.com/post/2019/discovering-interaction-effects-in-ensemble-models.html
# silhouette analysis
# KS statistic plot from labels and scores/probabilities
# reliability curves
# cumulative gain
# lift curve

mdates = mpl.dates


# in order to unify the use of metrics
Metrics = {
    'regression': lambda t, p, multiclass=False, **kwargs: RegressionMetrics(t, p, **kwargs),
    'classification': lambda t, p, multiclass=False, **kwargs: ClassificationMetrics(t, p,
                                                                                     multiclass=multiclass, **kwargs)
}
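
# A minimal sketch of how this dispatch dict gets used (the arrays here are
# hypothetical): both entries accept the same signature, so callers such as
# ``errors_plot`` below can stay agnostic of the mode. The regression lambda
# simply swallows the ``multiclass`` keyword that only classification needs:
#
#     t, p = np.random.random(10), np.random.random(10)
#     errors = Metrics['regression'](t, p, multiclass=False)
#     errors.calculate_all()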


class ProcessPredictions(Plot):
    """post-processing of results after training."""

    available_plots = [
        'regression', 'prediction', 'residual',
        'murphy', 'fdc', 'errors', "edf"
    ]
    def __init__(
            self,
            mode: str,
            forecast_len: int = None,
            output_features: Union[list, str] = None,
            wandb_config: dict = None,
            path: str = None,
            dpi: int = 300,
            show=1,
            save: bool = True,
            plots: Union[list, str] = None,
            quantiles: int = None,
    ):
        """
        Parameters
        ----------
        mode : str
            either "regression" or "classification"
        forecast_len : int, optional (default=None)
            forecast length, only valid when mode is regression
        output_features : str, optional
            names of output features
        plots : str/list, optional (default=None)
            the names of plots to draw. The following plots are available:

            ``residual``
            ``regression``
            ``prediction``
            ``errors``
            ``fdc``
            ``murphy``
            ``edf``
        path : str
            folder in which to save the results/plots
        show : bool
            whether to show the plots or not
        save : bool
            whether to save the plots or not
        wandb_config :
            weights and biases configuration dictionary
        dpi : int
            determines resolution of saved figure

        Examples
        --------
        >>> import numpy as np
        >>> from ai4water.postprocessing import ProcessPredictions
        >>> true = np.random.random(100)
        >>> predicted = np.random.random(100)
        >>> processor = ProcessPredictions("regression", forecast_len=1,
        ...     plots=['prediction', 'regression', 'residual'])
        >>> processor(true, predicted)

        # for postprocessing of classification results, set the mode
        >>> true = np.random.randint(0, 2, (100, 1))
        >>> predicted = np.random.randint(0, 2, (100, 1))
        >>> processor = ProcessPredictions("classification")
        >>> processor(true, predicted)
        """
        self.mode = mode
        self.forecast_len = forecast_len
        self.output_features = output_features
        self.wandb_config = wandb_config
        self.quantiles = quantiles
        self.show = show
        self.save = save
        self.dpi = dpi

        if plots is None:
            if mode == "regression":
                plots = ['regression', 'prediction', "residual", "errors", "edf"]
            else:
                plots = []
        elif not isinstance(plots, list):
            plots = [plots]

        assert all([plot in self.available_plots for plot in plots]), f"""
        unrecognized plot(s) in {plots}"""
        self.plots = plots

        super().__init__(path, save=save)
    @property
    def quantiles(self):
        return self._quantiles

    @quantiles.setter
    def quantiles(self, x):
        self._quantiles = x

    def _classes(self, array):
        if self.mode == "classification":
            return np.unique(array)
        return []

    def n_classes(self, array):
        if self.mode == "classification":
            return len(self._classes(array))
        return None

    def save_or_show(self, show=None, **kwargs):
        if show is None:
            show = self.show
        return super().save_or_show(save=self.save, show=show, **kwargs)

    def __call__(
            self,
            true_outputs,
            predicted,
            metrics="minimal",
            prefix="test",
            index=None,
            inputs=None,
            model=None,
    ):
        if self.quantiles:
            return self.process_quantiles(true_outputs, predicted)

        # if true_outputs and predicted are dictionaries of length 1,
        # just get the values
        true_outputs = get_values(true_outputs)
        predicted = get_values(predicted)

        true_outputs = np.array(true_outputs)
        predicted = np.array(predicted)

        AttribtueSetter(self, true_outputs)

        key = {"regression": "rgr", "classification": "cls"}

        getattr(self, f"process_{key[self.mode]}_results")(
            true_outputs,
            predicted,
            inputs=inputs,
            metrics=metrics,
            prefix=prefix,
            index=index,
        )

        return

    def process_quantiles(self, true, predicted):

        # assert self.num_outs == 1

        if true.ndim == 2:  # todo, this should be avoided
            true = np.expand_dims(true, axis=-1)

        self.plot_quantiles1(true, predicted)
        self.plot_quantiles2(true, predicted)
        self.plot_all_qs(true, predicted)

        return

    def horizon_plots(self, errors: dict, fname=''):
        plt.close('all')
        _, axis = plt.subplots(len(errors), sharex='all')

        legends = {'r2': "$R^2$", 'rmse': "RMSE", 'nse': "NSE"}
        idx = 0
        for metric_name, val in errors.items():
            ax = axis[idx]
            ax.plot(val, '--o', label=legends.get(metric_name, metric_name))
            ax.legend(fontsize=14)
            if idx >= len(errors) - 1:
                ax.set_xlabel("Horizons", fontsize=14)
            ax.set_ylabel(legends.get(metric_name, metric_name), fontsize=14)
            idx += 1
        self.save_or_show(fname=fname)
        return
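
    # A hypothetical illustration of the ``errors`` dict that
    # ``horizon_plots`` expects, as built in ``process_rgr_results``
    # below: one list of metric values per horizon step, e.g. for
    # forecast_len=3
    #
    #     errors = {'nse': [0.8, 0.7, 0.6], 'rmse': [1.2, 1.5, 1.9]}
    #     processor.horizon_plots(errors, fname='horizons.png')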

    def plot_results(
            self,
            true,
            predicted: pd.DataFrame,
            prefix,
            where,
            inputs=None,
    ):
        """
        Draws all the plots specified in ``self.plots``.

        kwargs for the underlying plot functions can be any/all of:
        fillstyle, marker, linestyle, markersize, color
        """
        for plot in self.plots:
            if plot == "murphy":
                self.murphy_plot(true, predicted, prefix, where, inputs)
            else:
                getattr(self, f"{plot}_plot")(true, predicted, prefix, where)
        return

    def average_target_across_feature(self, true, predicted, feature):
        raise NotImplementedError

    def prediction_distribution_across_feature(self, true, predicted, feature):
        raise NotImplementedError

    def edf_plot(self, true, predicted, prefix, where, **kwargs):
        """Plots the empirical (cumulative) distribution function of the
        absolute error between true and predicted.

        Parameters
        ----------
        true : array like
        predicted : array like
        prefix :
            prefix used in the name of the saved figure
        where :
            directory in which to save the figure
        """
        if isinstance(true, (pd.DataFrame, pd.Series)):
            true = true.values
        if isinstance(predicted, (pd.DataFrame, pd.Series)):
            predicted = predicted.values

        error = np.abs(true - predicted)

        edf_plot(error, xlabel="Absolute Error", show=False)
        return self.save_or_show(fname=f"{prefix}_error_dist", where=where)

    def murphy_plot(self, true, predicted, prefix, where, inputs, **kwargs):

        murphy_diagram(true,
                       predicted,
                       reference_model="LinearRegression",
                       plot_type="diff",
                       inputs=inputs,
                       show=False,
                       **kwargs)

        return self.save_or_show(fname=f"{prefix}_murphy", where=where)

    def fdc_plot(self, true, predicted, prefix, where, **kwargs):

        fdc_plot(predicted, true, show=False, **kwargs)

        return self.save_or_show(fname=f"{prefix}_fdc", where=where)

    def residual_plot(
            self,
            true,
            predicted,
            prefix,
            where,
            hist_kws: dict = None,
            **kwargs
    ):
        """
        Makes a residual plot: a histogram of the residuals on top and
        residuals against predicted values below.

        Parameters
        ----------
        true : array like
        predicted : array like
        prefix :
            prefix used in the name of the saved figure
        where :
            directory in which to save the figure
        hist_kws : dict, optional
            keyword arguments for the histogram of residuals
        """
        fig, axis = plt.subplots(2, sharex="all")

        x = predicted.values
        y = true.values - predicted.values

        _hist_kws = dict(bins=20, linewidth=0.5,
                         edgecolor="k", grid=False, color='khaki')
        if hist_kws is not None:
            _hist_kws.update(hist_kws)

        ep.hist(y, show=False, ax=axis[0], **_hist_kws)
        axis[0].set_xticks([])

        ep.plot(x, y, 'o', show=False,
                ax=axis[1],
                color="darksalmon",
                markerfacecolor=np.array([225, 121, 144]) / 256.0,
                markeredgecolor="black", markeredgewidth=0.5,
                ax_kws=dict(
                    xlabel="Predicted",
                    ylabel="Residual",
                    xlabel_kws={"fontsize": 14},
                    ylabel_kws={"fontsize": 14}),
                )

        # draw horizontal line on y=0
        axis[1].axhline(0.0)
        plt.suptitle("Residual")

        return self.save_or_show(fname=f"{prefix}_residual", where=where)

    def errors_plot(self, true, predicted, prefix, where, **kwargs):

        errors = Metrics[self.mode](true, predicted, multiclass=self.is_multiclass_)

        return plot_metrics(
            errors.calculate_all(),
            show=self.show,
            save_path=os.path.join(self.path, where),
            save=self.save,
            text_kws={"fontsize": 16},
            max_metrics_per_fig=20,
        )

    def regression_plot(
            self,
            true,
            predicted,
            target_name,
            where,
            annotate_with="r2"
    ):
        annotation_val = getattr(RegressionMetrics(true, predicted), annotate_with)()

        metric_names = {'r2': "$R^2$"}
        annotation_key = metric_names.get(annotate_with, annotate_with)

        RIDGE_LINE_KWS = {'color': 'firebrick', 'lw': 1.0}

        if isinstance(predicted, (pd.DataFrame, pd.Series)):
            predicted = predicted.values

        marginals = True
        if np.isnan(np.array(true)).any() or np.isnan(predicted).any():
            marginals = False

        # if all the values in predicted are the same, calculation of the
        # kde for the marginal distribution raises an error
        if (predicted == predicted[0]).all():
            marginals = False

        try:
            axes = ep.regplot(true, predicted,
                              marker_color='crimson',
                              line_color='k',
                              scatter_kws={'marker': "o", 'edgecolors': 'black',
                                           'linewidth': 0.5},
                              show=False,
                              marginals=marginals,
                              marginal_ax_pad=0.25,
                              marginal_ax_size=0.7,
                              ridge_line_kws=RIDGE_LINE_KWS,
                              hist=False,
                              )
        except np.linalg.LinAlgError:
            axes = ep.regplot(true, predicted,
                              marker_color='crimson',
                              line_color='k',
                              scatter_kws={'marker': "o", 'edgecolors': 'black',
                                           'linewidth': 0.5},
                              show=False,
                              marginals=False
                              )

        axes.annotate(f'{annotation_key}: {round(annotation_val, 3)}',
                      xy=(0.3, 0.95),
                      xycoords='axes fraction',
                      horizontalalignment='right', verticalalignment='top',
                      fontsize=16)

        return self.save_or_show(fname=f"{target_name}_regression", where=where)

    def prediction_plot(self, true, predicted, prefix, where):
        mpl.rcParams.update(mpl.rcParamsDefault)

        _, axis = init_subplots(width=12, height=8)

        # When the data is datetime-indexed it may not be equidistant; a line
        # plot would then draw a lot of useless interpolating lines across the
        # gaps where no data is present, so use markers instead.
        datetime_axis = False
        if isinstance(true.index, pd.DatetimeIndex) and pd.infer_freq(true.index) is not None:
            style = '.'
            datetime_axis = True
        else:
            if np.isnan(true.values).sum() > 0:
                # for NaN values we should use this style, otherwise
                # nothing is plotted
                style = '.'
            else:
                style = '-'
            true = true.values
            predicted = predicted.values

        ms = 4 if style == '.' else 2

        # if the data is very large, better to use a small marker size
        if len(true) > 1000:
            ms = 2

        axis.plot(predicted, style, color='r', label='Prediction')

        axis.plot(true, style, color='b', marker='o', fillstyle='none',
                  markersize=ms, label='True')

        axis.legend(loc="best", fontsize=22, markerscale=4)

        if datetime_axis:
            loc = mdates.AutoDateLocator(minticks=4, maxticks=6)
            axis.xaxis.set_major_locator(loc)
            fmt = mdates.AutoDateFormatter(loc)
            axis.xaxis.set_major_formatter(fmt)

        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel("Time", fontsize=18)

        return self.save_or_show(fname=f"{prefix}_prediction", where=where)

    def plot_all_qs(self, true_outputs, predicted, save=False):
        plt.close('all')
        plt.style.use('ggplot')

        st, en = 0, true_outputs.shape[0]

        plt.plot(np.arange(st, en), true_outputs[st:en, 0],
                 label="True", color='navy')

        for idx, q in enumerate(self.quantiles):
            q_name = "{:.1f}".format(q * 100)
            plt.plot(np.arange(st, en), predicted[st:en, idx],
                     label="q {} %".format(q_name))

        plt.legend(loc="best")
        self.save_or_show(save, fname="all_quantiles", where='results')

        return

    def plot_quantiles1(self, true_outputs, predicted,
                        st=0, en=None, save=True):
        plt.close('all')
        plt.style.use('ggplot')

        assert true_outputs.shape[-2:] == (1, 1)
        if en is None:
            en = true_outputs.shape[0]
        for q in range(len(self.quantiles) - 1):
            st_q = "{:.1f}".format(self.quantiles[q] * 100)
            en_q = "{:.1f}".format(self.quantiles[-q] * 100)

            plt.plot(np.arange(st, en), true_outputs[st:en, 0],
                     label="True", color='navy')
            plt.fill_between(np.arange(st, en),
                             predicted[st:en, q].reshape(-1,),
                             predicted[st:en, -q].reshape(-1,),
                             alpha=0.2,
                             color='g', edgecolor=None,
                             label=st_q + '_' + en_q)
            plt.legend(loc="best")
            self.save_or_show(save, fname='q' + st_q + '_' + en_q,
                              where='results')
        return

    def plot_quantiles2(
            self,
            true_outputs,
            predicted,
            st=0,
            en=None,
            save=True
    ):
        plt.close('all')
        plt.style.use('ggplot')

        if en is None:
            en = true_outputs.shape[0]
        for q in range(len(self.quantiles) - 1):
            st_q = "{:.1f}".format(self.quantiles[q] * 100)
            en_q = "{:.1f}".format(self.quantiles[q + 1] * 100)

            plt.plot(np.arange(st, en), true_outputs[st:en, 0],
                     label="True", color='navy')
            plt.fill_between(np.arange(st, en),
                             predicted[st:en, q].reshape(-1,),
                             predicted[st:en, q + 1].reshape(-1,),
                             alpha=0.2,
                             color='g', edgecolor=None,
                             label=st_q + '_' + en_q)
            plt.legend(loc="best")
            self.save_or_show(save, fname='q' + st_q + '_' + en_q + ".png",
                              where='results')
        return

    def plot_quantile(self, true_outputs, predicted, min_q: int, max_q,
                      st=0, en=None, save=False):
        plt.close('all')
        plt.style.use('ggplot')

        if en is None:
            en = true_outputs.shape[0]
        q_name = "{:.1f}_{:.1f}_{}_{}".format(self.quantiles[min_q] * 100,
                                              self.quantiles[max_q] * 100,
                                              str(st), str(en))

        plt.plot(np.arange(st, en), true_outputs[st:en, 0],
                 label="True", color='navy')
        plt.fill_between(np.arange(st, en),
                         predicted[st:en, min_q].reshape(-1,),
                         predicted[st:en, max_q].reshape(-1,),
                         alpha=0.2,
                         color='g', edgecolor=None,
                         label=q_name + ' %')
        plt.legend(loc="best")
        self.save_or_show(save, fname="q_" + q_name + ".png", where='results')
        return
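
    # Shape conventions assumed by the quantile plots above (inferred from
    # the indexing, not documented in the original): ``true_outputs`` is
    # (examples, 1, 1) and ``predicted`` is (examples, len(self.quantiles)),
    # one column per quantile. For instance, with quantiles=[0.25, 0.5, 0.75],
    # plot_quantile(true, pred, 0, 2) fills the band between the 25th and
    # 75th percentile predictions.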

    def roc_curve(self, estimator, x, y, prefix=None):

        if hasattr(estimator, '_model'):
            if estimator._model.__class__.__name__ in ["XGBClassifier", "XGBRFClassifier"] \
                    and isinstance(x, np.ndarray):
                x = pd.DataFrame(x, columns=estimator.input_features)

        plot_roc_curve(estimator, x, y.reshape(-1,))

        self.save_or_show(fname=f"{prefix}_roc")
        return

    def confusion_matrix(self, true, predicted, prefix=None, cmap="Blues", **kwargs):
        """Plots the confusion matrix.

        Parameters
        ----------
        true : array like
        predicted : array like
        prefix :
            prefix used in the name of the saved figure
        cmap :
            colormap to use
        **kwargs :
            any keyword arguments for imshow
        """
        cm = ClassificationMetrics(
            true,
            predicted,
            multiclass=self.is_multiclass_).confusion_matrix()

        kws = {
            'annotate': True,
            'colorbar': True,
            'cmap': cmap,
            'xticklabels': self.classes_,
            'yticklabels': self.classes_,
            'ax_kws': {'xlabel': "Predicted Label", 'ylabel': "True Label"},
            'show': False,
            'annotate_kws': {'fontsize': 14, "fmt": '%.f', 'ha': "left"}
        }
        kws.update(kwargs)

        ep.imshow(cm, **kws)

        self.save_or_show(fname=f"{prefix}_confusion_matrix")
        return

    def precision_recall_curve(self, estimator, x, y, prefix=None):

        if hasattr(estimator, '_model'):
            if estimator._model.__class__.__name__ in ["XGBClassifier", "XGBRFClassifier"] \
                    and isinstance(x, np.ndarray):
                x = pd.DataFrame(x, columns=estimator.input_features)

        plot_precision_recall_curve(estimator, x, y.reshape(-1,))

        self.save_or_show(fname=f"{prefix}_plot_precision_recall_curve")
        return

    def process_rgr_results(
            self,
            true: np.ndarray,
            predicted: np.ndarray,
            metrics="minimal",
            prefix=None,
            index=None,
            remove_nans=True,
            inputs=None,
    ):
        """
        predicted, true are arrays of shape (examples, outs, forecast_len).
        """
        if self.output_features is None:
            # when data is user-defined, we don't know what out_cols
            # and forecast_len are
            if predicted.size == len(predicted):
                out_cols = ['output']
                forecast_len = 1
            else:
                out_cols = [f'output_{i}' for i in range(predicted.shape[-1])]
                forecast_len = 1

            true, predicted = self.maybe_not_3d_data(true, predicted)
        else:
            # for cases if they are 2D/1D, add the third dimension.
            true, predicted = self.maybe_not_3d_data(true, predicted)

            forecast_len = self.forecast_len
            if isinstance(forecast_len, dict):
                forecast_len = np.unique(list(forecast_len.values())).item()

            out_cols = self.output_features
            if isinstance(out_cols, dict):
                _out_cols = []
                for cols in out_cols.values():
                    _out_cols = _out_cols + cols
                out_cols = _out_cols
                if len(out_cols) > 1 and not isinstance(predicted, np.ndarray):
                    raise NotImplementedError("""
                    can not process results with more than 1 output arrays""")

        for idx, out in enumerate(out_cols):

            horizon_errors = {metric_name: [] for metric_name in ['nse', 'rmse']}
            for h in range(forecast_len):

                errs = dict()

                fpath = os.path.join(self.path, out)
                if not os.path.exists(fpath):
                    os.makedirs(fpath)

                t = pd.DataFrame(true[:, idx, h], index=index, columns=['true_' + out])
                p = pd.DataFrame(predicted[:, idx, h], index=index, columns=['pred_' + out])

                if wandb is not None and self.wandb_config is not None:
                    _wandb_scatter(t.values, p.values, out)

                df = pd.concat([t, p], axis=1)
                df = df.sort_index()
                fname = f"{prefix}_{out}_{h}"
                df.to_csv(os.path.join(fpath, fname + ".csv"), index_label='index')

                self.plot_results(t, p, prefix=fname, where=out, inputs=inputs)

                if remove_nans:
                    nan_idx = np.isnan(t)
                    t = t.values[~nan_idx]
                    p = p.values[~nan_idx]

                errors = RegressionMetrics(t, p)
                errs[out + '_errors_' + str(h)] = getattr(errors, f'calculate_{metrics}')()
                errs[out + 'true_stats_' + str(h)] = ts_features(t)
                errs[out + 'predicted_stats_' + str(h)] = ts_features(p)

                dict_to_file(fpath, errors=errs, name=prefix)

                for metric_name in horizon_errors.keys():
                    horizon_errors[metric_name].append(getattr(errors, metric_name)())

            if forecast_len > 1:
                self.horizon_plots(horizon_errors, f'{prefix}_{out}_horizons.png')
        return

    def process_cls_results(
            self,
            true: np.ndarray,
            predicted: np.ndarray,
            metrics="minimal",
            prefix=None,
            index=None,
            inputs=None,
            model=None,
    ):
        """post-processes classification results."""
        if self.is_multilabel_:
            return self.process_multilabel(true, predicted, metrics, prefix, index)

        if self.is_multiclass_:
            return self.process_multiclass(true, predicted, metrics, prefix, index)
        else:
            return self.process_binary(true, predicted, metrics, prefix, index,
                                       model=model)

    def process_multilabel(self, true, predicted, metrics, prefix, index):

        for label in range(true.shape[1]):
            if self.n_classes(true[:, label]) == 2:
                self.process_binary(true[:, label], predicted[:, label],
                                    metrics, f"{prefix}_{label}", index)
            else:
                self.process_multiclass(true[:, label], predicted[:, label],
                                        metrics, f"{prefix}_{label}", index)
        return

    def process_multiclass(self, true, predicted, metrics, prefix, index):

        if len(predicted) == predicted.size:
            predicted = predicted.reshape(-1, 1)
        else:
            predicted = np.argmax(predicted, axis=1).reshape(-1, 1)
        if len(true) == true.size:
            true = true.reshape(-1, 1)
        else:
            true = np.argmax(true, axis=1).reshape(-1, 1)

        if self.output_features is None:
            self.output_features = [f'feature_{i}' for i in range(self.n_classes(true))]

        self.confusion_matrix(true, predicted, prefix=prefix)

        fname = os.path.join(self.path, f"{prefix}_prediction.csv")
        pd.DataFrame(np.concatenate([true, predicted], axis=1),
                     columns=['true', 'predicted'], index=index).to_csv(fname)

        class_metrics = ClassificationMetrics(true, predicted, multiclass=True)
        dict_to_file(self.path,
                     errors=class_metrics.calculate_all(),
                     name=f"{prefix}_{dateandtime_now()}.json")
        return

    def process_binary(self, true, predicted, metrics, prefix, index, model=None):

        assert self.n_classes(true) == 2

        if model is not None:
            try:  # todo, also plot for DL
                self.precision_recall_curve(model, x=true, y=predicted, prefix=prefix)
                self.roc_curve(model, x=true, y=predicted, prefix=prefix)
            except NotImplementedError:
                pass

        if predicted.ndim == 1:
            predicted = predicted.reshape(-1, 1)
        elif predicted.size != len(predicted):
            predicted = np.argmax(predicted, axis=1).reshape(-1, 1)

        if true.ndim == 1:
            true = true.reshape(-1, 1)
        elif true.size != len(true):
            true = np.argmax(true, axis=1).reshape(-1, 1)

        self.confusion_matrix(true, predicted, prefix=prefix)

        fpath = os.path.join(self.path, prefix)
        if not os.path.exists(fpath):
            os.makedirs(fpath)

        metrics_instance = ClassificationMetrics(true, predicted, multiclass=False)
        metrics = getattr(metrics_instance, f"calculate_{metrics}")()
        dict_to_file(fpath,
                     errors=metrics,
                     name=f"{prefix}_{dateandtime_now()}.json"
                     )

        fname = os.path.join(fpath, f"{prefix}_.csv")
        array = np.concatenate([true.reshape(-1, 1), predicted.reshape(-1, 1)], axis=1)
        pd.DataFrame(array, columns=['true', 'predicted'], index=index).to_csv(fname)

        return

    def maybe_not_3d_data(self, true, predicted):

        forecast_len = self.forecast_len

        if true.ndim < 3:
            if isinstance(forecast_len, dict):
                forecast_len = set(list(forecast_len.values()))
                assert len(forecast_len) == 1
                forecast_len = forecast_len.pop()

            assert forecast_len == 1, f'{forecast_len}'

            axis = 2 if true.ndim == 2 else (1, 2)
            true = np.expand_dims(true, axis=axis)

        if predicted.ndim < 3:
            assert forecast_len == 1
            axis = 2 if predicted.ndim == 2 else (1, 2)
            predicted = np.expand_dims(predicted, axis=axis)

        return true, predicted
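
    # For illustration (hypothetical shapes): with forecast_len=1, a 1D
    # array of shape (100,) is expanded along axis (1, 2) to (100, 1, 1),
    # and a 2D array of shape (100, 3) is expanded along axis 2 to
    # (100, 3, 1), i.e. (examples, outs, forecast_len).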

def plot_roc_curve(*args, **kwargs):
    try:
        func = sklearn.metrics.RocCurveDisplay.from_estimator
    except AttributeError:
        func = sklearn.metrics.plot_roc_curve

    return func(*args, **kwargs)


def plot_precision_recall_curve(*args, **kwargs):
    try:
        func = sklearn.metrics.PrecisionRecallDisplay.from_estimator
    except AttributeError:
        func = sklearn.metrics.plot_precision_recall_curve

    return func(*args, **kwargs)


def _wandb_scatter(true: np.ndarray, predicted: np.ndarray, name: str) -> None:
    """Adds a scatter plot on wandb."""
    data = [[x, y] for (x, y) in zip(true.reshape(-1,), predicted.reshape(-1,))]
    table = wandb.Table(data=data, columns=["true", "predicted"])
    wandb.log({
        "scatter_plot": wandb.plot.scatter(table, "true", "predicted",
                                           title=name)
    })
    return
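
# Note on the two shims above: ``RocCurveDisplay.from_estimator`` and
# ``PrecisionRecallDisplay.from_estimator`` exist in newer scikit-learn
# releases, while ``plot_roc_curve``/``plot_precision_recall_curve`` were
# their older counterparts (removed around scikit-learn 1.2), hence the
# AttributeError fallback to stay compatible with both.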