Skip to content

Commit

Permalink
[DOC] Add example to docstrings of sktime estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
felipeangelimvieira authored Aug 26, 2024
2 parents fba002c + c6a9da2 commit 8284b2e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 23 deletions.
70 changes: 47 additions & 23 deletions src/prophetverse/sktime/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,34 @@


class HierarchicalProphet(BaseProphetForecaster):
"""A Bayesian hierarchical time series forecasting model based on the Prophet.
"""A Bayesian hierarchical time series forecasting model based on Meta's Prophet.
This class forecasts all series in a hierarchy at once, using a MultivariateNormal
as the likelihood function and LKJ priors for the correlation matrix.
This method forecasts all bottom series in a hierarchy at once, using a
MultivariateNormal as the likelihood function and LKJ priors for the correlation
matrix.
This class may be interesting if you want to fit shared coefficients across series.
By default, all coefficients are obtained exclusively for each series, but this can
be changed through the `shared_coefficients` parameter.
This forecaster is particularly interesting if you want to fit shared coefficients
across series. In that case, `shared_features` parameter should be a list of
feature names that should have that behaviour.
Parameters
----------
trend : Union[str, BaseEffect], optional, one of "linear" (default) or "logistic"
trend : Union[str, BaseEffect], optional, default="linear"
Type of trend to use. Can also be a custom effect object.
changepoint_interval : int, optional, default=25
Number of potential changepoints to sample in the history.
changepoint_range : float or int, optional, default=0.8
changepoint_range : Union[float, int], optional, default=0.8
Proportion of the history in which trend changepoints will be estimated.
* if float, must be between 0 and 1.
* If float, must be between 0 and 1 (inclusive).
The range will be that proportion of the training history.
* if int, can be positive or negative.
Absolute value must be less than number of training points.
* If int, can be positive or negative.
Absolute value must be less than the number of training points.
The range will be that number of points.
A negative int indicates number of points
A negative int indicates the number of points
counting from the end of the history, a positive int from the beginning.
changepoint_prior_scale : float, optional, default=0.001
Expand All @@ -62,14 +63,14 @@ class HierarchicalProphet(BaseProphetForecaster):
feature_transformer : BaseTransformer or None, optional, default=None
A transformer to preprocess the exogenous features.
exogenous_effects : list of AbstractEffect, optional, default=None
exogenous_effects : list of AbstractEffect or None, optional, default=None
A list defining the exogenous effects to be used in the model.
default_effect : AbstractEffect, optional, default=None
default_effect : AbstractEffect or None, optional, default=None
The default effect to be used when no effect is specified for a variable.
shared_features : list, optional, default=[]
List of shared features across series.
List of features shared across all series in the hierarchy.
mcmc_samples : int, optional, default=2000
Number of MCMC samples to draw.
Expand All @@ -86,7 +87,7 @@ class HierarchicalProphet(BaseProphetForecaster):
optimizer_name : str, optional, default='Adam'
Name of the optimizer to use.
optimizer_kwargs : dict, optional, default={'step_size': 1e-4}
optimizer_kwargs : dict or None, optional, default={'step_size': 1e-4}
Additional keyword arguments for the optimizer.
optimizer_steps : int, optional, default=100_000
Expand All @@ -98,14 +99,38 @@ class HierarchicalProphet(BaseProphetForecaster):
correlation_matrix_concentration : float, optional, default=1.0
Concentration parameter for the correlation matrix.
rng_key : jax.random.PRNGKey, optional, default=None
rng_key : jax.random.PRNGKey or None, optional, default=None
Random number generator key.
Examples
--------
>>> from sktime.forecasting.naive import NaiveForecaster
>>> from sktime.transformations.hierarchical.aggregate import Aggregator
>>> from sktime.utils._testing.hierarchical import _bottom_hier_datagen
>>> from prophetverse.sktime.multivariate import HierarchicalProphet
>>> agg = Aggregator()
>>> y = _bottom_hier_datagen(
... no_bottom_nodes=3,
... no_levels=1,
... random_seed=123,
... length=7,
... )
>>> y = agg.fit_transform(y)
>>> forecaster = HierarchicalProphet()
>>> forecaster.fit(y)
>>> forecaster.predict(fh=[1])
"""

_tags = {
"scitype:y": "univariate", # which y are fine? univariate/multivariate/both
"ignores-exogeneous-X": False, # does estimator ignore the exogeneous X?
"handles-missing-data": False, # can estimator handle missing data?
# packaging info
# --------------
"authors": "felipeangelimvieira",
"maintainers": "felipeangelimvieira",
"python_dependencies": "prophetverse",
# estimator type
"scitype:y": "univariate",
"ignores-exogeneous-X": False,
"handles-missing-data": False,
"y_inner_mtype": [
"pd.DataFrame",
"pd-multiindex",
Expand All @@ -116,9 +141,8 @@ class HierarchicalProphet(BaseProphetForecaster):
"pd-multiindex",
"pd_multiindex_hier",
], # which types do _fit, _predict, assume for X?
"requires-fh-in-fit": False, # is forecasting horizon already required in fit?
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"enforce_index_type": None, # index type that needs to be enforced in X/y
"requires-fh-in-fit": False,
"X-y-must-have-same-index": False,
"fit_is_empty": False,
"capability:pred_int": True,
"capability:pred_int:insample": True,
Expand Down
24 changes: 24 additions & 0 deletions src/prophetverse/sktime/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,30 @@ class Prophet(Prophetverse):
rng_key : jax.random.PRNGKey or None (default
Random number generator key.
Examples
--------
>>> from sktime.datasets import load_airline
>>> from sktime.forecasting.prophetverse import Prophetverse
>>> from prophetverse.effects.fourier import LinearFourierSeasonality
>>> from prophetverse.utils.regex import no_input_columns
>>> y = load_airline()
>>> model = Prophetverse(
... exogenous_effects=[
... (
... "seasonality",
... LinearFourierSeasonality(
... sp_list=[12],
... fourier_terms_list=[3],
... freq="M",
... effect_mode="multiplicative",
... ),
... no_input_columns,
... )
... ],
... )
>>> model.fit(y)
>>> model.predict(fh=[1, 2, 3])
"""

def __init__(
Expand Down

0 comments on commit 8284b2e

Please sign in to comment.