Skip to content

Commit

Permalink
add wandb.log scatter-parity plot in test_{cgcnn,wrenformer}.py
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 8219c43 commit 255429a
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 83 deletions.
8 changes: 3 additions & 5 deletions matbench_discovery/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import subprocess
import sys
from collections.abc import Sequence
from datetime import datetime

SLURM_KEYS = (
"job_id array_task_id array_task_count mem_per_node nodelist submit_host"
Expand Down Expand Up @@ -74,11 +73,12 @@ def slurm_submit(
# before actual job command
pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"

today = f"{datetime.now():%Y-%m-%d}"
os.makedirs(log_dir, exist_ok=True) # slurm fails if log_dir is missing

cmd = [
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
*("--job-name", job_name),
*("--output", f"{log_dir}/{today}-slurm-%A{'-%a' if array else ''}.log"),
*("--output", f"{log_dir}/slurm-%A{'-%a' if array else ''}.log"),
*slurm_flags,
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
]
Expand All @@ -104,8 +104,6 @@ def slurm_submit(
if "slurm-submit" not in sys.argv:
return slurm_vars # if not submitting slurm job, resume outside code as normal

os.makedirs(log_dir, exist_ok=True) # slurm fails if log_dir is missing

result = subprocess.run(cmd, check=True)

# after sbatch submission, exit with slurm exit code
Expand Down
10 changes: 7 additions & 3 deletions models/bowsr/test_bowsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
slurm_mem_per_node = 12000
# set large job array size for fast testing/debugging
slurm_array_task_count = 1000
slurm_max_parallel = 100
# see https://stackoverflow.com/a/55431306 for how to change array throttling
# post-submission
slurm_max_parallel = 50
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
today = timestamp.split("@")[0]
energy_model = "megnet"
Expand Down Expand Up @@ -89,6 +91,9 @@
seed=42,
)
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
slurm_dict = dict(
slurm_max_parallel=slurm_max_parallel, slurm_max_job_time=slurm_max_job_time
)

run_params = dict(
bayes_optim_kwargs=bayes_optim_kwargs,
Expand All @@ -99,8 +104,7 @@
energy_model_version=version(energy_model),
optimize_kwargs=optimize_kwargs,
task_type=task_type,
slurm_max_job_time=slurm_max_job_time,
slurm_vars=slurm_vars | dict(slurm_max_parallel=slurm_max_parallel),
slurm_vars=slurm_vars | slurm_dict,
)
if wandb.run is None:
wandb.login()
Expand Down
77 changes: 54 additions & 23 deletions models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import os
from datetime import datetime
from importlib.metadata import version

import pandas as pd
import wandb
from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.deploy import predict_from_wandb_checkpoints
from pymatgen.core import Structure
from pymatviz import density_scatter
from torch.utils.data import DataLoader
from tqdm import tqdm

Expand All @@ -29,28 +29,25 @@

today = f"{datetime.now():%Y-%m-%d}"
log_dir = f"{os.path.dirname(__file__)}/{today}-test"
ensemble_id = "cgcnn-e_form-ensemble-1"
run_name = f"{ensemble_id}-IS2RE"
job_name = "test-cgcnn-ensemble"

slurm_submit(
job_name=run_name,
slurm_vars = slurm_submit(
job_name=job_name,
partition="ampere",
account="LEE-SL3-GPU",
time="1:0:0",
time=(slurm_max_job_time := "2:0:0"),
log_dir=log_dir,
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
)


# %%
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
task_type = "IS2RE"
if task_type == "IS2RE":
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
elif task_type == "RS2RE":
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-cses.json.bz2"
df = pd.read_json(data_path).set_index("material_id", drop=False)
old_len = len(df)
no_init_structs = df.query("initial_structure.isnull()").index
df = df.dropna() # two missing initial structures
assert len(df) == old_len - 2

assert all(df.index == df_wbm.drop(index=no_init_structs).index)

target_col = "e_form_per_atom_mp2020_corrected"
df[target_col] = df_wbm[target_col]
Expand All @@ -60,12 +57,38 @@

df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]

filters = {
"$and": [{"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}}],
"display_name": {"$regex": "^cgcnn-robust"},
}
wandb.login()
runs = wandb.Api().runs(
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)

assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
for idx, run in enumerate(runs):
for key, val in run.config.items():
if val == runs[0][key] or key.startswith(("slurm_", "timestamp")):
continue
raise ValueError(
f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
)

run_params = dict(
data_path=data_path,
df=dict(shape=str(df.shape), columns=", ".join(df)),
aviary_version=version("aviary"),
ensemble_size=len(runs),
task_type=task_type,
target_col=target_col,
input_col=input_col,
filters=filters,
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
)

assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
wandb.init(
project="matbench-discovery", name=f"{job_name}-{slurm_job_id}", config=run_params
)

cg_data = CrystalGraphData(
df, task_dict={target_col: "regression"}, structure_col=input_col
Expand All @@ -82,14 +105,22 @@
data_loader=data_loader,
)

df.round(6).to_csv(f"{log_dir}/{today}-{run_name}-preds.csv", index=False)
df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False)
table = wandb.Table(dataframe=df)


# %%
print(f"{runs[0].url=}")
ax = density_scatter(
df=df.query("e_form_per_atom_mp2020_corrected < 10"),
x="e_form_per_atom_mp2020_corrected",
y="e_form_per_atom_mp2020_corrected_pred_1",
pred_col = f"{target_col}_pred_ens"
MAE = ensemble_metrics["MAE"]
R2 = ensemble_metrics["R2"]

title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
print(title)

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity",
data_table=table,
fields=dict(x=target_col, y=pred_col, title=title),
)
# ax.figure.savefig(f"{ROOT}/tmp/{today}-{run_name}-scatter-preds.png", dpi=300)

wandb.log({"true_pred_scatter": scatter_plot})
8 changes: 6 additions & 2 deletions models/cgcnn/train_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
# %%
epochs = 300
target_col = "formation_energy_per_atom"
run_name = f"cgcnn-robust-{target_col}"
run_name = f"train-cgcnn-robust-{target_col}"
print(f"{run_name=}")
robust = "robust" in run_name.lower()
n_ens = 10
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
today = timestamp.split("@")[0]
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"

slurm_submit(
slurm_vars = slurm_submit(
job_name=run_name,
partition="ampere",
account="LEE-SL3-GPU",
Expand Down Expand Up @@ -63,11 +63,13 @@

train_df, test_df = df_train_test_split(df, test_size=0.05)

print(f"{train_df.shape=}")
train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
train_loader = DataLoader(
train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
)

print(f"{test_df.shape=}")
test_data = CrystalGraphData(test_df, task_dict={target_col: task_type})
test_loader = DataLoader(
test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
Expand All @@ -90,6 +92,7 @@
batch_size=batch_size,
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
slurm_vars=slurm_vars,
)


Expand All @@ -111,4 +114,5 @@
timestamp=timestamp,
train_loader=train_loader,
wandb_path="janosh/matbench-discovery",
run_params=run_params,
)
3 changes: 1 addition & 2 deletions models/m3gnet/test_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@
data_path=data_path,
m3gnet_version=version("m3gnet"),
task_type=task_type,
slurm_max_job_time=slurm_max_job_time,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
slurm_vars=slurm_vars,
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
)
if wandb.run is None:
wandb.login()
Expand Down
29 changes: 15 additions & 14 deletions models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import wandb
from megnet.utils.models import load_model
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import ROOT
Expand Down Expand Up @@ -54,8 +55,10 @@
# %%
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
print(f"{data_path=}")
df_wbm_structs = pd.read_json(data_path).set_index("material_id")
target_col = "e_form_per_atom_mp2020_corrected"
assert target_col in df_wbm, f"{target_col=} not in {list(df_wbm)=}"

df_wbm_structs = pd.read_json(data_path).set_index("material_id")
megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")


Expand All @@ -65,9 +68,9 @@
megnet_version=version("megnet"),
model_name=model_name,
task_type=task_type,
slurm_max_job_time=slurm_max_job_time,
target_col=target_col,
df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)),
slurm_vars=slurm_vars,
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
)
if wandb.run is None:
wandb.login()
Expand Down Expand Up @@ -105,26 +108,24 @@
print(f"{len(megnet_e_form_preds)=:,}")
print(f"{len(structures)=:,}")
print(f"missing: {len(structures) - len(megnet_e_form_preds):,}")
out_col = "e_form_per_atom_megnet"
df_wbm[out_col] = pd.Series(megnet_e_form_preds)
pred_col = "e_form_per_atom_megnet"
df_wbm[pred_col] = pd.Series(megnet_e_form_preds)

