1
+ """Current MP tools to validate VASP calculations."""
2
+ from __future__ import annotations
3
+
1
4
from datetime import datetime
2
- from typing import Dict , List , Union , Optional
5
+ from typing import TYPE_CHECKING
3
6
4
7
import numpy as np
5
8
from pydantic import ConfigDict , Field , ImportString
15
18
from emmet .core .vasp .calc_types .enums import CalcType , TaskType
16
19
from emmet .core .vasp .task_valid import TaskDocument
17
20
21
+ if TYPE_CHECKING :
22
+ from collections .abc import Sequence
23
+
18
24
SETTINGS = EmmetSettings ()
19
25
20
26
@@ -51,19 +57,19 @@ class ValidationDoc(EmmetBaseModel):
51
57
description = "Last updated date for this document" ,
52
58
default_factory = datetime .utcnow ,
53
59
)
54
- reasons : Optional [ List [ Union [ DeprecationMessage , str ]]] = Field (
60
+ reasons : list [ DeprecationMessage | str ] | None = Field (
55
61
None , description = "List of deprecation tags detailing why this task isn't valid"
56
62
)
57
- warnings : List [str ] = Field (
63
+ warnings : list [str ] = Field (
58
64
[], description = "List of potential warnings about this calculation"
59
65
)
60
- data : Dict = Field (
66
+ data : dict = Field (
61
67
description = "Dictioary of data used to perform validation."
62
68
" Useful for post-mortem analysis"
63
69
)
64
70
model_config = ConfigDict (extra = "allow" )
65
- nelements : Optional [ int ] = Field (None , description = "Number of elements." )
66
- symmetry_number : Optional [ int ] = Field (
71
+ nelements : int | None = Field (None , description = "Number of elements." )
72
+ symmetry_number : int | None = Field (
67
73
None ,
68
74
title = "Space Group Number" ,
69
75
description = "The spacegroup number for the lattice." ,
@@ -72,14 +78,14 @@ class ValidationDoc(EmmetBaseModel):
72
78
@classmethod
73
79
def from_task_doc (
74
80
cls ,
75
- task_doc : Union [ TaskDoc , TaskDocument ] ,
81
+ task_doc : TaskDoc | TaskDocument ,
76
82
kpts_tolerance : float = SETTINGS .VASP_KPTS_TOLERANCE ,
77
83
kspacing_tolerance : float = SETTINGS .VASP_KSPACING_TOLERANCE ,
78
- input_sets : Dict [str , ImportString ] = SETTINGS .VASP_DEFAULT_INPUT_SETS ,
79
- LDAU_fields : List [str ] = SETTINGS .VASP_CHECKED_LDAU_FIELDS ,
84
+ input_sets : dict [str , ImportString ] = SETTINGS .VASP_DEFAULT_INPUT_SETS ,
85
+ LDAU_fields : list [str ] = SETTINGS .VASP_CHECKED_LDAU_FIELDS ,
80
86
max_allowed_scf_gradient : float = SETTINGS .VASP_MAX_SCF_GRADIENT ,
81
- max_magmoms : Dict [str , float ] = SETTINGS .VASP_MAX_MAGMOM ,
82
- potcar_stats : Optional [ Dict [ CalcType , Dict [str , str ]]] = None ,
87
+ max_magmoms : dict [str , float ] = SETTINGS .VASP_MAX_MAGMOM ,
88
+ potcar_stats : dict [ CalcType , dict [str , str ]] | None = None ,
83
89
) -> "ValidationDoc" :
84
90
"""
85
91
Determines if a calculation is valid based on expected input parameters from a pymatgen inputset
@@ -120,7 +126,7 @@ def from_task_doc(
120
126
121
127
reasons = []
122
128
data = {} # type: ignore
123
- warnings : List [str ] = []
129
+ warnings : list [str ] = []
124
130
125
131
if str (calc_type ) in input_sets :
126
132
try :
@@ -349,12 +355,17 @@ def _kspacing_warnings(input_set, inputs, data, warnings, kspacing_tolerance):
349
355
)
350
356
351
357
352
- def _potcar_stats_check (task_doc , potcar_stats : dict ):
358
+ def _potcar_stats_check (
359
+ task_doc ,
360
+ potcar_stats : dict ,
361
+ exclude_keys : Sequence [str ] | None = ["sha256" , "copyr" ],
362
+ ):
353
363
"""
354
364
Checks to make sure the POTCAR summary stats is equal to the correct
355
365
value from the pymatgen input set.
356
366
"""
357
367
data_tol = 1.0e-6
368
+ excl : set [str ] = set ([k .lower () for k in (exclude_keys or [])])
358
369
359
370
try :
360
371
potcar_details = task_doc .calcs_reversed [0 ].model_dump ()["input" ]["potcar_spec" ]
@@ -396,13 +407,19 @@ def _potcar_stats_check(task_doc, potcar_stats: dict):
396
407
)
397
408
398
409
else :
410
+ entry_keys = {
411
+ key : set ([k .lower () for k in entry ["summary_stats" ]["keywords" ][key ]])
412
+ - excl
413
+ for key in ["header" , "data" ]
414
+ }
399
415
all_match = False
400
416
for ref_stat in ref_summ_stats :
401
- key_match = all (
402
- set (ref_stat ["keywords" ][key ])
403
- == set (entry ["summary_stats" ]["keywords" ][key ])
417
+ ref_keys = {
418
+ key : set ([k .lower () for k in ref_stat ["keywords" ][key ]]) - excl
404
419
for key in ["header" , "data" ]
405
- )
420
+ }
421
+
422
+ key_match = all (entry_keys [k ] == v for k , v in ref_keys .items ())
406
423
407
424
data_match = False
408
425
if key_match :
0 commit comments