Skip to content

Commit a037651

Browse files
authored
Wrap C++ Mesh from Python (fixes support for Python 3.11) (#2500)
* Wrap mesh in place of too clever casting * Wrap Mesh on Python side. * flake8 fixes * Update for gmsh interface * Code improvement * Fix function space clone * Various fixes * Flake8 fix * Re-enable mypy checks * Small edit * mypy work-around * Merge fix * flake8 fixes * Syntax fixes * _mesh -> _cpp_object * Small fixes * Updates * Add a hint * Simplifications * Add missing import * Tidy up * Simplifications * Tidy up * Flake8 fixes * Eliminate test warning * Wrap cell size function on Python side
1 parent 7422ea4 commit a037651

21 files changed

+252
-210
lines changed

python/demo/demo_lagrange_variants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ def saw_tooth(x):
157157
# elements, and plot the finite element interpolation.
158158

159159
# +
160-
mesh = mesh.create_unit_interval(MPI.COMM_WORLD, 10)
160+
msh = mesh.create_unit_interval(MPI.COMM_WORLD, 10)
161161

162-
x = ufl.SpatialCoordinate(mesh)
162+
x = ufl.SpatialCoordinate(msh)
163163
u_exact = saw_tooth(x[0])
164164

165165
for variant in [basix.LagrangeVariant.equispaced, basix.LagrangeVariant.gll_warped]:
166166
element = basix.create_element(basix.ElementFamily.P, basix.CellType.interval, 10, variant)
167167
ufl_element = basix.ufl_wrapper.BasixElement(element)
168-
V = fem.FunctionSpace(mesh, ufl_element)
168+
V = fem.FunctionSpace(msh, ufl_element)
169169
uh = fem.Function(V)
170170
uh.interpolate(lambda x: saw_tooth(x[0]))
171171
if MPI.COMM_WORLD.size == 1: # Skip this plotting in parallel
@@ -205,11 +205,11 @@ def saw_tooth(x):
205205
for variant in [basix.LagrangeVariant.equispaced, basix.LagrangeVariant.gll_warped]:
206206
element = basix.create_element(basix.ElementFamily.P, basix.CellType.interval, 10, variant)
207207
ufl_element = basix.ufl_wrapper.BasixElement(element)
208-
V = fem.FunctionSpace(mesh, ufl_element)
208+
V = fem.FunctionSpace(msh, ufl_element)
209209
uh = fem.Function(V)
210210
uh.interpolate(lambda x: saw_tooth(x[0]))
211211
M = fem.form((u_exact - uh)**2 * dx)
212-
error = mesh.comm.allreduce(fem.assemble_scalar(M), op=MPI.SUM)
212+
error = msh.comm.allreduce(fem.assemble_scalar(M), op=MPI.SUM)
213213
print(f"Computed L2 interpolation error ({variant.name}):", error ** 0.5)
214214
# -
215215

python/demo/demo_types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from dolfinx import fem, la, mesh, plot
2727

2828
from mpi4py import MPI
29-
3029
# -
3130

3231
# SciPy solvers do not support MPI, so all computations will be

python/dolfinx/fem/function.py

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# Copyright (C) 2009-2022 Chris N. Richardson, Garth N. Wells and Michal Habera
1+
# Copyright (C) 2009-2023 Chris N. Richardson, Garth N. Wells and Michal Habera
22
#
33
# This file is part of DOLFINx (https://www.fenicsproject.org)
44
#
55
# SPDX-License-Identifier: LGPL-3.0-or-later
6-
"""Collection of functions and function spaces"""
6+
"""Finite element function spaces and functions"""
77

88
from __future__ import annotations
99

@@ -14,20 +14,19 @@
1414

1515
from functools import singledispatch
1616

17-
import numpy as np
18-
import numpy.typing as npt
19-
2017
import basix
2118
import basix.ufl_wrapper
19+
import numpy as np
20+
import numpy.typing as npt
2221
import ufl
2322
import ufl.algorithms
2423
import ufl.algorithms.analysis
25-
from dolfinx import cpp as _cpp
26-
from dolfinx import jit, la
2724
from dolfinx.fem import dofmap
25+
from petsc4py import PETSc
2826
from ufl.domain import extract_unique_domain
2927

30-
from petsc4py import PETSc
28+
from dolfinx import cpp as _cpp
29+
from dolfinx import jit, la
3130

3231

3332
class Constant(ufl.Constant):
@@ -41,7 +40,6 @@ def __init__(self, domain, c: typing.Union[np.ndarray, typing.Sequence, float]):
4140
"""
4241
c = np.asarray(c)
4342
super().__init__(domain, c.shape)
44-
4543
try:
4644
if c.dtype == np.complex64:
4745
self._cpp_object = _cpp.fem.Constant_complex64(c)
@@ -373,13 +371,11 @@ def _(expr: Expression, cells: typing.Optional[np.ndarray] = None):
373371
except TypeError:
374372
# u is callable
375373
assert callable(u)
376-
x = _cpp.fem.interpolation_coords(
377-
self._V.element, self._V.mesh, cells)
378-
self._cpp_object.interpolate(
379-
np.asarray(u(x), dtype=self.dtype), cells)
374+
x = _cpp.fem.interpolation_coords(self._V.element, self._V.mesh._cpp_object, cells)
375+
self._cpp_object.interpolate(np.asarray(u(x), dtype=self.dtype), cells)
380376

381377
def copy(self) -> Function:
382-
"""Return a copy of the Function. The FunctionSpace is shared and the
378+
"""Create a copy of the Function. The FunctionSpace is shared and the
383379
degree-of-freedom vector is copied.
384380
385381
"""
@@ -445,8 +441,7 @@ def split(self) -> tuple[Function, ...]:
445441

446442
def collapse(self) -> Function:
447443
u_collapsed = self._cpp_object.collapse()
448-
V_collapsed = FunctionSpace(None, self.ufl_element(),
449-
u_collapsed.function_space)
444+
V_collapsed = FunctionSpace(self.function_space._mesh, self.ufl_element(), u_collapsed.function_space)
450445
return Function(V_collapsed, u_collapsed.x)
451446

452447

@@ -459,23 +454,13 @@ class ElementMetaData(typing.NamedTuple):
459454
class FunctionSpace(ufl.FunctionSpace):
460455
"""A space on which Functions (fields) can be defined."""
461456

462-
def __init__(self, mesh: typing.Union[None, Mesh],
457+
def __init__(self, mesh: Mesh,
463458
element: typing.Union[ufl.FiniteElementBase, ElementMetaData, typing.Tuple[str, int]],
464459
cppV: typing.Optional[_cpp.fem.FunctionSpace] = None,
465460
form_compiler_options: dict[str, typing.Any] = {}, jit_options: dict[str, typing.Any] = {}):
466461
"""Create a finite element function space."""
467462

468-
# Create function space from a UFL element and existing cpp
469-
# FunctionSpace
470-
if cppV is not None:
471-
assert mesh is None
472-
ufl_domain = cppV.mesh.ufl_domain()
473-
super().__init__(ufl_domain, element)
474-
self._cpp_object = cppV
475-
return
476-
477-
if mesh is not None:
478-
assert cppV is None
463+
if cppV is None:
479464
# Initialise the ufl.FunctionSpace
480465
if isinstance(element, ufl.FiniteElementBase):
481466
super().__init__(mesh.ufl_domain(), element)
@@ -491,17 +476,26 @@ def __init__(self, mesh: typing.Union[None, Mesh],
491476
jit_options=jit_options)
492477

493478
ffi = module.ffi
494-
cpp_element = _cpp.fem.FiniteElement(
495-
ffi.cast("uintptr_t", ffi.addressof(self._ufcx_element)))
479+
cpp_element = _cpp.fem.FiniteElement(ffi.cast("uintptr_t", ffi.addressof(self._ufcx_element)))
496480
cpp_dofmap = _cpp.fem.create_dofmap(mesh.comm, ffi.cast(
497481
"uintptr_t", ffi.addressof(self._ufcx_dofmap)), mesh.topology, cpp_element)
498482

499-
# Initialize the cpp.FunctionSpace
500-
self._cpp_object = _cpp.fem.FunctionSpace(
501-
mesh, cpp_element, cpp_dofmap)
483+
# Initialize the cpp.FunctionSpace and store mesh
484+
self._cpp_object = _cpp.fem.FunctionSpace(mesh._cpp_object, cpp_element, cpp_dofmap)
485+
self._mesh = mesh
486+
else:
487+
# Create function space from a UFL element and an existing
488+
# C++ FunctionSpace
489+
if mesh._cpp_object is not cppV.mesh:
490+
raise RecursionError("Meshes do not match in FunctionSpace initialisation.")
491+
ufl_domain = mesh.ufl_domain()
492+
super().__init__(ufl_domain, element)
493+
self._cpp_object = cppV
494+
self._mesh = mesh
495+
return
502496

503497
def clone(self) -> FunctionSpace:
504-
"""Return a new FunctionSpace :math:`W` which shares data with this
498+
"""Create a new FunctionSpace :math:`W` which shares data with this
505499
FunctionSpace :math:`V`, but with a different unique integer ID.
506500
507501
This function is helpful for defining mixed problems and using
@@ -513,10 +507,12 @@ def clone(self) -> FunctionSpace:
513507
diagonal blocks. This is relevant for the handling of boundary
514508
conditions.
515509
510+
Returns:
511+
A new function space that shares data
512+
516513
"""
517-
Vcpp = _cpp.fem.FunctionSpace(
518-
self._cpp_object.mesh, self._cpp_object.element, self._cpp_object.dofmap)
519-
return FunctionSpace(None, self.ufl_element(), Vcpp)
514+
Vcpp = _cpp.fem.FunctionSpace(self._cpp_object.mesh, self._cpp_object.element, self._cpp_object.dofmap)
515+
return FunctionSpace(self._mesh, self.ufl_element(), Vcpp)
520516

521517
@property
522518
def num_sub_spaces(self) -> int:
@@ -536,7 +532,7 @@ def sub(self, i: int) -> FunctionSpace:
536532
assert self.ufl_element().num_sub_elements() > i
537533
sub_element = self.ufl_element().sub_elements()[i]
538534
cppV_sub = self._cpp_object.sub([i])
539-
return FunctionSpace(None, sub_element, cppV_sub)
535+
return FunctionSpace(self._mesh, sub_element, cppV_sub)
540536

541537
def component(self):
542538
"""Return the component relative to the parent space."""
@@ -562,15 +558,13 @@ def __ne__(self, other):
562558
"""Comparison for inequality."""
563559
return super().__ne__(other) or self._cpp_object != other._cpp_object
564560

565-
def ufl_cell(self):
566-
return self._cpp_object.mesh.ufl_cell()
567-
568561
def ufl_function_space(self) -> ufl.FunctionSpace:
569562
"""UFL function space"""
570563
return self
571564

572565
@property
573566
def element(self):
567+
"""Function space finite element."""
574568
return self._cpp_object.element
575569

576570
@property
@@ -579,9 +573,9 @@ def dofmap(self) -> dofmap.DofMap:
579573
return dofmap.DofMap(self._cpp_object.dofmap)
580574

581575
@property
582-
def mesh(self) -> _cpp.mesh.Mesh:
583-
"""Return the mesh on which the function space is defined."""
584-
return self._cpp_object.mesh
576+
def mesh(self) -> Mesh:
577+
"""Mesh on which the function space is defined."""
578+
return self._mesh
585579

586580
def collapse(self) -> tuple[FunctionSpace, np.ndarray]:
587581
"""Collapse a subspace and return a new function space and a map from
@@ -592,31 +586,37 @@ def collapse(self) -> tuple[FunctionSpace, np.ndarray]:
592586
593587
"""
594588
cpp_space, dofs = self._cpp_object.collapse()
595-
V = FunctionSpace(None, self.ufl_element(), cpp_space)
589+
V = FunctionSpace(self._mesh, self.ufl_element(), cpp_space)
596590
return V, dofs
597591

598-
def tabulate_dof_coordinates(self) -> np.ndarray:
592+
def tabulate_dof_coordinates(self) -> npt.NDArray[np.float64]:
593+
"""Tabulate the coordinates of the degrees-of-freedom in the function space.
594+
595+
Returns:
596+
Coordinates of the degrees-of-freedom.
597+
598+
Notes:
599+
This method should be used only for elements with point
600+
evaluation degrees-of-freedom.
601+
602+
"""
599603
return self._cpp_object.tabulate_dof_coordinates()
600604

601605

602-
def VectorFunctionSpace(mesh: Mesh, element: typing.Union[ElementMetaData, typing.Tuple[str, int]], dim=None,
603-
restriction=None) -> FunctionSpace:
606+
def VectorFunctionSpace(mesh: Mesh, element: typing.Union[ElementMetaData, typing.Tuple[str, int]],
607+
dim=None) -> FunctionSpace:
604608
"""Create vector finite element (composition of scalar elements) function space."""
605-
606609
e = ElementMetaData(*element)
607-
ufl_element = basix.ufl_wrapper.create_vector_element(
608-
e.family, mesh.ufl_cell().cellname(), e.degree, dim=dim,
609-
gdim=mesh.geometry.dim)
610-
610+
ufl_element = basix.ufl_wrapper.create_vector_element(e.family, mesh.ufl_cell().cellname(), e.degree,
611+
dim=dim, gdim=mesh.geometry.dim)
611612
return FunctionSpace(mesh, ufl_element)
612613

613614

614615
def TensorFunctionSpace(mesh: Mesh, element: typing.Union[ElementMetaData, typing.Tuple[str, int]], shape=None,
615-
symmetry: typing.Optional[bool] = None, restriction=None) -> FunctionSpace:
616+
symmetry: typing.Optional[bool] = None) -> FunctionSpace:
616617
"""Create tensor finite element (composition of scalar elements) function space."""
617-
618618
e = ElementMetaData(*element)
619-
ufl_element = basix.ufl_wrapper.create_tensor_element(
620-
e.family, mesh.ufl_cell().cellname(), e.degree, shape=shape, symmetry=symmetry,
621-
gdim=mesh.geometry.dim)
619+
ufl_element = basix.ufl_wrapper.create_tensor_element(e.family, mesh.ufl_cell().cellname(),
620+
e.degree, shape=shape, symmetry=symmetry,
621+
gdim=mesh.geometry.dim)
622622
return FunctionSpace(mesh, ufl_element)

python/dolfinx/geometry.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99

1010
import typing
1111

12+
import numpy as np
13+
import numpy.typing as npt
14+
1215
if typing.TYPE_CHECKING:
1316
from dolfinx.mesh import Mesh
1417
from dolfinx.cpp.graph import AdjacencyList_int32
1518

1619
import numpy
20+
from dolfinx.cpp.geometry import compute_collisions, compute_distance_gjk
1721

1822
from dolfinx import cpp as _cpp
19-
from dolfinx.cpp.geometry import (compute_closest_entity, compute_collisions,
20-
compute_distance_gjk, create_midpoint_tree)
2123

2224
__all__ = ["compute_colliding_cells", "squared_distance", "compute_closest_entity", "compute_collisions",
2325
"compute_distance_gjk", "create_midpoint_tree"]
@@ -44,7 +46,30 @@ def __init__(self, mesh: Mesh, dim: int, entities=None, padding: float = 0.0):
4446
if entities is None:
4547
entities = range(0, map.size_local + map.num_ghosts)
4648

47-
super().__init__(mesh, dim, entities, padding)
49+
super().__init__(mesh._cpp_object, dim, entities, padding)
50+
51+
52+
def compute_closest_entity(tree: BoundingBoxTree, midpoint_tree: BoundingBoxTree, mesh: Mesh,
53+
points: numpy.ndarray) -> npt.NDArray[np.int32]:
54+
"""Compute closest mesh entity to a point.
55+
56+
Args:
57+
tree: bounding box tree for the entities
58+
midpoint_tree: A bounding box tree with the midpoints of all
59+
the mesh entities. This is used to accelerate the search.
60+
mesh: The mesh
61+
points: The points to check for collision, shape=(num_points, 3)
62+
63+
Returns:
64+
Mesh entity index for each point in `points`. Returns -1 for
65+
a point if the bounding box tree is empty.
66+
67+
"""
68+
return _cpp.geometry.compute_closest_entity(tree, midpoint_tree, mesh._cpp_object, points)
69+
70+
71+
def create_midpoint_tree(mesh: Mesh, dim: int, entities: numpy.ndarray):
72+
return _cpp.geometry.create_midpoint_tree(mesh._cpp_object, dim, entities)
4873

4974

5075
def compute_colliding_cells(mesh: Mesh, candidates: AdjacencyList_int32, x: numpy.ndarray):
@@ -60,7 +85,7 @@ def compute_colliding_cells(mesh: Mesh, candidates: AdjacencyList_int32, x: nump
6085
Adjacency list where the ith node is the list of entities that
6186
collide with the ith point
6287
"""
63-
return _cpp.geometry.compute_colliding_cells(mesh, candidates, x)
88+
return _cpp.geometry.compute_colliding_cells(mesh._cpp_object, candidates, x)
6489

6590

6691
def squared_distance(mesh: Mesh, dim: int, entities: typing.List[int], points: numpy.ndarray):
@@ -80,4 +105,4 @@ def squared_distance(mesh: Mesh, dim: int, entities: typing.List[int], points: n
80105
Squared shortest distance from points[i] to entities[i]
81106
82107
"""
83-
return _cpp.geometry.squared_distance(mesh, dim, entities, points)
108+
return _cpp.geometry.squared_distance(mesh._cpp_object, dim, entities, points)

python/dolfinx/io/gmshio.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ def model_to_mesh(model: gmsh.model, comm: _MPI.Comm, rank: int, gdim: int = 3,
230230

231231
cells = np.asarray(topologies[cell_id]["topology"], dtype=np.int64)
232232
cell_values = np.asarray(topologies[cell_id]["cell_data"], dtype=np.int32)
233-
234233
else:
235234
cell_id, num_nodes = comm.bcast([None, None], root=rank)
236235
cells, x = np.empty([0, num_nodes], dtype=np.int32), np.empty([0, gdim])
@@ -248,7 +247,8 @@ def model_to_mesh(model: gmsh.model, comm: _MPI.Comm, rank: int, gdim: int = 3,
248247
mesh = create_mesh(comm, cells, x[:, :gdim], ufl_domain, partitioner)
249248

250249
# Create MeshTags for cells
251-
local_entities, local_values = _cpp.io.distribute_entity_data(mesh, mesh.topology.dim, cells, cell_values)
250+
local_entities, local_values = _cpp.io.distribute_entity_data(
251+
mesh._cpp_object, mesh.topology.dim, cells, cell_values)
252252
mesh.topology.create_connectivity(mesh.topology.dim, 0)
253253
adj = _cpp.graph.AdjacencyList_int32(local_entities)
254254
ct = meshtags_from_entities(mesh, mesh.topology.dim, adj, local_values.astype(np.int32, copy=False))
@@ -257,7 +257,7 @@ def model_to_mesh(model: gmsh.model, comm: _MPI.Comm, rank: int, gdim: int = 3,
257257
# Create MeshTags for facets
258258
topology = mesh.topology
259259
if has_facet_data:
260-
# Permute facets from MSH to Dolfin-X ordering
260+
# Permute facets from MSH to DOLFINx ordering
261261
# FIXME: This does not work for prism meshes
262262
if topology.cell_type == CellType.prism or topology.cell_type == CellType.pyramid:
263263
raise RuntimeError(f"Unsupported cell type {topology.cell_type}")
@@ -267,7 +267,7 @@ def model_to_mesh(model: gmsh.model, comm: _MPI.Comm, rank: int, gdim: int = 3,
267267
marked_facets = marked_facets[:, gmsh_facet_perm]
268268

269269
local_entities, local_values = _cpp.io.distribute_entity_data(
270-
mesh, mesh.topology.dim - 1, marked_facets, facet_values)
270+
mesh._cpp_object, mesh.topology.dim - 1, marked_facets, facet_values)
271271
mesh.topology.create_connectivity(topology.dim - 1, topology.dim)
272272
adj = _cpp.graph.AdjacencyList_int32(local_entities)
273273
ft = meshtags_from_entities(mesh, topology.dim - 1, adj, local_values.astype(np.int32, copy=False))
@@ -301,11 +301,11 @@ def read_from_msh(filename: str, comm: _MPI.Comm, rank: int = 0, gdim: int = 3,
301301
gmsh.initialize()
302302
gmsh.model.add("Mesh from file")
303303
gmsh.merge(filename)
304-
305-
output = model_to_mesh(gmsh.model, comm, rank, gdim=gdim, partitioner=partitioner)
306-
if comm.rank == rank:
304+
msh = model_to_mesh(gmsh.model, comm, rank, gdim=gdim, partitioner=partitioner)
307305
gmsh.finalize()
308-
return output
306+
return msh
307+
else:
308+
return model_to_mesh(gmsh.model, comm, rank, gdim=gdim, partitioner=partitioner)
309309

310310
# Map from Gmsh cell type identifier (integer) to DOLFINx cell type
311311
# and degree http://gmsh.info//doc/texinfo/gmsh.html#MSH-file-format

0 commit comments

Comments
 (0)