3. MLIR-TensorRT Tutorial#

Step 1: Introduction#

This tutorial demonstrates the basic workflow for compiling an example JAX function to a TensorRT engine for the NVIDIA Aerial Framework runtime. We will use NVIDIA’s MLIR-TensorRT and TensorRT compilers.

Procedure:

  1. Define a JIT-able JAX function.

  2. Export the function to StableHLO.

  3. Compile the function using MLIR-TensorRT and TensorRT.

  4. Visualize results showing that JAX and TensorRT produce matching outputs.

Time: ~30 minutes

Step 2: Prerequisites and Environment Setup#

This tutorial requires:

  • Running inside the Aerial Framework Docker container

  • Completion of the Getting Started tutorial

Step 3: Configure CMake and Build Required Targets#

Python Environment Setup: This tutorial requires the RAN Python virtual environment with MLIR-TensorRT support. First-time setup may take several minutes to download and install the required Python packages.

[ ]:
import os
import sys

from tutorial_utils import (
    build_cmake_target,
    configure_cmake,
    get_project_root,
    is_running_in_docker,
    load_ran_env_file,
    require_mlir_trt,
)

# Ensure running inside Docker container
if not is_running_in_docker():
    msg = (
        "This notebook must be run inside the Docker container. "
        "Please refer to the Getting Started tutorial for instructions."
    )
    raise RuntimeError(msg)

PROJECT_ROOT = get_project_root()
RAN_PY_PATH = PROJECT_ROOT / "ran" / "py"
PRESET = "gcc-release"
RAN_BUILD_DIR = PROJECT_ROOT / "out" / "build" / PRESET

# Configure CMake if needed
configure_cmake(RAN_BUILD_DIR, PRESET)

# Build required targets (first time may take a few minutes)
try:
    build_cmake_target(RAN_BUILD_DIR, ["py_ran_setup", "sync_env_python"])
except RuntimeError as e:
    print(f"\n✗ Build failed: {e}\n")
    print("To view full build output and fix issues, enter the container:")
    print("  docker exec -it aerial-framework-base-$USER bash -l")
    print(f"  cmake --build out/build/{PRESET} --target py_ran_setup sync_env_python")
    sys.exit(1)

# Load environment variables from .env.python (includes MLIR_TRT_COMPILER_PATH)
load_ran_env_file()

# Check if MLIR-TensorRT is enabled
require_mlir_trt()

print(f"\nBuild directory: {RAN_BUILD_DIR}")
print("✅ Step 3 complete: CMake configured and targets built")
Building py_ran_setup, sync_env_python...
✓ py_ran_setup, sync_env_python ready

Build directory: /opt/nvidia/aerial-framework/out/build/gcc-release
✅ Step 3 complete: CMake configured and targets built

Step 4: Import Dependencies#

Import the required packages from the RAN Python environment. These were installed when the docs environment was set up via CMake.

[ ]:
# Enable lazy loading of CUDA modules (improves TensorRT load time)
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

# Available JAX backends are: cpu, cuda, and mlir_tensorrt
os.environ["JAX_PLATFORMS"] = "cuda"

# Limit JAX GPU memory pre-allocation to prevent OOM issues in CI
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.05"

# Third-party imports
import tempfile
from pathlib import Path

import jax
import numpy as np
from jax import export, numpy as jnp
from matplotlib import pyplot as plt

# Aerial Framework imports
from ran import mlir_trt_wrapper as mtw

# Set up TensorRT engine directory in a temporary location
build_dir = Path(tempfile.mkdtemp(prefix="mlir_trt_tutorial_"))
os.environ["RAN_TRT_ENGINE_PATH"] = str(build_dir)
print(f"Temporary build directory: {build_dir}")

print("✅ Step 4 complete: All imports successful!")
Temporary build directory: /tmp/mlir_trt_tutorial_xrcvexwx
✅ Step 4 complete: All imports successful!

Step 5: JAX Function#

Here we define the JAX function that we want to compile with TensorRT. We will use a simple finite impulse response (FIR) filter as an example. The function must be JIT-able and exportable to StableHLO; see https://docs.jax.dev/en/latest/jit-compilation.html for more details.

[ ]:

# FIR filter coefficients
KERNEL_COEFFS = jnp.array([0.25, 0.5, 0.25])

# Flip the kernel coefficients for the convolution operation,
# and reshape to (K, 1, 1) for the convolution operation.
KERNEL = KERNEL_COEFFS[::-1][:, None, None]


def my_func(x: jnp.ndarray) -> jnp.ndarray:
    """Finite impulse response (FIR) filter."""
    y = jax.lax.conv_general_dilated(
        lhs=x[None, :, None],  # input, shape (1, T, 1)
        rhs=KERNEL,  # kernel, shape (K, 1, 1)
        window_strides=(1,),
        padding="SAME",
        lhs_dilation=None,
        rhs_dilation=None,
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    return y[0, :, 0]


# Test the FIR filter.
N = 400
n = jnp.linspace(0.0, 1.0, N)
key = jax.random.key(0)
x = jnp.sin(2 * jnp.pi * n) + 0.3 * jax.random.normal(key, (N,))
y = my_func(x)

plt.figure(figsize=(12, 6))
plt.plot(n, x, label="x[n] (input)", alpha=0.6)
plt.plot(n, y, label="y[n] (filtered)", linewidth=2)
plt.xlabel("n")
plt.legend()
plt.tight_layout()
plt.show()

print("✅ Step 5 complete: FIR filter defined and plotted.")
../../_images/tutorials_generated_mlir_trt_tutorial_7_0.png
✅ Step 5 complete: FIR filter defined and plotted.
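Because the kernel has only three taps, the filter is easy to sanity-check against plain NumPy. The optional snippet below is a sketch (not one of the tutorial's cells) that compares the JAX output y with np.convolve in "same" mode, reusing x, y, and KERNEL_COEFFS from the cell above.

# Optional sanity check (sketch): the FIR filter computes
#   y[n] = 0.25*x[n-1] + 0.5*x[n] + 0.25*x[n+1]
# with zero padding at the edges, which matches np.convolve in "same" mode.
y_ref = np.convolve(np.array(x), np.array(KERNEL_COEFFS), mode="same")
print("Matches np.convolve:", np.allclose(np.array(y), y_ref, atol=1e-5))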

Step 6: Export to StableHLO#

Export the FIR filter to StableHLO IR. The following helper function pretty-prints the StableHLO IR.

[ ]:
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir


def _get_stablehlo_asm(module_str):
    """Pretty-print a StableHLO module string, eliding large constants."""
    with jax_mlir.make_ir_context():
        stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())
        return stablehlo_module.operation.get_asm(large_elements_limit=20)
[ ]:
# JIT compile the function
jit_my_func = jax.jit(my_func)

# Export to StableHLO
exported = export.export(jit_my_func)(x)

# Get the StableHLO MLIR module
stablehlo_mlir = exported.mlir_module()

print("StableHLO MLIR Representation:")
print("=" * 80)
print(_get_stablehlo_asm(stablehlo_mlir))

# Save the StableHLO MLIR to a file for inspection
stablehlo_file = build_dir / "stablehlo_mlir.mlir"
with open(stablehlo_file, "w") as f:
    f.write(stablehlo_mlir)

print(f"✅ Step 6 complete: Exported StableHLO MLIR to {stablehlo_file}")
StableHLO MLIR Representation:
================================================================================
module @jit_my_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<400xf32>) -> (tensor<400xf32> {jax.result_info = "result"}) {
    %cst = stablehlo.constant dense<[[[2.500000e-01]], [[5.000000e-01]], [[2.500000e-01]]]> : tensor<3x1x1xf32>
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<400xf32>) -> tensor<1x400x1xf32>
    %1 = stablehlo.convolution(%0, %cst) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {pad = [[1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x400x1xf32>, tensor<3x1x1xf32>) -> tensor<1x400x1xf32>
    %2 = stablehlo.slice %1 [0:1, 0:400, 0:1] : (tensor<1x400x1xf32>) -> tensor<1x400x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<1x400x1xf32>) -> tensor<400xf32>
    return %3 : tensor<400xf32>
  }
}

✅ Step 6 complete: Exported StableHLO MLIR to /tmp/mlir_trt_tutorial_xrcvexwx/stablehlo_mlir.mlir
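The Exported object carries more than the MLIR text. As an optional aside (a sketch using public jax.export.Exported fields, not one of the tutorial's cells), you can inspect the traced function's name and its abstract input/output shapes, and confirm that the saved file round-trips:

# Inspect export metadata (sketch)
print(exported.fun_name)   # name of the traced function, e.g. "my_func"
print(exported.in_avals)   # abstract input shapes/dtypes, e.g. (ShapedArray(float32[400]),)
print(exported.out_avals)  # abstract output shapes/dtypes

# The saved .mlir file can be re-read later and fed to the compiler unchanged.
assert stablehlo_file.read_text() == stablehlo_mlir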

Step 7: Compile to TensorRT and Execute#

Now we’ll use the MLIR-TensorRT and TensorRT compilers to lower the StableHLO MLIR to a TensorRT engine.

[ ]:

mlir_tensorrt_compilation_flags = [
    "tensorrt-builder-opt-level=0",
    "tensorrt-workspace-memory-pool-limit=50MiB",
    f"artifacts-dir={build_dir}",
]

exe = mtw.compile(
    stablehlo_mlir=stablehlo_mlir,
    name="my_func",
    export_path=build_dir,
    mlir_entry_point="main",
    mlir_tensorrt_compilation_flags=mlir_tensorrt_compilation_flags,
    trt_plugin_configs={},
)
print("✓ Compiled to TensorRT engine")

# Show build directory contents
print("\nBuild directory contents:")
for item in sorted(build_dir.iterdir()):
    if item.is_file():
        size_kb = item.stat().st_size / 1024
        print(f"  {item.name} ({size_kb:.1f} KB)")
    else:
        print(f"  {item.name}/ (directory)")

print("✅ Step 7 complete: Function compiled to TensorRT engine")
✓ Compiled to TensorRT engine

Build directory contents:
  my_func.bin (209.3 KB)
  my_func.opaque_plugin.stablehlo.mlir (3.4 KB)
  my_func.original.stablehlo.mlir (3.4 KB)
  my_func.stablehlo.mlir (1.0 KB)
  output.cpp (1.9 KB)
  stablehlo_mlir.mlir (3.4 KB)
  tensorrt_cluster_engine_data.trtengine (207.9 KB)
✅ Step 7 complete: Function compiled to TensorRT engine

Step 8: Execute the TensorRT Engine#

Execute the TensorRT engine using the MLIR-TensorRT runtime. This step verifies that the TensorRT engine outputs match the JAX outputs (and, therefore, that the engine is ready for integration into the Aerial Framework runtime).

[ ]:

# Execute the TensorRT engine
outputs = (np.zeros_like(y),)
mtw.execute(
    exe=exe,
    inputs=(np.array(x),),
    outputs=outputs,
    sync_stream=True,
    mlir_entry_point="main",
)

print(f"\nResults match: {np.allclose(outputs[0], y, rtol=1e-5, atol=1e-5)}")

# Plot all results in a single figure
plt.figure(figsize=(12, 6))
plt.plot(n, x, "gray", alpha=0.4, linewidth=1, label="x[n] (Input)")
plt.plot(n, y, "gray", alpha=0.4, linewidth=1, label="y[n] (JAX)")
plt.plot(
    n,
    outputs[0],
    "b-",
    linewidth=1.5,
    marker="o",
    markersize=3,
    markevery=5,
    label="y[n] (TensorRT)",
)
plt.xlabel("n", fontsize=11)
plt.title("FIR Filter: x[n] → y[n]\n", fontsize=12)
plt.legend(loc="best", fontsize=10)
plt.tight_layout()
plt.show()

print("✅ Step 8 complete: TensorRT engine executed and verified")
print("✅ Tutorial complete!")

Results match: True
../../_images/tutorials_generated_mlir_trt_tutorial_14_1.png
✅ Step 8 complete: TensorRT engine executed and verified
✅ Tutorial complete!

Step 9: Summary and Next Steps#

What we accomplished:

  1. ✅ Defined a simple FIR filter in JAX

  2. ✅ Exported the FIR filter to StableHLO

  3. ✅ Compiled the FIR filter to a TensorRT engine

  4. ✅ Verified correctness of the TensorRT engine

Next steps:

  • Experiment with my_func.

    • Try different functions to see how the MLIR-TensorRT compiler handles them (a starter sketch follows this list).

    • The MLIR-TensorRT compiler will try to lower the function to TensorRT operations and data types.

      • If successful, MLIR-TensorRT hands over to the TensorRT compiler backend to produce a single TensorRT engine.

      • If MLIR-TensorRT cannot lower the function to TensorRT operations and data types, it falls back to the TileIR (experimental) and MLIR-LLVM kernel-generator backends.

  • Move on to the PUSCH Receiver.
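As a starting point for experimenting, the sketch below swaps in a hypothetical replacement function my_other_func and pushes it through the same export/compile/execute calls used in Steps 6-8. It assumes the tutorial's variables (x, build_dir, mlir_tensorrt_compilation_flags) are still in scope; whether the result lowers to a single TensorRT engine depends on the operations you use.

# A minimal sketch (assumes the imports and variables from this tutorial are in scope).
def my_other_func(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical alternative: a smooth nonlinearity applied elementwise."""
    return jnp.tanh(x) + 0.1 * x**2


# Export to StableHLO and compile, exactly as in Steps 6 and 7.
other_mlir = export.export(jax.jit(my_other_func))(x).mlir_module()
other_exe = mtw.compile(
    stablehlo_mlir=other_mlir,
    name="my_other_func",
    export_path=build_dir,
    mlir_entry_point="main",
    mlir_tensorrt_compilation_flags=mlir_tensorrt_compilation_flags,
    trt_plugin_configs={},
)

# Execute and compare against the JAX reference, as in Step 8.
other_outputs = (np.zeros_like(x),)
mtw.execute(
    exe=other_exe,
    inputs=(np.array(x),),
    outputs=other_outputs,
    sync_stream=True,
    mlir_entry_point="main",
)
print("Matches JAX:", np.allclose(other_outputs[0], my_other_func(x), rtol=1e-5, atol=1e-5))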