# Source code for ai4water.preprocessing.dataset._pipeline



__all__ = ['DataSetPipeline']

from ai4water.backend import np, os

from ._main import _DataSet


class DataSetPipeline(_DataSet):
    """A collection of DataSets concatenated one after the other.

    A DataSetPipeline of four DataSets will be as follows:

        +----------+
        | DataSet1 |
        +----------+
        | DataSet2 |
        +----------+
        | DataSet3 |
        +----------+
        | DataSet4 |
        +----------+

    The only condition for the different datasets is that they all have
    the same output dimension.
    """
[docs] def __init__( self, *datasets: _DataSet, verbosity=1 ) -> None: """ Parameters ---------- *datasets : the datasets to be combined verbosity : controls the output information being printed. Examples --------- >>> import pandas as pd >>> from ai4water.preprocessing import DataSet, DataSetPipeline >>> df1 = pd.DataFrame(np.random.random((100, 10)), ... columns=[f"Feat_{i}" for i in range(10)]) >>> df2 = pd.DataFrame(np.random.random((200, 10)), ... columns=[f"Feat_{i}" for i in range(10)]) >>> ds1 = DataSet(df1) >>> ds2 = DataSet(df2) >>> ds = DataSetPipeline(ds1, ds2) >>> train_x, train_y = ds.training_data() >>> val_x, val_y = ds.validation_data() >>> test_x, test_y = ds.test_data() """ self.verbosity = verbosity self._datasets = [] for ds in datasets: ds.verbosity = 0 assert isinstance(ds, _DataSet), f""" {ds} is not a valid dataset""" self._datasets.append(ds) self.examples = {} _DataSet.__init__(self, config={}, path=os.getcwd()) self.index = 0
def __iter__(self): return self def __next__(self): try: item = self._datasets[self.index] except IndexError: self.index = 0 raise StopIteration self.index += 1 return item def __getitem__(self, item:int): return self._datasets[item] @property def num_datasets(self) -> int: return len(self._datasets) @property def teacher_forcing(self): return all([ds.teacher_forcing for ds in self._datasets]) @property def mode(self): return all([ds.mode for ds in self._datasets]) @property def is_binary(self): return all([ds.is_binary for ds in self._datasets]) @property def input_features(self): _input_features = [ds.input_features for ds in self._datasets] return _input_features @property def output_features(self): _output_features = [ds.output_features for ds in self._datasets] return _output_features
[docs] def training_data(self, key="train", **kwargs): if self.teacher_forcing: x, prev_y, y = self._get_x_yy('training_data') return x, prev_y, y else: x, y = self._get_xy('training_data') return self.return_xy(x, y, "Training")
[docs] def validation_data(self, key="val", **kwargs): if self.teacher_forcing: x, prev_y, y = self._get_x_yy('validation_data') return x, prev_y, y else: x, y = self._get_xy('validation_data') return self.return_xy(x, y, "Validation")
[docs] def test_data(self, key="test", **kwargs): if self.teacher_forcing: x, prev_y, y = self._get_x_yy('test_data') return x, prev_y, y else: x, y = self._get_xy('test_data') return self.return_xy(x, y, "Test")
def _get_x_yy(self, method): x, prev_y, y = [], [], [] exs = {} for idx, ds in enumerate(self._datasets): _x, _prev_y, _y = getattr(ds, method)() x.append(_x) prev_y.append(_prev_y) y.append(_y) exs[idx] = {'x': len(x), 'y': len(y)} self.examples[method] = exs if not all([i.size for i in x]): x = conform_shape(x) prev_y = conform_shape(prev_y) y = conform_shape(y) return np.row_stack(x), np.row_stack(prev_y), np.row_stack(y) def _get_xy(self, method): x, y = [], [] exs = {} for idx, ds in enumerate(self._datasets): _x, _y = getattr(ds, method)() x.append(_x) y.append(_y) exs[idx] = {'x': len(x), 'y': len(y)} self.examples[method] = exs if not all([i.size for i in x]): x = conform_shape(x) y = conform_shape(y) return np.row_stack(x), np.row_stack(y)
def conform_shape(alist: list) -> list:
    """Replace empty arrays in *alist* with zero-row arrays whose trailing
    dimensions match the first non-empty member, so that row-stacking the
    list cannot fail on a dimension mismatch.

    Parameters
    ----------
    alist : list
        list of numpy arrays, some of which may be empty (``size == 0``).

    Returns
    -------
    list
        arrays with consistent trailing dimensions; non-empty arrays are
        passed through unchanged (same objects).
    """
    non_empty_shapes = [arr.shape for arr in alist if arr.size != 0]
    if not non_empty_shapes:
        # BUGFIX: when *every* array is empty there is no shape to infer
        # and the original raised IndexError; return the input unchanged.
        return alist

    desired_shape = list(non_empty_shapes[0])
    desired_shape[0] = 0  # zero rows: contributes nothing when stacked

    return [np.zeros(desired_shape) if arr.size == 0 else arr for arr in alist]