GPU Optimized Layers#

PhysicsNeMo is a framework for scientific AI workloads on NVIDIA GPUs. It includes many optimizations for efficient computational performance, so scientists and practitioners can stay focused on delivering accurate model results quickly.

Below are some examples of the optimizations enabled in PhysicsNeMo. Many of these layers are used inside PhysicsNeMo models, but they are also available directly to any developer with PhysicsNeMo installed, so you can use them in your own custom models.

Warp Accelerated Ball Query#

The PhysicsNeMo DoMINO model takes inspiration from classic stencil-based kernels in High Performance Computing and simulation codes. When learning projections from one set of points to another, DoMINO uses a radius-based selection: for each point in a set of queries, up to max_points neighbors from points are returned. This operation is similar to the query_ball_point function from scipy.spatial.KDTree; however, in physicsnemo.utils.neighbors you will find a GPU-accelerated implementation leveraging the NVIDIA Warp library.
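As a rough sketch of how such a query is used, the snippet below shows the CPU-side scipy baseline, which is a known API, alongside a hedged call to the PhysicsNeMo utility. The exact name and signature of the GPU function (assumed here to be radius_search with a max_points argument) may differ in your installed version, so check physicsnemo.utils.neighbors for the authoritative API.

```python
import numpy as np
from scipy.spatial import KDTree

# 10,000 source points and 1,000 query points in 3D.
points = np.random.rand(10_000, 3).astype(np.float32)
queries = np.random.rand(1_000, 3).astype(np.float32)

# CPU baseline: for each query, scipy returns the indices of all
# source points within the given radius.
tree = KDTree(points)
neighbor_lists = tree.query_ball_point(queries, r=0.05)

# GPU version (assumed API -- verify the function name, signature,
# and return values against physicsnemo.utils.neighbors in your
# installed PhysicsNeMo version before use):
#
# import torch
# from physicsnemo.utils.neighbors import radius_search
#
# points_gpu = torch.from_numpy(points).cuda()
# queries_gpu = torch.from_numpy(queries).cuda()
# indices = radius_search(points_gpu, queries_gpu, radius=0.05, max_points=32)
```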

Transformer Engine Accelerated LayerNorm#

Many models, such as PhysicsNeMo's implementation of MeshGraphNet, use LayerNorm to normalize activations in the model and accelerate training. While accurate, for some instances of MeshGraphNet the PyTorch LayerNorm implementation accounted for more than 25% of the execution time. To mitigate this, PhysicsNeMo provides an optimized wrapper around LayerNorm that can take advantage of the more optimized LayerNorm implementation from TransformerEngine.
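The snippet below is a minimal sketch of this pattern, not PhysicsNeMo's actual wrapper: it prefers TransformerEngine's fused LayerNorm when the library and a GPU are available, and falls back to torch.nn.LayerNorm otherwise. The factory function name (build_layer_norm) is hypothetical.

```python
import torch

try:
    import transformer_engine.pytorch as te

    _HAS_TE = True
except ImportError:
    _HAS_TE = False


def build_layer_norm(hidden_size: int, eps: float = 1e-5) -> torch.nn.Module:
    """Hypothetical factory: return TransformerEngine's fused LayerNorm
    when available, otherwise fall back to the stock PyTorch version."""
    if _HAS_TE and torch.cuda.is_available():
        # te.LayerNorm is a drop-in replacement for torch.nn.LayerNorm
        # applied over the last dimension of the input.
        return te.LayerNorm(hidden_size, eps=eps)
    return torch.nn.LayerNorm(hidden_size, eps=eps)


# Usage: normalize per-node features, e.g. inside a graph network block.
device = "cuda" if torch.cuda.is_available() else "cpu"
norm = build_layer_norm(128).to(device)
x = torch.randn(200_000, 128, device=device)
y = norm(x)
```

Because both implementations share the same call signature, the wrapper can be swapped into an existing model without changing the surrounding code.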

In MeshGraphNet, training with up to 200k nodes and 1.2M edges in a single graph shows approximately a one-third reduction in runtime. In addition, using the PyTorch Geometric backend instead of DGL nearly halves the latency of a training iteration.


Fig. 10 Training time for MeshGraphNet comparing TransformerEngine LayerNorm to PyTorch LayerNorm, as well as PyTorch Geometric and DGL backends. Values are relative to DGL with torch.nn.LayerNorm; lower is better.#