Skip to content

Commit 87316d4

Browse files
Tweak potcar validation (#1220)
* add option to exclude certain keys from potcar check - needed for potcar 54 validation w/ and w/o hashes * bump robocrys to remove circular software dependence in mp stack
1 parent 2c3cf81 commit 87316d4

File tree

3 files changed

+46
-30
lines changed

3 files changed

+46
-30
lines changed

emmet-core/emmet/core/openmm/tasks.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import io
66
from pathlib import Path
7-
from typing import Optional, Union
7+
from typing import Optional, Union, TYPE_CHECKING
88

99
import openmm
1010
import pandas as pd # type: ignore[import-untyped]
@@ -17,6 +17,9 @@
1717
from emmet.core.openff.tasks import CompressedStr # type: ignore[import-untyped]
1818
from emmet.core.vasp.task_valid import TaskState # type: ignore[import-untyped]
1919

20+
if TYPE_CHECKING:
21+
from typing import Any
22+
2023

2124
class CalculationInput(BaseModel): # type: ignore[call-arg]
2225
"""OpenMM input settings for a job, these are the attributes of the OpenMMMaker."""
@@ -161,7 +164,7 @@ def from_directory(
161164
) -> CalculationOutput:
162165
"""Extract data from the output files in the directory."""
163166
state_file = Path(dir_name) / state_file_name
164-
column_name_map = {
167+
column_name_map: dict[str, str] = {
165168
'#"Step"': "steps_reported",
166169
"Potential Energy (kJ/mole)": "potential_energy",
167170
"Kinetic Energy (kJ/mole)": "kinetic_energy",
@@ -174,8 +177,8 @@ def from_directory(
174177
if state_is_not_empty:
175178
data = pd.read_csv(state_file, header=0)
176179
data = data.rename(columns=column_name_map)
177-
data = data.filter(items=column_name_map.values())
178-
attributes = data.to_dict(orient="list")
180+
data = data.filter(items=list(column_name_map.values()))
181+
attributes: dict[str, Any] = data.to_dict(orient="list") # type: ignore[assignment]
179182
else:
180183
attributes = {name: None for name in column_name_map.values()}
181184
state_file_name = None # type: ignore[assignment]
@@ -184,14 +187,10 @@ def from_directory(
184187
traj_is_not_empty = traj_file.exists() and traj_file.stat().st_size > 0
185188
traj_file_name = traj_file_name if traj_is_not_empty else None # type: ignore
186189

187-
if traj_is_not_empty:
188-
if embed_traj:
189-
with open(traj_file, "rb") as f:
190-
traj_blob = f.read().hex()
191-
else:
192-
traj_blob = None
193-
else:
194-
traj_blob = None
190+
traj_blob: str | None = None
191+
if traj_is_not_empty and embed_traj:
192+
with open(traj_file, "rb") as f:
193+
traj_blob = f.read().hex()
195194

196195
return CalculationOutput(
197196
dir_name=str(dir_name),

emmet-core/emmet/core/vasp/validation.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
"""Current MP tools to validate VASP calculations."""
2+
from __future__ import annotations
3+
14
from datetime import datetime
2-
from typing import Dict, List, Union, Optional
5+
from typing import TYPE_CHECKING
36

47
import numpy as np
58
from pydantic import ConfigDict, Field, ImportString
@@ -15,6 +18,9 @@
1518
from emmet.core.vasp.calc_types.enums import CalcType, TaskType
1619
from emmet.core.vasp.task_valid import TaskDocument
1720

21+
if TYPE_CHECKING:
22+
from collections.abc import Sequence
23+
1824
SETTINGS = EmmetSettings()
1925

2026

@@ -51,19 +57,19 @@ class ValidationDoc(EmmetBaseModel):
5157
description="Last updated date for this document",
5258
default_factory=datetime.utcnow,
5359
)
54-
reasons: Optional[List[Union[DeprecationMessage, str]]] = Field(
60+
reasons: list[DeprecationMessage | str] | None = Field(
5561
None, description="List of deprecation tags detailing why this task isn't valid"
5662
)
57-
warnings: List[str] = Field(
63+
warnings: list[str] = Field(
5864
[], description="List of potential warnings about this calculation"
5965
)
60-
data: Dict = Field(
66+
data: dict = Field(
6167
description="Dictioary of data used to perform validation."
6268
" Useful for post-mortem analysis"
6369
)
6470
model_config = ConfigDict(extra="allow")
65-
nelements: Optional[int] = Field(None, description="Number of elements.")
66-
symmetry_number: Optional[int] = Field(
71+
nelements: int | None = Field(None, description="Number of elements.")
72+
symmetry_number: int | None = Field(
6773
None,
6874
title="Space Group Number",
6975
description="The spacegroup number for the lattice.",
@@ -72,14 +78,14 @@ class ValidationDoc(EmmetBaseModel):
7278
@classmethod
7379
def from_task_doc(
7480
cls,
75-
task_doc: Union[TaskDoc, TaskDocument],
81+
task_doc: TaskDoc | TaskDocument,
7682
kpts_tolerance: float = SETTINGS.VASP_KPTS_TOLERANCE,
7783
kspacing_tolerance: float = SETTINGS.VASP_KSPACING_TOLERANCE,
78-
input_sets: Dict[str, ImportString] = SETTINGS.VASP_DEFAULT_INPUT_SETS,
79-
LDAU_fields: List[str] = SETTINGS.VASP_CHECKED_LDAU_FIELDS,
84+
input_sets: dict[str, ImportString] = SETTINGS.VASP_DEFAULT_INPUT_SETS,
85+
LDAU_fields: list[str] = SETTINGS.VASP_CHECKED_LDAU_FIELDS,
8086
max_allowed_scf_gradient: float = SETTINGS.VASP_MAX_SCF_GRADIENT,
81-
max_magmoms: Dict[str, float] = SETTINGS.VASP_MAX_MAGMOM,
82-
potcar_stats: Optional[Dict[CalcType, Dict[str, str]]] = None,
87+
max_magmoms: dict[str, float] = SETTINGS.VASP_MAX_MAGMOM,
88+
potcar_stats: dict[CalcType, dict[str, str]] | None = None,
8389
) -> "ValidationDoc":
8490
"""
8591
Determines if a calculation is valid based on expected input parameters from a pymatgen inputset
@@ -120,7 +126,7 @@ def from_task_doc(
120126

121127
reasons = []
122128
data = {} # type: ignore
123-
warnings: List[str] = []
129+
warnings: list[str] = []
124130

125131
if str(calc_type) in input_sets:
126132
try:
@@ -349,12 +355,17 @@ def _kspacing_warnings(input_set, inputs, data, warnings, kspacing_tolerance):
349355
)
350356

351357

352-
def _potcar_stats_check(task_doc, potcar_stats: dict):
358+
def _potcar_stats_check(
359+
task_doc,
360+
potcar_stats: dict,
361+
exclude_keys: Sequence[str] | None = ["sha256", "copyr"],
362+
):
353363
"""
354364
Checks to make sure the POTCAR summary stats is equal to the correct
355365
value from the pymatgen input set.
356366
"""
357367
data_tol = 1.0e-6
368+
excl: set[str] = set([k.lower() for k in (exclude_keys or [])])
358369

359370
try:
360371
potcar_details = task_doc.calcs_reversed[0].model_dump()["input"]["potcar_spec"]
@@ -396,13 +407,19 @@ def _potcar_stats_check(task_doc, potcar_stats: dict):
396407
)
397408

398409
else:
410+
entry_keys = {
411+
key: set([k.lower() for k in entry["summary_stats"]["keywords"][key]])
412+
- excl
413+
for key in ["header", "data"]
414+
}
399415
all_match = False
400416
for ref_stat in ref_summ_stats:
401-
key_match = all(
402-
set(ref_stat["keywords"][key])
403-
== set(entry["summary_stats"]["keywords"][key])
417+
ref_keys = {
418+
key: set([k.lower() for k in ref_stat["keywords"][key]]) - excl
404419
for key in ["header", "data"]
405-
)
420+
}
421+
422+
key_match = all(entry_keys[k] == v for k, v in ref_keys.items())
406423

407424
data_match = False
408425
if key_match:

emmet-core/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"all": [
4040
"matcalc>=0.3.1",
4141
"seekpath>=2.0.1",
42-
"robocrys>=0.2.8",
42+
"robocrys>=0.2.11",
4343
"pymatgen-analysis-defects>=2024.7.18",
4444
"pymatgen-analysis-diffusion>=2024.7.15",
4545
"pymatgen-analysis-alloys>=0.0.6",

0 commit comments

Comments
 (0)