DeepBaseModel
- class DeepBaseModel(*, net: etna.models.base.DeepBaseNet, encoder_length: int, decoder_length: int, train_batch_size: int, test_batch_size: int, trainer_params: Optional[dict], train_dataloader_params: Optional[dict], test_dataloader_params: Optional[dict], val_dataloader_params: Optional[dict], split_params: Optional[dict])
Bases: etna.models.base.DeepBaseAbstractModel, etna.models.mixins.SaveNNMixin, etna.models.base.NonPredictionIntervalContextRequiredAbstractModel
Class with a partial implementation of the interfaces for holding deep models.
Init DeepBaseModel.
- Parameters
net (etna.models.base.DeepBaseNet) – network to train
encoder_length (int) – encoder length
decoder_length (int) – decoder length
train_batch_size (int) – batch size for training
test_batch_size (int) – batch size for testing
trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (API reference: pytorch_lightning.trainer.trainer.Trainer)
train_dataloader_params (Optional[dict]) – parameters for the train dataloader, for example a sampler (API reference: torch.utils.data.DataLoader)
test_dataloader_params (Optional[dict]) – parameters for the test dataloader
val_dataloader_params (Optional[dict]) – parameters for the validation dataloader
split_params (Optional[dict]) – dictionary with parameters for torch.utils.data.random_split() for train-test splitting:
  - train_size: (float) value from 0 to 1, the fraction of samples to use for training
  - generator: (Optional[torch.Generator]) generator for reproducible train-test splitting
  - torch_dataset_size: (Optional[int]) number of samples in the dataset, needed when the dataset does not implement __len__
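The split_params keys can be illustrated without PyTorch. The sketch below is a hypothetical pure-Python stand-in for torch.utils.data.random_split(): a seeded random.Random plays the role of generator, and torch_dataset_size is consulted only when the dataset has no __len__. How DeepBaseModel itself turns train_size into integer lengths is not stated here; flooring the training share and giving the remainder to the held-out part is an assumption that keeps the two lengths summing to the dataset size.

```python
import random
from typing import Any, List, Optional, Tuple


def split_lengths(dataset_size: int, train_size: float) -> Tuple[int, int]:
    """Turn a training fraction into two integer lengths that sum to dataset_size."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")
    train_len = int(dataset_size * train_size)  # floor the training share
    return train_len, dataset_size - train_len  # remainder goes to the held-out part


def random_split_indices(
    dataset: Any,
    train_size: float,
    seed: Optional[int] = None,
    torch_dataset_size: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical pure-Python stand-in for torch.utils.data.random_split()."""
    # Use __len__ when available; otherwise fall back to an explicit size,
    # mirroring the role of the torch_dataset_size key.
    size = len(dataset) if hasattr(dataset, "__len__") else torch_dataset_size
    if size is None:
        raise ValueError("torch_dataset_size is required when the dataset has no __len__")
    train_len, _ = split_lengths(size, train_size)
    indices = list(range(size))
    random.Random(seed).shuffle(indices)  # a seeded RNG plays the role of `generator`
    return indices[:train_len], indices[train_len:]


train_idx, held_out_idx = random_split_indices(list(range(10)), train_size=0.8, seed=0)
print(len(train_idx), len(held_out_idx))  # 8 2
```

Reusing the same seed reproduces the same split, which is the property the generator key is meant to provide.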
Methods
fit(ts) – Fit model.
forecast(ts, prediction_size) – Make autoregressive predictions.
get_model() – Get model.
load(path) – Load an object.
predict(ts, prediction_size) – Make predictions using true values instead of previously predicted ones.
raw_fit(torch_dataset) – Fit model on a torch-like Dataset.
raw_predict(torch_dataset) – Make inference on a torch-like Dataset.
save(path) – Save the object.
to_dict() – Collect all information about an etna object in a dict.
Attributes
context_size – Context size of the model.
- fit(ts: etna.datasets.tsdataset.TSDataset) → etna.models.base.DeepBaseModel
Fit model.
- Parameters
ts (etna.datasets.tsdataset.TSDataset) – TSDataset with features
- Returns
Model after fit
- Return type
etna.models.base.DeepBaseModel
- forecast(ts: etna.datasets.tsdataset.TSDataset, prediction_size: int) → etna.datasets.tsdataset.TSDataset
Make predictions.
This method will make autoregressive predictions.
- Parameters
ts (etna.datasets.tsdataset.TSDataset) – Dataset with features and expected decoder length for context
prediction_size (int) – Number of last timestamps to keep after making the prediction. Earlier timestamps are used as context.
- Returns
Dataset with predictions
- Return type
etna.datasets.tsdataset.TSDataset
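The autoregressive loop that forecast() describes can be sketched with a toy one-step model. Both the helper and the doubling "model" are hypothetical: a real DeepBaseNet consumes windows of encoder_length observations plus features rather than a single value, so the sketch only illustrates the feedback pattern in which each prediction becomes the next input.

```python
from typing import Callable, List, Sequence


def autoregressive_forecast(
    context: Sequence[float], horizon: int, step: Callable[[float], float]
) -> List[float]:
    """Roll a one-step model forward, feeding each prediction back in as input."""
    if not context:
        raise ValueError("context must contain at least one observation")
    predictions: List[float] = []
    last = context[-1]
    for _ in range(horizon):
        last = step(last)  # the model's own output becomes the next input
        predictions.append(last)
    return predictions


def double(x: float) -> float:
    """Toy stand-in for a trained network (hypothetical)."""
    return 2 * x


print(autoregressive_forecast([1.0, 2.0, 4.0], horizon=3, step=double))
# [8.0, 16.0, 32.0]
```

Because errors are fed back, any mistake at one step carries into every later step of the horizon.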
- get_model() → etna.models.base.DeepBaseNet
Get model.
- Returns
Torch Module
- Return type
etna.models.base.DeepBaseNet
- predict(ts: etna.datasets.tsdataset.TSDataset, prediction_size: int) → etna.datasets.tsdataset.TSDataset
Make predictions.
This method makes predictions using true values instead of values predicted at the previous step. It can be useful for making in-sample forecasts.
- Parameters
ts (etna.datasets.tsdataset.TSDataset) – Dataset with features and expected decoder length for context
prediction_size (int) – Number of last timestamps to keep after making the prediction. Earlier timestamps are used as context.
- Returns
Dataset with predictions
- Return type
etna.datasets.tsdataset.TSDataset
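Unlike an autoregressive rollout, predict() conditions every step on observed values. A hypothetical one-step sketch with a toy doubling model shows the effect; prediction_size selects the tail to predict, and earlier points serve only as context.

```python
from typing import Callable, List, Sequence


def one_step_predict(
    series: Sequence[float], prediction_size: int, step: Callable[[float], float]
) -> List[float]:
    """Predict each of the last prediction_size points from the true previous value."""
    if not 0 < prediction_size < len(series):
        raise ValueError("need at least one context point before the predicted tail")
    start = len(series) - prediction_size  # earlier points serve only as context
    # Each prediction uses the observed value at t - 1, never an earlier prediction.
    return [step(series[t - 1]) for t in range(start, len(series))]


def double(x: float) -> float:
    """Toy stand-in for a trained network (hypothetical)."""
    return 2 * x


observed = [1.0, 2.0, 4.0, 9.0, 15.0]
print(one_step_predict(observed, prediction_size=2, step=double))
# [8.0, 18.0]
```

Rolling forward autoregressively from the first three points would instead give [8.0, 16.0], because the second step would start from the predicted 8.0 rather than the observed 9.0.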
- raw_fit(torch_dataset: torch.utils.data.dataset.Dataset) → etna.models.base.DeepBaseModel
Fit model on a torch-like Dataset.
- Parameters
torch_dataset (torch.utils.data.dataset.Dataset) – torch-like dataset for model fit
- Returns
Model after fit
- Return type
etna.models.base.DeepBaseModel
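raw_fit() and raw_predict() accept any torch-like dataset. In PyTorch terms a map-style dataset must provide __getitem__ and may provide __len__; a dataset without __len__ is why split_params has a torch_dataset_size key. The class below is a hypothetical sketch that slices a series into encoder/decoder windows; its key names are illustrative and are not claimed to match the sample format a particular DeepBaseNet expects.

```python
from typing import Dict, List, Sequence


class WindowDataset:
    """Minimal map-style dataset with __getitem__ and __len__ (hypothetical)."""

    def __init__(self, series: Sequence[float], encoder_length: int, decoder_length: int):
        if len(series) < encoder_length + decoder_length:
            raise ValueError("series is shorter than one encoder + decoder window")
        self.series = list(series)
        self.encoder_length = encoder_length
        self.decoder_length = decoder_length

    def __len__(self) -> int:
        # Number of full windows that fit when sliding by one step.
        return len(self.series) - self.encoder_length - self.decoder_length + 1

    def __getitem__(self, idx: int) -> Dict[str, List[float]]:
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        split = idx + self.encoder_length
        end = split + self.decoder_length
        return {
            "encoder_target": self.series[idx:split],  # history the model sees
            "decoder_target": self.series[split:end],  # values it learns to predict
        }


ds = WindowDataset([float(v) for v in range(10)], encoder_length=3, decoder_length=2)
print(len(ds), ds[0])
# 6 {'encoder_target': [0.0, 1.0, 2.0], 'decoder_target': [3.0, 4.0]}
```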
- raw_predict(torch_dataset: torch.utils.data.dataset.Dataset) → Dict[Tuple[str, str], numpy.ndarray]
Make inference on a torch-like Dataset.
- Parameters
torch_dataset (torch.utils.data.dataset.Dataset) – torch-like dataset for model inference
- Returns
Dictionary with predictions
- Return type
Dict[Tuple[str, str], numpy.ndarray]
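The return type only says that keys are Tuple[str, str]. Assuming, as an unverified guess, that each key is a (segment, column) pair such as ("segment_a", "target"), the flat dictionary can be regrouped per segment as below. The helper is hypothetical, and plain lists stand in for numpy.ndarray so the sketch has no dependencies.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def group_by_segment(
    raw: Dict[Tuple[str, str], Sequence[float]]
) -> Dict[str, Dict[str, List[float]]]:
    """Regroup {(segment, column): values} into {segment: {column: values}}."""
    # Assumption: the first key element is a segment name, the second a column name.
    grouped: Dict[str, Dict[str, List[float]]] = defaultdict(dict)
    for (segment, column), values in raw.items():
        grouped[segment][column] = list(values)  # copy lists or 1-D arrays into a list
    return dict(grouped)


raw = {("segment_a", "target"): [1.0, 2.0], ("segment_b", "target"): [3.0, 4.0]}
print(group_by_segment(raw))
# {'segment_a': {'target': [1.0, 2.0]}, 'segment_b': {'target': [3.0, 4.0]}}
```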
- property context_size: int
Context size of the model.
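context_size is the number of past timestamps that must precede the forecast horizon. For DeepBaseModel it is presumably tied to encoder_length; that is an assumption based on the constructor parameters, not something stated here. The hypothetical helper below keeps that much history and appends empty slots for prediction_size future points, in line with how forecast() and predict() treat earlier timestamps as context.

```python
from typing import List, Optional, Sequence


def build_forecast_input(
    history: Sequence[float], context_size: int, prediction_size: int
) -> List[Optional[float]]:
    """Keep the last context_size observations and append empty future slots."""
    if context_size > len(history):
        raise ValueError(f"need {context_size} observations of context, got {len(history)}")
    # Slice from an explicit start so that context_size == 0 yields no history
    # (history[-0:] would return the whole sequence instead).
    tail = list(history[len(history) - context_size:])
    return tail + [None] * prediction_size  # placeholders to be filled by the forecast


print(build_forecast_input([1.0, 2.0, 3.0, 4.0, 5.0], context_size=3, prediction_size=2))
# [3.0, 4.0, 5.0, None, None]
```

Each segment therefore needs at least context_size + prediction_size timestamps in the dataset passed to forecast().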