df_wbm[out_col].reset_index().to_csv(out_path, index=False)
df_wbm[pred_col].reset_index().to_csv(out_path, index=False)


# %%
fields = {"x": "e_form_per_atom_mp2020_corrected", "y": out_col}
cols = list(fields.values())
assert all(col in df_wbm for col in cols)

table = wandb.Table(dataframe=df_wbm[cols].reset_index())
table = wandb.Table(dataframe=df_wbm[[target_col, pred_col]].reset_index())

MAE = (df_wbm[fields["x"]] - df_wbm[fields["y"]]).abs().mean()
MAE = (df_wbm[target_col] - df_wbm[pred_col]).abs().mean()
R2 = r2_score(df_wbm[target_col], df_wbm[pred_col])
title = f"{model_name} {task_type} {MAE=:.4} {R2=:.4}"
print(title)

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity",
data_table=table,
fields=fields,
string_fields={"title": f"{model_name} {task_type} {MAE=:.4}"},
fields=dict(x=target_col, y=pred_col, title=title),
)

wandb.log({"true_pred_scatter": scatter_plot})
5 changes: 5 additions & 0 deletions models/voronoi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
]
featurizer = MultipleFeaturizer(featurizers)


# multiprocessing seems to be the cause of OOM errors on large structures even when
# taking only small slice of the data and launching slurm jobs with --mem 100G
# Alex Dunn has been aware of this problem for a while. Presumed cause: a chunk of
# data (e.g. 50 structures) is sent to a single process, but sometimes one of those
# structures might be huge, causing that process to stall. Other processes in the
# pool can't synchronize at the end, effectively freezing the job.
featurizer.set_n_jobs(1)
4 changes: 2 additions & 2 deletions models/voronoi/readme.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Voronoi Tessellation with matminer featurezation piped into `scikit-learn` Random Forest
# Voronoi Tessellation with `matminer` featurization piped into `scikit-learn` `RandomForestRegressor`

