From 00a45f52a1407594cdb2c9ea574f09dfc1be006a Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 Jul 2024 15:57:10 +0800 Subject: [PATCH] refactor: code refactoring; --- pypots/data/saving/pickle.py | 10 ++++++---- pypots/nn/modules/etsformer/layers.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py index f9049b1b..4d4e9c93 100644 --- a/pypots/data/saving/pickle.py +++ b/pypots/data/saving/pickle.py @@ -32,12 +32,13 @@ def pickle_dump(data: object, path: str) -> None: create_dir_if_not_exist(extract_parent_dir(path)) with open(path, "wb") as f: pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"Successfully saved to {path}") except Exception as e: logger.error( - f"❌ Pickling failed. No cache data saved. Please investigate the error below.\n{e}" + f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}" ) - return None - logger.info(f"Successfully saved to {path}") + + return None def pickle_load(path: str) -> object: @@ -58,6 +59,7 @@ def pickle_load(path: str) -> object: with open(path, "rb") as f: data = pickle.load(f) except Exception as e: - logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}") + logger.error(f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}" return None + return data diff --git a/pypots/nn/modules/etsformer/layers.py b/pypots/nn/modules/etsformer/layers.py index 7fe446cf..1a36ed51 100644 --- a/pypots/nn/modules/etsformer/layers.py +++ b/pypots/nn/modules/etsformer/layers.py @@ -160,8 +160,9 @@ def forward(self, x): f = fft.rfftfreq(t)[self.low_freq :] x_freq, index_tuple = self.topk_freq(x_freq) - f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device) - f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) + device = x_freq.device + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(device) + f = rearrange(f[index_tuple], "b f d -> b f () d").to(device) return self.extrapolate(x_freq, f, t)