STL Geometry: Blood Flow in Intracranial Aneurysm
In this tutorial, you will import an STL file for a complicated geometry, use Modulus Sym’s SDF library to sample points on the surface and in the interior of the STL geometry, and train a PINN to predict the flow in this complex geometry. You will learn the following:
How to import an STL file in Modulus Sym and sample points in the interior and on the surface of the geometry.
Fig. 110 Aneurysm STL file
This tutorial assumes that you have completed tutorial Introductory Example and have familiarized yourself with the basics of the Modulus Sym APIs. Additionally, to use the modules described in this tutorial, make sure your system satisfies the requirements for the SDF library (see system_requirements).
For the interior sampling to work, ensure that the STL geometry is watertight. This requirement is not necessary for sampling points on the surface.
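If you are unsure whether your interior STL is watertight, one quick way to check (outside of Modulus Sym) is with the trimesh package, if it is available in your environment. A minimal sketch, not part of the example script:
import trimesh

# load the closed geometry used for interior sampling (path relative to examples/aneurysm/)
mesh = trimesh.load_mesh("./stl_files/aneurysm_closed.stl")
print("watertight:", mesh.is_watertight)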
All the Python scripts for this problem can be found at examples/aneurysm/.
This simulation uses a no-slip boundary condition on the walls of the aneurysm, \(u, v, w = 0\). At the inlet, a parabolic velocity profile is prescribed, directed along the inlet normal with a peak velocity of 1.5. The outlet has a zero-pressure condition, \(p=0\). The kinematic viscosity of the fluid is \(0.025\) and the density is a constant \(1.0\).
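Concretely, the inlet velocity is \(\mathbf{u}(r) = 1.5 \, \left(1 - (r/R)^2\right) \hat{n}\) for \(r \le R\) (and zero otherwise), where \(r\) is the distance from the inlet center, \(R\) is the inlet radius, and \(\hat{n}\) is the inlet normal; this is the profile implemented by the circular_parabola helper later in the script.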
In this tutorial, you will use Modulus Sym’s Tessellation module to sample points on an STL geometry. The module works similarly to Modulus Sym’s geometry module, which means you can use PointwiseInteriorConstraint and PointwiseBoundaryConstraint to sample points in the interior and on the boundary of the geometry and define the appropriate constraints. Separate STL files for each boundary of the geometry, plus a watertight geometry for sampling points in the interior, are required.
Importing the required packages
The list of required packages can be found below. Import Modulus Sym’s Tessellation module to sample points on the STL geometry.
import os
import warnings

import torch
import numpy as np
from sympy import Symbol, sqrt, Max

import modulus.sym
from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig
from modulus.sym.utils.io import csv_to_dict
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.domain.constraint import (
    PointwiseBoundaryConstraint,
    PointwiseInteriorConstraint,
    IntegralBoundaryConstraint,
)
from modulus.sym.domain.validator import PointwiseValidator
from modulus.sym.domain.monitor import PointwiseMonitor
from modulus.sym.key import Key
from modulus.sym.eq.pdes.navier_stokes import NavierStokes
from modulus.sym.eq.pdes.basic import NormalDotVec
from modulus.sym.geometry.tessellation import Tessellation
Using STL files to generate point clouds
Import the STL geometries using the Tessellation.from_stl() function. This function takes the path of the STL file as input. You will need to set the attribute airtight to False for open surfaces (e.g., the boundary STL files). These mesh objects can then be used to create boundary or interior constraints, similar to tutorial Introductory Example, using the PointwiseBoundaryConstraint or PointwiseInteriorConstraint.
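Before wiring the meshes into constraints, it can be useful to verify that the sampling behaves as expected. The short sketch below is not part of the example script; it samples a small point cloud from a tessellated surface and writes it out for inspection, assuming the var_to_polyvtk helper from modulus.sym.utils.io.vtk that is used in the geometry tutorials.
from modulus.sym.geometry.tessellation import Tessellation
from modulus.sym.utils.io.vtk import var_to_polyvtk

# load one of the open boundary STL files and sample points on its surface
mesh = Tessellation.from_stl("./stl_files/aneurysm_inlet.stl", airtight=False)
samples = mesh.sample_boundary(10000)  # dict of numpy arrays (coordinates, normals, area)
var_to_polyvtk(samples, "inlet_boundary_samples")  # write the point cloud for visualization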
For this tutorial, you can normalize the geometry by scaling it and centering it about the origin (0, 0, 0). This helps speed up the training process. The code to sample from the STL geometry and to define these functions and boundary conditions is shown below.
@modulus.sym.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:
    # read stl files to make geometry
    point_path = to_absolute_path("./stl_files")
    inlet_mesh = Tessellation.from_stl(
        point_path + "/aneurysm_inlet.stl", airtight=False
    )
    outlet_mesh = Tessellation.from_stl(
        point_path + "/aneurysm_outlet.stl", airtight=False
    )
    noslip_mesh = Tessellation.from_stl(
        point_path + "/aneurysm_noslip.stl", airtight=False
    )
    integral_mesh = Tessellation.from_stl(
        point_path + "/aneurysm_integral.stl", airtight=False
    )
    interior_mesh = Tessellation.from_stl(
        point_path + "/aneurysm_closed.stl", airtight=True
    )

    # params
    nu = 0.025
    inlet_vel = 1.5

    # inlet velocity profile
    def circular_parabola(x, y, z, center, normal, radius, max_vel):
        centered_x = x - center[0]
        centered_y = y - center[1]
        centered_z = z - center[2]
        distance = sqrt(centered_x**2 + centered_y**2 + centered_z**2)
        parabola = max_vel * Max((1 - (distance / radius) ** 2), 0)
        return normal[0] * parabola, normal[1] * parabola, normal[2] * parabola

    # normalize meshes
    def normalize_mesh(mesh, center, scale):
        mesh = mesh.translate([-c for c in center])
        mesh = mesh.scale(scale)
        return mesh

    # normalize invars
    def normalize_invar(invar, center, scale, dims=2):
        invar["x"] -= center[0]
        invar["y"] -= center[1]
        invar["z"] -= center[2]
        invar["x"] *= scale
        invar["y"] *= scale
        invar["z"] *= scale
        if "area" in invar.keys():
            invar["area"] *= scale**dims
        return invar

    # scale and normalize mesh and openfoam data
    center = (-18.40381048596882, -50.285383353981196, 12.848136936899031)
    scale = 0.4
    inlet_mesh = normalize_mesh(inlet_mesh, center, scale)
    outlet_mesh = normalize_mesh(outlet_mesh, center, scale)
    noslip_mesh = normalize_mesh(noslip_mesh, center, scale)
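The excerpt ends with the no-slip mesh; the integral and interior meshes must be normalized in the same way so that every constraint is defined in the same scaled coordinate system. In the full example script this is done analogously (lines omitted from the excerpt above):
    integral_mesh = normalize_mesh(integral_mesh, center, scale)
    interior_mesh = normalize_mesh(interior_mesh, center, scale)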
Defining the Equations, Networks and Nodes
This process is similar to the other tutorials. In this problem you are only solving for laminar flow, so you can use just the NavierStokes and NormalDotVec equations and define a network similar to tutorial Introductory Example. The code to generate the network and the required nodes is shown below.
    domain = Domain()

    # make list of nodes to unroll graph on
    ns = NavierStokes(nu=nu * scale, rho=1.0, dim=3, time=False)
    normal_dot_vel = NormalDotVec(["u", "v", "w"])
    flow_net = instantiate_arch(
        input_keys=[Key("x"), Key("y"), Key("z")],
        output_keys=[Key("u"), Key("v"), Key("w"), Key("p")],
        cfg=cfg.arch.fully_connected,
    )
    nodes = (
        ns.make_nodes()
        + normal_dot_vel.make_nodes()
        + [flow_net.make_node(name="flow_network")]
    )
Setting up Domain and adding Constraints
Now that you have all the nodes and geometry elements defined, you can use the tessellated mesh objects to create boundary or interior constraints, similar to tutorial Introductory Example, using the PointwiseBoundaryConstraint or PointwiseInteriorConstraint.
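    # NOTE: inlet_center, inlet_normal, inlet_radius, and outlet_area are defined
    # earlier in the full example script; only the tail of that block is shown here.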
    outlet_radius = np.sqrt(outlet_area / np.pi)
    # add constraints to solver
    # inlet
    u, v, w = circular_parabola(
        Symbol("x"),
        Symbol("y"),
        Symbol("z"),
        center=inlet_center,
        normal=inlet_normal,
        radius=inlet_radius,
        max_vel=inlet_vel,
    )
    inlet = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=inlet_mesh,
        outvar={"u": u, "v": v, "w": w},
        batch_size=cfg.batch_size.inlet,
    )
    domain.add_constraint(inlet, "inlet")

    # outlet
    outlet = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=outlet_mesh,
        outvar={"p": 0},
        batch_size=cfg.batch_size.outlet,
    )
    domain.add_constraint(outlet, "outlet")

    # no slip
    no_slip = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=noslip_mesh,
        outvar={"u": 0, "v": 0, "w": 0},
        batch_size=cfg.batch_size.no_slip,
    )
    domain.add_constraint(no_slip, "no_slip")

    # interior
    interior = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=interior_mesh,
        outvar={"continuity": 0, "momentum_x": 0, "momentum_y": 0, "momentum_z": 0},
        batch_size=cfg.batch_size.interior,
    )
    domain.add_constraint(interior, "interior")

    # Integral Continuity 1
    integral_continuity = IntegralBoundaryConstraint(
        nodes=nodes,
        geometry=outlet_mesh,
        outvar={"normal_dot_vel": 2.540},
        batch_size=1,
        integral_batch_size=cfg.batch_size.integral_continuity,
        lambda_weighting={"normal_dot_vel": 0.1},
    )
    domain.add_constraint(integral_continuity, "integral_continuity_1")

    # Integral Continuity 2
    integral_continuity = IntegralBoundaryConstraint(
        nodes=nodes,
        geometry=integral_mesh,
        outvar={"normal_dot_vel": -2.540},
        batch_size=1,
        integral_batch_size=cfg.batch_size.integral_continuity,
        lambda_weighting={"normal_dot_vel": 0.1},
    )
Adding Validators and Monitors
The process of adding validation data and monitors is similar to previous tutorials. This example uses an OpenFOAM simulation to validate the Modulus Sym results. You can also create a monitor for the pressure drop across the aneurysm to track convergence and compare against the OpenFOAM data. The code to generate these domains is shown below.
    domain.add_constraint(integral_continuity, "integral_continuity_2")
    # add validation data
    file_path = "./openfoam/aneurysm_parabolicInlet_sol0.csv"
    if os.path.exists(to_absolute_path(file_path)):
        mapping = {
            "Points:0": "x",
            "Points:1": "y",
            "Points:2": "z",
            "U:0": "u",
            "U:1": "v",
            "U:2": "w",
            "p": "p",
        }
        openfoam_var = csv_to_dict(to_absolute_path(file_path), mapping)
        openfoam_invar = {
            key: value for key, value in openfoam_var.items() if key in ["x", "y", "z"]
        }
        openfoam_invar = normalize_invar(openfoam_invar, center, scale, dims=3)
        openfoam_outvar = {
            key: value
            for key, value in openfoam_var.items()
            if key in ["u", "v", "w", "p"]
        }
        openfoam_validator = PointwiseValidator(
            nodes=nodes,
            invar=openfoam_invar,
            true_outvar=openfoam_outvar,
            batch_size=4096,
        )
        domain.add_validator(openfoam_validator)
    else:
        warnings.warn(
            f"Directory {file_path} does not exist. Will skip adding validators. Please download the additional files from NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_sym_examples_supplemental_materials"
        )
    # add pressure monitor
    pressure_monitor = PointwiseMonitor(
        inlet_mesh.sample_boundary(16),
        output_names=["p"],
        metrics={"pressure_drop": lambda var: torch.mean(var["p"])},
        nodes=nodes,
    )
    domain.add_monitor(pressure_monitor)
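The listing stops at the pressure monitor. As in the Introductory Example, the script is completed by constructing a Solver from the config and domain and calling solve(); a minimal sketch of the usual closing lines:
    # make solver and start training
    slv = Solver(cfg, domain)
    slv.solve()


if __name__ == "__main__":
    run()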
Once the Python file is set up, training can be started by simply executing the script.
python aneurysm.py
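Training with tens of millions of sampled points (see below) is computationally demanding. Modulus Sym supports multi-GPU data-parallel training; if OpenMPI is available in your environment, the same script can be launched on, for example, four GPUs with:
mpirun -np 4 python aneurysm.py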
This tutorial also illustrates overfitting of the training points in a PINN. Fig. 111 compares the validation error for two different point densities. The case using 10 M points shows initial convergence but later diverges, even though the training error keeps decreasing. This implies that the network is overfitting the sampled points while sacrificing the accuracy of the flow between them. Increasing the number of points to 20 M solves the problem and the flow field generalizes to a better resolution.
Fig. 111 Convergence plots for different point densities
Fig. 113 shows the pressure developed inside the aneurysm and the vein. A cross-sectional view in Fig. 112 shows the distribution of velocity magnitude inside the aneurysm. One of the key challenges of this problem is getting the flow to develop inside the aneurysm sac and the streamline plot in Fig. 114 shows that Modulus Sym successfully captures the flow field inside.
Fig. 112 Cross-sectional view of the aneurysm showing velocity magnitude. Left: Modulus Sym. Center: OpenFOAM. Right: Difference
Fig. 113 Pressure across aneurysm. Left: Modulus Sym. Center: OpenFOAM. Right: Difference
Fig. 114 Flow streamlines inside the aneurysm generated from Modulus Sym simulation.
Numerous applications in science and engineering require repetitive simulations, such as the simulation of blood flow in different patient-specific models. Traditional solvers simulate these models independently and from scratch. Even a minor change to the model geometry (such as an adjustment to the patient-specific medical image segmentation) requires a new simulation. Interestingly, and unlike traditional solvers, neural network solvers can transfer knowledge across different neural network models via transfer learning. In transfer learning, the knowledge acquired by a (source) trained neural network model for a physical system is transferred to another (target) neural network model that is to be trained for a similar physical system with slightly different characteristics (such as geometrical differences). The network parameters of the target model are initialized from the source model and retrained to cope with the new system characteristics, without having to train the neural network model from scratch. This transfer of knowledge effectively reduces the time to convergence for neural network solvers. As an example, Fig. 115 shows the application of transfer learning in the training of neural network solvers for two intracranial aneurysm models with different sac shapes.
Fig. 115 Transfer learning accelerates intracranial aneurysm simulations. Results are for two intracranial aneurysms with different sac shapes.
To use transfer learning in Modulus Sym, set 'initialization_network_dir' in the configs to the source model's network checkpoint. Also, since in transfer learning you fine-tune the source model instead of training from scratch, use a relatively smaller learning rate compared to a full run, with a smaller number of iterations and faster decay, as shown below.
defaults:
  - modulus_default
  - arch:
      - fully_connected
  - scheduler: tf_exponential_lr
  - optimizer: adam
  - loss: sum
  - _self_

scheduler:
  decay_rate: 0.95
  #decay_steps: 15000 # full run
  decay_steps: 6000 # TL run

network_dir: "network_checkpoint_target"
initialization_network_dir: "../aneurysm/network_checkpoint_source/"

training:
  rec_results_freq: 10000
  rec_constraint_freq: 50000
  #max_steps: 1500000 # full run
  max_steps: 400000 # TL run

batch_size:
  inlet: 1100
  outlet: 650
  no_slip: 5200
  interior: 6000
  integral_continuity: 310
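The smaller learning rate mentioned above is not shown in this snippet; it can be set through the optimizer group of the same config file. The value below is purely illustrative and not taken from the example:

optimizer:
  lr: 0.0005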