NVIDIA Modulus Sym v1.1.0

Jupyter Notebook workflow

This tutorial builds a simple example in a Jupyter notebook. This workflow is useful for interactive development and rapid prototyping.

Hydra is a configuration package built into the heart of Modulus Sym that allows easy control over various hyperparameters using YAML files. For every problem solved with Modulus Sym, Hydra is the first component to be initialized, and it directly influences every component level inside Modulus Sym.

Typically, if you are setting up the problem using a Python script, this config file is loaded into a ModulusConfig object, which is then used by Modulus Sym. The code below shows a minimal Python script that ingests the Modulus Sym configs.


import modulus.sym
from modulus.sym.hydra import to_yaml
from modulus.sym.hydra.config import ModulusConfig


@modulus.sym.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:
    print(to_yaml(cfg))


if __name__ == "__main__":
    run()

We recommend that workflow for larger, more complex projects, and it is the configuration setup used in most of the examples documented in this User Guide. However, for running Modulus Sym in a Jupyter notebook environment, we can take the approach shown below, which uses the compose() utility. Let's see how the config object can be loaded in such a case. The contents of the config.yaml file are:


defaults:
  - modulus_default
  - scheduler: tf_exponential_lr
  - optimizer: adam
  - loss: sum
  - _self_

scheduler:
  decay_rate: 0.95
  decay_steps: 200

save_filetypes: "vtk,npz"

training:
  rec_results_freq: 1000
  rec_constraint_freq: 1000
  max_steps: 10000

Next, we load the config file using:

[1]:

import modulus.sym
from modulus.sym.hydra import to_yaml
from modulus.sym.hydra.utils import compose
from modulus.sym.hydra.config import ModulusConfig

cfg = compose(config_path="conf", config_name="config")
cfg.network_dir = 'outputs'  # Set the network directory for checkpoints
print(to_yaml(cfg))



/usr/local/lib/python3.10/dist-packages/modulus/sym/hydra/utils.py:149: UserWarning: The version_base parameter is not specified. Please specify a compatability version level, or None. Will assume defaults for version 1.1
  hydra.initialize(
TorchScript default is being turned off due to PyTorch version mismatch.




training:
  max_steps: 10000
  grad_agg_freq: 1
  rec_results_freq: 1000
  rec_validation_freq: ${training.rec_results_freq}
  rec_inference_freq: ${training.rec_results_freq}
  rec_monitor_freq: ${training.rec_results_freq}
  rec_constraint_freq: 1000
  save_network_freq: 1000
  print_stats_freq: 100
  summary_freq: 1000
  amp: false
  amp_dtype: float16
  ntk:
    use_ntk: false
    save_name: null
    run_freq: 1000
graph:
  func_arch: false
  func_arch_allow_partial_hessian: true
stop_criterion:
  metric: null
  min_delta: null
  patience: 50000
  mode: min
  freq: 1000
  strict: false
profiler:
  profile: false
  start_step: 0
  end_step: 100
  name: nvtx
network_dir: outputs
initialization_network_dir: ''
save_filetypes: vtk,npz
summary_histograms: false
jit: false
jit_use_nvfuser: true
jit_arch_mode: only_activation
jit_autograd_nodes: false
cuda_graphs: true
cuda_graph_warmup: 20
find_unused_parameters: false
broadcast_buffers: false
device: ''
debug: false
run_mode: train
arch: ???
models: ???
loss:
  _target_: modulus.sym.loss.aggregator.Sum
  weights: null
optimizer:
  _params_:
    compute_gradients: adam_compute_gradients
    apply_gradients: adam_apply_gradients
  _target_: torch.optim.Adam
  lr: 0.001
  betas:
  - 0.9
  - 0.999
  eps: 1.0e-08
  weight_decay: 0.0
  amsgrad: false
scheduler:
  _target_: custom
  _name_: tf.ExponentialLR
  decay_rate: 0.95
  decay_steps: 200
batch_size: ???
custom: ???
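
Note that the composed config behaves like a regular attribute-style object, so individual entries can also be overridden directly in the notebook before training, just as network_dir was set above. A minimal sketch with purely illustrative values:

# Illustrative only: override a couple of entries from the composed config
# before the solver is created (the values here are arbitrary examples).
cfg.training.max_steps = 5000   # train for fewer steps
cfg.optimizer.lr = 1.0e-3       # adjust the Adam learning rate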


For this purely physics-driven case, we won't use any external training data. Instead, we will create some geometry that we can use to sample the collocation points needed to impose the boundary and equation losses. Modulus Sym has several geometry objects to choose from, ranging from 1D shapes like Point1D and Line1D to more complex 3D ones like Torus and Tetrahedron. Let's use the Line1D object for this example to sample the required points.
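
As a brief aside, the 2D and 3D primitives live in similar modules. A hedged sketch (the primitives_2d/primitives_3d module paths and constructor signatures are assumptions, and none of these objects are needed for this example):

# Hypothetical illustration of a few other geometry primitives
from modulus.sym.geometry.primitives_2d import Rectangle, Circle
from modulus.sym.geometry.primitives_3d import Sphere

rect = Rectangle((0, 0), (1, 1))     # axis-aligned rectangle from two corner points
circle = Circle((0.5, 0.5), 0.25)    # circle from a center point and radius
sphere = Sphere((0, 0, 0), 1.0)      # sphere from a center point and radius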

[2]:

from sympy import Symbol
from modulus.sym.geometry.primitives_1d import Line1D

# make geometry
x = Symbol("x")
geo = Line1D(0, 1)

Once the geometry object is instantiated, you can use methods like sample_boundary and sample_interior to sample points on it and get a feel for what is being sampled. Feel free to plot the samples for further visualization.

[3]:

samples = geo.sample_boundary(10, quasirandom=True)
print("Boundary Samples", samples)

samples = geo.sample_interior(100, quasirandom=True)
# print("Interior Samples", samples)

import matplotlib.pyplot as plt

plt.figure()
plt.scatter(samples['x'], samples['sdf'], label='Signed Distance Field')
plt.legend()
plt.show()



Boundary Samples {'x': array([[0.], [0.], [0.], [0.], [1.], [1.], [1.], [1.], [1.], [1.]]),
 'normal_x': array([[-1.], [-1.], [-1.], [-1.], [ 1.], [ 1.], [ 1.], [ 1.], [ 1.], [ 1.]]),
 'area': array([[0.25      ], [0.25      ], [0.25      ], [0.25      ], [0.16666667],
                [0.16666667], [0.16666667], [0.16666667], [0.16666667], [0.16666667]])}



[Figure: scatter plot of the signed distance field (sdf) values at the sampled interior points]

In this section, we will create the nodes required for our problem: the neural network itself (which acts as an adaptable function) and the equations used to formulate the PDE loss functions. For this example, the problem being solved is the 1D equation d²u/dx² = f on the interval (0, 1), with u(0) = u(1) = 0 and f = 1. Let's first define this differential equation using SymPy.

The PDE class allows us to write the equations symbolically in SymPy, so you can express them in the most natural way possible. The SymPy equations are converted to PyTorch expressions in the back end, and they can also be printed to verify the implementation.

Subsequent examples will show how to code more complicated PDEs, but for this example the simple PDE can be set up as shown below. Modulus Sym also comes with several common PDEs predefined for the user to choose from. Some of the PDEs already available in the PDEs module are: Navier-Stokes, linear elasticity, advection-diffusion, the wave equation, etc.

[4]:

from sympy import Symbol, Number, Function
from modulus.sym.eq.pde import PDE


class CustomPDE(PDE):
    def __init__(self, f=1.0):
        # coordinates
        x = Symbol("x")

        # make input variables
        input_variables = {"x": x}

        # make u function
        u = Function("u")(*input_variables)

        # source term
        if type(f) is str:
            f = Function(f)(*input_variables)
        elif type(f) in [float, int]:
            f = Number(f)

        # set equations
        self.equations = {}
        self.equations["custom_pde"] = (
            u.diff(x, 2) - f
        )  # "custom_pde" key name will be used in constraints
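
For reference, using one of the predefined PDEs mentioned above would look roughly as follows. This is a hedged sketch (the NavierStokes import path and arguments are assumed from the PDEs module) and is not needed for this example:

# Sketch only: instantiate a predefined PDE instead of writing a custom one
from modulus.sym.eq.pdes.navier_stokes import NavierStokes

ns = NavierStokes(nu=0.01, rho=1.0, dim=2, time=False)
# ns.pprint()  # print the symbolic continuity and momentum equations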

Now that we have the custom PDE defined, let's set up the nodes for the problem.

[5]:

from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.key import Key

# make list of nodes to unroll graph on
eq = CustomPDE(f=1.0)
u_net = FullyConnectedArch(
    input_keys=[Key("x")], output_keys=[Key("u")], nr_layers=3, layer_size=32
)
nodes = eq.make_nodes() + [u_net.make_node(name="u_network")]

Let’s visualize the symbolic node that we created and the architecture itself.

[6]:

# visualize the network and symbolic equation in Modulus Sym:
print(u_net)
print(eq.pprint())

# graphically visualize the PyTorch execution graph
!pip install torchviz
import torch
from torchviz import make_dot
from IPython.display import Image

# pass dummy data through the model
data_out = u_net({"x": (torch.rand(10, 1)),})
make_dot(data_out["u"], params=dict(u_net.named_parameters())).render("u_network", format="png")
display(Image(filename='./u_network.png'))



FullyConnectedArch(
  (_impl): FullyConnectedArchCore(
    (layers): ModuleList(
      (0): FCLayer(
        (linear): WeightNormLinear(in_features=1, out_features=32, bias=True)
      )
      (1-2): 2 x FCLayer(
        (linear): WeightNormLinear(in_features=32, out_features=32, bias=True)
      )
    )
    (final_layer): FCLayer(
      (linear): Linear(in_features=32, out_features=1, bias=True)
    )
  )
)
custom_pde: u__x__x - 1.0
None
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from torchviz) (2.1.0a0+b5021ba)
Collecting graphviz (from torchviz)
  Downloading graphviz-0.20.1-py3-none-any.whl (47 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 47.0/47.0 kB 2.6 MB/s eta 0:00:00
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (4.7.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (1.5.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (2.6.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (3.0.3)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->torchviz) (2023.9.1)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->torchviz) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->torchviz) (1.3.0)
Building wheels for collected packages: torchviz
  Building wheel for torchviz (setup.py) ... done
  Created wheel for torchviz: filename=torchviz-0.0.2-py3-none-any.whl size=4147 sha256=0b583a441d2c857d53264d1e8f7b9ac8c224cbe6f3581aa50da8b3f048bd5eff
  Stored in directory: /tmp/pip-ephem-wheel-cache-_xhp9say/wheels/4c/97/88/a02973217949e0db0c9f4346d154085f4725f99c4f15a87094
Successfully built torchviz
Installing collected packages: graphviz, torchviz
Successfully installed graphviz-0.20.1 torchviz-0.0.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.2.1 -> 23.3
[notice] To update, run: python -m pip install --upgrade pip



[Figure: torchviz rendering of the u_network computational graph]

The Domain holds all of the constraints as well as the additional components needed in the training process, such as inferencers, validators, and monitors. When developing in Modulus Sym, the constraints that the user defines are added to the training Domain to create a collection of training objectives. The Domain and the configs are passed as inputs to the Solver class.

[7]:

from modulus.sym.domain import Domain

# make domain
domain = Domain()

Now let's look into adding constraints to this domain. This can be thought of as adding specific constraints to the neural network optimization. For this physics-driven problem, these constraints are the boundary conditions and the equation residuals. The goal is to satisfy the boundary conditions exactly and to drive the PDE residuals to zero. These constraints can be specified within Modulus Sym using classes like PointwiseBoundaryConstraint and PointwiseInteriorConstraint. An L2 loss (the default, which can be modified) is then constructed from these constraints, and the optimizer minimizes it. Specifying the constraints in this fashion is referred to as imposing soft constraints.

Boundary constraints: To generate a boundary condition, we need to sample the points on the required boundary/surface of the geometry, specify the nodes we would like to unroll/evaluate on these points, and then assign them the desired values.

A boundary can be sampled using the PointwiseBoundaryConstraint class. This samples the entire boundary of the geometry specified in the geometry argument, in this case both endpoints of the 1D line. A particular portion of the boundary can be sub-sampled with the criteria parameter; we will see its use in a later example. The desired values for the boundary condition are listed as a dictionary in the outvar argument. These dictionaries are then used when unrolling the computational graph (specified using the nodes argument) for training. The number of points to sample on each boundary is specified using the batch_size argument.

Equations to solve: The custom PDE we defined is enforced on all the points in the interior. We will use the PointwiseInteriorConstraint class to sample points in the interior of the geometry. Again, the appropriate geometry is specified in the geometry argument, and the equations to solve are specified as a dictionary input to the outvar argument. These dictionaries are then used when unrolling the computational graph (specified using the nodes argument) for training. For this problem we have {'custom_pde': 0}. The bounds argument determines the range for sampling the values of the variables.

[8]:

from modulus.sym.domain.constraint import (
    PointwiseBoundaryConstraint,
    PointwiseInteriorConstraint,
)

# bcs
bc = PointwiseBoundaryConstraint(
    nodes=nodes,
    geometry=geo,
    outvar={"u": 0},
    batch_size=2,
)
domain.add_constraint(bc, "bc")

# interior
interior = PointwiseInteriorConstraint(
    nodes=nodes,
    geometry=geo,
    outvar={"custom_pde": 0},
    batch_size=100,
    bounds={x: (0, 1)},
)
domain.add_constraint(interior, "interior")

Let's create an inferencer object to visualize our results.

[9]:

import numpy as np
from modulus.sym.domain.inferencer import PointwiseInferencer

# add inferencer
inference = PointwiseInferencer(
    nodes=nodes,
    invar={"x": np.linspace(0, 1.0, 100).reshape(-1, 1)},
    output_names=["u"],
)
domain.add_inferencer(inference, "inf_data")
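
Although it is not required for this simple example, a validator could be added in the same way to compare the prediction against known data; here the analytical solution of the problem is u(x) = (1/2)x(x - 1). A hedged sketch, assuming the PointwiseValidator import path and keyword names below:

# Sketch: validate against the analytical solution u(x) = 0.5*x*(x - 1)
from modulus.sym.domain.validator import PointwiseValidator

x_val = np.linspace(0, 1.0, 100).reshape(-1, 1)
u_true = 0.5 * x_val * (x_val - 1)
validator = PointwiseValidator(
    nodes=nodes,
    invar={"x": x_val},
    true_outvar={"u": u_true},
    batch_size=100,
)
domain.add_validator(validator, "val_data")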

We can create a solver using Modulus Sym's Solver class, passing in the domain we just created along with the configuration that defines the optimizer choices and settings (i.e. cfg). The solver can then be executed using the solve method.

[10]:

# to make the logging work in the jupyter cells
# execute this cell only once
import logging

logging.getLogger().addHandler(logging.StreamHandler())

[11]:

import os

from modulus.sym.solver import Solver

# optional
# set appropriate GPU in case of multi-GPU machine
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="2"

# make solver
slv = Solver(cfg, domain)

# start solver
slv.solve()



attempting to restore from: /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs
optimizer checkpoint not found
model u_network.0.pth not found
/usr/local/lib/python3.10/dist-packages/torch/_functorch/deprecated.py:61: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.vmap is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.vmap instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('vmap', 'torch.vmap')
/usr/local/lib/python3.10/dist-packages/torch/_functorch/deprecated.py:77: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('jvp')
/usr/local/lib/python3.10/dist-packages/torch/_functorch/deprecated.py:73: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.vjp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.vjp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('vjp')
[step: 0] saved constraint results to outputs
[step: 0] record constraint batch time: 2.020e-02s
[step: 0] saved inferencer results to outputs
[step: 0] record inferencers time: 3.926e-03s
[step: 0] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs
[step: 0] loss: 6.198e-01
Attempting cuda graph building, this may take a bit...
[step: 100] loss: 2.897e-04, time/iteration: 1.109e+02 ms [step: 200] loss: 1.992e-05, time/iteration: 4.593e+00 ms [step: 300] loss: 2.384e-06, time/iteration: 4.546e+00 ms [step: 400] loss: 1.701e-06, time/iteration: 4.583e+00 ms [step: 500] loss: 1.624e-06, time/iteration: 4.547e+00 ms [step: 600] loss: 1.805e-06, time/iteration: 4.580e+00 ms [step: 700] loss: 1.209e-06, time/iteration: 4.620e+00 ms [step: 800] loss: 1.135e-06, time/iteration: 4.614e+00 ms [step: 900] loss: 8.912e-07, time/iteration: 4.655e+00 ms [step: 1000] saved constraint results to outputs [step: 1000] record constraint batch time: 1.995e-02s [step: 1000] saved inferencer results to outputs [step: 1000] record inferencers time: 3.986e-03s [step: 1000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 1000] loss: 1.173e-06, time/iteration: 5.293e+00 ms [step: 1100] loss: 9.503e-07, time/iteration: 4.543e+00 ms [step: 1200] loss: 8.333e-07, time/iteration: 4.519e+00 ms [step: 1300] loss: 6.433e-07, time/iteration: 4.480e+00 ms [step: 1400] loss: 5.219e-07, time/iteration: 4.482e+00 ms [step: 1500] loss: 7.840e-07, time/iteration: 4.480e+00 ms [step: 1600] loss: 7.086e-07, time/iteration: 4.653e+00 ms [step: 1700] loss: 6.743e-07, time/iteration: 4.459e+00 ms [step: 1800] loss: 5.179e-07, time/iteration: 4.502e+00 ms [step: 1900] loss: 4.267e-07, time/iteration: 4.476e+00 ms [step: 2000] saved constraint results to outputs [step: 2000] record constraint batch time: 1.777e-02s [step: 2000] saved inferencer results to outputs [step: 2000] record inferencers time: 3.750e-03s [step: 2000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 2000] loss: 4.716e-07, time/iteration: 4.880e+00 ms [step: 2100] loss: 5.014e-07, time/iteration: 4.505e+00 ms [step: 2200] loss: 4.765e-07, time/iteration: 4.465e+00 ms [step: 2300] loss: 4.829e-07, time/iteration: 4.487e+00 ms [step: 2400] loss: 4.770e-07, time/iteration: 4.472e+00 ms [step: 2500] loss: 3.817e-07, time/iteration: 4.465e+00 ms [step: 2600] loss: 5.599e-07, time/iteration: 4.469e+00 ms [step: 2700] loss: 3.385e-07, time/iteration: 4.459e+00 ms [step: 2800] loss: 4.558e-07, time/iteration: 4.468e+00 ms [step: 2900] loss: 4.098e-07, time/iteration: 4.469e+00 ms [step: 3000] saved constraint results to outputs [step: 3000] record constraint batch time: 1.626e-02s [step: 3000] saved inferencer results to outputs [step: 3000] record inferencers time: 3.900e-03s [step: 3000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 3000] loss: 4.536e-07, time/iteration: 4.868e+00 ms [step: 3100] loss: 4.093e-07, time/iteration: 4.476e+00 ms [step: 3200] loss: 4.538e-07, time/iteration: 4.468e+00 ms [step: 3300] loss: 3.622e-07, time/iteration: 4.469e+00 ms [step: 3400] loss: 4.896e-07, time/iteration: 4.481e+00 ms [step: 3500] loss: 4.094e-07, time/iteration: 4.472e+00 ms [step: 3600] loss: 4.390e-07, time/iteration: 4.454e+00 ms [step: 3700] loss: 4.511e-07, time/iteration: 4.469e+00 ms [step: 3800] loss: 4.221e-07, time/iteration: 4.457e+00 ms [step: 3900] loss: 3.551e-07, time/iteration: 4.484e+00 ms [step: 4000] saved constraint results to outputs [step: 4000] record constraint batch time: 1.729e-02s [step: 4000] saved inferencer results to outputs [step: 4000] record inferencers time: 3.702e-03s [step: 4000] saved checkpoint to 
/examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 4000] loss: 2.978e-07, time/iteration: 4.864e+00 ms [step: 4100] loss: 3.336e-07, time/iteration: 4.458e+00 ms [step: 4200] loss: 4.271e-07, time/iteration: 4.470e+00 ms [step: 4300] loss: 3.872e-07, time/iteration: 4.460e+00 ms [step: 4400] loss: 3.924e-07, time/iteration: 4.472e+00 ms [step: 4500] loss: 3.110e-07, time/iteration: 4.467e+00 ms [step: 4600] loss: 3.840e-07, time/iteration: 4.475e+00 ms [step: 4700] loss: 3.453e-07, time/iteration: 4.465e+00 ms [step: 4800] loss: 3.240e-07, time/iteration: 4.474e+00 ms [step: 4900] loss: 3.005e-07, time/iteration: 4.455e+00 ms [step: 5000] saved constraint results to outputs [step: 5000] record constraint batch time: 1.635e-02s [step: 5000] saved inferencer results to outputs [step: 5000] record inferencers time: 3.474e-03s [step: 5000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 5000] loss: 4.937e-07, time/iteration: 4.863e+00 ms [step: 5100] loss: 3.434e-07, time/iteration: 4.472e+00 ms [step: 5200] loss: 3.444e-07, time/iteration: 4.467e+00 ms [step: 5300] loss: 2.832e-07, time/iteration: 4.454e+00 ms [step: 5400] loss: 3.869e-07, time/iteration: 4.453e+00 ms [step: 5500] loss: 2.286e-07, time/iteration: 4.460e+00 ms [step: 5600] loss: 2.485e-07, time/iteration: 4.496e+00 ms [step: 5700] loss: 2.399e-07, time/iteration: 4.469e+00 ms [step: 5800] loss: 2.396e-07, time/iteration: 4.455e+00 ms [step: 5900] loss: 2.385e-07, time/iteration: 4.478e+00 ms [step: 6000] saved constraint results to outputs [step: 6000] record constraint batch time: 1.662e-02s [step: 6000] saved inferencer results to outputs [step: 6000] record inferencers time: 3.530e-03s [step: 6000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 6000] loss: 2.343e-07, time/iteration: 4.879e+00 ms [step: 6100] loss: 2.495e-07, time/iteration: 4.464e+00 ms [step: 6200] loss: 1.919e-07, time/iteration: 4.481e+00 ms [step: 6300] loss: 2.476e-07, time/iteration: 4.451e+00 ms [step: 6400] loss: 3.320e-07, time/iteration: 4.469e+00 ms [step: 6500] loss: 2.624e-07, time/iteration: 4.462e+00 ms [step: 6600] loss: 2.089e-07, time/iteration: 4.464e+00 ms [step: 6700] loss: 2.930e-07, time/iteration: 4.477e+00 ms [step: 6800] loss: 1.763e-07, time/iteration: 4.463e+00 ms [step: 6900] loss: 1.849e-07, time/iteration: 4.466e+00 ms [step: 7000] saved constraint results to outputs [step: 7000] record constraint batch time: 1.687e-02s [step: 7000] saved inferencer results to outputs [step: 7000] record inferencers time: 3.583e-03s [step: 7000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 7000] loss: 1.626e-07, time/iteration: 4.878e+00 ms [step: 7100] loss: 2.192e-07, time/iteration: 4.693e+00 ms [step: 7200] loss: 2.445e-07, time/iteration: 4.456e+00 ms [step: 7300] loss: 1.656e-07, time/iteration: 4.453e+00 ms [step: 7400] loss: 4.015e-07, time/iteration: 4.461e+00 ms [step: 7500] loss: 1.807e-07, time/iteration: 4.529e+00 ms [step: 7600] loss: 9.484e-07, time/iteration: 4.509e+00 ms [step: 7700] loss: 1.505e-06, time/iteration: 4.572e+00 ms [step: 7800] loss: 1.922e-07, time/iteration: 4.590e+00 ms [step: 7900] loss: 3.970e-06, time/iteration: 4.615e+00 ms [step: 8000] saved constraint results to outputs [step: 8000] record constraint batch time: 1.870e-02s [step: 8000] saved inferencer results to outputs [step: 
8000] record inferencers time: 3.477e-03s [step: 8000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 8000] loss: 4.261e-06, time/iteration: 5.017e+00 ms [step: 8100] loss: 1.624e-07, time/iteration: 4.484e+00 ms [step: 8200] loss: 1.113e-07, time/iteration: 4.532e+00 ms [step: 8300] loss: 1.772e-07, time/iteration: 4.652e+00 ms [step: 8400] loss: 1.324e-07, time/iteration: 4.467e+00 ms [step: 8500] loss: 1.107e-07, time/iteration: 4.524e+00 ms [step: 8600] loss: 1.186e-07, time/iteration: 4.514e+00 ms [step: 8700] loss: 4.657e-07, time/iteration: 4.505e+00 ms [step: 8800] loss: 9.405e-08, time/iteration: 4.479e+00 ms [step: 8900] loss: 7.684e-08, time/iteration: 4.487e+00 ms [step: 9000] saved constraint results to outputs [step: 9000] record constraint batch time: 1.760e-02s [step: 9000] saved inferencer results to outputs [step: 9000] record inferencers time: 3.445e-03s [step: 9000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 9000] loss: 2.143e-07, time/iteration: 4.896e+00 ms [step: 9100] loss: 6.356e-08, time/iteration: 4.490e+00 ms [step: 9200] loss: 8.093e-08, time/iteration: 4.541e+00 ms [step: 9300] loss: 5.905e-08, time/iteration: 4.468e+00 ms [step: 9400] loss: 5.401e-08, time/iteration: 4.491e+00 ms [step: 9500] loss: 1.142e-07, time/iteration: 4.476e+00 ms [step: 9600] loss: 5.954e-08, time/iteration: 4.467e+00 ms [step: 9700] loss: 5.403e-08, time/iteration: 4.533e+00 ms [step: 9800] loss: 5.789e-08, time/iteration: 4.464e+00 ms [step: 9900] loss: 7.232e-08, time/iteration: 4.484e+00 ms [step: 10000] saved constraint results to outputs [step: 10000] record constraint batch time: 1.695e-02s [step: 10000] saved inferencer results to outputs [step: 10000] record inferencers time: 3.413e-03s [step: 10000] saved checkpoint to /examples/release_23.11/first_testing/modulus-sym/docs/user_guide/notebook/outputs [step: 10000] loss: 6.352e-08, time/iteration: 4.893e+00 ms [step: 10000] reached maximum training steps, finished training!


The inference domain can be visualized and the results can be plotted using Matplotlib.

[12]:

import matplotlib.pyplot as plt
import numpy as np

data = np.load('./outputs/inferencers/inf_data.npz', allow_pickle=True)
data = np.atleast_1d(data.f.arr_0)[0]

plt.figure()
x = data['x'].flatten()
pred_u = data['u'].flatten()
plt.plot(np.sort(x), pred_u[np.argsort(x)], label='Neural Solver')
plt.plot(np.sort(x), 0.5*(np.sort(x)*(np.sort(x)-1)), label='(1/2)(x-1)x')
x_np = np.array([0., 1.])
u_np = 0.5*(x_np-1)*x_np
plt.scatter(x_np, u_np, label='BC')
plt.legend()
plt.show()


[Figure: the neural solver prediction u(x) plotted against the analytical solution (1/2)x(x-1) and the boundary-condition points]
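
Alternatively, the trained network can be queried directly in the notebook by passing a dictionary of tensors through u_net, just as was done for the torchviz visualization earlier. A small sketch (the device handling is an assumption, since the solver may have moved the model to the GPU during training):

# Sketch: evaluate the trained network at a few points
import torch

device = next(u_net.parameters()).device            # match the model's device
x_tensor = torch.linspace(0, 1, 5, device=device).reshape(-1, 1)
with torch.no_grad():
    u_pred = u_net({"x": x_tensor})["u"]
print(u_pred.cpu().numpy())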

© Copyright 2023, NVIDIA Modulus Team. Last updated on Oct 17, 2023.