update evaluation code and scripts #113

Merged 2 commits on Aug 30, 2024
7 changes: 6 additions & 1 deletion project/benchmarks/README.md
@@ -7,8 +7,13 @@ This directory contains the code and scripts for benchmarking.

 `chronos_scripts` contains the scripts to run Chronos on different datasets.
 
-Example:
+### Examples
+On Monash dataset:
 ```
 sh chronos_scripts/monash_chronos_base.sh
 ```
 
+On datasets for Probabilistic forecasting:
+```
+sh chronos_scripts/pf_chronos_base.sh
+```
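
A brief usage note, not part of the diff: this PR adds the same probabilistic-forecasting script for every Chronos size (see the files below), so any size variant runs the same way, assuming commands are issued from `project/benchmarks/`:

```
sh chronos_scripts/pf_chronos_tiny.sh   # likewise mini, small, base
```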
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_base.sh
@@ -0,0 +1,6 @@
+model_size=base
+model_path=amazon/chronos-t5-${model_size}
+for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
+do
+python run_chronos.py --model_path=${model_path} --dataset=${ds} --run_name=chronos-${model_size} --save_dir=pf_results_20 --test_setting=pf --num_samples=20
+done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_mini.sh
@@ -0,0 +1,6 @@
+model_size=mini
+model_path=amazon/chronos-t5-${model_size}
+for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
+do
+python run_chronos.py --model_path=${model_path} --dataset=${ds} --run_name=chronos-${model_size} --save_dir=pf_results_20 --test_setting=pf --num_samples=20
+done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_small.sh
@@ -0,0 +1,6 @@
+model_size=small
+model_path=amazon/chronos-t5-${model_size}
+for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
+do
+python run_chronos.py --model_path=${model_path} --dataset=${ds} --run_name=chronos-${model_size} --save_dir=pf_results_20 --test_setting=pf --num_samples=20
+done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/pf_chronos_tiny.sh
@@ -0,0 +1,6 @@
+model_size=tiny
+model_path=amazon/chronos-t5-${model_size}
+for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
+do
+python run_chronos.py --model_path=${model_path} --dataset=${ds} --run_name=chronos-${model_size} --save_dir=pf_results_20 --test_setting=pf --num_samples=20
+done
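
The four `pf_chronos_*.sh` scripts above differ only in `model_size`. A minimal consolidated sketch, not part of this PR (the filename `pf_chronos.sh` is hypothetical), that takes the size as an argument instead:

```
#!/bin/sh
# Hypothetical variant: take the model size as the first argument,
# defaulting to "base"; the loop body mirrors the per-size scripts above.
model_size=${1:-base}
model_path=amazon/chronos-t5-${model_size}
for ds in electricity solar-energy walmart jena_weather istanbul_traffic turkey_power
do
python run_chronos.py --model_path=${model_path} --dataset=${ds} --run_name=chronos-${model_size} --save_dir=pf_results_20 --test_setting=pf --num_samples=20
done
```

Invoked as, e.g., `sh chronos_scripts/pf_chronos.sh small`.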
49 changes: 44 additions & 5 deletions project/benchmarks/run_chronos.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+from functools import partial
 
 import numpy as np
 import torch
@@ -24,14 +25,30 @@
 from gluonts.model.forecast import SampleForecast
 from tqdm.auto import tqdm
 
-from uni2ts.eval_util.data import get_gluonts_test_dataset
+from uni2ts.eval_util.data import get_gluonts_test_dataset, get_lsf_test_dataset
 from uni2ts.eval_util.evaluation import evaluate_forecasts
 from uni2ts.eval_util.metrics import MedianMSE
 
 
-def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
-    print("-" * 5, f"Evaluating {dataset}", "-" * 5)
-    test_data, metadata = get_gluonts_test_dataset(dataset)
+def evaluate(
+    pipeline,
+    dataset,
+    save_path,
+    num_samples=20,
+    batch_size=512,
+    test_setting="monash",
+    pred_length=96,
+):
+    print("-" * 5, f"Evaluating {dataset} on {test_setting} setting", "-" * 5)
+    if test_setting == "monash" or test_setting == "pf":
+        get_dataset = get_gluonts_test_dataset  # for monash and pf, the prediction length can be inferred.
+    elif test_setting == "lsf":
+        get_dataset = partial(get_lsf_test_dataset, prediction_length=pred_length)
+    else:
+        raise NotImplementedError(
+            f"Cannot find the test setting {test_setting}. Please select from monash, pf, lsf."
+        )
+    test_data, metadata = get_dataset(dataset)
     prediction_length = metadata.prediction_length

     while True:
@@ -110,6 +127,16 @@ def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
"--batch_size", type=int, default=512, help="Batch size for generating samples"
)
parser.add_argument("--run_name", type=str, default="test", help="Name of the run")
parser.add_argument(
"--test_setting",
type=str,
default="monash",
choices=["monash", "lsf", "pf"],
help="Name of the test setting",
)
parser.add_argument(
"--pred_length", type=int, default=96, help="Prediction length for LSF dataset"
)

 args = parser.parse_args()
 # Load Chronos
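
For reference, a hedged sketch of one invocation per test setting using the flags defined above; the dataset names `m1_monthly` and `ETTh1` and the output directories other than `pf_results_20` are illustrative assumptions, not taken from this PR:

```
# monash and pf infer the horizon from the dataset; lsf takes --pred_length.
python run_chronos.py --model_path=amazon/chronos-t5-base --dataset=m1_monthly --run_name=chronos-base --save_dir=monash_results --test_setting=monash
python run_chronos.py --model_path=amazon/chronos-t5-base --dataset=electricity --run_name=chronos-base --save_dir=pf_results_20 --test_setting=pf --num_samples=20
python run_chronos.py --model_path=amazon/chronos-t5-base --dataset=ETTh1 --run_name=chronos-base --save_dir=lsf_results --test_setting=lsf --pred_length=96
```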
@@ -122,4 +149,16 @@ def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
 output_dir = os.path.join(args.save_dir, args.run_name)
 if not os.path.exists(output_dir):
     os.makedirs(output_dir)
-evaluate(pipeline, args.dataset, os.path.join(output_dir, f"{args.dataset}.csv"))
+if args.test_setting == "lsf":
+    save_dir = os.path.join(output_dir, f"{args.dataset}_{args.pred_length}.csv")
+else:
+    save_dir = os.path.join(output_dir, f"{args.dataset}.csv")
+evaluate(
+    pipeline,
+    args.dataset,
+    save_dir,
+    args.num_samples,
+    args.batch_size,
+    test_setting=args.test_setting,
+    pred_length=args.pred_length,
+)
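
One consequence of the `save_dir` branch above: LSF results are keyed by horizon, so sweeping `--pred_length` writes a separate CSV per horizon. A minimal sketch, assuming a hypothetical LSF dataset named `ETTh1` and the commonly used LSF horizons:

```
# Hypothetical sweep; per the filename logic above, outputs land in
# lsf_results/chronos-base/ETTh1_96.csv, ETTh1_192.csv, and so on.
for plen in 96 192 336 720
do
python run_chronos.py --model_path=amazon/chronos-t5-base --dataset=ETTh1 --run_name=chronos-base --save_dir=lsf_results --test_setting=lsf --pred_length=${plen}
done
```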