1D Wave Equation


This tutorial walks you through setting up a custom PDE in Modulus, using a simple time-dependent 1D wave equation as the example. It also shows how to solve transient physics problems in Modulus. In this tutorial you will learn the following:

  1. How to write your own Partial Differential Equation and boundary conditions in Modulus.

  2. How to solve a time-dependent problem in Modulus.

  3. How to impose initial conditions and boundary conditions for a transient problem.

  4. How to generate validation data from analytical solutions.


This tutorial assumes that you have completed the Lid Driven Cavity Background tutorial and have familiarized yourself with the basics of Modulus APIs.

Problem Description

In this tutorial, you will solve a simple 1D wave equation, described by the following equations:

\[
\begin{aligned}
u_{tt} &= c^2 u_{xx}, \\
u(0,t) &= 0, \\
u(\pi,t) &= 0, \\
u(x,0) &= \sin(x), \\
u_t(x,0) &= \sin(x).
\end{aligned}
\tag{127}
\]

Here, the wave speed is \(c=1\), and the analytical solution to the above problem is \(u(x,t) = \sin(x)\,(\sin(t) + \cos(t))\).
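Before training, it can be worth confirming symbolically that the stated solution satisfies the PDE, both boundary conditions and both initial conditions. The following is a minimal SymPy sketch, independent of Modulus; each printed quantity should reduce to zero.

```python
import sympy as sp

x, t = sp.symbols("x t")
c = 1  # wave speed from the problem statement
u = sp.sin(x) * (sp.sin(t) + sp.cos(t))  # stated analytical solution

# PDE residual u_tt - c^2 u_xx should vanish identically
pde_residual = sp.simplify(sp.diff(u, t, 2) - c**2 * sp.diff(u, x, 2))

# Boundary conditions at x = 0 and x = pi
bc_left = sp.simplify(u.subs(x, 0))
bc_right = sp.simplify(u.subs(x, sp.pi))

# Initial conditions: u(x, 0) = sin(x) and u_t(x, 0) = sin(x)
ic_u = sp.simplify(u.subs(t, 0) - sp.sin(x))
ic_ut = sp.simplify(sp.diff(u, t).subs(t, 0) - sp.sin(x))

print(pde_residual, bc_left, bc_right, ic_u, ic_ut)
```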

Writing custom PDEs and boundary/initial conditions

In this tutorial, you will write the 1D wave equation. The wave equation in \(n\) dimensions can be found at https://en.wikipedia.org/wiki/Wave_equation. You will also see how to handle derivative-type initial conditions. The PDEs defined in the source code directory modulus/PDES/ can be used for reference.

In this tutorial you will create a file wave_equation.py and define the 1D wave equation in it. The PDES class allows you to write the equations symbolically in SymPy, so you can express them in the most natural way possible. The SymPy equations are converted to PyTorch expressions in the back-end and can also be printed to check that they are implemented correctly.

First create a class WaveEquation1D that inherits from PDES.

"""Wave equation
Reference: https://en.wikipedia.org/wiki/Wave_equation
"""

from sympy import Symbol, Function, Number
from modulus.pdes import PDES


class WaveEquation1D(PDES):
    """
    Wave equation 1D
    The equation is given as an example for implementing
    your own PDE. A more universal implementation of the
    wave equation can be found by
    `from modulus.PDES.wave_equation import WaveEquation`.

    c : float, string
        Wave speed coefficient. If a string then the
        wave speed is input into the equation.
    """

    name = "WaveEquation1D"

Now create the initialization method for this class, which defines the equation(s) of interest. You will define the wave equation using the wave speed \(c\). If \(c\) is given as a string, it is converted to a SymPy function of \(x\) and \(t\). This allows you to solve problems with spatially or temporally varying wave speed, and it is also used in the subsequent inverse example.

As shown in the code block below, the input variables \(x\) and \(t\) are first defined as SymPy symbols. Then the functions \(u\) and \(c\), which depend on \(x\) and \(t\), are defined. Using these, you can write the simple equation \(u_{tt} = (c^2 u_x)_x\). Store this equation in the class by adding it to the dictionary of equations.

    def __init__(self, c=1.0):
        # coordinates
        x = Symbol("x")

        # time
        t = Symbol("t")

        # make input variables
        input_variables = {"x": x, "t": t}

        # make u function
        u = Function("u")(*input_variables)

        # wave speed coefficient
        if type(c) is str:
            c = Function(c)(*input_variables)
        elif type(c) in [float, int]:
            c = Number(c)

        # set equations
        self.equations = {}
        self.equations["wave_equation"] = u.diff(t, 2) - (c ** 2 * u.diff(x)).diff(x)
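To inspect the residual that the class stores, the same construction can be reproduced in plain SymPy outside Modulus and printed. The sketch below is a standalone illustration; wave_residual is a hypothetical helper, not part of the Modulus API, and it passes the symbols to \(u\) explicitly rather than unpacking the input dictionary.

```python
from sympy import Symbol, Function, Number


def wave_residual(c=1.0):
    """Standalone re-creation (not the Modulus class) of the 1D wave residual."""
    x, t = Symbol("x"), Symbol("t")
    u = Function("u")(x, t)
    if isinstance(c, str):
        # String wave speed: treat c as an unknown function of x and t
        c = Function(c)(x, t)
    elif isinstance(c, (int, float)):
        # Numeric wave speed: wrap as a SymPy constant
        c = Number(c)
    # All terms on one side: u_tt - (c^2 u_x)_x
    return u.diff(t, 2) - (c**2 * u.diff(x)).diff(x)


print(wave_residual(1.0))
print(wave_residual("c"))
```

With a string wave speed, SymPy applies the product rule to \((c^2 u_x)_x\), so the printed residual contains an extra term proportional to \(c\,c_x\,u_x\); with a constant \(c\) that term vanishes.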

Note the structure of the equation for 'wave_equation'. All terms of the PDE are moved to one side, leaving only the source term (here zero) on the other. This way, when the equation is used in a constraint, you can assign a custom source function to the 'wave_equation' key instead of 0 to add a source to the PDE.
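As an illustration of how a source term fits this convention, one option is the method of manufactured solutions: choose a function for \(u\), substitute it into the residual \(u_{tt} - c^2 u_{xx}\), and use the result as the source. The sketch below assumes \(c=1\) and a hypothetical choice \(u = t^2\sin(x)\), for which the residual evaluates to \((t^2 + 2)\sin(x)\); that expression is what would be assigned to the 'wave_equation' key in place of 0.

```python
import sympy as sp

x, t = sp.symbols("x t")
c = 1  # assumed constant wave speed

# Hypothetical manufactured solution, chosen only for illustration
u_m = sp.sin(x) * t**2

# Evaluate the residual u_tt - c^2 u_xx; this becomes the source term
source = sp.simplify(sp.diff(u_m, t, 2) - c**2 * sp.diff(u_m, x, 2))
print(source)
```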

Great! You have just written your own PDE in Modulus. To verify the implementation, compare it with the script modulus/PDES/wave_equation.py. Once you understand the process of coding a simple PDE, you can extend it to other PDEs in multiple dimensions (2D, 3D, etc.) by adding input variables, constants, etc. You can also bundle multiple PDEs together in the same class by adding new keys to the equations dictionary.

Now you can write the solver file where you can make use of the newly coded wave equation to solve the 1D wave problem.

Case Setup

This tutorial uses Line1D to sample points in a single dimension. The time-dependent equation is solved by supplying \(t\) as a variable parameter through param_ranges, with its range set to the time domain of interest. param_ranges is also used for problems that involve varying geometric parameters or PDE constants.
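Conceptually, supplying \(t\) through param_ranges means every training point is an \((x, t)\) pair, with \(t\) drawn from the time range just like a spatial coordinate. The NumPy sketch below illustrates this for the present problem; it is not Modulus's internal sampler, and the function names and uniform random sampling are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
L = np.pi
t_range = (0.0, 2 * L)  # time domain of interest


def sample_interior(n):
    """Interior points: x uniform in [0, L), t uniform over the time range."""
    x = rng.uniform(0.0, L, size=(n, 1))
    t = rng.uniform(*t_range, size=(n, 1))
    return {"x": x, "t": t}


def sample_boundary(n):
    """Boundary points: x at one of the two endpoints, t over the time range."""
    x = rng.choice([0.0, L], size=(n, 1))
    t = rng.uniform(*t_range, size=(n, 1))
    return {"x": x, "t": t}


def sample_initial(n):
    """Initial-condition points: x over the line, t fixed at 0."""
    x = rng.uniform(0.0, L, size=(n, 1))
    return {"x": x, "t": np.zeros((n, 1))}
```

The initial-condition sampler mirrors fixing param_ranges={t_symbol: 0.0}, while the interior and boundary samplers mirror passing the full time range.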


  • This approach treats time as a continuous variable. An example of discrete time stepping, in the form of a continuous time-window approach, is presented in Moving Time Window: Taylor Green Vortex Decay.

  • The python script for this problem can be found at examples/wave_equation/wave_1d.py. The PDE coded in wave_equation.py is also in the same directory for reference.

Importing the required packages

The only new module imported in this tutorial is the 1D geometry module (csg_1d), which provides Line1D. Also import WaveEquation1D from the file you just created.

import numpy as np
from sympy import Symbol, sin

import modulus
from modulus.hydra import to_yaml, instantiate_arch
from modulus.hydra.config import ModulusConfig
from modulus.continuous.solvers.solver import Solver
from modulus.continuous.domain.domain import Domain
from modulus.geometry.csg.csg_1d import Line1D
from modulus.continuous.constraints.constraint import (
    PointwiseBoundaryConstraint,
    PointwiseInteriorConstraint,
)
from modulus.continuous.validator.validator import PointwiseValidator
from modulus.key import Key
from modulus.node import Node
from wave_equation import WaveEquation1D

Creating Nodes and Domain

This part of the problem is similar to the Lid Driven Cavity Background tutorial. The WaveEquation1D class is used to compute the wave equation, and the wave speed is set according to the problem statement (\(c=1\)). A neural network with \(x\) and \(t\) as inputs and \(u\) as output is also created.

@modulus.main(config_path="conf", config_name="config")
def run(cfg: ModulusConfig) -> None:

    # make list of nodes to unroll graph on
    we = WaveEquation1D(c=1.0)
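    wave_net = instantiate_arch(
        input_keys=[Key("x"), Key("t")],
        output_keys=[Key("u")],
        cfg=cfg.arch.fully_connected,  # assumes a fully_connected arch in conf/config.yaml
    )
    nodes = we.make_nodes() + [wave_net.make_node(name="wave_network", jit=cfg.jit)]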

Creating Geometry and Adding Constraints

To generate the geometry for this problem, use Line1D(pt1, pt2). The boundaries of a Line1D are its two endpoints, and its interior covers all points between them.

As described earlier, use the param_ranges attribute to solve for time. To define the initial conditions, set param_ranges={t_symbol: 0.0}. You will solve the wave equation for \(t \in [0, 2\pi]\). The derivative initial condition \(u_t(x,0) = \sin(x)\) is handled by specifying the key 'u__t'. Derivatives of a variable are specified by appending '__t' for a time derivative and '__x' for a spatial derivative ('u__x' for \(\partial u/\partial x\), 'u__x__x' for \(\partial^2 u/\partial x^2\), etc.).

The code below uses these tools to generate the geometry, the initial/boundary conditions, and the equations.

    # add constraints to solver
    # make geometry
    x, t_symbol = Symbol("x"), Symbol("t")
    L = float(np.pi)
    geo = Line1D(0, L)
    time_range = {t_symbol: (0, 2 * L)}

    # make domain
    domain = Domain()

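    # initial condition: u and u_t prescribed at t = 0
    IC = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": sin(x), "u__t": sin(x)},
        batch_size=cfg.batch_size.IC,  # batch sizes assumed to be set in conf/config.yaml
        bounds={x: (0, L)},
        lambda_weighting={"u": 1.0, "u__t": 1.0},
        param_ranges={t_symbol: 0.0},
    )
    domain.add_constraint(IC, "IC")

    # boundary condition: u = 0 at both endpoints for all t
    BC = PointwiseBoundaryConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"u": 0},
        batch_size=cfg.batch_size.BC,
        param_ranges=time_range,
    )
    domain.add_constraint(BC, "BC")

    # interior: enforce the wave equation residual over space and time
    interior = PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=geo,
        outvar={"wave_equation": 0},
        batch_size=cfg.batch_size.interior,
        bounds={x: (0, L)},
        param_ranges=time_range,
    )
    domain.add_constraint(interior, "interior")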
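The derivative keys used in these constraints, such as 'u__t', follow the naming convention described above. The small parser below makes that convention concrete by mapping a key to the corresponding SymPy derivative; key_to_derivative is a hypothetical helper for illustration, not a Modulus function.

```python
from sympy import Symbol, Function


def key_to_derivative(key, variables=("x", "t")):
    """Hypothetical helper: map a key such as 'u__x__t' to a SymPy derivative."""
    # 'u__x__x' -> name 'u', differentiate with respect to x, then x again
    name, *wrt = key.split("__")
    f = Function(name)(*[Symbol(v) for v in variables])
    if not wrt:
        return f  # plain variable, no derivative
    return f.diff(*[Symbol(v) for v in wrt])


print(key_to_derivative("u__t"))
print(key_to_derivative("u__x__x"))
```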

Adding Validation data from analytical solutions

For this problem, the analytical solution can be evaluated directly in the script instead of being imported from a .csv file. The following code shows how to define such a dataset:

    deltaT = 0.01
    deltaX = 0.01
    x = np.arange(0, L, deltaX)
    t = np.arange(0, 2 * L, deltaT)
    X, T = np.meshgrid(x, t)
    X = np.expand_dims(X.flatten(), axis=-1)
    T = np.expand_dims(T.flatten(), axis=-1)
    u = np.sin(X) * (np.cos(T) + np.sin(T))
    invar_numpy = {"x": X, "t": T}
    outvar_numpy = {"u": u}
    validator = PointwiseValidator(invar_numpy, outvar_numpy, nodes, batch_size=128)
    domain.add_validator(validator)


The figure below compares the Modulus results with the analytical solution. The error in the Modulus prediction grows as time increases. Some advanced approaches to tackle transient problems are covered in Moving Time Window: Taylor Green Vortex Decay.
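One way to quantify how the error grows in time is to reshape the validation predictions back onto the \((t, x)\) grid and compute an error per time slice. The following sketch assumes u_pred_flat holds network predictions in the same flattened order as the validator inputs (obtaining them is not shown, and a shifted copy of the exact solution stands in for them here); rmse_per_time is a hypothetical helper. A root-mean-square error is used rather than a relative error because the exact solution vanishes wherever \(\sin(t) + \cos(t) = 0\).

```python
import numpy as np

# Rebuild the validation grid used above: rows are time, columns are space
L = np.pi
x = np.arange(0, L, 0.01)
t = np.arange(0, 2 * L, 0.01)
X, T = np.meshgrid(x, t)
u_true = np.sin(X) * (np.cos(T) + np.sin(T))


def rmse_per_time(u_pred_flat, u_true_grid):
    """Root-mean-square error over x for each time slice.

    u_pred_flat has shape (N, 1) in the same row-major order as X.flatten().
    """
    u_pred_grid = u_pred_flat.reshape(u_true_grid.shape)
    return np.sqrt(np.mean((u_pred_grid - u_true_grid) ** 2, axis=1))


# Stand-in prediction for demonstration only (not Modulus output)
u_pred_flat = np.expand_dims(u_true.flatten(), axis=-1) + 0.05
err = rmse_per_time(u_pred_flat, u_true)
```

Plotting err against t gives a one-dimensional view of the error growth shown in the difference panel of the figure.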


Fig. 38 Left: Modulus. Center: Analytical Solution. Right: Difference

Temporal loss weighting and time-marching schedule

We have observed that two simple tricks, namely temporal loss weighting and time-marching schedule, can improve the performance of the continuous time approach for transient simulations. The idea behind the temporal loss weighting is to weight the loss terms temporally such that the terms corresponding to earlier times have a larger weight compared to those corresponding to later times in the time domain. For example, our temporal loss weighting can take the following linear form

\[\lambda_T = C_T \left( 1 - \frac{t}{T} \right) + 1 \tag{128}\]

Here, \(\lambda_T\) is the temporal loss weight, \(C_T\) is a constant that controls the weight scale, \(T\) is the upper bound for the time domain, and \(t\) is time.
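A minimal NumPy sketch of the linear weighting is shown below, applied to a batch of squared residuals. The function names and the value of \(C_T\) are assumptions for illustration; the weight equals \(C_T + 1\) at \(t = 0\) and decreases linearly to 1 at \(t = T\).

```python
import numpy as np


def temporal_weight(t, T, C_T):
    """Linear temporal loss weight lambda_T = C_T * (1 - t / T) + 1."""
    return C_T * (1.0 - t / T) + 1.0


def weighted_mse(residual, t, T, C_T):
    """Mean squared residual, with each point scaled by its temporal weight."""
    return np.mean(temporal_weight(t, T, C_T) * residual**2)


# Example batch: unit residuals at the start, middle and end of [0, T]
T = 2 * np.pi
t = np.array([0.0, np.pi, 2 * np.pi])
residual = np.ones_like(t)
loss = weighted_mse(residual, t, T, C_T=10.0)
```

With equal residuals, the point at \(t = 0\) contributes eleven times as much to the loss as the point at \(t = T\), which is the intended bias toward early times.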

The idea behind the time-marching schedule is to treat the time domain upper bound \(T\) as a variable that depends on the training iteration \(s\). This variable can then change so that more training iterations are spent on earlier times than on later times. Several schedules are possible; for instance, we can use the following:

\[T_v (s) = \min \left( 1, \frac{2s}{S} \right) \tag{129}\]

Here, \(T_v (s)\) is the variable time domain upper bound, \(s\) is the training iteration number, and \(S\) is the maximum number of training iterations. Under this schedule, \(T_v\) grows linearly and reaches 1 at \(s = S/2\), after which it stays at 1 for the remaining iterations. At each training iteration, we then sample continuously from the time domain in the range \([0, T_v (s)]\).
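The schedule and the sampling step can be sketched in a few lines of NumPy. This sketch assumes \(T_v(s)\) is a fraction of the full time domain, so sampled times are rescaled by the physical upper bound t_max; the function names are hypothetical.

```python
import numpy as np


def time_upper_bound(s, S):
    """Time-marching schedule T_v(s) = min(1, 2 s / S)."""
    return min(1.0, 2.0 * s / S)


def sample_times(s, S, n, t_max, rng):
    """Sample n times uniformly from [0, T_v(s)], then rescale by t_max."""
    t_frac = rng.uniform(0.0, time_upper_bound(s, S), size=(n, 1))
    return t_frac * t_max


rng = np.random.default_rng(0)
S = 10_000
# A quarter of the way through training, only the first half of [0, 2*pi] is sampled
t_batch = sample_times(s=2_500, S=S, n=256, t_max=2 * np.pi, rng=rng)
```

In this form, the first half of training sees progressively longer time windows, and the second half trains on the full time domain.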

The figures below show the Modulus validation error for models trained with and without temporal loss weighting and time marching for the transient 1D and 2D wave examples and a 2D channel flow over a bump. These results indicate that the two simple tricks can improve the training accuracy.


Fig. 39 Modulus validation error for the 1D transient wave example: (a) standard continuous time approach; (b) continuous time approach with temporal loss weighting and time marching.


Fig. 40 Modulus validation error for the 2D transient wave example: (a) standard continuous time approach; (b) continuous time approach with temporal loss weighting and time marching.


Fig. 41 Modulus validation error for a 2D transient channel flow over a bump: (a) standard continuous time approach; (b) continuous time approach with temporal loss weighting.