1
+ """Define schemas for DFPT, phonopy, and pheasy-derived phonon data."""
1
2
from __future__ import annotations
2
3
3
4
from datetime import datetime
5
+ import json
4
6
from functools import cached_property
5
7
8
+ from monty .dev import requires
6
9
import numpy as np
7
10
from pydantic import model_validator , BaseModel , Field , computed_field , PrivateAttr
8
11
from typing import Optional , TYPE_CHECKING
22
25
23
26
from typing_extensions import Literal
24
27
28
# pyarrow is an optional dependency: bind BOTH names on failure so that the
# `@requires(pa is not None, ...)` decorators in this module can be evaluated
# at import time instead of raising NameError when pyarrow is not installed.
try:
    import pyarrow as pa
    from pyarrow import Table as ArrowTable
except ImportError:
    pa = None
    ArrowTable = None
34
+
25
35
if TYPE_CHECKING :
26
36
from collections .abc import Sequence
27
37
from typing import Any
38
+ from typing_extensions import Self
28
39
29
40
30
41
class PhononDOS (BaseModel ):
@@ -35,8 +46,22 @@ class PhononDOS(BaseModel):
35
46
36
47
@cached_property
def to_pmg(self) -> PhononDosObject:
    """Return the pymatgen DOS object for this model, computed once and cached."""
    dos_kwargs = {
        "frequencies": self.frequencies,
        "densities": self.densities,
    }
    return PhononDosObject(**dos_kwargs)
39
51
52
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
def to_arrow(self, col_prefix: str | None = None) -> ArrowTable:
    """Serialize this PhononDOS into a one-row pyarrow Table.

    Args:
        col_prefix: Text prepended to each column name; ``None`` or an
            empty string means no prefix.

    Returns:
        A table with one column each for ``frequencies`` and ``densities``,
        every column holding a single cell.
    """
    prefix = col_prefix if col_prefix else ""
    columns = {}
    for name in ("frequencies", "densities"):
        # Wrap the value in a list so it becomes one cell, not one row per entry.
        columns[f"{prefix}{name}"] = [getattr(self, name)]
    return pa.Table.from_pydict(columns)
57
+
58
@classmethod
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
def from_arrow(cls, table: ArrowTable, col_prefix: str | None = None) -> Self:
    """Rebuild a PhononDOS from the first row of a pyarrow Table.

    Args:
        table: Table containing one column per model field, named
            ``<col_prefix><field>``.
        col_prefix: Text prepended to each column name; ``None`` or an
            empty string means no prefix.

    Returns:
        A new instance populated from row 0 of ``table``.
    """
    prefix = col_prefix if col_prefix else ""
    field_values = {}
    for name in cls.model_fields:
        column = table[f"{prefix}{name}"]
        field_values[name] = column.to_pylist()[0]
    return cls(**field_values)
