DirectEnsemble

class DirectEnsemble(pipelines: List[etna.pipeline.base.BasePipeline], n_jobs: int = 1, joblib_params: Optional[Dict[str, Any]] = None)

Bases: etna.ensembles.mixins.EnsembleMixin, etna.ensembles.mixins.SaveEnsembleMixin, etna.pipeline.base.BasePipeline

DirectEnsemble is a pipeline that forecasts future values by merging the forecasts of its base pipelines.

The ensemble takes several pipelines at initialization, and these pipelines are expected to have different forecasting horizons. For each future point, the ensemble forecast is taken from the base pipeline with the shortest horizon that still covers that point.
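The selection rule can be illustrated without etna. The following is a minimal sketch in plain Python, not the library's implementation; the helper name `assign_steps_to_horizons` is hypothetical, and future steps are assumed to be numbered from 1.

```python
from typing import Dict, List


def assign_steps_to_horizons(horizons: List[int]) -> Dict[int, int]:
    """Map each future step (1-based) to the horizon of the pipeline used for it.

    Hypothetical helper for illustration only; it is not part of etna.
    """
    assignment = {}
    for step in range(1, max(horizons) + 1):
        # Pick the shortest horizon that still covers this step.
        assignment[step] = min(h for h in horizons if h >= step)
    return assignment


# Two pipelines with horizons 3 and 5, as in the example on this page.
print(assign_steps_to_horizons([3, 5]))
# {1: 3, 2: 3, 3: 3, 4: 5, 5: 5}
```

With horizons 3 and 5, the first three steps come from the horizon-3 pipeline and steps 4 and 5 come from the horizon-5 pipeline.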

Examples

>>> from etna.datasets import generate_ar_df
>>> from etna.datasets import TSDataset
>>> from etna.ensembles import DirectEnsemble
>>> from etna.models import NaiveModel
>>> from etna.models import ProphetModel
>>> from etna.pipeline import Pipeline
>>> df = generate_ar_df(periods=30, start_time="2021-06-01", ar_coef=[1.2], n_segments=3)
>>> df_ts_format = TSDataset.to_dataset(df)
>>> ts = TSDataset(df_ts_format, "D")
>>> prophet_pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=3)
>>> naive_pipeline = Pipeline(model=NaiveModel(lag=10), transforms=[], horizon=5)
>>> ensemble = DirectEnsemble(pipelines=[prophet_pipeline, naive_pipeline])
>>> _ = ensemble.fit(ts=ts)
>>> forecast = ensemble.forecast()
>>> forecast
segment    segment_0 segment_1 segment_2
feature       target    target    target
timestamp
2021-07-01    -10.37   -232.60    163.16
2021-07-02    -10.59   -242.05    169.62
2021-07-03    -11.41   -253.82    177.62
2021-07-04     -5.85   -139.57     96.99
2021-07-05     -6.11   -167.69    116.59

Initialize DirectEnsemble.

Parameters
  • pipelines (List[etna.pipeline.base.BasePipeline]) – List of pipelines to use in the ensemble

  • n_jobs (int) – Number of jobs to run in parallel

  • joblib_params (Optional[Dict[str, Any]]) – Additional parameters for joblib.Parallel

Raises

ValueError – If two or more pipelines have the same horizon.
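The kind of check behind this error can be sketched as follows. This is an illustrative assumption about how duplicate horizons might be detected, not etna's actual code; the helper name `check_unique_horizons` and the error message are hypothetical.

```python
from collections import Counter
from typing import List


def check_unique_horizons(horizons: List[int]) -> None:
    """Raise ValueError if any horizon appears more than once.

    Hypothetical helper for illustration only; it is not part of etna.
    """
    duplicates = sorted(h for h, count in Counter(horizons).items() if count > 1)
    if duplicates:
        raise ValueError(f"Pipelines share the same horizon: {duplicates}")


check_unique_horizons([3, 5])  # distinct horizons, passes silently

try:
    check_unique_horizons([3, 5, 5])
except ValueError as error:
    print(error)  # Pipelines share the same horizon: [5]
```

Distinct horizons matter because each future point is assigned to exactly one base pipeline; two pipelines with the same horizon would make that assignment ambiguous.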

Methods

backtest(ts, metrics[, n_folds, mode, ...])

Run backtest with the pipeline.

fit(ts)

Fit pipelines in ensemble.

forecast([prediction_interval, quantiles, ...])

Make predictions.

load(path[, ts])

Load an object.

predict(ts[, start_timestamp, ...])

Make in-sample predictions on dataset in a given range.

save(path)

Save the object.

to_dict()

Collect all information about etna object in dict.

Attributes

fit(ts: etna.datasets.tsdataset.TSDataset) → etna.ensembles.direct_ensemble.DirectEnsemble

Fit pipelines in ensemble.

Parameters

ts (etna.datasets.tsdataset.TSDataset) – TSDataset to fit ensemble

Returns

Fitted ensemble

Return type

self