Skip to content

Commit

Permalink
refactor: code refactoring;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Jul 22, 2024
1 parent 1532cc3 commit 00a45f5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions pypots/data/saving/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ def pickle_dump(data: object, path: str) -> None:
create_dir_if_not_exist(extract_parent_dir(path))
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Successfully saved to {path}")
except Exception as e:
logger.error(
f"❌ Pickling failed. No cache data saved. Please investigate the error below.\n{e}"
f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}"
)
return None
logger.info(f"Successfully saved to {path}")

return None


def pickle_load(path: str) -> object:
Expand All @@ -58,6 +59,7 @@ def pickle_load(path: str) -> object:
with open(path, "rb") as f:
data = pickle.load(f)
except Exception as e:
logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}")
logger.error(f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}"
return None

return data
5 changes: 3 additions & 2 deletions pypots/nn/modules/etsformer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ def forward(self, x):
f = fft.rfftfreq(t)[self.low_freq :]

x_freq, index_tuple = self.topk_freq(x_freq)
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device)
device = x_freq.device
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(device)
f = rearrange(f[index_tuple], "b f d -> b f () d").to(device)

return self.extrapolate(x_freq, f, t)

Expand Down

0 comments on commit 00a45f5

Please sign in to comment.