Skip to content

Commit

Permalink
add MP (train set) elem count heatmap for comparison with WBM (test s…
Browse files Browse the repository at this point in the history
…et) heatmap
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 89a33b3 commit 3130c89
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ repos:
- prettier
- prettier-plugin-svelte
- svelte
exclude: ^(site/figures/|data/wbm/20).*$
# skip files under site/figures/ and data/wbm/ files whose names start with 20
# (no quantifier after the group: `(...)*.*` would match every path)
exclude: ^(site/figures/|data/wbm/20).*$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.31.0
Expand Down
95 changes: 95 additions & 0 deletions data/wbm/eda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# %%
import os

import pandas as pd
from pymatviz import count_elements, ptable_heatmap_plotly

from matbench_discovery import ROOT, today

# directory containing this script; the WBM summary CSV, the MP energies file
# and the exported SVG figures below are all resolved relative to it
module_dir = os.path.dirname(__file__)

"""
Compare MP and WBM elemental prevalence. Starting with WBM, MP below.
"""


# %% load WBM summary and count elements across its formulas
wbm_summary_path = f"{module_dir}/2022-10-19-wbm-summary.csv"
df_summary = pd.read_csv(wbm_summary_path).set_index("material_id")
elem_counts = count_elements(df_summary.formula).astype(int)

# export the counts as JSON into the site's about-the-test-set route
wbm_counts_path = (
    f"{ROOT}/site/src/routes/about-the-test-set/{today}-wbm-element-counts.json"
)
elem_counts.to_json(wbm_counts_path)


# %% periodic table heatmap of WBM element counts (log color scale)
title = "WBM Elements"
wbm_layout = {
    "title": {"text": title, "x": 0.35, "y": 0.9, "font_size": 20},
    # fixedrange=True on both axes
    "xaxis": {"fixedrange": True},
    "yaxis": {"fixedrange": True},
    "paper_bgcolor": "rgba(0,0,0,0)",  # alpha 0, i.e. transparent background
}

fig = ptable_heatmap_plotly(
    elem_counts,
    log=True,
    colorscale="YlGnBu",
    hover_props={"atomic_number": "atomic number"},
    hover_data=elem_counts,
    font_size="1vw",
)
fig.update_layout(**wbm_layout)
fig.show()


# %% save the WBM heatmap as a static SVG next to this script
wbm_svg_path = f"{module_dir}/{today}-wbm-elements.svg"
fig.write_image(wbm_svg_path, width=1000, height=500)
# fig.write_html(
#     f"{module_dir}/{today}-wbm-elements.svelte",
#     include_plotlyjs=False,
#     full_html=False,
#     config=dict(showTips=False, displayModeBar=False, responsive=True),
# )


# %% load MP training set
mp_energies_path = f"{module_dir}/../mp/2022-08-13-mp-energies.json.gz"
df = pd.read_json(mp_energies_path)
# NOTE: from here on elem_counts holds the MP counts, replacing the WBM ones
elem_counts = count_elements(df.formula_pretty).astype(int)

# export the counts as JSON into the site's about-the-test-set route
mp_counts_path = (
    f"{ROOT}/site/src/routes/about-the-test-set/{today}-mp-element-counts.json"
)
elem_counts.to_json(mp_counts_path)
elem_counts.describe()


# %% periodic table heatmap of MP element counts (log color scale)
title = "MP Elements"
mp_layout = {
    "title": {"text": title, "x": 0.35, "y": 0.9, "font_size": 20},
    # fixedrange=True on both axes
    "xaxis": {"fixedrange": True},
    "yaxis": {"fixedrange": True},
    "paper_bgcolor": "rgba(0,0,0,0)",  # alpha 0, i.e. transparent background
}

# only counts greater than 1 are passed as heatmap values; the hover data
# still receives the full, unfiltered counts
mp_heat_counts = elem_counts[elem_counts > 1]
fig = ptable_heatmap_plotly(
    mp_heat_counts,
    log=True,
    colorscale="YlGnBu",
    hover_props={"atomic_number": "atomic number"},
    hover_data=elem_counts,
    font_size="1vw",
)
fig.update_layout(**mp_layout)
fig.show()


# %% save the MP heatmap as a static SVG next to this script
mp_svg_path = f"{module_dir}/{today}-mp-elements.svg"
fig.write_image(mp_svg_path, width=1000, height=500)

