# Source code for physicsnemo.mesh.visualization.draw_mesh

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main entry point for mesh visualization with backend selection."""

from typing import TYPE_CHECKING, Any, Literal

import torch

from physicsnemo.core.version_check import check_version_spec

if TYPE_CHECKING:
    import matplotlib.axes
    import pyvista

    from physicsnemo.mesh.mesh import Mesh

# Probe each plotting package once, when this module is imported, and cache
# the result keyed by package name. To register another backend, add its
# package name to the tuple below.
BACKENDS_INSTALLED: dict[str, bool] = {
    package: check_version_spec(package) for package in ("matplotlib", "pyvista")
}


def draw_mesh(
    mesh: "Mesh",
    backend: Literal["matplotlib", "pyvista", "auto"] = "auto",
    show: bool = True,
    point_scalars: None | torch.Tensor | str | tuple[str, ...] = None,
    cell_scalars: None | torch.Tensor | str | tuple[str, ...] = None,
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    alpha_points: float = 1.0,
    alpha_cells: float = 1.0,
    alpha_edges: float = 1.0,
    show_edges: bool = True,
    ax: Any = None,
    backend_options: dict[str, Any] | None = None,
) -> "matplotlib.axes.Axes | pyvista.Plotter":
    """Render a mesh with either the matplotlib or the PyVista backend.

    Top-level plotting entry point for Mesh objects. The scalar inputs are
    validated first, then the spatial dimensionality is checked, then a
    backend is resolved (explicitly or automatically) and the call is
    forwarded to that backend's implementation.

    Parameters
    ----------
    mesh : Mesh
        The mesh to visualize.
    backend : {"auto", "matplotlib", "pyvista"}
        Which visualization backend to use:

        - "auto": choose from ``mesh.n_spatial_dims`` -- matplotlib for
          0D/1D/2D, PyVista for 3D. If the preferred package is not
          installed, the other backend is used instead.
        - "matplotlib": force matplotlib (supports 3D via mplot3d).
        - "pyvista": force PyVista (requires n_spatial_dims <= 3).
    show : bool
        If True, display the plot immediately (calls plt.show() or
        plotter.show()). If False, the axes/plotter is returned so it can be
        customized before being displayed.
    point_scalars : torch.Tensor or str or tuple[str, ...] or None, optional
        Scalar data used to color points. Mutually exclusive with
        ``cell_scalars``. Accepted values:

        - None: points use a neutral color (black).
        - torch.Tensor: direct values of shape (n_points,) or
          (n_points, ...); trailing dimensions are L2-normed.
        - str or tuple[str, ...]: key to look up in ``mesh.point_data``.
    cell_scalars : torch.Tensor or str or tuple[str, ...] or None, optional
        Scalar data used to color cells. Mutually exclusive with
        ``point_scalars``. Accepted values:

        - None: cells use a neutral color (lightblue if no scalars,
          lightgray if point_scalars is active).
        - torch.Tensor: direct values of shape (n_cells,) or (n_cells, ...);
          trailing dimensions are L2-normed.
        - str or tuple[str, ...]: key to look up in ``mesh.cell_data``.
    cmap : str
        Name of the colormap used for scalar visualization.
    vmin : float or None, optional
        Lower bound for colormap normalization. If None, the data minimum is
        used.
    vmax : float or None, optional
        Upper bound for colormap normalization. If None, the data maximum is
        used.
    alpha_points : float
        Opacity of points, in the range [0, 1].
    alpha_cells : float
        Opacity of cells/faces, in the range [0, 1].
    alpha_edges : float
        Opacity of cell edges, in the range [0, 1]. Not forwarded to the
        PyVista backend; a warning is emitted if a value other than 1.0 is
        given while PyVista is the resolved backend.
    show_edges : bool
        Whether cell edges are drawn.
    ax : matplotlib.axes.Axes or pyvista.Plotter, optional
        Existing canvas to draw on: a matplotlib Axes for the matplotlib
        backend, a pyvista Plotter for the PyVista backend. If ``None``, a
        new figure/plotter is created. Use this to overlay several meshes in
        one scene.
    backend_options : dict[str, Any], optional
        Extra keyword arguments forwarded to the underlying visualization
        backend (e.g. PyVista's ``plotter.add_mesh()``). Not forwarded to
        the matplotlib backend; a warning is emitted if a non-empty dict is
        given while matplotlib is the resolved backend.

    Returns
    -------
    matplotlib.axes.Axes or pyvista.Plotter
        - matplotlib backend: a matplotlib.axes.Axes object
        - PyVista backend: a pyvista.Plotter object

    Raises
    ------
    ValueError
        If both point_scalars and cell_scalars are specified, if
        ``mesh.n_spatial_dims`` exceeds 3, or if ``backend`` is not a
        recognized name.
    ImportError
        If the requested backend is not installed, or if ``backend="auto"``
        and no backend is installed at all.

    Examples
    --------
    >>> # Draw mesh with automatic backend selection
    >>> mesh.draw()  # doctest: +SKIP
    >>>
    >>> # Color cells by pressure data
    >>> mesh.draw(cell_scalars="pressure", cmap="coolwarm")  # doctest: +SKIP
    >>>
    >>> # Color points by velocity magnitude (computing norm of vector field)
    >>> mesh.draw(point_scalars="velocity")  # velocity is (n_points, 3)  # doctest: +SKIP
    >>>
    >>> # Use nested TensorDict key
    >>> mesh.draw(cell_scalars=("flow", "temperature"))  # doctest: +SKIP
    >>>
    >>> # Customize and display later
    >>> ax = mesh.draw(show=False, backend="matplotlib")  # doctest: +SKIP
    >>> ax.set_title("My Mesh")  # doctest: +SKIP
    >>> plt.show()  # doctest: +SKIP
    """
    # --- Scalars: validate and resolve what (if anything) drives the coloring.
    from physicsnemo.mesh.visualization._scalar_utils import (
        validate_and_process_scalars,
    )

    processed_scalars = validate_and_process_scalars(
        point_scalars=point_scalars,
        cell_scalars=cell_scalars,
        point_data=mesh.point_data,
        cell_data=mesh.cell_data,
        n_points=mesh.n_points,
        n_cells=mesh.n_cells,
    )
    point_values, cell_values, scalar_source, label = processed_scalars

    # --- Dimensionality: nothing above three spatial dimensions can be drawn.
    if mesh.n_spatial_dims > 3:
        raise ValueError(
            f"Visualization does not support {mesh.n_spatial_dims=}.\n"
            "Maximum spatial dimensions: 3."
        )

    # --- Backend resolution. ``backend`` is rebound in place (rather than
    # copied to a new local) because the f-strings below use the ``=``
    # specifier, which embeds the variable name in the message.
    if backend == "auto":
        # At least one plotting package has to be importable.
        if not any(BACKENDS_INSTALLED.values()):
            options = ", ".join(BACKENDS_INSTALLED)
            raise ImportError(
                f"No visualization backend available. Install one of: {options}"
            )

        # matplotlib is preferred for 0D/1D/2D, PyVista for 3D; whichever is
        # preferred, the other one serves as the fallback.
        if mesh.n_spatial_dims <= 2:
            first_choice, second_choice = "matplotlib", "pyvista"
        else:
            first_choice, second_choice = "pyvista", "matplotlib"
        backend = first_choice if BACKENDS_INSTALLED[first_choice] else second_choice
    elif backend not in BACKENDS_INSTALLED:
        supported = ", ".join(repr(b) for b in BACKENDS_INSTALLED)
        raise ValueError(
            f"Unknown {backend=!r}. Supported backends: {supported}, 'auto'."
        )
    elif not BACKENDS_INSTALLED[backend]:
        # Known backend, but its package is missing: mention what is usable.
        others = [
            name
            for name, installed in BACKENDS_INSTALLED.items()
            if installed and name != backend
        ]
        hint = ""
        if others:
            hint = f" ({', '.join(others)} available)"
        raise ImportError(f"{backend} is not installed{hint}.")

    # --- Tell the caller about options the resolved backend does not receive.
    if backend_options and backend == "matplotlib":
        import warnings

        warnings.warn(
            "backend_options are only supported with the 'pyvista' backend and will be ignored.",
            stacklevel=2,
        )
    if alpha_edges != 1.0 and backend == "pyvista":
        import warnings

        warnings.warn(
            "alpha_edges is not supported by the 'pyvista' backend and will be ignored.",
            stacklevel=2,
        )

    # --- Dispatch. Keyword arguments accepted by both implementations are
    # assembled once; each branch adds only its backend-specific extras.
    shared_kwargs = {
        "mesh": mesh,
        "point_scalar_values": point_values,
        "cell_scalar_values": cell_values,
        "active_scalar_source": scalar_source,
        "scalar_label": label,
        "show": show,
        "cmap": cmap,
        "vmin": vmin,
        "vmax": vmax,
        "alpha_points": alpha_points,
        "alpha_cells": alpha_cells,
        "show_edges": show_edges,
    }

    if backend == "matplotlib":
        from physicsnemo.mesh.visualization._matplotlib_impl import draw_mesh_matplotlib

        return draw_mesh_matplotlib(**shared_kwargs, alpha_edges=alpha_edges, ax=ax)

    if backend == "pyvista":
        from physicsnemo.mesh.visualization._pyvista_impl import draw_mesh_pyvista

        return draw_mesh_pyvista(
            **shared_kwargs,
            plotter=ax,
            **(backend_options or {}),
        )

    raise AssertionError(
        f"Unreachable: {backend=!r} passed validation but has no dispatch."
    )