Skip to content

Commit

Permalink
fix load() util for model checkpoint loading
Browse files Browse the repository at this point in the history
discovered broken for alignn_checkpoint
  • Loading branch information
janosh committed Feb 6, 2024
1 parent f33c5fd commit e6b1c5e
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def load(
if ".pkl" in file_path: # handle key='mp_patched_phase_diagram' separately
with gzip.open(cache_path, "rb") as zip_file:
return pickle.load(zip_file)
if ".pth" in file_path: # handle model checkpoints (e.g. key='alignn_checkpoint')
return cache_path

csv_ext = (".csv", ".csv.gz", ".csv.bz2")
reader = pd.read_csv if file_path.endswith(csv_ext) else pd.read_json
Expand Down

0 comments on commit e6b1c5e

Please sign in to comment.