"""
VTK export functionality for IWFM models.
This module provides classes for exporting IWFM model data to
VTK formats for 3D visualization in tools like ParaView.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import numpy as np
from numpy.typing import NDArray
if TYPE_CHECKING:
import pyvista as pv
import vtk
from pyiwfm.core.mesh import AppGrid
from pyiwfm.core.stratigraphy import Stratigraphy
[docs]
class VTKExporter:
    """
    Export IWFM model data to VTK formats.

    This class converts model meshes and stratigraphy to VTK
    UnstructuredGrid objects that can be exported to VTU or
    legacy VTK formats for visualization in ParaView.

    Attributes:
        grid: Model mesh
        stratigraphy: Model stratigraphy (optional, required for 3D)
    """

    def __init__(
        self,
        grid: AppGrid,
        stratigraphy: Stratigraphy | None = None,
    ) -> None:
        """
        Initialize the VTK exporter.

        Args:
            grid: Model mesh
            stratigraphy: Model stratigraphy (optional, required for 3D)

        Raises:
            ImportError: If the ``vtk`` package cannot be imported
        """
        # Fail at construction time rather than on the first export call.
        try:
            import vtk  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "VTK is required for VTK export. Install with: pip install vtk"
            ) from e

        self.grid = grid
        self.stratigraphy = stratigraphy

    # ------------------------------------------------------------------
    # Shared helpers (pure Python / NumPy, no VTK or PyVista required)
    # ------------------------------------------------------------------

    def _node_order(self) -> tuple[list[int], dict[int, int]]:
        """Return node IDs in ascending order and a node-ID -> position map."""
        sorted_node_ids = sorted(self.grid.nodes.keys())
        node_id_to_idx = {nid: i for i, nid in enumerate(sorted_node_ids)}
        return sorted_node_ids, node_id_to_idx

    def _layered_points(self, sorted_node_ids: list[int]) -> NDArray[np.float64]:
        """
        Build the 3D point coordinates for every layer surface.

        Surface 0 takes its elevation from column 0 of ``top_elev``;
        surface ``k`` (k >= 1) takes column ``k - 1`` of ``bottom_elev``.
        Row ``i`` of those arrays is used for the i-th node of
        ``sorted_node_ids``.

        Args:
            sorted_node_ids: Node IDs in the order points are emitted

        Returns:
            Array of shape ``(n_surfaces * n_nodes, 3)``; the point for
            surface ``s`` and node position ``i`` is row ``s * n_nodes + i``.
        """
        strat = self.stratigraphy
        nodes = self.grid.nodes
        n_nodes = len(sorted_node_ids)
        n_surfaces = strat.n_layers + 1

        points = np.zeros((n_surfaces * n_nodes, 3))
        if n_nodes == 0:
            return points

        xy = np.array(
            [(nodes[nid].x, nodes[nid].y) for nid in sorted_node_ids],
            dtype=np.float64,
        )
        top_elev = np.asarray(strat.top_elev, dtype=np.float64)
        bottom_elev = np.asarray(strat.bottom_elev, dtype=np.float64)

        for surf_idx in range(n_surfaces):
            rows = slice(surf_idx * n_nodes, (surf_idx + 1) * n_nodes)
            points[rows, :2] = xy
            if surf_idx == 0:
                # Top surface (top of layer 1)
                points[rows, 2] = top_elev[:n_nodes, 0]
            else:
                # Bottom of layer surf_idx (shared with the layer below)
                points[rows, 2] = bottom_elev[:n_nodes, surf_idx - 1]
        return points

    def _layered_cells(
        self,
        node_id_to_idx: dict[int, int],
        n_nodes: int,
    ) -> list[tuple[object, int, list[int]]]:
        """
        Build the connectivity of every volumetric cell.

        One cell is produced per element per layer, layers outermost.
        Each connectivity lists the element's corners on the layer's
        bottom surface first, then the same corners on its top surface,
        in the element's own vertex order.

        Args:
            node_id_to_idx: Node-ID -> position map matching the point layout
            n_nodes: Number of points per surface

        Returns:
            List of ``(element, layer_number, point_ids)`` where
            ``layer_number`` is 1-based.
        """
        # NOTE(review): VTK documents the wedge base triangle (ids 0-2) as
        # having its normal point away from the opposite triangle (ids 3-5).
        # The original ordering (bottom first) is preserved here; confirm it
        # against the winding of ``elem.vertices``.
        cells: list[tuple[object, int, list[int]]] = []
        for layer in range(self.stratigraphy.n_layers):
            top_offset = layer * n_nodes
            bottom_offset = top_offset + n_nodes
            for elem in self.grid.iter_elements():
                node_indices = [node_id_to_idx[vid] for vid in elem.vertices]
                n_corners = 3 if elem.is_triangle else 4
                # Explicit indexing so a short vertex list still raises IndexError.
                corners = [node_indices[k] for k in range(n_corners)]
                connectivity = [bottom_offset + c for c in corners]
                connectivity += [top_offset + c for c in corners]
                cells.append((elem, layer + 1, connectivity))
        return cells

    @staticmethod
    def _checked_scalars(
        name: str,
        values: NDArray[np.float64],
        expected: int,
        what: str,
    ) -> NDArray[np.float64]:
        """Return ``values`` as a flat float64 array, or raise on a size mismatch."""
        data = np.asarray(values, dtype=np.float64).ravel()
        if data.size != expected:
            raise ValueError(
                f"Scalar array '{name}' has {data.size} values, "
                f"but the grid has {expected} {what}"
            )
        return data

    # ------------------------------------------------------------------
    # VTK mesh construction
    # ------------------------------------------------------------------

    def create_2d_mesh(self) -> vtk.vtkUnstructuredGrid:
        """
        Create a 2D VTK UnstructuredGrid from the mesh.

        Points are emitted in ``grid.iter_nodes()`` order at z = 0.
        Triangular elements become VTK triangles, all others VTK quads.

        Returns:
            VTK UnstructuredGrid with 2D mesh
        """
        import vtk

        vtk_grid = vtk.vtkUnstructuredGrid()

        # Create points
        points = vtk.vtkPoints()
        node_id_to_vtk_id: dict[int, int] = {}
        for i, node in enumerate(self.grid.iter_nodes()):
            points.InsertNextPoint(node.x, node.y, 0.0)
            node_id_to_vtk_id[node.id] = i
        vtk_grid.SetPoints(points)

        # Create cells
        for elem in self.grid.iter_elements():
            cell = vtk.vtkTriangle() if elem.is_triangle else vtk.vtkQuad()
            point_ids = cell.GetPointIds()
            for i, vid in enumerate(elem.vertices):
                point_ids.SetId(i, node_id_to_vtk_id[vid])
            vtk_grid.InsertNextCell(cell.GetCellType(), point_ids)

        return vtk_grid

    def create_3d_mesh(self) -> vtk.vtkUnstructuredGrid:
        """
        Create a 3D VTK UnstructuredGrid from mesh and stratigraphy.

        Quad elements become hexahedra, triangles become wedges. A 1-based
        ``layer`` integer array is attached to the cells.

        Returns:
            VTK UnstructuredGrid with 3D mesh

        Raises:
            ValueError: If stratigraphy is not set
        """
        # Check the precondition before touching VTK.
        if self.stratigraphy is None:
            raise ValueError("Stratigraphy required for 3D mesh")

        import vtk

        sorted_node_ids, node_id_to_idx = self._node_order()
        n_nodes = len(sorted_node_ids)

        vtk_grid = vtk.vtkUnstructuredGrid()

        # Create points - nodes at each layer surface, surface-major order
        points = vtk.vtkPoints()
        for x, y, z in self._layered_points(sorted_node_ids).tolist():
            points.InsertNextPoint(x, y, z)
        vtk_grid.SetPoints(points)

        # Create cells - one cell per element per layer
        layer_data: list[int] = []
        for elem, layer_number, connectivity in self._layered_cells(
            node_id_to_idx, n_nodes
        ):
            cell = vtk.vtkWedge() if elem.is_triangle else vtk.vtkHexahedron()
            point_ids = cell.GetPointIds()
            for k, point_id in enumerate(connectivity):
                point_ids.SetId(k, point_id)
            vtk_grid.InsertNextCell(cell.GetCellType(), point_ids)
            layer_data.append(layer_number)

        # Add layer data as cell attribute; fall back to a manual fill when
        # the NumPy bridge is not importable.
        try:
            from vtk.util.numpy_support import numpy_to_vtk
        except ImportError:
            layer_array = vtk.vtkIntArray()
            layer_array.SetNumberOfValues(len(layer_data))
            for i, val in enumerate(layer_data):
                layer_array.SetValue(i, val)
        else:
            layer_array = numpy_to_vtk(
                np.array(layer_data, dtype=np.int32), deep=True
            )
        layer_array.SetName("layer")
        vtk_grid.GetCellData().AddArray(layer_array)

        return vtk_grid

    # ------------------------------------------------------------------
    # Scalar attachment
    # ------------------------------------------------------------------

    def add_node_scalar(
        self,
        vtk_grid: vtk.vtkUnstructuredGrid,
        name: str,
        values: NDArray[np.float64],
    ) -> None:
        """
        Add scalar data to mesh nodes.

        Args:
            vtk_grid: VTK grid to add data to
            name: Scalar array name
            values: Scalar values (one per grid point)

        Raises:
            ValueError: If the number of values differs from the number
                of points in ``vtk_grid``
        """
        data = self._checked_scalars(
            name, values, vtk_grid.GetNumberOfPoints(), "points"
        )

        import vtk

        array = vtk.vtkDoubleArray()
        array.SetName(name)
        array.SetNumberOfTuples(data.size)
        for i, val in enumerate(data.tolist()):
            array.SetValue(i, val)
        vtk_grid.GetPointData().AddArray(array)

    def add_cell_scalar(
        self,
        vtk_grid: vtk.vtkUnstructuredGrid,
        name: str,
        values: NDArray[np.float64],
    ) -> None:
        """
        Add scalar data to mesh cells.

        Args:
            vtk_grid: VTK grid to add data to
            name: Scalar array name
            values: Scalar values (one per cell)

        Raises:
            ValueError: If the number of values differs from the number
                of cells in ``vtk_grid``
        """
        data = self._checked_scalars(
            name, values, vtk_grid.GetNumberOfCells(), "cells"
        )

        import vtk

        array = vtk.vtkDoubleArray()
        array.SetName(name)
        array.SetNumberOfTuples(data.size)
        for i, val in enumerate(data.tolist()):
            array.SetValue(i, val)
        vtk_grid.GetCellData().AddArray(array)

    # ------------------------------------------------------------------
    # File export
    # ------------------------------------------------------------------

    def _build_export_grid(
        self,
        mode: str,
        node_scalars: dict[str, NDArray[np.float64]] | None,
        cell_scalars: dict[str, NDArray[np.float64]] | None,
    ) -> vtk.vtkUnstructuredGrid:
        """Create the mesh for ``mode`` and attach the requested scalars."""
        # Any mode other than "3d" yields the 2D surface mesh.
        vtk_grid = self.create_3d_mesh() if mode == "3d" else self.create_2d_mesh()

        for name, values in (node_scalars or {}).items():
            self.add_node_scalar(vtk_grid, name, values)
        for name, values in (cell_scalars or {}).items():
            self.add_cell_scalar(vtk_grid, name, values)
        return vtk_grid

    @staticmethod
    def _write_grid(writer, vtk_grid, output_path: Path) -> None:
        """Run ``writer`` on ``vtk_grid`` and raise if it reports failure."""
        writer.SetFileName(str(output_path))
        writer.SetInputData(vtk_grid)
        # VTK writers report failure through a 0 return code, not an exception.
        if writer.Write() == 0:
            raise OSError(f"VTK writer failed to write '{output_path}'")

    def export_vtu(
        self,
        output_path: Path | str,
        mode: Literal["2d", "3d"] = "2d",
        node_scalars: dict[str, NDArray[np.float64]] | None = None,
        cell_scalars: dict[str, NDArray[np.float64]] | None = None,
    ) -> None:
        """
        Export mesh to VTU format (XML-based VTK).

        Args:
            output_path: Output file path (.vtu)
            mode: '3d' for volumetric mesh; any other value gives the
                2D surface mesh
            node_scalars: Dict of name -> values for node data
            cell_scalars: Dict of name -> values for cell data

        Raises:
            ValueError: If mode is '3d' without stratigraphy, or a scalar
                array does not match the mesh size
            OSError: If the VTK writer reports a failed write
        """
        import vtk

        vtk_grid = self._build_export_grid(mode, node_scalars, cell_scalars)
        self._write_grid(
            vtk.vtkXMLUnstructuredGridWriter(), vtk_grid, Path(output_path)
        )

    def export_vtk(
        self,
        output_path: Path | str,
        mode: Literal["2d", "3d"] = "2d",
        node_scalars: dict[str, NDArray[np.float64]] | None = None,
        cell_scalars: dict[str, NDArray[np.float64]] | None = None,
    ) -> None:
        """
        Export mesh to legacy VTK format.

        Args:
            output_path: Output file path (.vtk)
            mode: '3d' for volumetric mesh; any other value gives the
                2D surface mesh
            node_scalars: Dict of name -> values for node data
            cell_scalars: Dict of name -> values for cell data

        Raises:
            ValueError: If mode is '3d' without stratigraphy, or a scalar
                array does not match the mesh size
            OSError: If the VTK writer reports a failed write
        """
        import vtk

        vtk_grid = self._build_export_grid(mode, node_scalars, cell_scalars)
        self._write_grid(vtk.vtkUnstructuredGridWriter(), vtk_grid, Path(output_path))

    # ------------------------------------------------------------------
    # PyVista conversion
    # ------------------------------------------------------------------

    def to_pyvista_3d(
        self,
        node_scalars: dict[str, NDArray[np.float64]] | None = None,
        cell_scalars: dict[str, NDArray[np.float64]] | None = None,
    ) -> pv.UnstructuredGrid:
        """
        Create a PyVista UnstructuredGrid from mesh and stratigraphy.

        This method converts the IWFM mesh and stratigraphy to a PyVista
        UnstructuredGrid for use in interactive 3D visualization. The
        resulting mesh can be used with PyVista plotting functions or
        the Trame web visualization framework.

        Args:
            node_scalars: Dict of name -> 1D array to add to mesh points,
                one value per node (2D) or per node-surface point (3D)
            cell_scalars: Dict of name -> 1D array to add to mesh cells,
                one value per cell

        Returns:
            PyVista UnstructuredGrid with 3D volumetric mesh if
            stratigraphy is available, otherwise 2D surface mesh.

        Raises:
            ImportError: If PyVista is not installed.

        Examples:
            Create a 3D mesh for visualization:

            >>> exporter = VTKExporter(grid=grid, stratigraphy=strat)
            >>> pv_mesh = exporter.to_pyvista_3d()
            >>> pv_mesh.plot()

            Add scalar data:

            >>> kh_values = np.random.rand(n_cells)
            >>> pv_mesh = exporter.to_pyvista_3d(cell_scalars={"Kh": kh_values})
            >>> pv_mesh.plot(scalars="Kh", cmap="viridis")
        """
        try:
            import pyvista as pv  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "PyVista is required for this method. Install with: pip install pyvista"
            ) from e

        if self.stratigraphy is None:
            return self._to_pyvista_2d(node_scalars, cell_scalars)
        return self._to_pyvista_3d_impl(node_scalars, cell_scalars)

    @staticmethod
    def _attach_pyvista_scalars(
        mesh: pv.UnstructuredGrid,
        node_scalars: dict[str, NDArray[np.float64]] | None,
        cell_scalars: dict[str, NDArray[np.float64]] | None,
    ) -> None:
        """Copy the caller-supplied scalar arrays onto ``mesh``."""
        for name, values in (node_scalars or {}).items():
            mesh.point_data[name] = values
        for name, values in (cell_scalars or {}).items():
            mesh.cell_data[name] = values

    def _to_pyvista_2d(
        self,
        node_scalars: dict[str, NDArray[np.float64]] | None = None,
        cell_scalars: dict[str, NDArray[np.float64]] | None = None,
    ) -> pv.UnstructuredGrid:
        """
        Create a 2D PyVista mesh.

        Points are ordered by ascending node ID at z = 0. An
        ``element_id`` cell array is attached before any custom scalars.
        """
        import pyvista as pv

        sorted_node_ids, node_id_to_idx = self._node_order()

        # Create points array
        points = np.zeros((len(sorted_node_ids), 3))
        for i, nid in enumerate(sorted_node_ids):
            node = self.grid.nodes[nid]
            points[i] = [node.x, node.y, 0.0]

        # Build cells (count-prefixed connectivity) and element IDs in one pass
        cells_list: list[int] = []
        cell_types_list: list[int] = []
        elem_ids = []
        for elem in self.grid.iter_elements():
            vertex_indices = [node_id_to_idx[vid] for vid in elem.vertices]
            cells_list.append(len(vertex_indices))
            cells_list.extend(vertex_indices)
            cell_types_list.append(
                pv.CellType.TRIANGLE if elem.is_triangle else pv.CellType.QUAD
            )
            elem_ids.append(elem.id)

        mesh = pv.UnstructuredGrid(
            np.array(cells_list), np.array(cell_types_list), points
        )
        mesh.cell_data["element_id"] = np.array(elem_ids)

        self._attach_pyvista_scalars(mesh, node_scalars, cell_scalars)
        return mesh

    def _to_pyvista_3d_impl(
        self,
        node_scalars: dict[str, NDArray[np.float64]] | None = None,
        cell_scalars: dict[str, NDArray[np.float64]] | None = None,
    ) -> pv.UnstructuredGrid:
        """
        Create a 3D PyVista mesh with stratigraphy.

        Uses the same point layout and cell connectivity as
        ``create_3d_mesh`` and attaches ``layer`` (1-based) and
        ``element_id`` cell arrays before any custom scalars.

        Raises:
            ValueError: If stratigraphy is not set
        """
        if self.stratigraphy is None:
            raise ValueError("Stratigraphy required for 3D mesh")

        import pyvista as pv

        sorted_node_ids, node_id_to_idx = self._node_order()
        points = self._layered_points(sorted_node_ids)

        # Build cells for each element in each layer
        cells_list: list[int] = []
        cell_types_list: list[int] = []
        layer_data: list[int] = []
        element_ids = []
        for elem, layer_number, connectivity in self._layered_cells(
            node_id_to_idx, len(sorted_node_ids)
        ):
            cell_types_list.append(
                pv.CellType.WEDGE if elem.is_triangle else pv.CellType.HEXAHEDRON
            )
            cells_list.append(len(connectivity))
            cells_list.extend(connectivity)
            layer_data.append(layer_number)
            element_ids.append(elem.id)

        mesh = pv.UnstructuredGrid(
            np.array(cells_list), np.array(cell_types_list), points
        )

        # Add standard cell data
        mesh.cell_data["layer"] = np.array(layer_data)
        mesh.cell_data["element_id"] = np.array(element_ids)

        self._attach_pyvista_scalars(mesh, node_scalars, cell_scalars)
        return mesh