Skip to content

Commit 864a7ee

Browse files
add parquet
1 parent 941412c commit 864a7ee

File tree

1 file changed

+108
-4
lines changed

1 file changed

+108
-4
lines changed

emmet-core/emmet/core/phonon.py

Lines changed: 108 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
"""Define schemas for DFPT, phonopy, and pheasy-derived phonon data."""
12
from __future__ import annotations
23

34
from datetime import datetime
5+
import json
46
from functools import cached_property
57

8+
from monty.dev import requires
69
import numpy as np
710
from pydantic import model_validator, BaseModel, Field, computed_field, PrivateAttr
811
from typing import Optional, TYPE_CHECKING
@@ -22,9 +25,17 @@
2225

2326
from typing_extensions import Literal
2427

28+
try:
29+
import pyarrow as pa
30+
31+
from pyarrow import Table as ArrowTable
32+
except ImportError:
33+
ArrowTable = None
34+
2535
if TYPE_CHECKING:
2636
from collections.abc import Sequence
2737
from typing import Any
38+
from typing_extensions import Self
2839

2940

3041
class PhononDOS(BaseModel):
@@ -35,8 +46,22 @@ class PhononDOS(BaseModel):
3546

3647
@cached_property
3748
def to_pmg(self) -> PhononDosObject:
49+
"""Get / cache corresponding pymatgen object."""
3850
return PhononDosObject(frequencies=self.frequencies, densities=self.densities)
3951

52+
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
53+
def to_arrow(self, col_prefix : str | None = None) -> ArrowTable:
54+
"""Convert PhononDOS to a pyarrow Table."""
55+
col_prefix = col_prefix or ""
56+
return pa.Table.from_pydict({f"{col_prefix}{k}": [getattr(self,k)] for k in ("frequencies","densities")})
57+
58+
@classmethod
59+
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
60+
def from_arrow(cls, table: ArrowTable, col_prefix : str | None = None) -> Self:
61+
"""Create a PhononDOS from a pyarrow Table."""
62+
col_prefix = col_prefix or ""
63+
return cls(**{k: table[f"{col_prefix}{k}"].to_pylist()[0] for k in cls.model_fields})
64+
4065

4166
class PhononBS(BaseModel):
4267
"""Define schema of pymatgen phonon band structure."""
@@ -47,9 +72,7 @@ class PhononBS(BaseModel):
4772
frequencies: list[list[float]] = Field(
4873
description="The phonon frequencies in THz, with the first index representing the band, and the second the q-point.",
4974
)
50-
reciprocal_lattice: tuple[Vector3D, Vector3D, Vector3D] = Field(
51-
description="The reciprocal lattice."
52-
)
75+
reciprocal_lattice: Matrix3D = Field(description="The reciprocal lattice.")
5376
has_nac: bool = Field(
5477
False,
5578
description="Whether the calculation includes non-analytical corrections at Gamma.",
@@ -102,6 +125,7 @@ def primitive_structure(self) -> Structure | None:
102125

103126
@cached_property
104127
def to_pmg(self) -> PhononBandStructureSymmLine:
128+
"""Get / cache corresponding pymatgen object."""
105129
rlatt = Lattice(self.reciprocal_lattice)
106130
return PhononBandStructureSymmLine(
107131
[Kpoint(q, lattice=rlatt).frac_coords for q in self.qpoints], # type: ignore[misc]
@@ -117,6 +141,71 @@ def to_pmg(self) -> PhononBandStructureSymmLine:
117141
coords_are_cartesian=False,
118142
)
119143

144+
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
145+
def to_arrow(self, col_prefix : str | None = None) -> ArrowTable:
146+
"""Convert a PhononBS to an arrow table."""
147+
config = self.model_dump()
148+
if structure := config.pop("structure", None):
149+
config["structure"] = json.dumps(structure.as_dict())
150+
151+
for k in ("qpoints", "frequencies", "reciprocal_lattice", "eigendisplacements"):
152+
if (vals := config.pop(k, None)) and k == "eigendisplacements":
153+
cvals = np.array(vals)
154+
config["eigendisplacements_real"] = cvals.real.flatten()
155+
config["eigendisplacements_imag"] = cvals.imag.flatten()
156+
config["eigendisplacements_shape"] = list(cvals.shape)
157+
elif vals:
158+
rvals = np.array(vals)
159+
config[k] = rvals.flatten()
160+
config[f"{k}_shape"] = list(rvals.shape)
161+
162+
if qpt_labels := config.pop("labels_dict"):
163+
config["qpoint_labels"] = list(qpt_labels)
164+
config["qpoint_labelled_points"] = [
165+
qpt_labels[k] for k in config["qpoint_labels"]
166+
]
167+
168+
col_prefix = col_prefix or ""
169+
return pa.Table.from_pydict({f"{col_prefix}{k}": [v] for k, v in config.items()})
170+
171+
@classmethod
172+
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
173+
def from_arrow(cls, table: ArrowTable, col_prefix : str | None = None) -> Self:
174+
"""Create a PhononBS from an arrow table."""
175+
col_prefix= col_prefix or ""
176+
config: dict[str, Any] = {}
177+
for k in (
178+
"structure",
179+
"has_nac",
180+
"qpoints",
181+
"frequencies",
182+
"reciprocal_lattice",
183+
"eigendisplacements_real",
184+
"qpoint_labels",
185+
):
186+
_k = f"{col_prefix}{k}"
187+
if _k not in table.column_names:
188+
continue
189+
v = table[_k].to_pylist()[0]
190+
if k == "structure":
191+
config[k] = Structure.from_dict(json.loads(v))
192+
elif k in ("qpoints", "frequencies", "reciprocal_lattice"):
193+
config[k] = np.array(v).reshape(
194+
tuple(table[f"{_k}_shape"].to_pylist()[0])
195+
)
196+
elif k == "eigendisplacements_real":
197+
config["eigendisplacements"] = (
198+
table[f"{col_prefix}eigendisplacements_real"].to_numpy()[0]
199+
+ 1.0j * table[f"{col_prefix}eigendisplacements_imag"].to_numpy()[0]
200+
).reshape(tuple(table[f"{col_prefix}eigendisplacements_shape"].to_pylist()[0]))
201+
elif k == "qpoint_labels":
202+
config["labels_dict"] = dict(
203+
zip(v, table[f"{col_prefix}qpoint_labelled_points"].to_pylist()[0])
204+
)
205+
else:
206+
config[k] = v
207+
return cls(**config)
208+
120209

121210
class SumRuleChecks(BaseModel):
122211
"""Container class for defining sum rule checks."""
@@ -281,7 +370,22 @@ def compute_thermo_quantites(
281370
thermo_props["temperature"] = temperatures
282371
return thermo_props
283372

284-
373+
@requires(pa is not None, "`pip install pyarrow` to use this functionality.")
374+
def objects_to_arrow(self) -> ArrowTable:
375+
"""Convert band structure and DOS to pyarrow table row."""
376+
table = pa.Table.from_pydict({"material_id": [self.material_id]})
377+
if self.phonon_bandstructure:
378+
bst = self.phonon_bandstructure.to_arrow(col_prefix="bs_")
379+
for k in bst.column_names:
380+
table = table.append_column(k,bst[k])
381+
382+
if self.phonon_dos:
383+
dost = self.phonon_dos.to_arrow(col_prefix="dos_")
384+
385+
for k in dost.column_names:
386+
table = table.append_column(k,dost[k])
387+
return table
388+
285389
class PhononComputationalSettings(BaseModel):
286390
"""Collection to store computational settings for the phonon computation."""
287391

0 commit comments

Comments
 (0)