Models

Module with models for time-series forecasting.

Basic usage

Models are used to make predictions. Here is a basic usage example:

>>> import pandas as pd
>>> from etna.datasets import TSDataset, generate_ar_df
>>> from etna.transforms import LagTransform
>>> from etna.models import LinearPerSegmentModel
>>>
>>> df = generate_ar_df(periods=100, start_time="2021-01-01", ar_coef=[1/2], n_segments=2)
>>> ts = TSDataset(df, freq="D")
>>> lag_transform = LagTransform(in_column="target", lags=[3, 4, 5])
>>> ts.fit_transform(transforms=[lag_transform])
>>> future_ts = ts.make_future(future_steps=3, transforms=[lag_transform])
>>> model = LinearPerSegmentModel()
>>> model.fit(ts)
LinearPerSegmentModel(fit_intercept = True, )
>>> forecast_ts = model.forecast(future_ts)
>>> forecast_ts is future_ts
True

Note that future_ts and forecast_ts are the same object: the forecast method only fills the ‘target’ column in future_ts and returns a reference to it.
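
The fill-and-return-the-same-object behaviour can be illustrated with a minimal pure-Python sketch. It does not use ETNA: ToyDataset and ToyModel are hypothetical stand-ins invented for this illustration, and only the pattern (mutate the input, return a reference to it) mirrors the note above.

```python
from typing import List, Optional


class ToyDataset:
    """Hypothetical stand-in for a dataset whose future targets are unknown."""

    def __init__(self, target: List[Optional[float]]) -> None:
        self.target = target


class ToyModel:
    """Hypothetical model that always predicts one constant value."""

    def __init__(self, value: float) -> None:
        self.value = value

    def forecast(self, ts: ToyDataset) -> ToyDataset:
        # Fill the missing targets on the dataset that was passed in.
        ts.target = [self.value if y is None else y for y in ts.target]
        # Return a reference to the same dataset, not a copy.
        return ts


future = ToyDataset(target=[None, None, None])
result = ToyModel(value=1.5).forecast(future)
print(result is future)  # True
print(future.target)     # [1.5, 1.5, 1.5]
```

With this pattern, the unfilled dataset is gone once forecast has run, so a caller who still needs it would have to copy it beforehand (for example with copy.deepcopy).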

API details

Base:

NonPredictionIntervalContextIgnorantAbstractModel()

Interface for models that don't support prediction intervals and don't need context for prediction.

NonPredictionIntervalContextRequiredAbstractModel()

Interface for models that don't support prediction intervals and need context for prediction.

PredictionIntervalContextIgnorantAbstractModel()

Interface for models that support prediction intervals and don't need context for prediction.

PredictionIntervalContextRequiredAbstractModel()

Interface for models that support prediction intervals and need context for prediction.

Naive models:

SeasonalMovingAverageModel([window, seasonality])

Seasonal moving average.

MovingAverageModel([window])

MovingAverageModel averages previous series values to forecast the next one.

NaiveModel([lag])

Naive model predicts the t-th value of a series with its (t - lag)-th value.

DeadlineMovingAverageModel([window, seasonality])

Moving average model that uses exact previous dates to predict.
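
The first three naive models above can be read as variations of one rule: predict a value from earlier values of the same series. The following pure-Python sketch assumes that the forecast for step t is the mean of the `window` values at t - seasonality, t - 2 * seasonality, and so on, and that multi-step forecasts reuse earlier predictions. The function name seasonal_moving_average and the toy series are illustrative and not part of ETNA; the calendar-based logic of DeadlineMovingAverageModel is not covered.

```python
from typing import List


def seasonal_moving_average(
    history: List[float], horizon: int, window: int = 5, seasonality: int = 7
) -> List[float]:
    """Forecast `horizon` steps by averaging values spaced `seasonality` apart."""
    if len(history) < window * seasonality:
        raise ValueError("history must contain at least window * seasonality points")
    values = list(history)
    for _ in range(horizon):
        t = len(values)
        # Collect the values at t - seasonality, t - 2 * seasonality, ...
        past = [values[t - i * seasonality] for i in range(1, window + 1)]
        # Earlier predictions are appended, so later steps may reuse them.
        values.append(sum(past) / window)
    return values[len(history):]


series = [1, 2, 3, 4, 5, 6]

# Seasonal moving average: two values, three steps apart.
seasonal = seasonal_moving_average(series, horizon=3, window=2, seasonality=3)
# Plain moving average: seasonality of one step.
moving = seasonal_moving_average(series, horizon=1, window=3, seasonality=1)
# Naive forecast with lag 2: a window of one value, two steps back.
naive = seasonal_moving_average(series, horizon=3, window=1, seasonality=2)

print(seasonal)  # [2.5, 3.5, 4.5]
print(moving)    # [5.0]
print(naive)     # [5.0, 6.0, 5.0]
```

Under this reading, a model of this kind needs the last window * seasonality points of history at prediction time.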

Statistical models:

AutoARIMAModel(**kwargs)

Class for holding the auto-ARIMA model.

SARIMAXModel([order, seasonal_order, trend, ...])

Class for holding SARIMAX model.

HoltWintersModel([trend, damped_trend, ...])

Holt-Winters' etna model.

HoltModel([exponential, damped_trend, ...])

Holt etna model.

SimpleExpSmoothingModel([...])

Exponential smoothing etna model.

ProphetModel([growth, changepoints, ...])

Class for holding Prophet model.

TBATSModel([use_box_cox, box_cox_bounds, ...])

Class for holding segment interval TBATS model.

BATSModel([use_box_cox, box_cox_bounds, ...])

Class for holding segment interval BATS model.

StatsForecastARIMAModel([order, ...])

Class for holding statsforecast.models.ARIMA.

StatsForecastAutoARIMAModel([d, D, max_p, ...])

Class for holding statsforecast.models.AutoARIMA.

StatsForecastAutoCESModel([season_length, model])

Class for holding statsforecast.models.AutoCES.

StatsForecastAutoETSModel([season_length, ...])

Class for holding statsforecast.models.AutoETS.

StatsForecastAutoThetaModel([season_length, ...])

Class for holding statsforecast.models.AutoTheta.
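
As an illustration of what the simplest of these statistical models computes, here is a sketch of the textbook simple exponential smoothing recursion, level = alpha * y + (1 - alpha) * level, with a flat forecast equal to the last level. It is a teaching sketch under stated assumptions, not the code behind SimpleExpSmoothingModel: the function name simple_exp_smoothing, the fixed alpha, and the choice to initialise the level with the first observation are all assumptions made here.

```python
from typing import List


def simple_exp_smoothing(history: List[float], alpha: float, horizon: int) -> List[float]:
    """Return a flat forecast from the textbook exponential smoothing recursion."""
    if not history:
        raise ValueError("history must not be empty")
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    # Assumption: start the level at the first observation.
    level = history[0]
    for y in history[1:]:
        # Blend the new observation with the previous level.
        level = alpha * y + (1 - alpha) * level
    # Every future step gets the last smoothed level.
    return [level] * horizon


ses_forecast = simple_exp_smoothing([10.0, 12.0, 14.0], alpha=0.5, horizon=2)
print(ses_forecast)  # [12.5, 12.5]
```

A larger alpha makes the forecast follow recent observations more closely, while a smaller alpha gives more weight to older values.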

ML-models:

CatBoostMultiSegmentModel([iterations, ...])

Class for holding Catboost model for all segments.

CatBoostPerSegmentModel([iterations, depth, ...])

Class for holding per segment Catboost model.

ElasticMultiSegmentModel([alpha, l1_ratio, ...])

Class holding sklearn.linear_model.ElasticNet for all segments.

ElasticPerSegmentModel([alpha, l1_ratio, ...])

Class holding per segment sklearn.linear_model.ElasticNet.

LinearMultiSegmentModel([fit_intercept])

Class holding sklearn.linear_model.LinearRegression for all segments.

LinearPerSegmentModel([fit_intercept])

Class holding per segment sklearn.linear_model.LinearRegression.

SklearnMultiSegmentModel(regressor)

Class for holding Sklearn model for all segments.

SklearnPerSegmentModel(regressor)

Class for holding per segment Sklearn model.
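
The difference between the per-segment and multi-segment variants listed above can be sketched without any library: either one regressor is fitted for each segment, or a single regressor is fitted on the data of all segments together. The helper fit_line, the segment names, and the toy data below are hypothetical and chosen only to make the contrast visible; the ETNA classes delegate the fitting to the wrapped estimator.

```python
from typing import Dict, List, Tuple

Sample = Tuple[float, float]  # (feature, target)


def fit_line(samples: List[Sample]) -> Tuple[float, float]:
    """Ordinary least squares for y = intercept + slope * x."""
    n = len(samples)
    mean_x = sum(x for x, _ in samples) / n
    mean_y = sum(y for _, y in samples) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in samples)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in samples)
    slope = sxy / sxx
    return mean_y - slope * mean_x, slope


segments: Dict[str, List[Sample]] = {
    "segment_0": [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)],  # y = x
    "segment_1": [(0.0, 4.0), (1.0, 6.0), (2.0, 8.0)],  # y = 4 + 2x
}

# Per-segment: one (intercept, slope) pair for each segment.
per_segment = {name: fit_line(samples) for name, samples in segments.items()}

# Multi-segment: one (intercept, slope) pair fitted on all samples at once.
pooled = fit_line([s for samples in segments.values() for s in samples])

print(per_segment)  # {'segment_0': (0.0, 1.0), 'segment_1': (4.0, 2.0)}
print(pooled)       # (2.0, 1.5)
```

In this toy case the per-segment fits recover each line exactly, while the pooled fit is a compromise between them. A shared model can still be the better choice when segments behave alike or individual segments have little history.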

Native neural network models:

nn.RNNModel(input_size, decoder_length, ...)

RNN model based on an LSTM cell.

nn.MLPModel(input_size, decoder_length, ...)

MLPModel.

nn.DeepStateModel(ssm, input_size, ...[, ...])

DeepState model.

nn.NBeatsGenericModel(input_size, output_size)

Generic N-BEATS model.

nn.NBeatsInterpretableModel(input_size, ...)

Interpretable N-BEATS model.

nn.PatchTSModel(decoder_length, encoder_length)

PatchTS model using PyTorch layers.

nn.DeepARNativeModel(input_size, ...[, ...])

DeepAR model based on an LSTM cell.

nn.TFTNativeModel(encoder_length, decoder_length)

TFT model.

Utilities for DeepStateModel:

nn.deepstate.CompositeSSM(seasonal_ssms[, ...])

Class to compose several State Space Models.

nn.deepstate.LevelSSM()

Class for Level State Space Model.

nn.deepstate.LevelTrendSSM()

Class for Level-Trend State Space Model.

nn.deepstate.SeasonalitySSM(num_seasons, ...)

Class for Seasonality State Space Model.

nn.deepstate.DaylySeasonalitySSM()

Class for Daily Seasonality State Space Model.

nn.deepstate.YearlySeasonalitySSM()

Class for Yearly Seasonality State Space Model.

Neural network models based on pytorch_forecasting:

nn.DeepARModel(*args, **kwargs)

Wrapper for pytorch_forecasting.models.deepar.DeepAR.

nn.TFTModel(*args, **kwargs)

Wrapper for pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.

Utilities for neural network models based on pytorch_forecasting:

nn.PytorchForecastingDatasetBuilder([...])

Builder for PytorchForecasting dataset.