From f959ee8a903148d4d2aa1bf35aa0aaae797ef5e8 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 6 Dec 2023 07:38:38 -0800 Subject: [PATCH] use SymLogNorm for MP+WBM+MPtrj ptable element count heatmaps --- .pre-commit-config.yaml | 8 ++++---- data/mp/eda_mp_trj.py | 22 +++++++++++----------- data/wbm/eda_wbm.py | 13 ++++++------- matbench_discovery/preds.py | 7 ++++--- models/chgnet/join_chgnet_results.py | 5 +++-- scripts/model_figs/make_metrics_tables.py | 11 ++++++----- 6 files changed, 34 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 133ee7c0..8483c6d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.1.7 hooks: - id: ruff args: [--fix] @@ -30,7 +30,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.7.0 + rev: v1.7.1 hooks: - id: mypy additional_dependencies: [types-pyyaml, types-requests] @@ -45,7 +45,7 @@ repos: args: [--ignore-words-list, "nd,te,fpr", --check-filenames] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 + rev: v4.0.0-alpha.3 hooks: - id: prettier args: [--write] # edit files in-place @@ -56,7 +56,7 @@ repos: exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$ - repo: https://github.com/pre-commit/mirrors-eslint - rev: v8.53.0 + rev: v8.55.0 hooks: - id: eslint types: [file] diff --git a/data/mp/eda_mp_trj.py b/data/mp/eda_mp_trj.py index d3e7abc6..6ff76efb 100644 --- a/data/mp/eda_mp_trj.py +++ b/data/mp/eda_mp_trj.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd import plotly.express as px +from matplotlib.colors import SymLogNorm from pymatgen.core import Composition from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio, ptable_hists from pymatviz.io import save_fig @@ -213,28 +214,27 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% -count_mode = "composition" -if "trj_elem_counts" not in locals(): - trj_elem_counts = pd.read_json( - f"{data_page}/mp-trj-element-counts-by-{count_mode}.json", - typ="series", - ) +count_mode = "occurrence" +trj_elem_counts = pd.read_json( + f"{data_page}/mp-trj-element-counts-by-{count_mode}.json", typ="series" +) excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else () ax_ptable = ptable_heatmap( # matplotlib version looks better for SI trj_elem_counts, - fmt=lambda x, _: si_fmt(x, ".0f"), - cbar_fmt=lambda x, _: si_fmt(x, ".0f"), zero_color="#efefef", - log=(log := True), + log=(log := SymLogNorm(linthresh=10_000)), exclude_elements=excl_elems, # drop noble gases - cbar_range=None if excl_noble else (10_000, None), + # cbar_range=None if excl_noble else (10_000, None), label_font_size=17, value_font_size=14, + cbar_title="MPtrj Element Counts", ) -img_name = f"mp-trj-element-counts-by-{count_mode}{'-log' if log else ''}" +img_name = f"mp-trj-element-counts-by-{count_mode}" +if log: + img_name += "-symlog" if isinstance(log, SymLogNorm) else "-log" if excl_noble: img_name += "-excl-noble" save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf") diff --git a/data/wbm/eda_wbm.py b/data/wbm/eda_wbm.py index 83bad373..46996a62 100644 --- a/data/wbm/eda_wbm.py +++ b/data/wbm/eda_wbm.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import plotly.express as px +from matplotlib.colors import SymLogNorm from pymatgen.core import Composition from pymatviz import ( count_elements, @@ -64,13 +65,9 @@ # %% -log = True for dataset, count_mode, elem_counts in all_counts: filename = f"{dataset}-element-counts-by-{count_mode}" - if log: - filename += "-log" - else: - elem_counts.to_json(f"{data_page}/{filename}.json") + elem_counts.to_json(f"{data_page}/{filename}.json") title = f"Number of {dataset.upper()} structures containing each element" fig = ptable_heatmap_plotly(elem_counts, font_size=10) @@ -85,9 +82,11 @@ label_font_size=17, value_font_size=14, cbar_title=f"{dataset.upper()} Element Count", - log=log, - cbar_range=(100, None), + log=(log := SymLogNorm(linthresh=100)), + # cbar_range=(100, None), ) + if log: + filename += "-symlog" if isinstance(log, SymLogNorm) else "-log" save_fig(ax_mp_cnt, f"{PDF_FIGS}/{filename}.pdf") diff --git a/matbench_discovery/preds.py b/matbench_discovery/preds.py index 4d21aa8f..c9ddcfc4 100644 --- a/matbench_discovery/preds.py +++ b/matbench_discovery/preds.py @@ -145,9 +145,10 @@ def load_df_wbm_with_preds( else: cols = list(df) - raise ValueError( - f"No pred col for {model_name=} ({model_key=}), available {cols=}" - ) + msg = f"No pred col for {model_name=}, available {cols=}" + if model_name != model_key: + msg = msg.replace(", ", f" ({model_key=}), ") + raise ValueError(msg) return df_out diff --git a/models/chgnet/join_chgnet_results.py b/models/chgnet/join_chgnet_results.py index ce5d39af..5d3787cd 100644 --- a/models/chgnet/join_chgnet_results.py +++ b/models/chgnet/join_chgnet_results.py @@ -53,12 +53,13 @@ # %% compute corrected formation energies -e_form_chgnet_col = "e_form_per_atom_chgnet" +e_pred_col = "chgnet_energy" +e_form_chgnet_col = f"e_form_per_atom_{e_pred_col.split('_energy')[0]}" df_chgnet[formula_col] = df_preds[formula_col] df_chgnet[e_form_chgnet_col] = [ get_e_form_per_atom(dict(energy=ene, composition=formula)) for formula, ene in tqdm( - df_chgnet.set_index(formula_col).chgnet_energy.items(), total=len(df_chgnet) + df_chgnet.set_index(formula_col)[e_pred_col].items(), total=len(df_chgnet) ) ] df_preds[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col] diff --git a/scripts/model_figs/make_metrics_tables.py b/scripts/model_figs/make_metrics_tables.py index 64b4f91c..328bc225 100644 --- a/scripts/model_figs/make_metrics_tables.py +++ b/scripts/model_figs/make_metrics_tables.py @@ -37,12 +37,12 @@ n_structs = MODEL_METADATA[model_name]["training_set"]["n_structures"] n_materials = MODEL_METADATA[model_name]["training_set"].get("n_materials") - formatted = si_fmt(n_structs) + n_structs_fmt = si_fmt(n_structs) if n_materials: - formatted += f" ({si_fmt(n_materials)})" + n_structs_fmt += f" ({si_fmt(n_materials)})" - df_metrics.loc[train_size_col, model] = formatted - df_metrics_10k.loc[train_size_col, model] = formatted + df_metrics.loc[train_size_col, model] = n_structs_fmt + df_metrics_10k.loc[train_size_col, model] = n_structs_fmt # %% add dummy classifier results to df_metrics @@ -157,6 +157,7 @@ cmap="viridis_r", subset=list(lower_is_better & {*df_filtered}) ) ) + # add up/down arrows to indicate which metrics are better when higher/lower arrow_suffix = dict.fromkeys(higher_is_better, " ↑") | dict.fromkeys( lower_is_better, " ↓" ) @@ -182,7 +183,7 @@ f"{SITE_FIGS}/metrics-table{label}.svelte", inline_props="class='roomy'", # draw dotted line between classification and regression metrics - styles=f"{col_selector} {{ border-left: 1px dotted white; }}{hide_scroll_bar}", + styles=f"{col_selector} {{ border-left: 2px dotted white; }}{hide_scroll_bar}", ) try: df_to_pdf(styler, f"{PDF_FIGS}/metrics-table{label}.pdf")