diff --git a/nbs/index.ipynb b/nbs/index.ipynb index dcfb7e139..9a620a91f 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -123,8 +123,8 @@ "\n", "# Fit and predict with NBEATS and NHITS models\n", "horizon = len(Y_test_df)\n", - "models = [NBEATS(input_size=2 * horizon, h=horizon, max_epochs=50),\n", - " NHITS(input_size=2 * horizon, h=horizon, max_epochs=50)]\n", + "models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50),\n", + " NHITS(input_size=2 * horizon, h=horizon, max_steps=50)]\n", "nf = NeuralForecast(models=models, freq='M')\n", "nf.fit(df=Y_train_df)\n", "Y_hat_df = nf.predict().reset_index()\n",