Skip to content

Commit a7f693d

Browse files
authored
Add metrics analysis excluding WBM materials with duplicate/MP-matching structure prototype (#75)
* bump requires-python = ">=3.(9->11)"
* WBM df_summary add column wyckoff_spglib from DFT-relaxed structs after renaming prev wyk col to wyckoff_spglib_initial_structure
* add wbm_summary col Key.uniq_proto indicating WBM materials with matching prototype in MP or duplicate prototypes in WBM (keeping only the lowest energy one)
* drop mace_checkpoint(,1,2) from DataFiles, update wbm_summary = "wbm/2023-12-13-wbm-summary.csv.gz"
* add site/src/figs/metrics-table-uniq-protos.svelte
  add table captions, change clf/regr metrics separator line from dotted to solid
* fix scripts/upload_to_figshare.py FileNotFoundError in file upload loop
  fix IPython.display unimportable in CI
  breaking: rename mbd.data.load() 1st arg data_key->key
* update wbm_summary figshare URLs, remove outdated MACE checkpoints
* run CI tests with py 3.11
  assert duplicate prototype counts in compile_wbm_test_set.py
  print prevalence of stable structures in eda_wbm.py
  remove metrics tables captions
  fix TNR typos in preprint/+page.md and iclr-ml4mat/+page.md
* fix 2023-12-13-wbm-summary.csv.gz figshare URL in 1.0.0.json
  fix test_data.py expected df_wbm.shape and load() bad key err msg
1 parent d221cee commit a7f693d

26 files changed

Lines changed: 765 additions & 183 deletions

.github/workflows/slow-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Set up Python
1818
uses: actions/setup-python@v5
1919
with:
20-
python-version: 3.9
20+
python-version: 3.11
2121

2222
- name: Install dependencies
2323
run: pip install -e .[test]

.github/workflows/test-scripts.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Set up Python
2424
uses: actions/setup-python@v5
2525
with:
26-
python-version: 3.9
26+
python-version: 3.11
2727

2828
- name: Install package and dependencies
2929
run: pip install -e .[fetch-data]

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ jobs:
2323
uses: janosh/workflows/.github/workflows/pytest-release.yml@main
2424
with:
2525
os: ${{ matrix.os }}
26+
python-version: 3.11

data/figshare/1.0.0.json

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@
44
"https://figshare.com/ndownloader/files/41233560",
55
"2023-06-02-pbenner-best-alignn-model.pth.zip"
66
],
7-
"mace_checkpoint_1": [
8-
"https://figshare.com/ndownloader/files/42374049",
9-
"2023-08-14-mace-yuan-trained-mptrj-04.model"
10-
],
11-
"mace_checkpoint_2": [
12-
"https://figshare.com/ndownloader/files/43117273",
13-
"2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss.model"
14-
],
157
"mp_computed_structure_entries": [
168
"https://figshare.com/ndownloader/files/40344436",
179
"2023-02-07-mp-computed-structure-entries.json.gz"
@@ -41,8 +33,8 @@
4133
"2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
4234
],
4335
"wbm_summary": [
44-
"https://figshare.com/ndownloader/files/41296866",
45-
"2022-10-19-wbm-summary.csv.gz"
36+
"https://figshare.com/ndownloader/files/44225498",
37+
"2023-12-13-wbm-summary.csv.gz"
4638
],
4739
"mp_trj_extxyz_by_yuan": [
4840
"https://figshare.com/ndownloader/files/43302033",

data/wbm/compile_wbm_test_set.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pymatviz.io import save_fig
2020
from tqdm import tqdm
2121

22-
from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, today
22+
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today
2323
from matbench_discovery.data import DATA_FILES
2424
from matbench_discovery.energy import get_e_form_per_atom
2525

@@ -38,9 +38,6 @@
3838
"""
3939

4040

41-
module_dir = os.path.dirname(__file__)
42-
43-
4441
# %% links to google drive files received via email from 1st author Hai-Chen Wang
4542
# on 2021-06-15 containing initial and relaxed structures
4643
google_drive_ids = {
@@ -53,10 +50,10 @@
5350

5451

5552
# %%
56-
os.makedirs(f"{module_dir}/raw", exist_ok=True)
53+
os.makedirs(f"{WBM_DIR}/raw", exist_ok=True)
5754

5855
for step, file_id in google_drive_ids.items():
59-
file_path = f"{module_dir}/raw/wbm-structures-step-{step}.json.bz2"
56+
file_path = f"{WBM_DIR}/raw/wbm-structures-step-{step}.json.bz2"
6057

6158
if os.path.exists(file_path):
6259
print(f"{file_path} already exists, skipping")
@@ -67,7 +64,7 @@
6764

6865

6966
# %%
70-
summary_path = f"{module_dir}/raw/wbm-summary.txt"
67+
summary_path = f"{WBM_DIR}/raw/wbm-summary.txt"
7168

7269
if not os.path.exists(summary_path):
7370
summary_id_file = "1639IFUG7poaDE2uB6aISUOi65ooBwCIg"
@@ -76,7 +73,7 @@
7673

7774

7875
# %%
79-
json_paths = sorted(glob(f"{module_dir}/raw/wbm-structures-step-*.json.bz2"))
76+
json_paths = sorted(glob(f"{WBM_DIR}/raw/wbm-structures-step-*.json.bz2"))
8077
step_lens = (61848, 52800, 79205, 40328, 23308)
8178
# step 3 has 79,211 initial structures but only 79,205 ComputedStructureEntries
8279
# i.e. 6 extra structures which have missing energy, volume, etc. in the summary file
@@ -177,7 +174,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
177174
# "summary.txt.bz2",
178175
*(f"step_{step}.json.bz2" for step in range(1, 6)),
179176
):
180-
file_path = f"{module_dir}/raw/wbm-cse-{filename.lower().replace('_', '-')}"
177+
file_path = f"{WBM_DIR}/raw/wbm-cse-{filename.lower().replace('_', '-')}"
181178
if os.path.exists(file_path):
182179
print(f"{file_path} already exists, skipping")
183180
continue
@@ -191,7 +188,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
191188

192189

193190
# %%
194-
cse_step_paths = sorted(glob(f"{module_dir}/raw/wbm-cse-step-*.json.bz2"))
191+
cse_step_paths = sorted(glob(f"{WBM_DIR}/raw/wbm-cse-step-*.json.bz2"))
195192
assert len(cse_step_paths) == 5
196193

197194
"""
@@ -295,14 +292,14 @@ def increment_wbm_material_id(wbm_id: str) -> str:
295292
"vol": "volume",
296293
"e": Key.dft_energy,
297294
"e_form": Key.e_form_wbm,
298-
"e_hull": "e_above_hull_wbm",
295+
"e_hull": Key.each_wbm,
299296
"gap": Key.bandgap_pbe,
300297
"id": Key.mat_id,
301298
}
302299
# WBM summary was shared twice, once on google drive, once on materials cloud
303300
# download both and check for consistency
304301
df_summary = pd.read_csv(
305-
f"{module_dir}/raw/wbm-summary.txt", sep="\t", names=col_map.values()
302+
f"{WBM_DIR}/raw/wbm-summary.txt", sep="\t", names=col_map.values()
306303
).set_index(Key.mat_id)
307304

