# Source code for pyiwfm.visualization.plot_mesh

"""Mesh and spatial plotting functions for IWFM models."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
from matplotlib.axes import Axes  # noqa: E402
from matplotlib.figure import Figure  # noqa: E402
from numpy.typing import NDArray  # noqa: E402

from pyiwfm.visualization._plot_utils import (  # noqa: E402
    SPATIAL_STYLE,
    _format_thousands,
    _with_style,
)

if TYPE_CHECKING:
    from pyiwfm.components.lake import AppLake
    from pyiwfm.components.stream import AppStream
    from pyiwfm.core.mesh import AppGrid


@_with_style(SPATIAL_STYLE)
def plot_mesh(
    grid: AppGrid,
    ax: Axes | None = None,
    show_edges: bool = True,
    show_node_ids: bool = False,
    show_element_ids: bool = False,
    edge_color: str = "black",
    edge_width: float = 0.5,
    fill_color: str = "lightblue",
    alpha: float = 0.3,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Draw every mesh element as a filled polygon, with optional ID labels.

    Args:
        grid: Model mesh whose elements and nodes are drawn
        ax: Axes to draw into; a new figure/axes pair is created when None
        show_edges: Draw element outlines (outline color is "none" when False)
        show_node_ids: Annotate each node with its ID at the node location
        show_element_ids: Annotate each element with its ID at the mean of
            its vertex coordinates
        edge_color: Outline color
        edge_width: Outline line width
        fill_color: Polygon face color
        alpha: Transparency passed to the polygon collection
        figsize: Size in inches of a newly created figure

    Returns:
        Tuple of (Figure, Axes)
    """
    from matplotlib.collections import PolyCollection

    if ax is not None:
        fig = ax.get_figure()  # type: ignore[assignment]
    else:
        fig, ax = plt.subplots(figsize=figsize)

    nodes = grid.nodes

    # One (x, y) vertex ring per element, in iteration order.
    polygons = [
        [(nodes[vid].x, nodes[vid].y) for vid in elem.vertices]
        for elem in grid.iter_elements()
    ]
    ax.add_collection(
        PolyCollection(
            polygons,
            edgecolors=edge_color if show_edges else "none",
            facecolors=fill_color,
            linewidths=edge_width,
            alpha=alpha,
        )
    )

    if show_node_ids:
        for node in grid.iter_nodes():
            ax.annotate(
                str(node.id),
                (node.x, node.y),
                fontsize=8,
                ha="center",
                va="center",
            )

    if show_element_ids:
        for elem in grid.iter_elements():
            # Label position: arithmetic mean of the vertex coordinates.
            xs = [nodes[vid].x for vid in elem.vertices]
            ys = [nodes[vid].y for vid in elem.vertices]
            center = (sum(xs) / len(xs), sum(ys) / len(ys))
            ax.annotate(
                str(elem.id),
                center,
                fontsize=8,
                ha="center",
                va="center",
                color="red",
            )

    # add_collection does not rescale the view on its own.
    ax.autoscale_view()
    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