## OOM errors during featurization

Expand All @@ -14,4 +14,4 @@ Saving tip came from [Alex Dunn via Slack](https://berkeleytheory.slack.com/arch

## Archive

Files in `2022-10-04-rhys-voronoi.zip` received from Rhys via [Slack](https://ml-physics.slack.com/archives/DD8GBBRLN/p1664929946687049). All originals before making any changes for this project.
Files in `2022-10-04-rhys-voronoi.zip` received from Rhys via [Slack](https://ml-physics.slack.com/archives/DD8GBBRLN/p1664929946687049). They are unchanged originals.
16 changes: 8 additions & 8 deletions models/voronoi/voronoi_featurize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

today = f"{datetime.now():%Y-%m-%d}"
module_dir = os.path.dirname(__file__)
assert featurizer._n_jobs == 1, "set n_jobs=1 to avoid OOM errors"

data_name = "mp" # "mp"
if data_name == "wbm":
Expand All @@ -25,17 +24,17 @@
data_path = f"{ROOT}/data/mp/2022-09-16-mp-computed-structure-entries.json.gz"
input_col = "structure"

slurm_array_task_count = 10
slurm_array_task_count = 30
job_name = f"voronoi-features-{data_name}"
log_dir = f"{module_dir}/{today}-{job_name}"

slurm_vars = slurm_submit(
job_name=job_name,
partition="icelake-himem",
account="LEE-SL3-CPU",
time=(slurm_max_job_time := "8:0:0"),
time=(slurm_max_job_time := "12:0:0"),
array=f"1-{slurm_array_task_count}",
slurm_flags=("--mem", "30G") if data_name == "mp" else (),
slurm_flags=("--mem", "20G") if data_name == "mp" else (),
log_dir=log_dir,
)

Expand Down Expand Up @@ -66,10 +65,9 @@
# %%
run_params = dict(
data_path=data_path,
slurm_max_job_time=slurm_max_job_time,
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
input_col=input_col,
slurm_vars=slurm_vars,
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
)
if wandb.run is None:
wandb.login()
Expand All @@ -88,10 +86,12 @@

df_features = featurizer.featurize_dataframe(
df_this_job, input_col, ignore_errors=True, pbar=dict(position=0, leave=True)
).drop(columns=input_col)
)


# %%
df_features.to_csv(out_path, default_handler=as_dict_handler)
df_features[featurizer.feature_labels()].to_csv(
out_path, default_handler=as_dict_handler
)

wandb.log({"voronoi_features": wandb.Table(dataframe=df_features)})
Loading

0 comments on commit 255429a

Please sign in to comment.