Source code for ai4water.models._torch.pytorch_training


from typing import Union

import gc
from SeqMetrics import RegressionMetrics

from ai4water.backend import easy_mpl as ep
from ai4water.backend import os, np, plt, torch, pd

try:
    import wandb
except ModuleNotFoundError:
    wandb = None


# only so that docs can be built without torch being installed
try:
    from .utils import to_torch_dataset, TorchMetrics
except ModuleNotFoundError:
    to_torch_dataset, TorchMetrics = None, None

if torch is not None:
    from .pytorch_attributes import LOSSES
else:
    LOSSES = {}

from ai4water.utils.utils import dateandtime_now, find_best_weight

F = {
    'mse': [np.nanmin, np.less],
    'nse': [np.nanmax, np.greater],
    'r2': [np.nanmax, np.greater],
    'pbias': [np.nanmin, np.less],
    'mape': [np.nanmin, np.less],
    'rmse': [np.nanmin, np.less],
    'nrmse': [np.nanmin, np.less],
    'kge': [np.nanmax, np.greater],
}
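
# Illustration (not part of the original module): each entry of ``F`` pairs a
# reducer over the metric history with a comparator that is True when a new
# value is an improvement. For an error metric such as 'rmse':
# >>> reduce_fn, improved = F['rmse']   # np.nanmin, np.less
# >>> improved(0.2, reduce_fn(np.array([0.5, 0.3, np.nan])))  # 0.2 < 0.3
# True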


class AttributeContainer(object):

    def __init__(self, num_epochs, to_monitor=None, use_cuda=None,
                 path=None, verbosity=1):
        self.to_monitor = get_metrics_to_monitor(to_monitor)
        self.num_epochs = num_epochs

        self.epoch = 0
        self.val_loader = None
        self.train_loader = None
        self.criterion = None
        self.optimizer = None
        self.val_epoch_losses = {}
        self.train_epoch_losses = None
        self.train_metrics = {metric: np.full(num_epochs, np.nan) for metric in self.to_monitor}
        self.val_metrics = {f'val_{metric}': np.full(num_epochs, np.nan) for metric in self.to_monitor}
        self.best_epoch = 0  # todo,
        self.use_cuda = use_cuda if use_cuda is not None else torch.cuda.is_available()
        self.verbosity = verbosity

        def_path = path if path is not None else os.path.join(os.getcwd(), 'results', dateandtime_now())
        if not os.path.exists(def_path) and verbosity >= 0:
            # the path does not exist yet, so create it along with any missing parents
            os.makedirs(def_path)
        self.path = def_path

    @property
    def use_cuda(self):
        return self._use_cuda

    @use_cuda.setter
    def use_cuda(self, x):
        self._use_cuda = x

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, x):
        self._optimizer = x

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, x):
        if isinstance(x, str):
            x = LOSSES[x.upper()]()
        self._loss = x

    @property
    def path(self):
        return self._path

    @path.setter
    def path(self, x):
        self._path = x

    def _device(self):
        if self.use_cuda:
            return torch.device("cuda")
        else:
            return torch.device("cpu")


class Learner(AttributeContainer):
    """Trains the pytorch model. Inspired by fastai."""
    def __init__(
            self,
            model,  # torch.nn.Module
            batch_size: int = 32,
            num_epochs: int = 14,
            patience: int = 100,
            shuffle: bool = True,
            to_monitor: list = None,
            use_cuda: bool = False,
            path: str = None,
            wandb_config: dict = None,
            verbosity=1,
            **kwargs
    ):
        """
        Initializes the Learner class

        Arguments:
            model : a pytorch model having the following attributes and methods
                - `num_outs`
                - `w_path`
                - `loss`
                - `get_optimizer`
            batch_size : batch size
            num_epochs : number of epochs for which to train the model
            patience : how many epochs to wait before stopping the training in
                case `to_monitor` does not improve.
            shuffle : whether to shuffle the training data or not
            use_cuda : whether to use cuda or not
            to_monitor : list of metrics to monitor
            path : path to save results/weights
            wandb_config : config for wandb

        Example
        -------
        >>> from torch import nn
        >>> import torch
        >>> from ai4water.models._torch import Learner
        ...
        >>> class Net(nn.Module):
        ...     def __init__(self, D_in, H, D_out):
        ...         super(Net, self).__init__()
        ...         # hidden layer
        ...         self.linear1 = nn.Linear(D_in, H)
        ...         self.linear2 = nn.Linear(H, D_out)
        ...     def forward(self, x):
        ...         l1 = self.linear1(x)
        ...         a1 = torch.sigmoid(l1)
        ...         yhat = torch.sigmoid(self.linear2(a1))
        ...         return yhat
        ...
        >>> learner = Learner(model=Net(1, 2, 1),
        ...                   num_epochs=501,
        ...                   patience=50,
        ...                   batch_size=1,
        ...                   shuffle=False)
        ...
        >>> learner.optimizer = torch.optim.SGD(learner.model.parameters(), lr=0.1)
        >>> def criterion_cross(labels, outputs):
        ...     out = -1 * torch.mean(labels * torch.log(outputs) + (1 - labels) * torch.log(1 - outputs))
        ...     return out
        >>> learner.loss = criterion_cross
        ...
        >>> X = torch.arange(-20, 20, 1).view(-1, 1).type(torch.FloatTensor)
        >>> Y = torch.zeros(X.shape[0])
        >>> Y[(X[:, 0] > -4) & (X[:, 0] < 4)] = 1.0
        ...
        >>> learner.fit(X, Y)
        >>> metrics = learner.evaluate(X, y=Y, metrics=['r2', 'nse', 'mape'])
        >>> t = learner.predict(X, y=Y, name='training')
        """
        super().__init__(num_epochs, to_monitor, path=path, use_cuda=use_cuda, verbosity=verbosity)

        if self.use_cuda:
            model = model.to(self._device())

        self.model = model
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.patience = patience
        self.wandb_config = wandb_config
    def fit(
            self,
            x,
            y=None,
            validation_data=None,
            **kwargs
    ):
        """Runs the training loop for the pytorch model.

        Arguments
        ---------
            x :
                Can be one of the following
                    - an instance of torch.Dataset, y will be ignored
                    - an instance of torch.DataLoader, y will be ignored
                    - a torch tensor containing input data for each example
                    - a numpy array or pandas DataFrame
                    - a list of torch tensors or numpy arrays
            y :
                if `x` is a torch tensor, then `y` is the label/target for each
                corresponding example.
            validation_data :
                can be one of the following:
                    - an instance of torch.Dataset
                    - an instance of torch.DataLoader
                    - a tuple of x,y pairs where x and y are tensors
                Default is None, which means no validation is performed.
            kwargs :
                can be `callbacks`. For example, to use a callable as callback
                use the following
                >>> callbacks = [{'after_epochs': 300, 'func': PlotStuff}]
                where `PlotStuff` is a callable. Each `callable` is provided with
                the following keyword arguments
                    - epoch : the current epoch at which the callable is called.
                    - model : the model
                    - train_data : training data_loader
                    - val_data : validation data_loader
        """
        self.on_train_begin(x, y=y, validation_data=validation_data, **kwargs)

        for epoch in range(self.num_epochs):

            self.epoch = epoch

            self.train_for_epoch()
            self.validate_for_epoch()

            self.on_epoch_end()

            if epoch - self.best_epoch > self.patience:
                if self.verbosity > 0:
                    print(f"Stopping early because the loss has not improved since epoch {self.best_epoch}")
                break

        return self.on_train_end()
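    # A minimal callback sketch (the ``print_progress`` function is hypothetical,
    # and ``learner``, ``X`` and ``Y`` come from the class docstring example):
    # ``fit`` calls each callback every ``after_epochs`` epochs with the keyword
    # arguments listed in the docstring above.
    # >>> def print_progress(epoch, model, train_data, val_data):
    # ...     print(f"completed epoch {epoch}")
    # >>> learner.fit(X, Y, callbacks=[{'after_epochs': 10, 'func': print_progress}])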
    def predict(
            self,
            x,
            y=None,
            batch_size: int = None,
            reg_plot: bool = True,
            name: str = None,
            **kwargs
    ) -> np.ndarray:
        """Makes predictions on the given data

        Arguments:
            x : data on which to evaluate. It can be
                - a torch.utils.data.Dataset
                - a torch.utils.data.DataLoader
                - a torch.Tensor
                - a numpy array
                - a list of torch tensors or numpy arrays
            y : only relevant if `x` is a torch.Tensor. It comprises labels for
                the corresponding x.
            batch_size : None means make predictions on the whole data in one go
            reg_plot : whether to plot the regression line or not
            name : string to be used for the title and name of the saved plot

        Returns:
            predicted output as numpy array
        """
        true, pred = self._eval(x=x, y=y, batch_size=batch_size)

        if y is not None and reg_plot and pred.size > 0:
            ep.regplot(true, pred, show=False)
            plt.savefig(os.path.join(self.path, f'{name}_regplot.png'))

        # if self.use_cuda:
        #     torch.cuda.empty_cache()
        gc.collect()

        return pred
    def _eval(self, x, y=None, batch_size=None):

        loader, _ = self._get_loader(x=x, y=y, batch_size=batch_size, shuffle=False)

        true, pred = [], []

        for i, (batch_x, batch_y) in enumerate(loader):

            batch_y, pred_y = self.eval(batch_x, batch_y)

            true.append(batch_y.detach().cpu().numpy())
            pred.append(pred_y.detach().cpu().numpy())

        true = np.concatenate(true)
        pred = np.concatenate(pred)

        del loader
        del batch_x
        del batch_y
        gc.collect()

        return true, pred

    def eval(self, batch_x, batch_y):
        """Calls the model with x and y data and returns trues and preds"""
        batch_x = batch_x if isinstance(batch_x, list) else [batch_x]
        batch_x = [tensor.float() for tensor in batch_x]

        if self.use_cuda:
            batch_x = [tensor.cuda() for tensor in batch_x]
            batch_y = batch_y.cuda()

        pred_y = self.model(*batch_x)

        del batch_x

        return batch_y, pred_y
    def evaluate(
            self,
            x,
            y,
            batch_size: int = None,
            metrics: Union[str, list] = 'r2',
            **kwargs
    ):
        """
        Evaluates the `model` on the given data.

        Arguments:
            x : data on which to evaluate. It can be
                - a torch.utils.data.Dataset
                - a torch.utils.data.DataLoader
                - a torch.Tensor
                - a numpy.ndarray
                - a list of torch tensors or numpy arrays
            y : It comprises labels for the corresponding x.
            batch_size : None means evaluate on the whole data in one go
            metrics : name of performance metric to measure. It can be a single
                metric or a list of metrics. Allowed metrics are any from
                `ai4water.post_processing.SeqMetrics.RegressionMetrics`
            kwargs :

        Returns:
            if metrics is a string, the returned value is a float, otherwise it
            will be a dictionary
        """
        # todo, y->pred is only converting tensor into numpy array
        true, pred = self._eval(x=x, y=y, batch_size=batch_size)

        evaluator = RegressionMetrics(true, pred)

        errors = {}

        if isinstance(metrics, str):
            errors = getattr(evaluator, metrics)()
        else:
            assert isinstance(metrics, list)
            for m in metrics:
                errors[m] = getattr(evaluator, m)()

        return errors
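    # Return-type illustration for ``evaluate`` (the numbers below are made up):
    # a single metric name yields a float, a list yields a dict keyed by metric.
    # >>> learner.evaluate(X, y=Y, metrics='r2')
    # 0.92
    # >>> learner.evaluate(X, y=Y, metrics=['r2', 'nse'])
    # {'r2': 0.92, 'nse': 0.88}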
    def train_for_epoch(self):
        """Trains the pytorch model for one complete epoch"""

        epoch_losses = {metric: np.full(len(self.train_loader), np.nan) for metric in self.to_monitor}

        # todo, it would be better to avoid reshaping/view at all
        if hasattr(self.model, 'num_outs'):
            num_outs = self.model.num_outs
        else:
            num_outs = self.num_outs

        for i, (batch_x, batch_y) in enumerate(self.train_loader):

            # todo, feeding batch_y to eval is only putting it on the right device
            # can we do it before?
            batch_y, pred_y = self.eval(batch_x, batch_y)

            if num_outs:
                batch_y = batch_y.float().view(len(batch_y), num_outs)
                pred_y = pred_y.view(len(pred_y), num_outs)

            loss = self.criterion(batch_y, pred_y)
            loss = loss.float()
            loss.backward()

            self.optimizer.step()
            self.optimizer.zero_grad()

            # calculate metrics for each mini-batch
            er = TorchMetrics(batch_y, pred_y)

            for k, v in epoch_losses.items():
                v[i] = getattr(er, k)().detach().item()

            # epoch_losses['mse'][i] = loss.detach()

        # take the mean over all mini-batches without considering infinite values
        self.train_epoch_losses = {k: round(float(np.mean(v[np.isfinite(v)])), 4) for k, v in epoch_losses.items()}

        if self.wandb_config is not None:
            wandb.log(self.train_epoch_losses, step=self.epoch)

        if self.use_cuda:
            torch.cuda.empty_cache()

        return

    def validate_for_epoch(self):
        """If validation data is available, then it performs the validation"""

        if self.val_loader is not None:

            epoch_losses = {metric: np.full(len(self.val_loader), np.nan) for metric in self.to_monitor}

            for i, (batch_x, batch_y) in enumerate(self.val_loader):

                batch_y, pred_y = self.eval(batch_x, batch_y)

                # calculate metrics for each mini-batch
                er = TorchMetrics(batch_y, pred_y)

                for k, v in epoch_losses.items():
                    v[i] = getattr(er, k)().detach().item()

            # take the mean over all mini-batches
            self.val_epoch_losses = {f'val_{k}': round(float(np.mean(v)), 4) for k, v in epoch_losses.items()}

            if self.wandb_config is not None:
                wandb.log(self.val_epoch_losses, step=self.epoch)

            for k, v in self.val_epoch_losses.items():

                metric = k.split('_', 1)[1]
                f1 = F[metric][0]
                f2 = F[metric][1]

                if f2(v, f1(self.val_metrics[k])):

                    torch.save(self.model.state_dict(), self._weight_fname(self.epoch, v))
                    self.best_epoch = self.epoch
                    break  # weights are saved for this epoch so no need to check other metrics

        return

    def _weight_fname(self, epoch, loss):
        return os.path.join(getattr(self.model, 'w_path', self.path), f"weights_{epoch}_{loss}")

    def _get_train_val_loaders(self, x, y=None, validation_data=None):

        train_loader, self.num_outs = self._get_loader(x=x, y=y,
                                                       batch_size=self.batch_size,
                                                       shuffle=self.shuffle)
        val_loader, _ = self._get_loader(x=validation_data,
                                         batch_size=self.batch_size,
                                         shuffle=self.shuffle)

        return train_loader, val_loader

    def on_train_begin(self, x, y=None, validation_data=None, **kwargs):

        self.cbs = kwargs.get('callbacks', [])  # no callback by default

        if self.verbosity > 0:
            print("{}{}{}".format('*' * 25, 'Training Started', '*' * 25))
            formatter = "{:<7}" + " {:<15}" * (len(self.train_metrics) + len(self.val_metrics))
            print(formatter.format('Epoch: ',
                                   *self.train_metrics.keys(),
                                   *self.val_metrics.keys()))

            print("{}".format('*' * 70))

        if hasattr(self.model, 'loss'):
            self.criterion = self.model.loss()
        else:
            self.criterion = self.loss

        if hasattr(self.model, 'get_optimizer'):
            self.optimizer = self.model.get_optimizer()
        else:
            self.optimizer = self.optimizer  # keep the optimizer already set by the user

        self.train_loader, self.val_loader = self._get_train_val_loaders(
            x, y=y, validation_data=validation_data)

        if self.wandb_config is not None:
            assert wandb is not None
            assert isinstance(self.wandb_config, dict)

            wandb.init(name=os.path.basename(self.path),
                       project=self.wandb_config.get('project', 'test_project'),
                       notes='This is Learner from AI4Water test run',
                       tags=['ai4water', 'pytorch'],
                       entity=self.wandb_config.get('entity', ''))

        return

    def on_train_end(self):

        self.update_weights()

        self.train_metrics['loss'] = self.train_metrics.pop('mse')
        self.val_metrics['val_loss'] = self.val_metrics.pop('val_mse')

        class History(object):
            history = {}
            history.update(self.train_metrics)
            history.update(self.val_metrics)

        setattr(self, 'history', History())

        if self.wandb_config is not None:
            wandb.finish()

        return History()
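    # After training, the per-epoch metric arrays are exposed on a Keras-style
    # ``history`` attribute (a sketch; the exact keys depend on ``to_monitor``):
    # >>> h = learner.fit(X, Y)
    # >>> sorted(h.history)                      # e.g. ['loss', 'val_loss', ...]
    # >>> h.history['loss'][learner.best_epoch]  # training loss at the best epoch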
    def update_weights(self, weight_file_path: str = None):
        """If `weight_file_path` is not given, then it finds the best weights
        and updates the model with the best weights.

        Arguments:
            weight_file_path : complete path of the weights which are to be loaded
        """
        if weight_file_path:
            assert os.path.exists(weight_file_path)
            best_weights = os.path.basename(weight_file_path)
        else:
            w_path = getattr(self.model, 'w_path', self.path)
            best_weights = find_best_weight(w_path, epoch_identifier=self.best_epoch)
            if best_weights is not None:
                if best_weights.endswith(".hdf5"):  # todo, find_best_weight should not add .hdf5
                    best_weights = best_weights.split(".hdf5")[0]
                weight_file_path = os.path.join(w_path, best_weights)

        if best_weights is not None:
            # fpath = os.path.splitext(weight_file_path)[0]
            # we are not saving the whole model but only the state_dict
            self.model.load_state_dict(torch.load(weight_file_path))
            if self.verbosity > 0:
                print("{} Successfully loaded weights from {} file {}".format('*' * 10, best_weights, '*' * 10))
        return
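    # Usage sketch for ``update_weights`` (the explicit path below is hypothetical
    # and must exist on disk):
    # >>> learner.update_weights()  # load the best weights found under w_path/path
    # >>> learner.update_weights('/results/weights_42_0.1234')  # load a specific file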
    def update_metrics(self):

        for k, v in self.train_metrics.items():
            v[self.epoch] = self.train_epoch_losses[k]

        if self.val_loader is not None:
            for k, v in self.val_metrics.items():
                v[self.epoch] = self.val_epoch_losses[k]

        return
    def on_epoch_begin(self):
        return

    def on_epoch_end(self):

        formatter = "{:<7}" + "{:<15.7f} " * (len(self.val_epoch_losses) + len(self.train_epoch_losses))

        if self.val_loader is None:  # otherwise model is already saved based upon validation performance

            for k, v in self.train_epoch_losses.items():
                f1 = F[k][0]
                f2 = F[k][1]

                if f2(v, f1(self.train_metrics[k])):

                    torch.save(self.model.state_dict(), self._weight_fname(self.epoch, v))
                    self.best_epoch = self.epoch
                    break

        if self.verbosity > 0:
            print(formatter.format(self.epoch,
                                   *self.train_epoch_losses.values(),
                                   *self.val_epoch_losses.values()))

        for cb in self.cbs:
            if self.epoch % cb['after_epochs'] == 0:
                cb['func'](epoch=self.epoch,
                           model=self.model,
                           train_data=self.train_loader,
                           val_data=self.val_loader)

        self.update_metrics()

        return

    def _get_loader(self, x, y=None, batch_size=None, shuffle=True):

        data_loader = None
        num_outs = None

        if x is None:
            return None, None

        if isinstance(x, list):
            if len(x) == 1:
                x = x[0]
                if isinstance(x, torch.utils.data.Dataset):
                    dataset = x
                else:
                    dataset = to_torch_dataset(x, y)
            else:
                dataset = to_torch_dataset(x, y)

        elif isinstance(x, (np.ndarray, pd.DataFrame)):

            if y is not None:
                # if x is a numpy array or DataFrame, so should be y
                assert isinstance(y, (np.ndarray, pd.DataFrame, pd.Series))

                # if it is a DataFrame or Series
                if hasattr(y, 'values'):
                    y = y.values

                if len(y.shape) == 1:
                    num_outs = 1
                else:
                    num_outs = y.shape[-1]

            if isinstance(x, pd.DataFrame):
                x = x.values

            dataset = to_torch_dataset(x, y)

        elif isinstance(x, torch.utils.data.Dataset):
            dataset = x

        elif isinstance(x, torch.utils.data.DataLoader):
            data_loader = x

        elif isinstance(x, torch.Tensor):
            dataset = to_torch_dataset(x=x, y=y)

        elif isinstance(x, tuple):  # x is a tuple of x,y pairs
            assert len(x) == 2
            dataset = to_torch_dataset(x=x[0], y=x[1])

        else:
            raise NotImplementedError(f'unrecognized data of type {x.__class__.__name__} given')

        if data_loader is None:

            if batch_size is None:
                batch_size = len(dataset)

            data_loader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=batch_size,
                                                      shuffle=shuffle)

        return data_loader, num_outs

    def plot_model_using_tensorboard(
            self,
            x=None,
            path='tensorboard/tensorboard'
    ):
        """Plots the neural network on tensorboard

        Arguments
        ---------
            x : torch.Tensor
                input to the model
            path : str
                path to save the tensorboard graph
        """
        from torch.utils.tensorboard import SummaryWriter

        # default `log_dir` is "runs" - we'll be more specific here
        writer = SummaryWriter(path)

        if x is None:
            x, _ = next(iter(self.train_loader))

        writer.add_graph(self.model, x)
        writer.close()
        return
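    # Usage sketch: once ``fit`` has built ``train_loader``, the graph can be
    # exported and then inspected with ``tensorboard --logdir=tensorboard``.
    # >>> learner.plot_model_using_tensorboard()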
    def plot_model(self, y=None):
        """Helper function to plot a dot diagram of the model using the torchviz module.

        Arguments
        ---------
            y : torch.Tensor
                output tensor
        """
        try:
            from torchviz import make_dot
        except ModuleNotFoundError:
            print("You must install torchviz to plot the model. "
                  "See https://github.com/szagoruyko/pytorchviz#installation for installation")
            return

        if y is None:
            x, _ = next(iter(self.train_loader))
            y = self.model(x)

        fname = os.path.join(self.path, 'model.png')
        dot = make_dot(y, dict(self.model.named_parameters()), show_attrs=True, show_saved=True)
        dot.render(fname)
        return dot
def get_metrics_to_monitor(metrics):

    if metrics is None:
        _metrics = ['mse']
    elif isinstance(metrics, list):
        _metrics = metrics + ['mse']
    else:
        assert isinstance(metrics, str)
        _metrics = ['mse', metrics]

    return list(set(_metrics))
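
# Behaviour sketch: 'mse' is always included (it doubles as the training loss),
# duplicates are dropped, and the ``set`` makes the returned order arbitrary.
# >>> sorted(get_metrics_to_monitor(['nse', 'mse']))
# ['mse', 'nse']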