Skip to content

Commit

Permalink
add df metadata to slurm_array_{m3gnet,bowsr}_wbm.py run_params
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 444fb7f commit f127a9e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 19 deletions.
19 changes: 10 additions & 9 deletions models/bowsr/slurm_array_bowsr_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@
raise SystemExit(f"{out_path = } already exists, exciting early")


# %%
print(f"Loading from {data_path = }")
df_wbm = pd.read_json(data_path).set_index("material_id")

df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
slurm_array_task_id - 1
]


# %%
bayes_optim_kwargs = dict(
relax_coords=True,
Expand All @@ -83,6 +92,7 @@
run_params = dict(
bayes_optim_kwargs=bayes_optim_kwargs,
data_path=data_path,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
maml_version=version("maml"),
megnet_version=version("megnet"),
optimize_kwargs=optimize_kwargs,
Expand All @@ -104,15 +114,6 @@
)


# %%
print(f"Loading from {data_path = }")
df_wbm = pd.read_json(data_path).set_index("material_id")

df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[
slurm_array_task_id - 1
]


# %%
model = MEGNet()
relax_results: dict[str, dict[str, Any]] = {}
Expand Down
4 changes: 2 additions & 2 deletions models/cgcnn/slurm_train_cgcnn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@
run_params = dict(
data_path=data_path,
batch_size=batch_size,
train_df=dict(shape=train_data.df.shape, columns=", ".join(train_df)),
test_df=dict(shape=test_data.df.shape, columns=", ".join(test_df)),
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
)


Expand Down
1 change: 1 addition & 0 deletions models/m3gnet/slurm_array_m3gnet_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
slurm_array_task_count=slurm_array_task_count,
task_type=task_type,
slurm_max_job_time=slurm_max_job_time,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
**slurm_vars,
)
if wandb.run is None:
Expand Down
24 changes: 16 additions & 8 deletions models/wrenformer/mp/use_wrenformer_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@
assert target_col in df, f"{target_col=} not in {list(df)}"
assert input_col in df, f"{input_col=} not in {list(df)}"

wandb.login()
wandb_api = wandb.Api()
runs = wandb_api.runs(
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
)

assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"

data_loader = df_to_in_mem_dataloader(
df=df,
target_col=target_col,
Expand All @@ -63,6 +55,22 @@
shuffle=False, # False is default but best be explicit
)


# %%
wandb.login()
wandb_api = wandb.Api()
runs = wandb_api.runs(
"janosh/matbench-discovery",
filters={
"$and": [{"created_at": {"$gt": "2022-11-10", "$lt": "2022-11-11"}}],
"display_name": "wrenformer-robust-mp-formation_energy_per_atom-epochs=300",
},
)

assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"


# %%
df, ensemble_metrics = predict_from_wandb_checkpoints(
runs, data_loader=data_loader, df=df, model_cls=Wrenformer
)
Expand Down

0 comments on commit f127a9e

Please sign in to comment.