Deep Operator Network

This tutorial illustrates how to learn abstract operators using data-informed and physics-informed Deep Operator Networks (DeepONet) in Modulus Sym. In this tutorial, you will learn:

  1. How to use DeepONet architecture in Modulus Sym

  2. How to set up data-informed and physics-informed DeepONet for learning operators

Note

This tutorial assumes that you have completed the tutorial Introductory Example and are familiar with Modulus Sym APIs.

Problem Description

Consider a 1D initial value problem

\[\frac{du}{dx} = a(x), \quad x \in [0, 1], \tag{176}\]

subject to an initial condition \(u(0)=0\). The anti-derivative operator \(G\) over \([0,1]\) is given by

\[G:\quad a(x) \mapsto G(a)(x):= \int_0^x a(t)\, dt, \quad x \in [0,1]. \tag{177}\]

You will set up a DeepONet to learn the operator \(G\). Here, \(a\) is the input of the branch net and \(x\) is the input of the trunk net. As the branch input, \(a\) is discretized on a fixed uniform grid. These grid points need not coincide with the query coordinates \(x\) at which the DeepONet model is evaluated. For example, you may provide \(a\) as \(\{a(0),\ a(0.5),\ a(1)\}\) but evaluate the output at \(\{G(a)(0.1),\ G(a)(0.8),\ G(a)(0.9)\}\). This is one of the advantages of DeepONet over the Fourier neural operator.
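To make the operator concrete, the following minimal NumPy sketch approximates \(G(a)\) from samples of \(a\) on a uniform sensor grid and then evaluates it at query points that are not sensor locations. The helper name antiderivative_on_grid, the 100-point grid, and the test function \(a(x)=\cos(x)\) are assumptions chosen for illustration; this is a classical numerical reference, not the DeepONet itself.

```python
import numpy as np


def antiderivative_on_grid(a_values, sensors):
    """Cumulative trapezoid approximation of G(a) at the sensors, with G(a)(0) = 0."""
    dx = np.diff(sensors)
    increments = 0.5 * (a_values[1:] + a_values[:-1]) * dx
    return np.concatenate([[0.0], np.cumsum(increments)])


# Branch-side data: a(x) sampled on a fixed uniform grid (hypothetical choice of 100 sensors)
sensors = np.linspace(0.0, 1.0, 100)
a_values = np.cos(sensors)
u_on_sensors = antiderivative_on_grid(a_values, sensors)

# Trunk-side data: query coordinates that need not coincide with the sensors
query_x = np.array([0.1, 0.8, 0.9])
u_at_queries = np.interp(query_x, sensors, u_on_sensors)

# For a(x) = cos(x) the exact anti-derivative with u(0) = 0 is sin(x)
max_error = np.max(np.abs(u_at_queries - np.sin(query_x)))
```

A trained DeepONet plays the role of antiderivative_on_grid plus the interpolation step: it maps the sensor values of \(a\) and a query coordinate \(x\) directly to \(G(a)(x)\).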

Data Preparation

To prepare the data, generate \(10,000\) input functions \(a\) from a zero-mean Gaussian random field (GRF) with an exponential quadratic kernel of length scale \(l=0.2\). Then obtain the corresponding \(10,000\) ODE solutions \(u\) using an explicit Runge-Kutta method. Note that for each input-output pair \((a, u)\), only one observation of \(u(\cdot)\) is selected, which highlights the flexibility of DeepONet in handling various data structures. The training and validation data are provided in /examples/anti_derivative/data/. With this data, you can start on the data-informed DeepONet code.
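The data generation described above can be sketched roughly as follows. This is not the script that produced the shipped dataset: the function name sample_grf, the small sample count, the array layout, and the eigendecomposition-based sampling are assumptions, and a cumulative trapezoid rule is swapped in for the explicit Runge-Kutta solver to keep the example short.

```python
import numpy as np


def sample_grf(n_samples, sensors, length_scale=0.2, seed=0):
    """Draw zero-mean GRF samples with an exponential quadratic (RBF) kernel."""
    rng = np.random.default_rng(seed)
    diff = sensors[:, None] - sensors[None, :]
    cov = np.exp(-0.5 * diff**2 / length_scale**2)
    # Eigendecomposition with clipping is robust to the kernel's ill-conditioning
    eigvals, eigvecs = np.linalg.eigh(cov)
    factor = eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))
    return rng.standard_normal((n_samples, len(sensors))) @ factor.T


sensors = np.linspace(0.0, 1.0, 100)
a_samples = sample_grf(8, sensors)  # 8 functions here instead of 10,000

# Solve du/dx = a(x), u(0) = 0 on the grid (trapezoid rule as a stand-in solver)
dx = sensors[1] - sensors[0]
increments = 0.5 * (a_samples[:, 1:] + a_samples[:, :-1]) * dx
u_all = np.concatenate(
    [np.zeros((len(a_samples), 1)), np.cumsum(increments, axis=1)], axis=1
)

# Keep only one observation of u per input function, at a random location
rng = np.random.default_rng(1)
idx = rng.integers(0, len(sensors), size=len(a_samples))
x_obs = sensors[idx][:, None]
u_obs = u_all[np.arange(len(a_samples)), idx][:, None]
```

Each training row then consists of one discretized function \(a\), one coordinate \(x\), and one target value \(u(x)\).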

Note

The python script for this problem can be found at /examples/anti_derivative/data_informed.py.

Case Setup

Let us first import the necessary packages.

import os
import sys
import warnings

import torch
import numpy as np

import modulus.sym
from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig, to_yaml
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.models.fourier_net import FourierNetArch
from modulus.sym.models.deeponet import DeepONetArch
from modulus.sym.domain.constraint.continuous import DeepONetConstraint
from modulus.sym.domain.validator.discrete import GridValidator
from modulus.sym.dataset.discrete import DictGridDataset

# Key is used by the model definitions in the following sections
from modulus.sym.key import Key

Initializing the Model

In this case, you will use a fully-connected network as the branch net and a Fourier feature network as the trunk net. In the branch net, Key("a", 100) and Key("branch", 128) specify the input and output shapes corresponding to one input function \(a\). Similarly, in the trunk net, Key("x", 1) and Key("trunk", 128) specify the input and output shapes corresponding to one coordinate point \(x\). In the config, these models are defined under the arch config group.

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
  - modulus_default
  - /arch/fully_connected_cfg@arch.branch
  - /arch/fourier_cfg@arch.trunk
  - /arch/deeponet_cfg@arch.deeponet
  - scheduler: tf_exponential_lr
  - optimizer: adam
  - loss: sum
  - _self_

arch:
  branch:
    nr_layers: 4
    layer_size: 128
  trunk:
    frequencies: "('axis', [i for i in range(5)])"
    nr_layers: 4
    layer_size: 128
  deeponet:
    output_keys: u

scheduler:
  decay_rate: 0.9
  decay_steps: 100

training:
  rec_validation_freq: 1000
  max_steps: 10000

batch_size:
  train: 10000
  validation: 100

save_filetypes: "np"
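The scheduler entries in the config above (decay_rate 0.9, decay_steps 100) can be read as a TensorFlow-style exponential decay. The sketch below shows that formula under two assumptions: that the decay is continuous rather than staircase, and that the base learning rate is 1e-3, a hypothetical value that the config shown here does not state.

```python
def exponential_lr(step, base_lr=1e-3, decay_rate=0.9, decay_steps=100):
    """TensorFlow-style exponential decay: lr = base_lr * decay_rate ** (step / decay_steps)."""
    return base_lr * decay_rate ** (step / decay_steps)


lr_start = exponential_lr(0)
lr_after_100 = exponential_lr(100)  # one full decay period: base_lr * 0.9
lr_at_end = exponential_lr(10000)   # max_steps from the training section of the config
```

Under these assumptions the learning rate shrinks by 10% every 100 steps, so it is very small by the end of the 10,000 training steps.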

The models are initialized in the Python script using the following:

trunk_net = FourierNetArch(
    input_keys=[Key("x")],
    output_keys=[Key("trunk", 128)],
)
branch_net = FullyConnectedArch(
    input_keys=[Key("a", 100)],
    output_keys=[Key("branch", 128)],
)
deeponet = DeepONetArch(
    output_keys=[Key("u")],
    branch_net=branch_net,
    trunk_net=trunk_net,
)

nodes = [deeponet.make_node("deepo")]

Note

The DeepONet architecture in Modulus Sym is flexible and allows different networks to be used for the branch and trunk. For example, a convolutional model can be used in the branch network while a fully-connected network is used in the trunk.
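As a rough picture of how the branch and trunk outputs are combined, the following NumPy sketch implements the standard DeepONet formulation from the literature, in which the prediction is the inner product of the 128 branch features and the 128 trunk features. The tiny tanh networks, the random weights, and the helper names are assumptions for illustration; the actual DeepONetArch implementation may differ, for example by adding an output layer.

```python
import numpy as np

rng = np.random.default_rng(0)


def init(sizes):
    """Random weights for a small MLP (illustrative, untrained)."""
    return [
        (rng.standard_normal((m, n)) / np.sqrt(m), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def mlp(x, weights):
    """Tanh MLP with a linear final layer."""
    for w, b in weights[:-1]:
        x = np.tanh(x @ w + b)
    w, b = weights[-1]
    return x @ w + b


branch_weights = init([100, 128, 128])  # a at 100 sensors -> 128 features
trunk_weights = init([1, 128, 128])     # coordinate x     -> 128 features

a = rng.standard_normal((4, 100))  # 4 input functions
x = rng.uniform(size=(4, 1))       # one query coordinate per function

branch_features = mlp(a, branch_weights)
trunk_features = mlp(x, trunk_weights)

# Inner product over the shared feature dimension gives one value of u per row
u_pred = np.sum(branch_features * trunk_features, axis=-1, keepdims=True)
```

Because the two networks only meet through this shared feature dimension, either one can be replaced by a different architecture as long as the feature sizes match.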

Loading Data

Then import the data from the .npy file.

# load training data
file_path = "data/anti_derivative.npy"
if not os.path.exists(to_absolute_path(file_path)):
    warnings.warn(
        f"Directory {file_path} does not exist. Cannot continue. Please download the additional files from NGC https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_sym_examples_supplemental_materials"
    )
    sys.exit()

data = np.load(to_absolute_path(file_path), allow_pickle=True).item()
x_train = data["x_train"]
a_train = data["a_train"]
u_train = data["u_train"]

# load test data
x_test = data["x_test"]
a_test = data["a_test"]
u_test = data["u_test"]
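The loading code relies on a NumPy idiom that may be unfamiliar: a Python dictionary saved with np.save is stored as a zero-dimensional object array, so it must be loaded with allow_pickle=True and unwrapped with .item(). The sketch below reproduces this round trip with a toy dictionary; the file name toy.npy and the array shapes are assumptions, not the layout of the shipped dataset.

```python
import os
import tempfile

import numpy as np

# Toy stand-in for the dataset dictionary (shapes are illustrative)
toy = {
    "x_train": np.zeros((5, 1)),
    "a_train": np.zeros((5, 100)),
    "u_train": np.zeros((5, 1)),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.npy")
    np.save(path, toy)                      # dict is pickled into a 0-d object array
    raw = np.load(path, allow_pickle=True)  # fails without allow_pickle=True
    loaded = raw.item()                     # unwrap the 0-d array back into a dict

raw_shape = raw.shape
a_train_shape = loaded["a_train"].shape
```

Since allow_pickle=True can execute arbitrary code during loading, it is best used only on files from a trusted source.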

Adding Data Constraints

To add the data constraint, use DeepONetConstraint.

# make domain
domain = Domain()

data = DeepONetConstraint.from_numpy(
    nodes=nodes,
    invar={"a": a_train, "x": x_train},
    outvar={"u": u_train},
    batch_size=cfg.batch_size.train,
)
domain.add_constraint(data, "data")

Adding Data Validator

You can set validators to verify the results.

# add validators
for k in range(10):
    invar_valid = {
        "a": a_test[k * 100 : (k + 1) * 100],
        "x": x_test[k * 100 : (k + 1) * 100],
    }
    outvar_valid = {"u": u_test[k * 100 : (k + 1) * 100]}
    dataset = DictGridDataset(invar_valid, outvar_valid)

    validator = GridValidator(nodes=nodes, dataset=dataset, plotter=None)
    domain.add_validator(validator, "validator_{}".format(k))
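The loop above splits the test arrays into ten consecutive chunks of 100 rows, one validator per chunk. The sketch below mimics that slicing on synthetic arrays and computes a relative L2 error per chunk. The synthetic data, the assumption that the test arrays have at least 1,000 rows, and the choice of error metric are all illustrative; the metrics reported by GridValidator are not reproduced here.

```python
import numpy as np


def relative_l2(pred, true):
    """Relative L2 error between a prediction and the ground truth."""
    return np.linalg.norm(pred - true) / np.linalg.norm(true)


rng = np.random.default_rng(0)
u_test = rng.standard_normal((1000, 1))                  # synthetic ground truth
u_pred = u_test + 0.01 * rng.standard_normal((1000, 1))  # synthetic prediction

errors = [
    relative_l2(u_pred[k * 100 : (k + 1) * 100], u_test[k * 100 : (k + 1) * 100])
    for k in range(10)
]
```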

Training the Model

Start the training by executing the python script.


python data_informed.py

Results

The validation results (ground truth, DeepONet prediction, and difference, respectively) are shown below (Fig. 89, Fig. 90, Fig. 91).

data_deeponet_0.png

Fig. 89 Data informed DeepONet validation result, sample 1

data_deeponet_1.png

Fig. 90 Data informed DeepONet validation result, sample 2

data_deeponet_2.png

Fig. 91 Data informed DeepONet validation result, sample 3

This section uses the physics-informed DeepONet to learn the anti-derivative operator without any observations other than the given initial condition of the ODE system. Although no training data is needed, you still need some data for validation.

Note

The python script for this problem can be found at /examples/anti_derivative/physics_informed.py.

Case Setup

Most of the setup for physics-informed DeepONet is the same as the data informed version. First you import the needed packages.

import os

Initializing the Model

In the run function, set up the branch and trunk nets. This part is the same as in the data-informed version.

# make list of nodes to unroll graph on
trunk_net = FourierNetArch(
    input_keys=[Key("x")],
    output_keys=[Key("trunk", 128)],
)
branch_net = FullyConnectedArch(
    input_keys=[Key("a", 100)],
    output_keys=[Key("branch", 128)],
)
deeponet = DeepONetArch(
    output_keys=[Key("u")],
    branch_net=branch_net,
    trunk_net=trunk_net,
)

nodes = [deeponet.make_node("deepo")]

Loading Data

Then, import the data as in the data-informed version.

# load training data
file_path = "data/anti_derivative.npy"
if not os.path.exists(to_absolute_path(file_path)):
    warnings.warn(
        f"Directory {file_path} does not exist. Cannot continue. Please download the additional files from NGC https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_sym_examples_supplemental_materials"
    )
    sys.exit()

data = np.load(to_absolute_path(file_path), allow_pickle=True).item()
x_train = data["x_train"]
a_train = data["a_train"]
u_train = data["u_train"]

x_r_train = data["x_r_train"]
a_r_train = data["a_r_train"]
u_r_train = data["u_r_train"]

# load test data
x_test = data["x_test"]
a_test = data["a_test"]
u_test = data["u_test"]

Adding Constraints

The main difference between the physics-informed and data-informed versions lies in the constraints. First, impose the initial value constraint \(u(0)=0\). To do this, set the trunk net input to all zeros, so that the output function is evaluated only at \(x=0\).

# make domain
domain = Domain()

# add constraints to domain
IC = DeepONetConstraint.from_numpy(
    nodes=nodes,
    invar={"a": a_train, "x": np.zeros_like(x_train)},
    outvar={"u": np.zeros_like(u_train)},
    batch_size=cfg.batch_size.train,
)
domain.add_constraint(IC, "IC")

Next, impose the derivative constraint \(\frac{d}{dx}u(x) = a(x)\). Note that u__x denotes the derivative of u with respect to x.

interior = DeepONetConstraint.from_numpy(
    nodes=nodes,
    invar={"a": a_r_train, "x": x_r_train},
    outvar={"u__x": u_r_train},
    batch_size=cfg.batch_size.train,
)
domain.add_constraint(interior, "Residual")
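To see why both constraints are needed, the sketch below evaluates the two loss terms for candidate solutions given as plain Python functions. It is a conceptual illustration only: the function physics_informed_losses is hypothetical, a central finite difference stands in for the automatic differentiation that produces u__x, and the losses are simple mean squared errors.

```python
import numpy as np


def physics_informed_losses(u_fn, a_fn, x_interior, eps=1e-4):
    """Return (initial-condition loss, ODE residual loss) for a candidate u."""
    # IC term: u(0) should be 0
    ic_loss = u_fn(np.zeros(1))[0] ** 2
    # Residual term: du/dx should match a(x); finite difference stands in for autodiff
    du_dx = (u_fn(x_interior + eps) - u_fn(x_interior - eps)) / (2 * eps)
    residual_loss = np.mean((du_dx - a_fn(x_interior)) ** 2)
    return ic_loss, residual_loss


x_interior = np.linspace(0.05, 0.95, 50)

# The true anti-derivative of cos(x) with u(0) = 0 satisfies both constraints
good = physics_informed_losses(np.sin, np.cos, x_interior)

# A shifted copy still satisfies du/dx = a(x) but violates u(0) = 0
shifted = physics_informed_losses(lambda x: np.sin(x) + 1.0, np.cos, x_interior)
```

The shifted candidate has a near-zero residual loss but a nonzero initial-condition loss, which shows that the residual constraint alone determines \(u\) only up to a constant.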

Adding Data Validator

Finally, add the validator. This is the same as in the data-informed version.

# add validators
for k in range(10):
    invar_valid = {
        "a": a_test[k * 100 : (k + 1) * 100],
        "x": x_test[k * 100 : (k + 1) * 100],
    }
    outvar_valid = {"u": u_test[k * 100 : (k + 1) * 100]}
    dataset = DictGridDataset(invar_valid, outvar_valid)

    validator = GridValidator(nodes=nodes, dataset=dataset, plotter=None)
    domain.add_validator(validator, "validator_{}".format(k))

Training the Model

Start the training by executing the python script.


python physics_informed.py

Results

The validation results (ground truth, DeepONet prediction, and difference, respectively) are shown below (Fig. 92, Fig. 93, Fig. 94).

physics_deeponet_0.png

Fig. 92 Physics informed DeepONet validation result, sample 1

physics_deeponet_1.png

Fig. 93 Physics informed DeepONet validation result, sample 2

physics_deeponet_2.png

Fig. 94 Physics informed DeepONet validation result, sample 3

Case Setup

In this section, you will set up a data-informed DeepONet for learning the solution operator of a 2D Darcy flow in Modulus Sym. The problem setup and training data are the same as in Fourier Neural Operators. Please see the tutorial Darcy Flow with Fourier Neural Operator for more details. It is worth emphasising that one can employ any built-in Modulus Sym network architectures in a DeepONet model.

Note

The python script for this problem can be found at examples/darcy/darcy_DeepO.py.

Loading Data

Loading both the training and validation datasets into memory follows a similar process as the Darcy Flow with Fourier Neural Operator example.

# load training/ test data
branch_input_keys = [Key("coeff")]
trunk_input_keys = [Key("x"), Key("y")]
output_keys = [Key("sol")]

download_FNO_dataset("Darcy_241", outdir="datasets/")
invar_train, outvar_train = load_deeponet_dataset(
    "datasets/Darcy_241/piececonst_r241_N1024_smooth1.hdf5",
    [k.name for k in branch_input_keys],
    [k.name for k in output_keys],
    n_examples=1000,
)
invar_test, outvar_test = load_deeponet_dataset(
    "datasets/Darcy_241/piececonst_r241_N1024_smooth2.hdf5",
    [k.name for k in branch_input_keys],
    [k.name for k in output_keys],
    n_examples=10,
)

Initializing the Model

Initializing the DeepONet and the domain is similar to the anti-derivative example, but this time you will use a convolutional model. As in the FNO example, the model can be configured entirely through the config file. A pix2pix convolutional model is used as the branch net, while a fully-connected network is used as the trunk (the config below selects the Fourier feature variant, fourier_cfg). The DeepONet architecture automatically handles the difference in dimensionality between the two.

defaults:
  - modulus_default
  - /arch/pix2pix_cfg@arch.branch
  - /arch/fourier_cfg@arch.trunk
  - /arch/deeponet_cfg@arch.deeponet
  - scheduler: tf_exponential_lr
  - optimizer: adam
  - loss: sum
  - _self_

arch:
  branch:
    input_keys: [coeff]
    output_keys: [branch]
    dimension: 2
    conv_layer_size: 32
  trunk:
    input_keys: [x, y]
    output_keys: ['trunk', 256]
    frequencies: "('axis',[0,1,2,3,4])"
    nr_layers: 4
    layer_size: 128
  deeponet:
    output_keys: sol
    branch_dim: 1024

scheduler:
  decay_rate: 0.9
  decay_steps: 2000

training:
  rec_validation_freq: 1000
  max_steps: 100000

batch_size:
  train: 1000

save_filetypes: "np"

The models are initialized inside the Python script using the following:

# make list of nodes to unroll graph on
branch_net = instantiate_arch(
    cfg=cfg.arch.branch,
)
trunk_net = instantiate_arch(
    cfg=cfg.arch.trunk,
)
deeponet = instantiate_arch(
    cfg=cfg.arch.deeponet,
    branch_net=branch_net,
    trunk_net=trunk_net,
)

nodes = [deeponet.make_node(name="deepo")]
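The dimensionality difference between an image-valued branch input and point-valued trunk inputs can be pictured with the following NumPy sketch. It is only one plausible way to reconcile the two and is not a description of DeepONetArch internals: the toy sizes, the single linear map standing in for the pix2pix branch, the one-layer trunk, and the all-pairs inner product are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_fields, height, width, p = 3, 8, 8, 16  # toy sizes, not the tutorial's

# Branch input: coefficient fields as single-channel images
coeff = rng.uniform(size=(n_fields, 1, height, width))

# Stand-in for the convolutional branch: flatten each image, then one linear map
w_branch = rng.standard_normal((height * width, p)) / np.sqrt(height * width)
branch_features = coeff.reshape(n_fields, -1) @ w_branch

# Trunk input: (x, y) coordinates of every grid point
xs, ys = np.meshgrid(np.linspace(0.0, 1.0, width), np.linspace(0.0, 1.0, height))
coords = np.stack([xs.ravel(), ys.ravel()], axis=-1)

# Stand-in for the trunk network: one tanh layer
w_trunk = rng.standard_normal((2, p))
trunk_features = np.tanh(coords @ w_trunk)

# Inner product of every field's features with every point's features
sol = (branch_features @ trunk_features.T).reshape(n_fields, height, width)
```

Once the branch output is reduced to a feature vector per field, the combination step is the same as in the 1D anti-derivative example.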

Adding Data Constraints and Validators

Then add the data constraint and the validator as before.

# make domain
domain = Domain()

# add constraint to domain
data = DeepONetConstraint.from_numpy(
    nodes=nodes,
    invar=invar_train,
    outvar=outvar_train,
    batch_size=cfg.batch_size.train,
)
domain.add_constraint(data, "data")

# add validators
val = PointwiseValidator(
    nodes=nodes,
    invar=invar_test,
    true_outvar=outvar_test,
    plotter=None,
)
domain.add_validator(val, "val")

Training the Model

The training can now be simply started by executing the python script.


python darcy_DeepO.py

Results and Post-processing

The validation results (ground truth, DeepONet prediction, and difference, respectively) are shown below.

deeponet_darcy_1.png

Fig. 95 DeepONet validation result, sample 1

deeponet_darcy_2.png

Fig. 96 DeepONet validation result, sample 2

deeponet_darcy_3.png

Fig. 97 DeepONet validation result, sample 3

© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.