# fig.write_html(
#     f"{module_dir}/{today}-mp-elements.svelte",
#     include_plotlyjs=False,
#     full_html=False,
#     config=dict(showTips=False, displayModeBar=False, responsive=True),
# )
39 changes: 1 addition & 38 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
MaterialsProjectCompatibility as MPLegacyCompat,
)
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatviz import count_elements, density_scatter, ptable_heatmap_plotly
from pymatviz import density_scatter
from tqdm import tqdm

from matbench_discovery import ROOT, today
Expand Down Expand Up @@ -628,40 +628,3 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
]


# %% count elements across WBM formulas and export the counts as JSON
elem_counts = count_elements(df_summary.formula).astype(int)

wbm_counts_path = (
    f"{ROOT}/site/src/routes/about-the-test-set/{today}-wbm-element-counts.json"
)
elem_counts.to_json(wbm_counts_path)


# %% periodic table heatmap of WBM element counts
title = "WBM Elements (log color scale)"
wbm_layout = {
    "title": {"text": title, "x": 0.35, "y": 0.9, "font_size": 20},
    # fixedrange=True on both axes
    "xaxis": {"fixedrange": True},
    "yaxis": {"fixedrange": True},
}

fig = ptable_heatmap_plotly(
    elem_counts,
    log=True,
    colorscale="YlGnBu",
    hover_props={"atomic_number": "atomic number"},
    hover_data=elem_counts,
    font_size="1vw",
)
fig.update_layout(**wbm_layout)
fig.show()


# %% save the figure as static SVG and as an HTML fragment in a .svelte file
wbm_fig_stem = f"{module_dir}/{today}-wbm-elements-log"
fig.write_image(f"{wbm_fig_stem}.svg", width=1000, height=500)
fig.write_html(
    f"{wbm_fig_stem}.svelte",
    include_plotlyjs=False,
    full_html=False,
    config={"showTips": False, "displayModeBar": False, "responsive": True},
)
13 changes: 10 additions & 3 deletions data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@ materialscloud:2021.68 includes a readme file with a description of the dataset,

## 📊   Data Plots

<caption>Heatmap of elemental prevalence in WBM dataset.</caption>
<slot name="wbm-elements-log">
<img src="./2022-12-30-wbm-elements-log.svg" alt="Periodic table log heatmap of WBM elements">
<caption>Heatmap of WBM test set element counts</caption>
<slot name="wbm-elements-heatmap">
<img src="./2023-01-08-wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">
</slot>

which compares as follows to the training set (all 146323 MP ComputedStructureEntries)