308305
df_summary_bz2 = pd.read_csv(
@@ -398,7 +395,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
398395
),
399396
):
400397
cols = ["formula_from_cse", *cols] # type: ignore[list-item]
401-
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{fname}.json.bz2")
398+
df_wbm[cols].reset_index().to_json(f"{WBM_DIR}/{today}-wbm-{fname}.json.bz2")
402399

403400

404401
# %%
@@ -589,18 +586,34 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
589586
try:
590587
from aviary.wren.utils import get_aflow_label_from_spglib
591588

592-
if Key.wyckoff not in df_wbm:
593-
df_summary[Key.wyckoff] = None
589+
# add Aflow-style Wyckoff labels for initial and relaxed structures
590+
for key in (Key.init_wyckoff, Key.wyckoff):
591+
if key not in df_wbm:
592+
df_summary[key] = None
594593

595-
for idx, struct in tqdm(df_wbm[Key.init_struct].items(), total=len(df_wbm)):
596-
if not pd.isna(df_summary.loc[idx, Key.wyckoff]):
594+
# from initial structures
595+
for idx in tqdm(df_wbm.index):
596+
if not pd.isna(df_summary.loc[idx, Key.init_wyckoff]):
597597
continue # Aflow label already computed
598598
try:
599-
struct = Structure.from_dict(struct)
599+
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
600+
df_summary.loc[idx, Key.init_wyckoff] = get_aflow_label_from_spglib(struct)
601+
except Exception as exc:
602+
print(f"{idx=} {exc=}")
603+
604+
# from relaxed structures
605+
for idx in tqdm(df_wbm.index):
606+
if not pd.isna(df_summary.loc[idx, Key.wyckoff]):
607+
continue
608+
609+
try:
610+
cse = df_wbm.loc[idx, Key.cse]
611+
struct = Structure.from_dict(cse["structure"])
600612
df_summary.loc[idx, Key.wyckoff] = get_aflow_label_from_spglib(struct)
601613
except Exception as exc:
602614
print(f"{idx=} {exc=}")
603615