64
+
40
65
41
66
class PhononBS (BaseModel ):
42
67
"""Define schema of pymatgen phonon band structure."""
@@ -47,9 +72,7 @@ class PhononBS(BaseModel):
47
72
frequencies : list [list [float ]] = Field (
48
73
description = "The phonon frequencies in THz, with the first index representing the band, and the second the q-point." ,
49
74
)
50
- reciprocal_lattice : tuple [Vector3D , Vector3D , Vector3D ] = Field (
51
- description = "The reciprocal lattice."
52
- )
75
+ reciprocal_lattice : Matrix3D = Field (description = "The reciprocal lattice." )
53
76
has_nac : bool = Field (
54
77
False ,
55
78
description = "Whether the calculation includes non-analytical corrections at Gamma." ,
@@ -102,6 +125,7 @@ def primitive_structure(self) -> Structure | None:
102
125
103
126
@cached_property
104
127
def to_pmg (self ) -> PhononBandStructureSymmLine :
128
+ """Get / cache corresponding pymatgen object."""
105
129
rlatt = Lattice (self .reciprocal_lattice )
106
130
return PhononBandStructureSymmLine (
107
131
[Kpoint (q , lattice = rlatt ).frac_coords for q in self .qpoints ], # type: ignore[misc]
@@ -117,6 +141,71 @@ def to_pmg(self) -> PhononBandStructureSymmLine:
117
141
coords_are_cartesian = False ,
118
142
)
119
143
144
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
def to_arrow(self, col_prefix: str | None = None) -> ArrowTable:
    """Serialize this PhononBS into a one-row pyarrow Table.

    The model dump is rewritten as follows before it becomes a table:

    * ``structure`` (if truthy) is stored as a JSON string of its ``as_dict()``.
    * ``qpoints``, ``frequencies`` and ``reciprocal_lattice`` are flattened,
      with the original shape kept in a sibling ``<name>_shape`` column.
    * ``eigendisplacements`` is split into flattened ``_real`` / ``_imag``
      columns plus an ``eigendisplacements_shape`` column.
    * ``labels_dict`` becomes ``qpoint_labels`` (the keys) and
      ``qpoint_labelled_points`` (the values in the same order).

    Falsy entries among these are dropped from the output.

    Args:
        col_prefix: Text prepended to each column name; ``None`` or an
            empty string means no prefix.

    Returns:
        A table in which every column holds a single cell.
    """
    record = self.model_dump()

    structure = record.pop("structure", None)
    if structure:
        record["structure"] = json.dumps(structure.as_dict())

    for name in ("qpoints", "frequencies", "reciprocal_lattice", "eigendisplacements"):
        values = record.pop(name, None)
        if not values:
            continue
        arr = np.array(values)
        if name == "eigendisplacements":
            # Complex values are stored as separate real and imaginary columns.
            record["eigendisplacements_real"] = arr.real.flatten()
            record["eigendisplacements_imag"] = arr.imag.flatten()
            record["eigendisplacements_shape"] = list(arr.shape)
        else:
            record[name] = arr.flatten()
            record[f"{name}_shape"] = list(arr.shape)

    labels = record.pop("labels_dict")
    if labels:
        label_names = list(labels)
        record["qpoint_labels"] = label_names
        record["qpoint_labelled_points"] = [labels[label] for label in label_names]

    prefix = col_prefix if col_prefix else ""
    columns = {}
    for name, value in record.items():
        columns[f"{prefix}{name}"] = [value]
    return pa.Table.from_pydict(columns)
170
+
171
@classmethod
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
def from_arrow(cls, table: ArrowTable, col_prefix: str | None = None) -> Self:
    """Create a PhononBS from an arrow table.

    Reverses the column layout written by ``to_arrow``. Only row 0 of
    ``table`` is read. A column that is missing, or whose row-0 cell is
    null, is left out of the constructor arguments.

    Args:
        table: Table with columns named ``<col_prefix><name>``.
        col_prefix: Text prepended to each column name; ``None`` or an
            empty string means no prefix.

    Returns:
        A new instance built from the columns that are present and non-null.
    """
    col_prefix = col_prefix or ""
    available = set(table.column_names)

    def first(name: str) -> Any:
        """Return the row-0 value of the prefixed column ``name``."""
        return table[f"{col_prefix}{name}"].to_pylist()[0]

    config: dict[str, Any] = {}
    for k in (
        "structure",
        "has_nac",
        "qpoints",
        "frequencies",
        "reciprocal_lattice",
        "eigendisplacements_real",
        "qpoint_labels",
    ):
        if f"{col_prefix}{k}" not in available:
            continue
        v = first(k)
        if v is None:
            # A null cell (e.g. after concatenating rows with different
            # optional columns) is treated the same as an absent column.
            continue

        if k == "structure":
            config[k] = Structure.from_dict(json.loads(v))
        elif k in ("qpoints", "frequencies", "reciprocal_lattice"):
            # Flattened on write; restore the shape from the sibling column.
            config[k] = np.array(v).reshape(tuple(first(f"{k}_shape")))
        elif k == "eigendisplacements_real":
            # Recombine the real and imaginary columns into one complex array.
            real = np.array(v)
            imag = np.array(first("eigendisplacements_imag"))
            shape = tuple(first("eigendisplacements_shape"))
            config["eigendisplacements"] = (real + 1.0j * imag).reshape(shape)
        elif k == "qpoint_labels":
            config["labels_dict"] = dict(zip(v, first("qpoint_labelled_points")))
        else:
            config[k] = v
    return cls(**config)
208
+
120
209
121
210
class SumRuleChecks (BaseModel ):
122
211
"""Container class for defining sum rule checks."""
@@ -281,7 +370,22 @@ def compute_thermo_quantites(
281
370
thermo_props ["temperature" ] = temperatures
282
371
return thermo_props
283
372
284
-
373
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
def objects_to_arrow(self) -> ArrowTable:
    """Build a one-row pyarrow Table from the band structure and DOS.

    The row starts with a ``material_id`` column. If ``phonon_bandstructure``
    is truthy, its ``to_arrow`` columns are appended with a ``bs_`` prefix;
    likewise ``phonon_dos`` with a ``dos_`` prefix.

    Returns:
        The combined table.
    """
    row = pa.Table.from_pydict({"material_id": [self.material_id]})
    for attr_name, prefix in (("phonon_bandstructure", "bs_"), ("phonon_dos", "dos_")):
        component = getattr(self, attr_name)
        if not component:
            continue
        component_table = component.to_arrow(col_prefix=prefix)
        for column_name in component_table.column_names:
            # append_column returns a new table, so rebind on every step.
            row = row.append_column(column_name, component_table[column_name])
    return row
388
+
285
389
class PhononComputationalSettings (BaseModel ):
286
390
"""Collection to store computational settings for the phonon computation."""
287
391
0 commit comments