From e6b1c5ecf833b322979c7d238859d509caf07c74 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 6 Feb 2024 11:17:24 +0100 Subject: [PATCH] fix load() util for model checkpoint loading discovered broken for alignn_checkpoint --- matbench_discovery/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matbench_discovery/data.py b/matbench_discovery/data.py index ad125d44..05e381f7 100644 --- a/matbench_discovery/data.py +++ b/matbench_discovery/data.py @@ -115,6 +115,8 @@ def load( if ".pkl" in file_path: # handle key='mp_patched_phase_diagram' separately with gzip.open(cache_path, "rb") as zip_file: return pickle.load(zip_file) + if ".pth" in file_path: # handle model checkpoints (e.g. key='alignn_checkpoint') + return cache_path csv_ext = (".csv", ".csv.gz", ".csv.bz2") reader = pd.read_csv if file_path.endswith(csv_ext) else pd.read_json