From 7b9adb7fbf8a15f810036eeaceb0a755d8f892b7 Mon Sep 17 00:00:00 2001 From: chenmouxiang Date: Tue, 17 Sep 2024 13:57:21 +0000 Subject: [PATCH 1/2] fix VisionTS context len issue --- project/benchmarks/run_visionts.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/project/benchmarks/run_visionts.py b/project/benchmarks/run_visionts.py index 5fb229e..c9579f7 100644 --- a/project/benchmarks/run_visionts.py +++ b/project/benchmarks/run_visionts.py @@ -156,12 +156,11 @@ def evaluate( best_valid_mae = float("inf") best_valid_p = 1 for periodicity in seasonality_list: - # Round context length to the integer multiples of the period - context_len = convert_context_len( + cur_context_len = convert_context_len( context_len, no_periodicity_context_len, periodicity ) model.update_config( - context_len, + cur_context_len, prediction_length, periodicity, norm_const, @@ -177,7 +176,7 @@ def evaluate( [x["target"][-prediction_length:] for x in batch] ) cur_forecast_samples = forcast_batch( - input_batch, device, context_len, align_const, model + input_batch, device, cur_context_len, align_const, model ) assert cur_forecast_samples.shape == label_batch.shape cur_mae_list.append(np.abs(label_batch - cur_forecast_samples)) @@ -198,15 +197,14 @@ def evaluate( else: periodicity = int(periodicity) - print(f"Use periodicity = {periodicity}, context len = {context_len}") - # Generate forecast samples forecast_samples = [] - context_len = convert_context_len( + cur_context_len = convert_context_len( context_len, no_periodicity_context_len, periodicity ) + print(f"Use periodicity = {periodicity}, context len = {cur_context_len}") model.update_config( - context_len, prediction_length, periodicity, norm_const, align_const + cur_context_len, prediction_length, periodicity, norm_const, align_const ) for batch in tqdm( list(batcher(test_data.input, batch_size=batch_size)), @@ -214,7 +212,7 @@ def evaluate( ): batch = [x["target"] for x in batch] cur_forecast_samples = forcast_batch( - batch, device, context_len, align_const, model + batch, device, cur_context_len, align_const, model ) forecast_samples.append(cur_forecast_samples) forecast_samples = np.concatenate(forecast_samples, axis=0) From 9b38eb36111670664177a85ebc8dc2bfbb5f0780 Mon Sep 17 00:00:00 2001 From: chenmouxiang Date: Wed, 18 Sep 2024 02:32:33 +0000 Subject: [PATCH 2/2] fix VisionTS default seasonality issue --- project/benchmarks/run_visionts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/benchmarks/run_visionts.py b/project/benchmarks/run_visionts.py index c9579f7..82bf046 100644 --- a/project/benchmarks/run_visionts.py +++ b/project/benchmarks/run_visionts.py @@ -59,7 +59,7 @@ def norm_freq_str(freq_str: str) -> str: def get_seasonality_list(freq: str) -> int: offset = pd.tseries.frequencies.to_offset(freq) - base_seasonality_list = POSSIBLE_SEASONALITIES.get(norm_freq_str(offset.name), 1) + base_seasonality_list = POSSIBLE_SEASONALITIES.get(norm_freq_str(offset.name), []) seasonality_list = [] for base_seasonality in base_seasonality_list: seasonality, remainder = divmod(base_seasonality, offset.n)