616+
assert df_summary[Key.init_wyckoff].isna().sum() == 0
604617
assert df_summary[Key.wyckoff].isna().sum() == 0
605618
except ImportError:
606619
print("aviary not installed, skipping Wyckoff label generation")
@@ -609,7 +622,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
609622

610623

611624
# %%
612-
fingerprints_path = f"{module_dir}/site-stats.json.gz"
625+
fingerprints_path = f"{WBM_DIR}/site-stats.json.gz"
613626
suggest = "not found, run scripts/compute_struct_fingerprints.py to generate"
614627
fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"
615628
try:
@@ -621,16 +634,40 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
621634
print(f"{fingerprints_path=} does not contain {fp_diff_col=}")
622635

623636

637+
# %% mark WBM materials with matching prototype in MP or duplicate prototypes
638+
# in WBM (keeping only the lowest energy one)
639+
df_mp = pd.read_csv(DATA_FILES.mp_energies, index_col=0)
640+
641+
# mask WBM materials with matching prototype in MP
642+
mask_proto_in_mp = df_summary[Key.wyckoff].isin(df_mp[Key.wyckoff])
643+
# mask duplicate prototypes in WBM (keeping the lowest energy one)
644+
mask_dupe_protos = df_summary.sort_values(by=[Key.wyckoff, Key.each_wbm]).duplicated(
645+
subset=Key.wyckoff, keep="first"
646+
)
647+
assert sum(mask_proto_in_mp) == 11_175, f"{sum(mask_proto_in_mp)=:_}"
648+
assert sum(mask_dupe_protos) == 32_784, f"{sum(mask_dupe_protos)=:_}"
649+
650+
df_summary[Key.uniq_proto] = ~(mask_proto_in_mp | mask_dupe_protos)
651+
assert dict(df_summary[Key.uniq_proto].value_counts()) == {True: 215_488, False: 41_475}
652+
assert list(df_summary.query(f"~{Key.uniq_proto}").head(5).index) == [
653+
"wbm-1-7",
654+
"wbm-1-8",
655+
"wbm-1-15",
656+
"wbm-1-20",
657+
"wbm-1-33",
658+
]
659+
660+
624661
# %% write final summary data to disk (yeah!)
625-
df_summary.round(6).to_csv(f"{module_dir}/{today}-wbm-summary.csv")
662+
df_summary.round(6).to_csv(f"{WBM_DIR}/{today}-wbm-summary.csv.gz")
626663

627664

628665
# %% only here to load data for later inspection
629666
if False:
630-
wbm_summary_path = f"{module_dir}/2022-10-19-wbm-summary.csv.gz"
667+
wbm_summary_path = f"{WBM_DIR}/2022-10-19-wbm-summary.csv.gz"
631668
df_summary = pd.read_csv(wbm_summary_path).set_index(Key.mat_id)
632669
df_wbm = pd.read_json(
633-
f"{module_dir}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
670+
f"{WBM_DIR}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
634671
).set_index(Key.mat_id)
635672

