Darcy Flow with Physics-Informed Fourier Neural Operator

This tutorial solves the 2D Darcy flow problem using Physics-Informed Neural Operators (PINO) [1]. You will learn:

  1. Differences between PINO and Fourier Neural Operators (FNO).

  2. How to set up and train PINO in Modulus.

  3. How to define a custom PDE constraint for grid data.

Note

This tutorial assumes that you are familiar with the basic functionality of Modulus and understand the PINO architecture. Please see the Lid Driven Cavity Background and Physics Informed Neural Operator sections for additional information. This tutorial also builds upon the Darcy Flow with Fourier Neural Operator tutorial, which should be read first.

Warning

The Python package gdown is required for this example if you have not already downloaded and converted the example data. Install it using pip install gdown.

This problem illustrates developing a surrogate model that learns the mapping between the permeability and pressure fields of a Darcy flow system. The learned mapping, \(\textbf{K} \rightarrow \textbf{U}\), should hold for a distribution of permeability fields \(\textbf{K} \sim p(\textbf{K})\), not just for a single solution.

The key difference between PINO and FNO is that PINO adds a physics-informed term to the FNO loss function. As discussed in the Physics Informed Neural Operator theory section, the PINO loss function is:

\[\mathcal{L} = \mathcal{L}_{data} + \mathcal{L}_{pde}, \tag{163}\]

where

\[\mathcal{L}_{data} = \lVert u - \mathcal{G}_\theta(a) \rVert^2 , \tag{164}\]

where \(\mathcal{G}_\theta(a)\) is an FNO model with learnable parameters \(\theta\) and input field \(a\), and \(\mathcal{L}_{pde}\) is an appropriate PDE loss. For the 2D Darcy problem (see Darcy Flow with Fourier Neural Operator) this is given by

\[\mathcal{L}_{pde} = \lVert -\nabla \cdot \left(k(\textbf{x})\nabla \mathcal{G}_\theta(a)(\textbf{x})\right) - f(\textbf{x}) \rVert^2 , \tag{165}\]

where \(k(\textbf{x})\) is the permeability field, \(f(\textbf{x})\) is the forcing function (here \(f = 1\)), and the input field is the permeability itself, \(a = k\).
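To make the two loss terms concrete, here is a minimal NumPy sketch (not the Modulus implementation) that evaluates the Darcy residual on a uniform grid with second-order central differences and combines it with the data loss. The function names, the use of mean squared values as a discrete stand-in for the squared norm, and the default weights (chosen to mirror the loss.weights values in the configuration) are illustrative assumptions. Because the loss is squared, the sign convention of the residual does not matter.

```python
import numpy as np


def darcy_residual(u, k, f, h):
    """Residual -div(k grad u) - f at the interior points of a uniform 2D grid.

    Uses the expanded form -(k_x u_x + k u_xx + k_y u_y + k u_yy) - f.
    Axis 0 is x and axis 1 is y; h is the grid spacing in both directions.
    """
    # first derivatives (central differences in the interior)
    u_x, u_y = np.gradient(u, h, h)
    k_x, k_y = np.gradient(k, h, h)
    # second derivatives with the 3-point stencil, interior points only
    u_xx = (u[2:, 1:-1] - 2.0 * u[1:-1, 1:-1] + u[:-2, 1:-1]) / h**2
    u_yy = (u[1:-1, 2:] - 2.0 * u[1:-1, 1:-1] + u[1:-1, :-2]) / h**2
    i = (slice(1, -1), slice(1, -1))  # interior index
    div = k_x[i] * u_x[i] + k[i] * u_xx + k_y[i] * u_y[i] + k[i] * u_yy
    return -div - f[i]


def pino_loss(u_pred, u_true, k, f, h, w_data=1.0, w_pde=0.1):
    """Weighted sum of the data loss and the PDE loss (mean squared values)."""
    loss_data = np.mean((u_true - u_pred) ** 2)
    loss_pde = np.mean(darcy_residual(u_pred, k, f, h) ** 2)
    return w_data * loss_data + w_pde * loss_pde


# Manufactured check: with k = 2 and f = 1, u = x(1 - x)/4 solves the PDE exactly.
n = 65
x = np.linspace(0.0, 1.0, n)
h = x[1] - x[0]
X, Y = np.meshgrid(x, x, indexing="ij")
u_exact = X * (1.0 - X) / 4.0
k = 2.0 * np.ones_like(u_exact)
f = np.ones_like(u_exact)

residual = darcy_residual(u_exact, k, f, h)
loss = pino_loss(u_exact + 0.01, u_exact, k, f, h)  # prediction with a constant offset
```

The residual of the exact solution vanishes (up to round-off), and a constant offset in the prediction changes only the data term, since it has no effect on the derivatives.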

Note that the PDE loss requires various partial derivatives of the FNO ansatz \(\mathcal{G}_\theta(a)\). Computing these is nontrivial in general, so Modulus provides three methods, based on the original PINO paper:

  1. Numerical differentiation computed via Finite-Difference Method (FDM)

  2. Numerical differentiation computed via spectral derivative

  3. Hybrid differentiation based on a combination of first-order “exact” [2] derivatives and second-order FDM derivatives

The first two approaches are the same as those proposed in the original paper. The third is a modification of the paper's “exact” approach, which is slower and more memory intensive than the numerical approaches when computing second-order derivatives because it requires computing a Hessian matrix. The “hybrid” approach offers a compromise by combining first-order “exact” derivatives with second-order FDM derivatives.
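The trade-off between the two numerical approaches can be seen on a 1D periodic test function. The NumPy sketch below is an illustration only, not Modulus's dx/ddx or fourier_derivatives helpers: it computes second derivatives with a 3-point finite-difference stencil and with the FFT, assuming a periodic signal on a uniform grid.

```python
import numpy as np


def fdm_second_derivative(u, h):
    """3-point central stencil, assuming u is periodic (np.roll wraps around)."""
    return (np.roll(u, -1) - 2.0 * u + np.roll(u, 1)) / h**2


def spectral_derivatives(u, length):
    """First and second derivatives of a periodic signal sampled on [0, length)."""
    n = u.shape[0]
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)  # angular wavenumbers
    u_hat = np.fft.fft(u)
    du = np.real(np.fft.ifft(1j * k * u_hat))
    ddu = np.real(np.fft.ifft(-(k**2) * u_hat))
    return du, ddu


n = 64
x = np.arange(n) / n  # periodic grid on [0, 1), endpoint excluded
h = 1.0 / n
u = np.sin(2.0 * np.pi * x)
ddu_exact = -((2.0 * np.pi) ** 2) * u

err_fdm = np.max(np.abs(fdm_second_derivative(u, h) - ddu_exact))
du_spec, ddu_spec = spectral_derivatives(u, 1.0)
err_spec = np.max(np.abs(ddu_spec - ddu_exact))
```

For smooth periodic data the spectral result is accurate to near machine precision, while the FDM error (about 0.03 here) shrinks as \(h^2\). Spectral derivatives rely on periodicity, which is why non-periodic fields must be padded before the FFT. The hybrid method pairs an exact first derivative with an FDM second derivative like the one above.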

The setup for this problem is largely the same as the FNO example (Darcy Flow with Fourier Neural Operator), except that a PDE loss is defined and used to constrain the FNO model. This process is described in detail in Defining PDE Loss below.

As in the FNO chapter, the training and validation data for this example can be found on the Fourier Neural Operator GitHub page. An automated script for downloading and converting this dataset is included; it requires the gdown package, which can be installed with pip install gdown.

Note

The python script for this problem can be found at examples/darcy/darcy_pino.py.

Configuration

The configuration for this problem is similar to the FNO example, but importantly there is an extra parameter, custom.gradient_method, which selects the method used to compute the gradients in the PDE loss. It can be one of fdm, fourier or hybrid, corresponding to the three options above. The balance between the data and PDE terms in the loss function can be controlled using the loss.weights parameter group.


defaults:
  - modulus_default
  - arch:
      - fno
  - scheduler: tf_exponential_lr
  - optimizer: adam
  - loss: sum
  - _self_

jit: false

custom:
  gradient_method: hybrid
  ntrain: 1000
  ntest: 100

arch:
  fno:
    dimension: 2
    nr_fno_layers: 4
    fno_layer_size: 32
    fno_modes: 12
    padding: 9
    output_fc_layer_sizes:
      - 128

scheduler:
  decay_rate: 0.95
  decay_steps: 1000

training:
  rec_results_freq: 1000
  max_steps: 10000

loss:
  weights:
    sol: 1.0
    darcy: 0.1

batch_size:
  grid: 16
  validation: 8
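The two custom settings can be read as follows. This plain-Python sketch assumes that the sum loss multiplies each per-variable loss by its entry in loss.weights and adds the results; the helper names are hypothetical, and the dictionary simply mirrors part of the configuration above.

```python
VALID_GRADIENT_METHODS = ("fdm", "fourier", "hybrid")

# Hypothetical mirror of the relevant part of the configuration above.
config = {
    "custom": {"gradient_method": "hybrid"},
    "loss": {"weights": {"sol": 1.0, "darcy": 0.1}},
}


def check_gradient_method(method: str) -> str:
    """Reject values other than the three supported gradient methods."""
    if method not in VALID_GRADIENT_METHODS:
        raise ValueError(f"Derivative method {method} not supported.")
    return method


def weighted_sum(losses: dict, weights: dict) -> float:
    """Scale each named loss term by its weight (assumed default 1.0) and add them."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


method = check_gradient_method(config["custom"]["gradient_method"])
total = weighted_sum({"sol": 0.02, "darcy": 0.5}, config["loss"]["weights"])
```

With darcy: 0.1, a PDE loss of 0.5 contributes only 0.05 to the total, so the data term carries most of the weight unless the residual is large.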

Defining PDE Loss

For this example, a custom PDE residual calculation is defined using the various approaches proposed above. Defining a custom PDE residual using sympy and auto-diff is discussed in 1D Wave Equation, but in this problem we will not be relying on standard auto-diff for calculating the derivatives. Rather, we want to explicitly define how the residual is calculated using a custom torch.nn.Module called Darcy. The purpose of this module is to compute and return the Darcy PDE residual given the input and output tensors of the FNO model, which is done via its .forward(...) method:


from typing import Dict

import torch
import torch.nn.functional as F

from modulus.architecture.layers import fourier_derivatives  # helper function for computing spectral derivatives
from ops import dx, ddx  # helper functions for computing finite difference derivatives


class Darcy(torch.nn.Module):
    "Custom Darcy PDE definition for PINO"

    def __init__(self, gradient_method: str = "hybrid"):
        super().__init__()
        self.gradient_method = str(gradient_method)

    def forward(self, input_var: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # get inputs
        u = input_var["sol"]
        c = input_var["coeff"]
        dcdx = input_var["Kcoeff_y"]  # data is reversed
        dcdy = input_var["Kcoeff_x"]

        dxf = 1.0 / u.shape[-2]
        dyf = 1.0 / u.shape[-1]
        # Compute gradients based on method
        # Exact first order and FDM second order
        if self.gradient_method == "hybrid":
            dudx_exact = input_var["sol__x"]
            dudy_exact = input_var["sol__y"]
            dduddx_fdm = ddx(
                u, dx=dxf, channel=0, dim=0, order=1, padding="replication"
            )
            dduddy_fdm = ddx(
                u, dx=dyf, channel=0, dim=1, order=1, padding="replication"
            )
            # compute darcy equation
            darcy = (
                1.0
                + (dcdx * dudx_exact)
                + (c * dduddx_fdm)
                + (dcdy * dudy_exact)
                + (c * dduddy_fdm)
            )

        # FDM gradients
        elif self.gradient_method == "fdm":
            dudx_fdm = dx(u, dx=dxf, channel=0, dim=0, order=1, padding="replication")
            dudy_fdm = dx(u, dx=dyf, channel=0, dim=1, order=1, padding="replication")
            dduddx_fdm = ddx(
                u, dx=dxf, channel=0, dim=0, order=1, padding="replication"
            )
            dduddy_fdm = ddx(
                u, dx=dyf, channel=0, dim=1, order=1, padding="replication"
            )
            # compute darcy equation
            darcy = (
                1.0
                + (dcdx * dudx_fdm)
                + (c * dduddx_fdm)
                + (dcdy * dudy_fdm)
                + (c * dduddy_fdm)
            )

        # Fourier derivative
        elif self.gradient_method == "fourier":
            dim_u_x = u.shape[2]
            dim_u_y = u.shape[3]
            # reflect-pad so the field is (nearly) periodic for the FFT;
            # the padded domain length is roughly 2.0 in each direction
            u = F.pad(u, (0, dim_u_y - 1, 0, dim_u_x - 1), mode="reflect")
            f_du, f_ddu = fourier_derivatives(u, [2.0, 2.0])
            dudx_fourier = f_du[:, 0:1, :dim_u_x, :dim_u_y]
            dudy_fourier = f_du[:, 1:2, :dim_u_x, :dim_u_y]
            dduddx_fourier = f_ddu[:, 0:1, :dim_u_x, :dim_u_y]
            dduddy_fourier = f_ddu[:, 1:2, :dim_u_x, :dim_u_y]
            # compute darcy equation
            darcy = (
                1.0
                + (dcdx * dudx_fourier)
                + (c * dduddx_fourier)
                + (dcdy * dudy_fourier)
                + (c * dduddy_fourier)
            )
        else:
            raise ValueError(f"Derivative method {self.gradient_method} not supported.")

        # Zero outer boundary
        darcy = F.pad(darcy[:, :, 2:-2, 2:-2], [2, 2, 2, 2], "constant", 0)
        # Return darcy
        output_var = {
            "darcy": dxf * darcy,
        }  # weight boundary loss higher
        return output_var
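For readers without the example's ops module at hand, the NumPy functions below are hypothetical stand-ins that approximate what dx, ddx and the boundary-zeroing step do. They assume second-order central differences, replication (edge) padding, and that dim=0 refers to the first spatial axis of a [batch, channel, x, y] array (consistent with dxf = 1.0 / u.shape[-2] in the code above). They are not the actual ops implementations.

```python
import numpy as np


def _shifted(u, axis):
    """Return the forward and backward neighbours of u along axis, edge-padded."""
    pad = [(0, 0)] * u.ndim
    pad[axis] = (1, 1)
    up = np.pad(u, pad, mode="edge")  # replicate boundary values
    n = u.shape[axis]
    fwd = np.take(up, np.arange(2, n + 2), axis=axis)
    bwd = np.take(up, np.arange(0, n), axis=axis)
    return fwd, bwd


def central_dx(u, h, axis):
    """First derivative by central differences (less accurate in the edge cells)."""
    fwd, bwd = _shifted(u, axis)
    return (fwd - bwd) / (2.0 * h)


def central_ddx(u, h, axis):
    """Second derivative with the 3-point stencil."""
    fwd, bwd = _shifted(u, axis)
    return (fwd - 2.0 * u + bwd) / h**2


def zero_border(r, width=2):
    """Keep the interior and zero a `width`-cell outer frame, like the F.pad step."""
    out = np.zeros_like(r)
    out[..., width:-width, width:-width] = r[..., width:-width, width:-width]
    return out


# Usage: u = x^2 on a [1, 1, n, n] grid, varying along array axis 2.
n = 32
h = 1.0 / n
x = np.arange(n) * h
X = np.broadcast_to(x[:, None], (n, n))
u = (X**2)[None, None]
du = central_dx(u, h, axis=2)
ddu = central_ddx(u, h, axis=2)
```

The edge cells are where the padded stencils are least accurate, which is one reason the residual is zeroed in a thin outer frame before the loss is computed.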

The gradients of the FNO solution are computed according to the selected gradient method. When the hybrid method is used, the FNO model itself outputs the first-order gradients, so no extra computation is needed for them. The gradients of the permeability field are already included as tensors in the FNO input training data (with keys Kcoeff_x and Kcoeff_y), so these do not need to be computed either.

Next, incorporate this module into Modulus by wrapping it in a Modulus Node. This ensures the module becomes part of Modulus’ computational graph and can be used to optimise the FNO.


from modulus.node import Node


# Make custom Darcy residual node for PINO
inputs = [
    "sol",
    "coeff",
    "Kcoeff_x",
    "Kcoeff_y",
]
if cfg.custom.gradient_method == "hybrid":
    inputs += [
        "sol__x",
        "sol__y",
    ]
darcy_node = Node(
    inputs=inputs,
    outputs=["darcy"],
    evaluate=Darcy(gradient_method=cfg.custom.gradient_method),
    name="Darcy Node",
)

Finally, define the PDE loss term by adding a constraint to the darcy output variable (see Adding Constraints below).

Loading Data

Loading the training and validation datasets follows the same process as in the FNO example:


if DistributedManager().distributed:
    print("Multi-GPU currently not supported for this example. Exiting.")
    return

# load training/ test data
input_keys = [
    Key("coeff", scale=(7.48360e00, 4.49996e00)),
    Key("Kcoeff_x"),
    Key("Kcoeff_y"),
]
output_keys = [
    Key("sol", scale=(5.74634e-03, 3.88433e-03)),
]

download_FNO_dataset("Darcy_241", outdir="datasets/")
invar_train, outvar_train = load_FNO_dataset(
    "datasets/Darcy_241/piececonst_r241_N1024_smooth1.hdf5",
    [k.name for k in input_keys],
    [k.name for k in output_keys],
    n_examples=cfg.custom.ntrain,
)
invar_test, outvar_test = load_FNO_dataset(
    "datasets/Darcy_241/piececonst_r241_N1024_smooth2.hdf5",
    [k.name for k in input_keys],
    [k.name for k in output_keys],
    n_examples=cfg.custom.ntest,
)

Initializing the Model

Initializing the model also follows the same process as in the FNO example:


# Define FNO model
if cfg.custom.gradient_method == "hybrid":
    output_keys += [
        Key("sol", derivatives=[Key("x")]),
        Key("sol", derivatives=[Key("y")]),
    ]
model = instantiate_arch(
    input_keys=[input_keys[0]],
    output_keys=output_keys,
    cfg=cfg.arch.fno,
    domain_length=[1.0, 1.0],
)

When the hybrid gradient method is used, the model must also be instructed to output the first-order gradients of sol (consumed by the Darcy node as sol__x and sol__y), which is done by adding them to its output keys as shown above.
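The key names that link the model outputs to the Darcy node inputs follow the double-underscore pattern visible in the residual code (sol__x, sol__y). The sketch below uses hypothetical helpers, not part of Modulus, that build those names and the per-method input list of the Darcy node, assuming a derivative key's name is the variable name joined to each derivative variable with "__".

```python
def derivative_name(var: str, wrt: list) -> str:
    """Hypothetical: build a derivative key name such as 'sol__x'."""
    return "__".join([var] + list(wrt))


def darcy_node_inputs(gradient_method: str) -> list:
    """Inputs of the Darcy node; only the hybrid method needs model-provided gradients."""
    inputs = ["sol", "coeff", "Kcoeff_x", "Kcoeff_y"]
    if gradient_method == "hybrid":
        inputs += [derivative_name("sol", ["x"]), derivative_name("sol", ["y"])]
    return inputs
```

With fdm or fourier the FNO only needs to output sol, and the Darcy node computes every derivative of the solution from it.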

Adding Constraints

Finally, add constraints to the model in the same way as in the FNO example. The same SupervisedGridConstraint can be used; to include the PDE loss term, define additional target values for the darcy output variable defined above (zeros, to minimise the PDE residual) and add them to the outvar_train dictionary:
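Using zeros as the target for darcy turns an ordinary supervised error into the PDE loss: the mean squared error against zero is exactly the mean squared residual. The small NumPy check below uses made-up arrays in place of the training data and the Darcy node output, and assumes an unweighted mean squared error per output variable.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up stand-ins with a [batch, channel, x, y] layout.
outvar_train = {"sol": rng.normal(size=(4, 1, 8, 8))}
outvar_train["darcy"] = np.zeros_like(outvar_train["sol"])  # target residual is zero

predicted = {
    "sol": outvar_train["sol"] + 0.1,  # prediction with a constant offset
    "darcy": rng.normal(size=(4, 1, 8, 8)),  # pretend residual from the Darcy node
}

# Per-variable mean squared error between prediction and target
mse = {key: np.mean((predicted[key] - outvar_train[key]) ** 2) for key in outvar_train}
```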


nodes = model.make_nodes(name="FNO", jit=False) + [darcy_node]

# make domain
domain = Domain()

# add constraints to domain
outvar_train["darcy"] = np.zeros_like(outvar_train["sol"])  # constrain darcy node
supervised = SupervisedGridConstraint(
    nodes=nodes,
    invar=invar_train,
    outvar=outvar_train,
    batch_size=cfg.batch_size.grid,
)
domain.add_constraint(supervised, "supervised")

The same data validator as the FNO example is used.

Training can now be started by executing the Python script:


python darcy_pino.py

Warning

Multi-GPU training is currently not supported for this problem.

Results and Post-processing

Results are recorded at the frequency specified by the training.rec_results_freq parameter (see Results Frequency for more information). The validator output folder (in this case outputs/darcy_pino/validators) contains several plots of different validation predictions.

Fig. 69: PINO validation prediction. (Left to right) Input permeability and its spatial derivatives, true pressure, predicted pressure, error.

Comparison to FNO

The TensorBoard plot below compares the validation loss of PINO (with all three gradient methods) and FNO. With a large amount of training data (1000 training examples), FNO and PINO perform similarly.

Fig. 70: Comparison between PINO and FNO accuracy for surrogate modeling of Darcy flow.

A benefit of PINO is that the PDE loss regularises the model, meaning that it can be more effective in “small data” regimes. The plot below shows the validation loss when both models are trained with only 100 training examples:

Fig. 71: Comparison between PINO and FNO accuracy for surrogate modeling of Darcy flow (small data regime).

In this case, PINO outperforms FNO.

References / Footnotes

[1] Li, Zongyi, et al. “Physics-informed neural operator for learning partial differential equations.” arXiv preprint arXiv:2111.03794 (2021).

[2] The “exact” method is technically not exact because it uses a combination of numerical spectral derivatives and exact differentiation. See the original paper for more details.

© Copyright 2021-2022, NVIDIA. Last updated on Apr 7, 2022.