@_with_style(SPATIAL_STYLE)
def plot_nodes(
    grid: AppGrid,
    ax: Axes | None = None,
    highlight_boundary: bool = False,
    marker_size: float = 20,
    color: str = "blue",
    boundary_color: str = "red",
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Scatter-plot the mesh nodes.

    Args:
        grid: Model mesh
        ax: Axes to draw into; a new figure/axes pair is created when None
        highlight_boundary: Draw nodes whose ``is_boundary`` flag is set in
            *boundary_color* and add a figure legend; when False every node
            is drawn in *color*
        marker_size: Marker size passed to ``scatter``
        color: Color for interior (or all, when not highlighting) nodes
        boundary_color: Color for boundary nodes
        figsize: Size in inches of a newly created figure

    Returns:
        Tuple of (Figure, Axes)
    """
    if ax is not None:
        fig = ax.get_figure()  # type: ignore[assignment]
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # (xs, ys) buckets; nodes go to the boundary bucket only when highlighting.
    interior: tuple[list[float], list[float]] = ([], [])
    boundary: tuple[list[float], list[float]] = ([], [])
    for node in grid.iter_nodes():
        bucket = boundary if (highlight_boundary and node.is_boundary) else interior
        bucket[0].append(node.x)
        bucket[1].append(node.y)

    if interior[0]:
        ax.scatter(interior[0], interior[1], s=marker_size, c=color, label="Interior")

    if highlight_boundary and boundary[0]:
        ax.scatter(
            boundary[0], boundary[1], s=marker_size, c=boundary_color, label="Boundary"
        )
        fig.legend(loc="outside right upper")

    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
@_with_style(SPATIAL_STYLE)
def plot_elements(
    grid: AppGrid,
    ax: Axes | None = None,
    color_by: Literal["subregion", "area", "none"] = "none",
    cmap: str = "viridis",
    show_colorbar: bool = True,
    edge_color: str = "black",
    edge_width: float = 0.5,
    alpha: float = 0.7,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Draw mesh elements, optionally colored by an element attribute.

    Args:
        grid: Model mesh
        ax: Axes to draw into; a new figure/axes pair is created when None
        color_by: "subregion" gives each distinct subregion its own color
            (categorical), "area" maps element area through *cmap*
            (continuous), "none" fills every element light blue
        cmap: Colormap name
        show_colorbar: For "area", add a colorbar; for "subregion", add a
            figure legend with one entry per subregion
        edge_color: Element outline color
        edge_width: Element outline width
        alpha: Transparency passed to the polygon collection
        figsize: Size in inches of a newly created figure

    Returns:
        Tuple of (Figure, Axes)
    """
    import matplotlib.colors as mcolors
    from matplotlib.collections import PolyCollection

    if ax is not None:
        fig = ax.get_figure()  # type: ignore[assignment]
    else:
        fig, ax = plt.subplots(figsize=figsize)

    def _element_value(elem: Any) -> float:
        # Attribute that drives the color for the chosen mode.
        if color_by == "subregion":
            return float(elem.subregion)
        if color_by == "area":
            return elem.area
        return 0.0

    nodes = grid.nodes
    polygons: list[list[tuple[float, float]]] = []
    values_list: list[float] = []
    for elem in grid.iter_elements():
        polygons.append([(nodes[vid].x, nodes[vid].y) for vid in elem.vertices])
        values_list.append(_element_value(elem))

    # Styling shared by every branch.
    outline = {"edgecolors": edge_color, "linewidths": edge_width, "alpha": alpha}

    if color_by == "subregion":
        from matplotlib.patches import Patch

        values = np.array(values_list)
        subregions = np.unique(values)
        colormap = plt.get_cmap(cmap)
        # Spread the distinct subregions evenly over [0, 1] of the colormap.
        spread = max(len(subregions) - 1, 1)
        palette = {sub: colormap(rank / spread) for rank, sub in enumerate(subregions)}
        collection = PolyCollection(
            polygons, facecolors=[palette[val] for val in values], **outline
        )
        if show_colorbar:
            handles = [
                Patch(
                    facecolor=palette[sub],
                    edgecolor=edge_color,
                    alpha=alpha,
                    label=f"Subregion {int(sub)}",
                )
                for sub in subregions
            ]
            fig.legend(handles=handles, loc="outside right upper")
    elif color_by == "none":
        collection = PolyCollection(polygons, facecolors="lightblue", **outline)
    else:
        values = np.array(values_list)
        norm = mcolors.Normalize(vmin=values.min(), vmax=values.max())
        collection = PolyCollection(
            polygons, array=values, cmap=plt.get_cmap(cmap), norm=norm, **outline
        )
        if show_colorbar:
            cbar = fig.colorbar(collection, ax=ax)
            cbar.set_label(color_by.capitalize(), fontsize=10)

    ax.add_collection(collection)
    ax.autoscale_view()
    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
def _subdivide_quads( elem_conn: list[list[int]], x: NDArray[np.float64], y: NDArray[np.float64], values: NDArray[np.float64], n: int, ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], NDArray[np.int64]]: """ Subdivide quad elements using bilinear FE shape functions. Each quad is subdivided into an n x n grid of points using bilinear interpolation, then triangulated. Triangle elements pass through unchanged. Fully vectorized -- no Python loop over elements. Parameters ---------- elem_conn : list of list of int Element connectivity (each inner list has 3 or 4 vertex indices). x, y : ndarray Node coordinates indexed by the vertex indices in elem_conn. values : ndarray Scalar values at each node. n : int Subdivision level (n x n points per quad, must be >= 2). Returns ------- sub_x, sub_y, sub_values : ndarray Coordinates and values at the subdivided points. sub_triangles : ndarray of shape (n_tri, 3) Triangle connectivity into the sub_x/sub_y arrays. """ # Separate triangles from quads tri_conn = [v for v in elem_conn if len(v) == 3] quad_conn = [v for v in elem_conn if len(v) == 4] all_x: list[NDArray[np.float64]] = [] all_y: list[NDArray[np.float64]] = [] all_v: list[NDArray[np.float64]] = [] all_tri: list[NDArray[np.int64]] = [] offset = 0 # --- Process quads --- if quad_conn: quad_arr = np.array(quad_conn) # (n_quads, 4) n_quads = quad_arr.shape[0] # Precompute reference grid and bilinear shape functions xi_1d = np.linspace(-1, 1, n) xi, eta = np.meshgrid(xi_1d, xi_1d) xi_flat = xi.ravel() eta_flat = eta.ravel() # Shape function matrix: (n*n, 4) shape_funcs = 0.25 * np.column_stack( [ (1 - xi_flat) * (1 - eta_flat), (1 + xi_flat) * (1 - eta_flat), (1 + xi_flat) * (1 + eta_flat), (1 - xi_flat) * (1 + eta_flat), ] ) # Precompute triangle template for the n x n structured grid row, col = np.mgrid[: n - 1, : n - 1] i0 = (row * n + col).ravel() i1 = i0 + 1 i2 = i0 + n i3 = i2 + 1 tri_template = np.column_stack([i0, i1, i3, i0, i3, 
i2]).reshape(-1, 3) # Batch map all quads via matrix multiply vx_all = x[quad_arr] # (n_quads, 4) vy_all = y[quad_arr] vv_all = values[quad_arr] sub_qx = (shape_funcs @ vx_all.T).T # (n_quads, n*n) sub_qy = (shape_funcs @ vy_all.T).T sub_qv = (shape_funcs @ vv_all.T).T # Build triangle indices with vectorized offsets n_pts = n * n offsets = np.arange(n_quads, dtype=np.int64) * n_pts + offset quad_tris = offsets[:, None, None] + tri_template[None, :, :] all_x.append(sub_qx.ravel()) all_y.append(sub_qy.ravel()) all_v.append(sub_qv.ravel()) all_tri.append(quad_tris.reshape(-1, 3)) offset += n_quads * n_pts # --- Process triangles (pass through) --- if tri_conn: tri_arr = np.array(tri_conn) # (n_tris, 3) all_x.append(x[tri_arr].ravel()) all_y.append(y[tri_arr].ravel()) all_v.append(values[tri_arr].ravel()) n_tris = tri_arr.shape[0] tri_indices = np.arange(n_tris * 3, dtype=np.int64).reshape(n_tris, 3) + offset all_tri.append(tri_indices) return ( np.concatenate(all_x), np.concatenate(all_y), np.concatenate(all_v), np.concatenate(all_tri).astype(np.int64), )
@_with_style(SPATIAL_STYLE)
def plot_scalar_field(
    grid: AppGrid,
    values: NDArray[np.float64],
    field_type: Literal["node", "cell"] = "node",
    ax: Axes | None = None,
    cmap: str = "viridis",
    show_colorbar: bool = True,
    vmin: float | None = None,
    vmax: float | None = None,
    show_mesh: bool = True,
    edge_color: str = "gray",
    edge_width: float = 0.3,
    n_subdiv: int = 4,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Plot scalar field values on the mesh.

    Node values are drawn with Gouraud-shaded triangles. Cell values are drawn
    as flat-colored element polygons.

    Args:
        grid: Model mesh
        values: Scalar values (any array-like). For ``field_type="node"``, one
            value per node ordered by ascending node ID. For
            ``field_type="cell"``, one value per element in
            ``grid.iter_elements()`` order.
        field_type: 'node' for node values, 'cell' for cell values
        ax: Existing axes to plot on
        cmap: Colormap name
        show_colorbar: Show colorbar
        vmin: Minimum value for colormap (default: NaN-ignoring minimum)
        vmax: Maximum value for colormap (default: NaN-ignoring maximum)
        show_mesh: Show mesh edges
        edge_color: Color for mesh edges
        edge_width: Width of mesh edges
        n_subdiv: Subdivision level for bilinear quad interpolation
            (>=2 enables FE subdivision; 1 uses legacy diagonal-split
            triangulation)
        figsize: Figure size in inches

    Returns:
        Tuple of (Figure, Axes)
    """
    import matplotlib.colors as mcolors
    from matplotlib.collections import PolyCollection
    from matplotlib.tri import Triangulation

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()  # type: ignore[assignment]

    # Accept any array-like; asanyarray keeps masked arrays masked.
    values = np.asanyarray(values, dtype=np.float64)

    # NaN-ignoring limits so a few missing values do not make the whole
    # color scale NaN.
    if vmin is None:
        vmin = np.nanmin(values)
    if vmax is None:
        vmax = np.nanmax(values)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    if field_type == "node":
        # Node arrays in ascending node-ID order; values are positional in
        # that same order.
        sorted_node_ids = sorted(grid.nodes.keys())
        node_id_to_idx = {nid: i for i, nid in enumerate(sorted_node_ids)}
        x = np.array([grid.nodes[nid].x for nid in sorted_node_ids])
        y = np.array([grid.nodes[nid].y for nid in sorted_node_ids])

        # Element connectivity as 0-based indices into x/y.
        elem_conn: list[list[int]] = [
            [node_id_to_idx[vid] for vid in elem.vertices] for elem in grid.iter_elements()
        ]

        has_quads = any(len(v) == 4 for v in elem_conn)
        if has_quads and n_subdiv > 1:
            # Bilinear FE subdivision for quads
            sub_x, sub_y, sub_v, sub_tri = _subdivide_quads(
                elem_conn,
                x,
                y,
                values,
                n_subdiv,
            )
            triang = Triangulation(sub_x, sub_y, sub_tri)
            tcf = ax.tripcolor(triang, sub_v, cmap=cmap, norm=norm, shading="gouraud")
        else:
            # Legacy 2-triangle diagonal split
            triangles_list: list[list[int]] = []
            for verts in elem_conn:
                if len(verts) == 3:
                    triangles_list.append(verts)
                else:
                    triangles_list.append([verts[0], verts[1], verts[2]])
                    triangles_list.append([verts[0], verts[2], verts[3]])
            triang = Triangulation(x, y, np.array(triangles_list))
            tcf = ax.tripcolor(triang, values, cmap=cmap, norm=norm, shading="gouraud")

        if show_mesh:
            node_xy = np.column_stack([x, y])
            # Build one polygon per element: a single np.array over the
            # connectivity is ragged (and raises) when triangles and quads
            # are mixed.
            mesh_polys = [node_xy[verts] for verts in elem_conn]
            mesh_collection = PolyCollection(
                mesh_polys,
                edgecolors=edge_color,
                facecolors="none",
                linewidths=edge_width,
            )
            ax.add_collection(mesh_collection)
    else:  # cell values
        polygons = [
            [(grid.nodes[vid].x, grid.nodes[vid].y) for vid in elem.vertices]
            for elem in grid.iter_elements()
        ]
        collection = PolyCollection(
            polygons,
            array=values,
            cmap=cmap,
            norm=norm,
            edgecolors=edge_color if show_mesh else "none",
            linewidths=edge_width,
        )
        ax.add_collection(collection)
        ax.autoscale_view()
        tcf = collection  # type: ignore[assignment]

    if show_colorbar:
        fig.colorbar(tcf, ax=ax)

    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
@_with_style(SPATIAL_STYLE)
def plot_streams(
    streams: AppStream,
    ax: Axes | None = None,
    show_nodes: bool = False,
    line_color: str = "blue",
    line_width: float = 2.0,
    node_color: str = "blue",
    node_size: float = 30,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Draw the stream network, one polyline per reach.

    Reach node IDs that are not present in ``streams.nodes`` are skipped, and
    reaches left with fewer than two drawable nodes are not drawn.

    Args:
        streams: Stream network
        ax: Axes to draw into; a new figure/axes pair is created when None
        show_nodes: Also scatter every stream node (drawn at zorder 5)
        line_color: Reach line color
        line_width: Reach line width
        node_color: Stream node marker color
        node_size: Stream node marker size
        figsize: Size in inches of a newly created figure

    Returns:
        Tuple of (Figure, Axes)
    """
    if ax is not None:
        fig = ax.get_figure()  # type: ignore[assignment]
    else:
        fig, ax = plt.subplots(figsize=figsize)

    stream_nodes = streams.nodes

    for reach in streams.iter_reaches():
        located = [stream_nodes[nid] for nid in reach.nodes if nid in stream_nodes]
        if len(located) < 2:
            continue
        ax.plot(
            [sn.x for sn in located],
            [sn.y for sn in located],
            color=line_color,
            linewidth=line_width,
        )

    if show_nodes:
        every_node = list(stream_nodes.values())
        ax.scatter(
            [sn.x for sn in every_node],
            [sn.y for sn in every_node],
            s=node_size,
            c=node_color,
            zorder=5,
        )

    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
@_with_style(SPATIAL_STYLE)
def plot_lakes(
    lakes: AppLake,
    grid: AppGrid,
    ax: Axes | None = None,
    fill_color: str = "cyan",
    edge_color: str = "blue",
    edge_width: float = 1.5,
    alpha: float = 0.5,
    show_labels: bool = True,
    label_fontsize: float = 9,
    cmap: str | None = None,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Draw each lake's elements as polygons, optionally labelled per lake.

    Parameters
    ----------
    lakes : AppLake
        Lake component; ``iter_lakes()`` supplies the lakes and
        ``get_elements_for_lake(lake.id)`` their element assignments.
    grid : AppGrid
        Model mesh used to look up element vertex coordinates. Lake elements
        whose ID is not in ``grid.elements`` are skipped.
    ax : Axes, optional
        Axes to draw into. A new figure is created if None.
    fill_color : str, default "cyan"
        Fill color for lake elements (used when *cmap* is None).
    edge_color : str, default "blue"
        Edge color for lake element polygons.
    edge_width : float, default 1.5
        Width of lake element edges.
    alpha : float, default 0.5
        Transparency passed to each lake element patch.
    show_labels : bool, default True
        Label each lake (its name, or ``"Lake <id>"`` when the name is empty)
        at the mean of all its drawn vertex coordinates.
    label_fontsize : float, default 9
        Font size for lake labels.
    cmap : str, optional
        If given, lake number ``i`` of ``N`` gets ``colormap(i / N)`` instead
        of *fill_color*.
    figsize : tuple, default (10, 8)
        Size in inches of a newly created figure.

    Returns
    -------
    tuple
        (Figure, Axes) matplotlib objects.
    """
    from matplotlib.patches import Polygon as MplPolygon

    if ax is not None:
        fig = ax.get_figure()  # type: ignore[assignment]
    else:
        fig, ax = plt.subplots(figsize=figsize)

    lake_list = list(lakes.iter_lakes())
    colormap = None if cmap is None else plt.get_cmap(cmap)
    lake_count = max(len(lake_list), 1)

    for position, lake in enumerate(lake_list):
        face = fill_color if colormap is None else colormap(position / lake_count)

        xs: list[float] = []
        ys: list[float] = []
        for lake_elem in lakes.get_elements_for_lake(lake.id):
            if lake_elem.element_id not in grid.elements:
                continue
            element = grid.elements[lake_elem.element_id]
            ring = [(grid.nodes[vid].x, grid.nodes[vid].y) for vid in element.vertices]
            ax.add_patch(
                MplPolygon(
                    ring,
                    facecolor=face,
                    edgecolor=edge_color,
                    linewidth=edge_width,
                    alpha=alpha,
                )
            )
            xs.extend(px for px, _ in ring)
            ys.extend(py for _, py in ring)

        if not (show_labels and xs):
            continue
        ax.text(
            sum(xs) / len(xs),
            sum(ys) / len(ys),
            lake.name or f"Lake {lake.id}",
            ha="center",
            va="center",
            fontsize=label_fontsize,
            fontweight="bold",
            zorder=10,
        )

    # add_patch does not rescale the view on its own.
    ax.autoscale_view()
    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
def _boundary_loops(grid: AppGrid) -> list[list[Any]]:
    """
    Chain the mesh's free edges into loops of node IDs.

    An edge used by exactly one element lies on the model boundary. Walking
    those edges node to node yields the outer boundary plus one loop per
    interior hole or disconnected mesh part. Loops do not repeat their first
    node at the end. Returns an empty list when the grid has no elements.
    """
    edge_uses: dict[tuple[Any, Any], int] = {}
    for elem in grid.iter_elements():
        verts = list(elem.vertices)
        for a, b in zip(verts, verts[1:] + verts[:1]):
            if a == b:
                continue  # repeated vertex -> degenerate edge, not a real side
            key = (a, b) if a < b else (b, a)
            edge_uses[key] = edge_uses.get(key, 0) + 1

    adjacency: dict[Any, list[Any]] = {}
    for (a, b), uses in edge_uses.items():
        if uses == 1:
            adjacency.setdefault(a, []).append(b)
            adjacency.setdefault(b, []).append(a)

    loops: list[list[Any]] = []
    for start in adjacency:
        # Each pass consumes edges, so this terminates even at pinch nodes
        # (boundary nodes with more than two free edges).
        while adjacency[start]:
            loop = [start]
            current = adjacency[start].pop()
            adjacency[current].remove(start)
            while current != start:
                loop.append(current)
                if not adjacency[current]:
                    break  # open chain (non-manifold mesh): keep what was traced
                nxt = adjacency[current].pop()
                adjacency[nxt].remove(current)
                current = nxt
            loops.append(loop)
    return loops


def _signed_ring_area(ring: NDArray[np.float64]) -> float:
    """Shoelace area of an (n, 2) ring without a repeated end point; > 0 means counter-clockwise."""
    xs, ys = ring[:, 0], ring[:, 1]
    return 0.5 * float(np.dot(xs, np.roll(ys, -1)) - np.dot(np.roll(xs, -1), ys))


@_with_style(SPATIAL_STYLE)
def plot_boundary(
    grid: AppGrid,
    ax: Axes | None = None,
    line_color: str = "black",
    line_width: float = 2.0,
    fill: bool = False,
    fill_color: str = "lightgray",
    alpha: float = 0.3,
    figsize: tuple[float, float] = (10, 8),
) -> tuple[Figure, Axes]:
    """
    Plot model boundary.

    The outline is traced from the mesh's free edges (edges belonging to
    exactly one element), so concave outlines and interior holes are drawn
    correctly. If the grid has no elements, the convex hull of all nodes is
    drawn instead (this fallback uses scipy).

    Args:
        grid: Model mesh
        ax: Existing axes to plot on
        line_color: Color for boundary line
        line_width: Width of boundary line
        fill: Fill the area enclosed by the boundary (holes are left empty)
        fill_color: Fill color
        alpha: Transparency of the filled patch (used only when fill=True)
        figsize: Figure size in inches

    Returns:
        Tuple of (Figure, Axes)
    """
    from matplotlib.patches import PathPatch
    from matplotlib.path import Path as MplPath

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()  # type: ignore[assignment]

    loops = _boundary_loops(grid)
    if loops:
        rings = [
            np.array([(grid.nodes[nid].x, grid.nodes[nid].y) for nid in loop], dtype=np.float64)
            for loop in loops
        ]
    else:
        # Fallback: no elements to derive edges from -> convex hull of nodes
        from scipy.spatial import ConvexHull

        points = np.array([[n.x, n.y] for n in grid.iter_nodes()])
        hull = ConvexHull(points)
        rings = [points[hull.vertices]]

    if fill:
        areas = [_signed_ring_area(ring) for ring in rings]
        outer = max(range(len(rings)), key=lambda i: abs(areas[i]))
        vertices: list[list[float]] = []
        codes: list[int] = []
        for i, ring in enumerate(rings):
            # Largest ring counter-clockwise, all others clockwise. Under
            # either the non-zero or the even-odd fill rule, holes inside the
            # outer ring then stay empty and disjoint parts are still filled.
            if (areas[i] > 0) != (i == outer):
                ring = ring[::-1]
            vertices.extend(ring.tolist())
            vertices.append(ring[0].tolist())  # CLOSEPOLY needs an (ignored) vertex
            codes.extend(
                [MplPath.MOVETO] + [MplPath.LINETO] * (len(ring) - 1) + [MplPath.CLOSEPOLY]
            )
        patch = PathPatch(
            MplPath(vertices, codes),
            facecolor=fill_color,
            edgecolor=line_color,
            linewidth=line_width,
            alpha=alpha,
        )
        ax.add_patch(patch)
    else:
        for ring in rings:
            # Repeat the first point so the outline closes.
            closed = np.vstack([ring, ring[:1]])
            ax.plot(closed[:, 0], closed[:, 1], color=line_color, linewidth=line_width)

    ax.autoscale_view()
    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    _format_thousands(ax)
    return fig, ax
class MeshPlotter:
    """
    High-level class for creating mesh visualizations.

    Each ``plot_*`` call creates a figure and stores it (with its axes), so
    that :meth:`save` writes the most recent plot.

    Attributes:
        grid: Model mesh
        streams: Stream network (optional)
        figsize: Default figure size in inches
    """

    def __init__(
        self,
        grid: AppGrid,
        streams: AppStream | None = None,
        figsize: tuple[float, float] = (10, 8),
    ) -> None:
        """
        Initialize the mesh plotter.

        Args:
            grid: Model mesh
            streams: Stream network (optional)
            figsize: Default figure size
        """
        self.grid = grid
        self.streams = streams
        self.figsize = figsize
        # Most recently created figure/axes; None until something is plotted.
        self._fig: Figure | None = None
        self._ax: Axes | None = None

    def plot_mesh(
        self,
        show_edges: bool = True,
        show_node_ids: bool = False,
        show_element_ids: bool = False,
        show_streams: bool = False,
        **kwargs: Any,
    ) -> tuple[Figure, Axes]:
        """
        Plot the mesh with optional overlays.

        Args:
            show_edges: Show element edges
            show_node_ids: Label nodes with their IDs
            show_element_ids: Label elements with their IDs
            show_streams: Overlay stream network (ignored when no streams
                were given)
            **kwargs: Additional arguments passed to the module-level
                plot_mesh; a ``figsize`` entry overrides the plotter default

        Returns:
            Tuple of (Figure, Axes)
        """
        # Passing figsize both explicitly and via **kwargs used to raise
        # "multiple values for keyword argument"; let the caller override it.
        kwargs.setdefault("figsize", self.figsize)
        fig, ax = plot_mesh(
            self.grid,
            show_edges=show_edges,
            show_node_ids=show_node_ids,
            show_element_ids=show_element_ids,
            **kwargs,
        )

        if show_streams and self.streams is not None:
            plot_streams(self.streams, ax=ax)

        self._fig = fig
        self._ax = ax
        return fig, ax

    def plot_composite(
        self,
        show_mesh: bool = True,
        show_streams: bool = False,
        node_values: NDArray[np.float64] | None = None,
        cell_values: NDArray[np.float64] | None = None,
        title: str | None = None,
        cmap: str = "viridis",
        **kwargs: Any,
    ) -> tuple[Figure, Axes]:
        """
        Create a composite plot with multiple layers.

        If both *node_values* and *cell_values* are given, only the node
        values are drawn. With neither, the plain mesh is drawn when
        *show_mesh* is True.

        Args:
            show_mesh: Show mesh edges
            show_streams: Overlay stream network
            node_values: Scalar values at nodes (optional)
            cell_values: Scalar values at cells (optional)
            title: Plot title
            cmap: Colormap for scalar values
            **kwargs: Only ``figsize`` is used (overrides the plotter
                default); other entries are currently ignored

        Returns:
            Tuple of (Figure, Axes)
        """
        fig, ax = plt.subplots(figsize=kwargs.get("figsize", self.figsize))

        # Plot scalar field if provided
        if node_values is not None:
            plot_scalar_field(
                self.grid,
                node_values,
                field_type="node",
                ax=ax,
                cmap=cmap,
                show_mesh=show_mesh,
            )
        elif cell_values is not None:
            plot_scalar_field(
                self.grid,
                cell_values,
                field_type="cell",
                ax=ax,
                cmap=cmap,
                show_mesh=show_mesh,
            )
        elif show_mesh:
            plot_mesh(self.grid, ax=ax)

        # Add streams
        if show_streams and self.streams is not None:
            plot_streams(self.streams, ax=ax)

        if title:
            ax.set_title(title)

        self._fig = fig
        self._ax = ax
        return fig, ax

    def save(
        self,
        output_path: Path | str,
        dpi: int = 150,
        **kwargs: Any,
    ) -> None:
        """
        Save the current figure to file.

        If nothing has been plotted yet, a default mesh plot is created first.

        Args:
            output_path: Output file path
            dpi: Resolution in dots per inch
            **kwargs: Additional arguments passed to savefig
        """
        if self._fig is None:
            # Create default plot if none exists
            self.plot_mesh()

        if self._fig is not None:
            self._fig.savefig(output_path, dpi=dpi, bbox_inches="tight", **kwargs)