636673
df_wbm["cse"] = [

data/wbm/eda_wbm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@
5757
)
5858

5959

60+
# %% print prevalence of stable structures in full WBM and uniq-prototypes only
61+
for df, label in (
62+
(df_wbm, "full WBM"),
63+
(df_wbm.query(Key.uniq_proto), "WBM unique prototypes"),
64+
):
65+
n_stable = sum(df[Key.each_true] <= STABILITY_THRESHOLD)
66+
stable_rate = n_stable / len(df)
67+
print(f"{label}: {stable_rate=:.1%} ({n_stable:,} out of {len(df):,})")
68+
69+
6070
# %%
6171
for dataset, count_mode, elem_counts in all_counts:
6272
filename = f"{dataset}-element-counts-by-{count_mode}"
@@ -303,7 +313,7 @@
303313

304314

305315
# %%
306-
df_wbm[Key.spacegroup] = df_wbm[Key.wyckoff].str.split("_").str[2].astype(int)
316+
df_wbm[Key.spacegroup] = df_wbm[Key.init_wyckoff].str.split("_").str[2].astype(int)
307317
df_mp[Key.spacegroup] = df_mp[Key.wyckoff].str.split("_").str[2].astype(int)
308318

309319

data/wbm/readme.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ The full set of processing steps used to curate the WBM test set from the raw da
2323
- remove 6 pathological structures (with 0 volume)
2424
- remove formation energy outliers below -5 and above 5 eV/atom (502 and 22 crystals respectively out of 257,487 total, including an anomaly of 500 structures at exactly -10 eV/atom)
2525

26-
<caption>WBM Formation energy distribution. 524 materials outside dashed lines were discarded.<br />(zoom out on this plot to see discarded samples)</caption>
26+
<caption style="margin: 1em;">WBM Formation energy distribution. 524 materials outside dashed lines were discarded.</caption>
2727
<slot name="hist-e-form-per-atom">
2828
<img src="./figs/hist-wbm-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
2929
</slot>

matbench_discovery/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import warnings
66
from datetime import datetime
7-
from enum import Enum, unique
7+
from enum import StrEnum, unique
88
from importlib.metadata import Distribution
99

1010
import matplotlib.pyplot as plt
@@ -54,10 +54,6 @@
5454
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
5555

5656

57-
class StrEnum(str, Enum):
58-
"""Enum whose members are also (and must be) strings."""
59-
60-
6157
@unique
6258
class Key(StrEnum):
6359
"""Keys used to access dataframes columns."""
@@ -72,8 +68,10 @@ class Key(StrEnum):
7268
e_form_pred = "e_form_per_atom_pred"
7369
e_form_raw = "e_form_per_atom_uncorrected"
7470
e_form_wbm = "e_form_per_atom_wbm"
71+
each = "energy_above_hull" # as returned by MP API
7572
each_pred = "e_above_hull_pred"
7673
each_true = "e_above_hull_mp2020_corrected_ppd_mp"
74+
each_wbm = "e_above_hull_wbm"
7775
final_struct = "relaxed_structure"
7876
forces = "forces"
7977
form_energy = "formation_energy_per_atom"
@@ -91,8 +89,11 @@ class Key(StrEnum):
9189
stress_trace = "stress_trace"
9290
struct = "structure"
9391
task_id = "task_id"
92+
# lowest-energy WBM structure for a given prototype that isn't already in MP
93+
uniq_proto = "unique_prototype"
9494
volume = "volume"
95-
wyckoff = "wyckoff_spglib"
95+
wyckoff = "wyckoff_spglib" # relaxed structure Aflow label
96+
init_wyckoff = "wyckoff_spglib_initial_structure" # initial structure Aflow label
9697

9798

9899
@unique

0 commit comments

Comments
 (0)