Skip to content

Commit

Permalink
record numpy + pytorch versions in train/test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 3130c89 commit 7988f52
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 18 deletions.
14 changes: 7 additions & 7 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ def load_train_test(


PRED_FILENAMES = {
"CGCNN": "models/cgcnn/2022-11-23-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv",
"Voronoi RF": "models/voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv",
"Wrenformer": "models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv",
"MEGNet": "models/megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv",
"M3GNet": "models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
"BOWSR MEGNet": "models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.csv",
"CGCNN": "cgcnn/2022-11-23-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv",
"Voronoi RF": "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv",
"Wrenformer": "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv",
"MEGNet": "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv",
"M3GNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
"BOWSR MEGNet": "bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.csv",
}


Expand Down Expand Up @@ -212,7 +212,7 @@ def load_df_wbm_with_preds(

for model_name in (bar := tqdm(models, disable=not pbar)):
bar.set_description(model_name)
pattern = PRED_FILENAMES[model_name]
pattern = f"models/{PRED_FILENAMES[model_name]}"
df = glob_to_df(pattern, pbar=False, **kwargs).set_index(id_col)
dfs[model_name] = df

Expand Down
5 changes: 2 additions & 3 deletions models/bowsr/test_bowsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@
print(f"\nJob started running {timestamp}")
print(f"{data_path = }")
print(f"{out_path = }")
print(f"{version('maml') = }")
print(f"{version(energy_model) = }")


# %%
Expand All @@ -92,9 +90,10 @@
bayes_optim_kwargs=bayes_optim_kwargs,
data_path=data_path,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
maml_version=version("maml"),
energy_model=energy_model,
maml_version=version("maml"),
energy_model_version=version(energy_model),
numpy_version=version("numpy"),
optimize_kwargs=optimize_kwargs,
task_type=task_type,
slurm_vars=slurm_vars,
Expand Down
9 changes: 8 additions & 1 deletion models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
# %%
if task_type == "IS2RE":
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
# or for debug
# data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json-1k-samples.bz2"
# created with:
# df = df.sample(1000)
# df.reset_index().to_json(data_path.replace(".json", "-1k-samples.json"))
input_col = "initial_structure"
elif task_type == "RS2RE":
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-computed-structure-entries.json.bz2"
Expand Down Expand Up @@ -83,6 +88,8 @@
data_path=data_path,
df=dict(shape=str(df.shape), columns=", ".join(df)),
aviary_version=version("aviary"),
numpy_version=version("numpy"),
torch_version=version("torch"),
ensemble_size=len(runs),
task_type=task_type,
target_col=target_col,
Expand Down Expand Up @@ -113,7 +120,7 @@
)

slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv", index=False)
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv")
pred_col = f"{target_col}_pred_ens"
assert pred_col in df, f"{pred_col=} not in {list(df)}"
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
Expand Down
7 changes: 6 additions & 1 deletion models/cgcnn/train_cgcnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
import os
from importlib.metadata import version

import pandas as pd
from aviary.cgcnn.data import CrystalGraphData, collate_batch
Expand Down Expand Up @@ -27,7 +28,8 @@
target_col = "formation_energy_per_atom"
input_col = "structure"
id_col = "material_id"
augment = 3
augment = 1 # 0 for no augmentation, n>1 means train on n perturbations of each crystal
# in the training set all assigned the same original target energy
job_name = f"train-cgcnn-robust-{augment=}{'-debug' if DEBUG else ''}"
print(f"{job_name=}")
robust = "robust" in job_name.lower()
Expand Down Expand Up @@ -100,6 +102,9 @@
run_params = dict(
data_path=data_path,
batch_size=batch_size,
aviary_version=version("aviary"),
numpy_version=version("numpy"),
torch_version=version("torch"),
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
slurm_vars=slurm_vars,
Expand Down
2 changes: 1 addition & 1 deletion models/m3gnet/test_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
f"{ROOT}/data/wbm/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
)
print(f"\nJob started running {timestamp}")
print(f"{version('m3gnet') = }")
print(f"{data_path=}")
df_wbm = pd.read_json(data_path).set_index("material_id")

Expand All @@ -74,6 +73,7 @@
run_params = dict(
data_path=data_path,
m3gnet_version=version("m3gnet"),
numpy_version=version("numpy"),
task_type=task_type,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
slurm_vars=slurm_vars,
Expand Down
3 changes: 2 additions & 1 deletion models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
run_params = dict(
data_path=data_path,
megnet_version=version("megnet"),
numpy_version=version("numpy"),
model_name=model_name,
task_type=task_type,
target_col=target_col,
Expand Down Expand Up @@ -104,7 +105,7 @@
pred_col = "e_form_per_atom_megnet"
df_wbm[pred_col] = pd.Series(megnet_e_form_preds)

df_wbm[pred_col].reset_index().round(4).to_csv(out_path, index=False)
df_wbm[pred_col].round(4).to_csv(out_path)


# %%
Expand Down
1 change: 1 addition & 0 deletions models/voronoi/train_test_voronoi_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
mp_energies_path=mp_energies_path,
scikit_learn_version=version("scikit-learn"),
matminer_version=version("matminer"),
numpy_version=version("numpy"),
model_name=model_name,
train_target_col=train_target_col,
test_target_col=test_target_col,
Expand Down
1 change: 1 addition & 0 deletions models/voronoi/voronoi_featurize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
slurm_vars=slurm_vars,
out_path=out_path,
matminer_version=version("matminer"),
numpy_version=version("numpy"),
)

wandb.init(project="matbench-discovery", name=run_name, config=run_params)
Expand Down
2 changes: 2 additions & 0 deletions models/wrenformer/test_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
data_path=data_path,
df=dict(shape=str(df.shape), columns=", ".join(df)),
aviary_version=version("aviary"),
numpy_version=version("numpy"),
torch_version=version("torch"),
ensemble_size=len(runs),
task_type=task_type,
target_col=target_col,
Expand Down
2 changes: 2 additions & 0 deletions models/wrenformer/train_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
run_params = dict(
data_path=data_path,
aviary_version=version("aviary"),
numpy_version=version("numpy"),
torch_version=version("torch"),
batch_size=batch_size,
train_df=dict(shape=train_df.shape, columns=", ".join(train_df)),
test_df=dict(shape=test_df.shape, columns=", ".join(test_df)),
Expand Down
4 changes: 3 additions & 1 deletion scripts/cumulative_clf_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
fig.suptitle(title)
# fig.text(0.5, -0.08, xlabel_cumulative, ha="center", fontdict={"size": 16})
elif backend == "plotly":
fig.update_layout(title=title, matches=None)
fig.update_layout(title=title)
fig.update_xaxes(matches=None, showticklabels=True)
fig.update_yaxes(matches=None, showticklabels=True)

fig.show()

Expand Down
4 changes: 1 addition & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def test_load_df_wbm_with_preds_raises() -> None:

def test_pred_filenames() -> None:
assert len(PRED_FILENAMES) >= 6
assert all(
path.startswith(("models/", "data/")) for path in PRED_FILENAMES.values()
)
assert all(path.endswith((".csv", ".json")) for path in PRED_FILENAMES.values())


@pytest.mark.parametrize("pattern", ["tmp/*df.csv", "tmp/*df.json"])
Expand Down

0 comments on commit 7988f52

Please sign in to comment.