
add aflow wyckoff labels to 2022-10-19-wbm-summary.csv
janosh committed Jun 20, 2023
1 parent c1e55e1 commit 6450ebb
Showing 6 changed files with 44 additions and 29 deletions.
2 changes: 1 addition & 1 deletion data/mp/get_mp_energies.py
@@ -46,7 +46,7 @@

df["spacegroup_number"] = df.pop("symmetry").map(lambda x: x.number)

df["wyckoff"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]

df.to_json(f"{module_dir}/{today}-mp-energies.json.gz", default_handler=as_dict_handler)

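For context on what the renamed column holds, here is a minimal sketch (not part of the diff) of computing one spglib-derived AFLOW-style Wyckoff label for a single pymatgen structure. The aviary import path is an assumption; only the function name is taken from the diff above.

# illustrative sketch only; import path assumed, not shown in this diff
from aviary.wren.utils import get_aflow_label_from_spglib
from pymatgen.core import Lattice, Structure

bcc_fe = Structure(Lattice.cubic(2.87), ["Fe", "Fe"], [[0, 0, 0], [0.5, 0.5, 0.5]])
# returns an AFLOW-style protostructure string, e.g. something like "A_cI2_229_a:Fe"
print(get_aflow_label_from_spglib(bcc_fe))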
50 changes: 33 additions & 17 deletions data/wbm/fetch_process_wbm_dataset.py
@@ -79,9 +79,10 @@
# %%
json_paths = sorted(glob(f"{module_dir}/raw/wbm-structures-step-*.json.bz2"))
step_lens = (61848, 52800, 79205, 40328, 23308)
-# step 3 has 79,211 structures but only 79,205 ComputedStructureEntries
+# step 3 has 79,211 initial structures but only 79,205 ComputedStructureEntries
# i.e. 6 extra structures which have missing energy, volume, etc. in the summary file
bad_struct_ids = (70802, 70803, 70825, 70826, 70828, 70829)
+# step 5 has 2 missing initial structures: 23166, 23294


assert len(json_paths) == len(step_lens), "Mismatch in WBM steps and JSON files"
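A hedged sketch of the per-step sanity check these comments describe (not part of the diff): load each raw step file and compare its length against step_lens, allowing for the handful of structures noted above. It assumes each file deserializes to a flat list or dict of structures.

# illustrative sketch only; raw file schema assumed, adjust if the payload is nested
import bz2
import json

for json_path, expected in zip(json_paths, step_lens):
    with bz2.open(json_path, "rt") as file:
        step_data = json.load(file)
    # tolerance covers the 6 extra step-3 structures and the 2 missing step-5 ones
    assert abs(len(step_data) - expected) <= 6, f"{json_path=}: {len(step_data)=} vs {expected=}"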
@@ -229,6 +230,14 @@ def increment_wbm_material_id(wbm_id: str) -> str:
).value_counts().to_dict() == {"GGA": 248481, "GGA+U": 9008}


+# drop two materials with missing initial structures
+assert list(df_wbm.query("initial_structure.isna()").index) == [
+    "wbm-step-5-23166",
+    "wbm-step-5-23294",
+]
+df_wbm = df_wbm.dropna(subset=["initial_structure"])


# %% get composition from CSEs
df_wbm["composition_from_cse"] = [
ComputedStructureEntry.from_dict(cse).composition
@@ -273,12 +282,13 @@ def increment_wbm_material_id(wbm_id: str) -> str:
x.alphabetical_formula for x in df_wbm.pop("composition_from_cse")
]
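For reference, pymatgen's alphabetical_formula sorts elements alphabetically and separates them with spaces; a tiny example (not part of the diff):

# illustration only
from pymatgen.core import Composition

print(Composition("Fe2O3").alphabetical_formula)  # "Fe2 O3"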

-for key, col_name in (
-    ("cses", "computed_structure_entry"),
-    ("init-structs", "initial_structure"),
+for fname, cols in (
+    ("cses", ["computed_structure_entry"]),
+    ("init-structs", ["initial_structure"]),
+    ("cses+init-structs", ["initial_structure", "computed_structure_entry"]),
):
cols = ["initial_structure", "formula_from_cse", col_name]
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{key}.json.bz2")
cols = ["formula_from_cse", *cols]
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{fname}.json.bz2")
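A hedged round-trip check one could run after the export loop (not part of the diff), confirming the chosen columns survive the JSON round trip:

# illustrative sketch only
df_check = pd.read_json(f"{module_dir}/{today}-wbm-init-structs.json.bz2").set_index("material_id")
assert {"formula_from_cse", "initial_structure"} <= set(df_check)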


# %%
@@ -486,26 +496,32 @@ def increment_wbm_material_id(wbm_id: str) -> str:
f"{module_dir}/2022-10-19-wbm-cses+init-structs.json.bz2"
).set_index("material_id")

-df_init_struct = pd.read_json(
-    f"{module_dir}/2022-10-19-wbm-init-structs.json.bz2"
-).set_index("material_id")

df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
]
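Each deserialized ComputedStructureEntry bundles the relaxed structure with its energy; a short illustration (not part of the diff):

# illustration only
cse = df_wbm.cse.iloc[0]
print(cse.composition.reduced_formula, cse.energy_per_atom)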


# %%
df_wbm["init_struct"] = [
Structure.from_dict(x) if x else None for x in tqdm(df_wbm.initial_structure)
]
df_init_struct = pd.read_json(
f"{module_dir}/2022-10-19-wbm-init-structs.json.bz2"
).set_index("material_id")

wyckoff_col = "wyckoff_spglib"
-for idx, struct in tqdm(df_wbm.init_struct.items(), total=len(df_wbm)):
-    if struct is None:
+if wyckoff_col not in df_init_struct:
+    df_init_struct[wyckoff_col] = None

+for idx, struct in tqdm(
+    df_init_struct.initial_structure.items(), total=len(df_init_struct)
+):
+    if not pd.isna(df_summary.at[idx, wyckoff_col]):
continue
-    if not df_wbm.at[idx, wyckoff_col]:
-        df_wbm.at[idx, wyckoff_col] = get_aflow_label_from_spglib(struct)
+    try:
+        struct = Structure.from_dict(struct)
+        df_summary.at[idx, wyckoff_col] = get_aflow_label_from_spglib(struct)
+    except Exception as exc:
+        print(f"{idx=} {exc=}")

+assert df_summary[wyckoff_col].isna().sum() == 0
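A small hedged follow-up one might add after the back-fill loop (not part of the diff): report coverage of the new column and write the summary back out; the save path here is an assumption based on the commit title.

# illustrative sketch only; the CSV path is assumed from the commit title
n_labels = df_summary[wyckoff_col].notna().sum()
print(f"{n_labels:,} / {len(df_summary):,} rows have a {wyckoff_col!r} label")
df_summary.to_csv(f"{module_dir}/2022-10-19-wbm-summary.csv")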


# %% make sure material IDs within each step are consecutive
3 changes: 1 addition & 2 deletions
@@ -72,5 +72,4 @@
ax.legend(loc="center left", frameon=False)

fig_name = f"wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
-img_path = f"{ROOT}/figures/{today}-{fig_name}.pdf"
-# fig.savefig(img_path)
+# fig.savefig(f"{ROOT}/figures/{today}-{fig_name}.pdf")
5 changes: 1 addition & 4 deletions matbench_discovery/plot_scripts/precision_recall.py
@@ -15,8 +15,6 @@


# %%
-rare = "all"

dfs: dict[str, pd.DataFrame] = {}
for model_name in ("wren", "cgcnn", "voronoi"):
csv_path = (
@@ -118,7 +116,6 @@
ax.set(xlim=(0, None))


-img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
# x-ticks every 10k materials
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))

@@ -128,4 +125,4 @@


# %%
-fig.savefig(f"{ROOT}/figures/{img_name}.pdf")
+# fig.savefig(f"{ROOT}/figures/{today}-precision-recall-curves.pdf")
6 changes: 3 additions & 3 deletions models/wrenformer/mp/use_wrenformer_ensemble.py
@@ -39,10 +39,10 @@

# %%
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
-df = pd.read_csv(data_path).set_index("material_id")
+target_col = "e_form_per_atom_mp2020_corrected"
+input_col = "wyckoff_spglib"
+df = pd.read_csv(data_path).dropna(subset=input_col).set_index("material_id")

target_col = "e_form_per_atom"
input_col = "wyckoff"
assert target_col in df, f"{target_col=} not in {list(df)}"
assert input_col in df, f"{input_col=} not in {list(df)}"

7 changes: 5 additions & 2 deletions models/wrenformer/slurm_train_wrenformer_ensemble.py
@@ -3,9 +3,9 @@
from datetime import datetime

import pandas as pd
-from aviary import ROOT
from aviary.train import df_train_test_split, train_wrenformer

+from matbench_discovery import ROOT
from matbench_discovery.slurm import slurm_submit_python

"""
@@ -45,13 +45,15 @@
batch_size = 128
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
input_col = "wyckoff_spglib"

print(f"Job started running {timestamp}")
print(f"{run_name=}")
print(f"{data_path=}")

df = pd.read_json(data_path).set_index("material_id", drop=False)
-assert target_col in df
+assert target_col in df, f"{target_col=} not in {list(df)}"
+assert input_col in df, f"{input_col=} not in {list(df)}"
train_df, test_df = df_train_test_split(df, test_size=0.3)

run_params = dict(
@@ -70,6 +72,7 @@
# folds=(n_folds, slurm_array_task_id),
epochs=epochs,
checkpoint="wandb", # None | 'local' | 'wandb',
+    input_col=input_col,
learning_rate=learning_rate,
batch_size=batch_size,
wandb_path="janosh/matbench-discovery",
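For orientation, a hedged sketch of how the new input_col kwarg plausibly threads into the truncated train_wrenformer call above. Only the kwargs visible in this hunk are taken from the diff; run_name, target_col, train_df and test_df are defined elsewhere in the script, but whether train_wrenformer accepts them under these exact keyword names is an assumption, not a verified signature.

# illustrative sketch only; kwarg names outside the visible hunk are assumptions
train_wrenformer(
    run_name=run_name,
    train_df=train_df,
    test_df=test_df,
    target_col=target_col,
    input_col=input_col,  # "wyckoff_spglib", the column this commit standardizes on
    epochs=epochs,
    checkpoint="wandb",  # None | 'local' | 'wandb'
    learning_rate=learning_rate,
    batch_size=batch_size,
    wandb_path="janosh/matbench-discovery",
)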