<caption>Heatmap of MP training set element counts</caption>
<slot name="mp-elements-heatmap">
<img src="./2023-01-08-mp-elements.svg" alt="Periodic table log heatmap of MP elements">
</slot>
8 changes: 4 additions & 4 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
"devDependencies": {
"@iconify/svelte": "^3.0.1",
"@rollup/plugin-yaml": "^4.0.1",
"@sveltejs/adapter-static": "1.0.0",
"@sveltejs/kit": "1.0.1",
"@sveltejs/adapter-static": "1.0.1",
"@sveltejs/kit": "1.0.7",
"@sveltejs/vite-plugin-svelte": "^2.0.2",
"@typescript-eslint/eslint-plugin": "^5.48.0",
"@typescript-eslint/parser": "^5.48.0",
Expand All @@ -28,7 +28,7 @@
"hastscript": "^7.2.0",
"highlight.js": "^11.7.0",
"mdsvex": "^0.10.6",
"prettier": "^2.8.1",
"prettier": "^2.8.2",
"prettier-plugin-svelte": "^2.9.0",
"rehype-autolink-headings": "^6.1.1",
"rehype-slug": "^5.1.0",
Expand All @@ -38,7 +38,7 @@
"svelte-preprocess": "^5.0.0",
"svelte-toc": "^0.5.1",
"svelte2tsx": "^0.6.0",
"sveriodic-table": "^0.1.2",
"sveriodic-table": "^0.1.4",
"tslib": "^2.4.1",
"typescript": "^4.9.4",
"vite": "^4.0.4"
Expand Down
1 change: 1 addition & 0 deletions site/src/app.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
/// <reference types="mdsvex/globals" />

declare module '*.md'
declare module '*package.json'
43 changes: 31 additions & 12 deletions site/src/routes/about-the-test-set/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
import type { ChemicalElement } from 'sveriodic-table'
import { PeriodicTable, TableInset, Toggle } from 'sveriodic-table'
import { pretty_num } from 'sveriodic-table/labels'
import elem_counts from './2022-12-30-wbm-element-counts.json'
import mp_elem_counts from './2023-01-08-mp-element-counts.json'
import wbm_elem_counts from './2023-01-08-wbm-element-counts.json'
let log_color_scale = false
const heatmap_values: number[] = Object.values(elem_counts)
let log = false // log color scale
const wbm_heat_vals: number[] = Object.values(wbm_elem_counts)
const mp_heat_vals: number[] = Object.values(mp_elem_counts)
const color_map = {
200: `blue`,
35_000: `green`,
80_000: `yellow`,
150_000: `red`,
}
let active_element: ChemicalElement
let active_mp_elem: ChemicalElement
let active_wbm_elem: ChemicalElement
</script>

<DataReadme>
Expand All @@ -23,17 +26,33 @@
<FormEnergyHist />
{/if}
</svelte:fragment>
<svelte:fragment slot="wbm-elements-log">
<span>Log color scale <Toggle bind:checked={log_color_scale} /></span>
<PeriodicTable {heatmap_values} {color_map} log={log_color_scale} bind:active_element>
<svelte:fragment slot="wbm-elements-heatmap">
<span>Log color scale <Toggle bind:checked={log} /></span>
<PeriodicTable heatmap_values={wbm_heat_vals} {color_map} {log} bind:active_element={active_wbm_elem}>
<TableInset slot="inset" grid_row="3">
{#if active_element?.name}
{#if active_wbm_elem?.name}
<strong>
{active_element?.name}: {pretty_num(elem_counts[active_element?.symbol])}
{active_wbm_elem?.name}: {pretty_num(wbm_elem_counts[active_wbm_elem?.symbol])}
<!-- compute percent of total -->
{#if elem_counts[active_element?.symbol] > 0}
{@const total = heatmap_values.reduce((a, b) => a + b, 0)}
({pretty_num((elem_counts[active_element?.symbol] / total) * 100)}%)
{#if wbm_elem_counts[active_wbm_elem?.symbol] > 0}
{@const total = wbm_heat_vals.reduce((a, b) => a + b, 0)}
({pretty_num((wbm_elem_counts[active_wbm_elem?.symbol] / total) * 100)}%)
{/if}
</strong>
{/if}
</TableInset>
</PeriodicTable>
</svelte:fragment>
<svelte:fragment slot="mp-elements-heatmap">
<PeriodicTable heatmap_values={mp_heat_vals} {color_map} {log} bind:active_element={active_mp_elem}>
<TableInset slot="inset" grid_row="3">
{#if active_mp_elem?.name}
  <!-- look up the hovered element in the MP counts (not the WBM ones) -->
  <strong>
    {active_mp_elem?.name}: {pretty_num(mp_elem_counts[active_mp_elem?.symbol])}
    <!-- compute percent of total MP element occurrences -->
    {#if mp_elem_counts[active_mp_elem?.symbol] > 0}
      {@const total = mp_heat_vals.reduce((a, b) => a + b, 0)}
      ({pretty_num((mp_elem_counts[active_mp_elem?.symbol] / total) * 100)}%)
    {/if}
  </strong>
{/if}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"H":23584,"He":0,"Li":14313,"Be":1989,"B":21858,"C":13678,"N":27918,"O":150013,"F":57810,"Ne":0,"Na":13573,"Mg":25502,"Al":55485,"Si":47749,"P":25924,"S":43214,"Cl":26071,"Ar":0,"K":17952,"Ca":21697,"Sc":18533,"Ti":15680,"V":12107,"Cr":15394,"Mn":33133,"Fe":44576,"Co":38020,"Ni":47391,"Cu":43476,"Zn":32173,"Ga":45347,"Ge":49788,"As":27578,"Se":42190,"Br":22704,"Kr":0,"Rb":16884,"Sr":22328,"Y":19544,"Zr":18331,"Nb":13125,"Mo":3942,"Tc":1636,"Ru":26488,"Rh":38467,"Pd":37770,"Ag":15880,"Cd":18170,"In":44110,"Sn":43997,"Sb":20297,"Te":27842,"I":16387,"Xe":2,"Cs":14702,"Ba":21181,"La":18720,"Ce":16002,"Pr":17375,"Nd":17224,"Pm":1462,"Sm":17223,"Eu":10483,"Gd":9462,"Tb":19230,"Dy":18638,"Ho":17874,"Er":17707,"Tm":18983,"Yb":19848,"Lu":12650,"Hf":15081,"Ta":11155,"W":3761,"Re":4180,"Os":13551,"Ir":27121,"Pt":38649,"Au":28489,"Hg":10508,"Tl":16972,"Pb":20005,"Bi":12823,"Po":0,"At":0,"Rn":0,"Fr":0,"Ra":0,"Ac":1863,"Th":17945,"Pa":4051,"U":14301,"Np":10177,"Pu":13117,"Am":0,"Cm":0,"Bk":0,"Cf":0,"Es":0,"Fm":0,"Md":0,"No":0,"Lr":0,"Rf":0,"Db":0,"Sg":0,"Bh":0,"Hs":0,"Mt":0,"Ds":0,"Rg":0,"Cn":0,"Nh":0,"Fl":0,"Mc":0,"Lv":0,"Ts":0,"Og":0}
38 changes: 19 additions & 19 deletions site/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@ import { exec } from 'child_process'
import { resolve } from 'path'
import type { UserConfig } from 'vite'

if (process.env.PROD) {
  // update generated API docs on production builds

  // child_process.exec is callback-based and returns a ChildProcess, not a
  // promise, so awaiting it directly does not wait for the command to finish.
  // Promisify it so the steps below really run one after another.
  const { promisify } = await import('util')
  const run = promisify(exec)

  const src_url = `https://github.com/janosh/matbench-discovery/blob/main`
  const route = `src/routes/api`

  try {
    await run(`rm -f ${route}/*.md`)
    await run(
      `cd .. && lazydocs matbench_discovery --output-path site/${route} --no-watermark --src-base-url ${src_url}`
    )

    // remove <b> tags from generated markdown
    await run(`sed -i 's/<b>//g' ${route}/*.md`)
    await run(`sed -i 's/<\\/b>//g' ${route}/*.md`)
    // tweak look of badges linking to source code
    const old_badge = `src="https://img.shields.io/badge/-source-cccccc?style=flat-square"`
    const new_badge = `src="https://img.shields.io/badge/source-blue?style=flat" alt="source link"`
    // both badge strings contain slashes, so use | as the sed delimiter
    await run(`sed -i 's|${old_badge}|${new_badge}|g' ${route}/*.md`)
  } catch (error) {
    // command failures were never surfaced before, so keep docs generation
    // best-effort: report the problem without failing the build
    console.error(`Failed to update generated API docs:`, error)
  }
}

const vite_config: UserConfig = {
plugins: [sveltekit(), yaml()],

Expand All @@ -16,7 +34,7 @@ const vite_config: UserConfig = {
},

server: {
fs: { allow: [`../..`] }, // needed to import readme.md
fs: { allow: [`../..`] }, // needed to import from $root
port: 3000,
},

Expand All @@ -26,21 +44,3 @@ const vite_config: UserConfig = {
}

export default vite_config

if (process.env.PROD) {
  // update generated API docs on production builds

  // child_process.exec is callback-based and returns a ChildProcess, not a
  // promise, so awaiting it directly does not wait for the command to finish.
  // Promisify it so the steps below really run one after another.
  const { promisify } = await import('util')
  const run = promisify(exec)

  const src_url = `https://github.com/janosh/matbench-discovery/blob/main`
  const route = `src/routes/api`

  try {
    await run(`rm -f ${route}/*.md`)
    await run(
      `cd .. && lazydocs matbench_discovery --output-path site/${route} --no-watermark --src-base-url ${src_url}`
    )

    // remove <b> tags from generated markdown
    await run(`sed -i 's/<b>//g' ${route}/*.md`)
    await run(`sed -i 's/<\\/b>//g' ${route}/*.md`)
    // tweak look of badges linking to source code
    const old_src = `src="https://img.shields.io/badge/-source-cccccc?style=flat-square"`
    const new_src = `src="https://img.shields.io/badge/source-blue?style=flat" alt="source link"`
    // both badge strings contain slashes, so use | as the sed delimiter
    await run(`sed -i 's|${old_src}|${new_src}|g' ${route}/*.md`)
  } catch (error) {
    // command failures were never surfaced before, so keep docs generation
    // best-effort: report the problem without failing the build
    console.error(`Failed to update generated API docs:`, error)
  }
}

0 comments on commit 3130c89

Please sign in to comment.