MLIR-TensorRT Wrapper#

Python wrapper for MLIR-TensorRT compilation and execution.

Overview#

The MLIR-TensorRT Wrapper module provides Python interfaces for:

  • Compilation - Compile JAX/MLIR to TensorRT engines

  • Execution - Execute TensorRT engines from Python

  • Custom Calls - Register and transform MLIR custom calls

  • Type Mapping - Map between MLIR and TensorRT types

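A minimal end-to-end sketch tying these pieces together (the toy function and paths are illustrative; see the compile and execute references below):

>>> import jax
>>> import jax.export
>>> import numpy as np
>>> from jax import numpy as jnp
>>> from pathlib import Path
>>> from ran import mlir_trt_wrapper as mtw
>>>
>>> # Export a jitted function to StableHLO, compile it, then run it.
>>> exported = jax.export.export(jax.jit(lambda x, y: x + y))(
...     jnp.zeros(2, jnp.float32), jnp.zeros(2, jnp.float32))
>>> exe = mtw.compile(exported.mlir_module(), name="add", export_path=Path("./out"))
>>> out = np.zeros(2, dtype=np.float32)
>>> result = mtw.execute(
...     exe,
...     (np.ones(2, np.float32), np.ones(2, np.float32)),
...     (out,),
...     sync_stream=False,  # no asynchronous plugins in this toy example
... )
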
API Reference#

A wrapper for the MLIR-TensorRT compiler and runtime.

ran.mlir_trt_wrapper.compile(
stablehlo_mlir: str,
name: str,
export_path: Path,
*,
export_mlir: bool = True,
mlir_entry_point: str = 'main',
mlir_tensorrt_compilation_flags: list[str] | None = None,
enable_strongly_typed: bool = True,
trt_plugin_configs: dict[str, dict[str, str | dict[str, str]]] | None = None,
mlir_trt_compiler: str | None = None,
) → mlir_tensorrt.compiler.api.Executable[source]#

Compile a StableHLO MLIR module to a GPU executable.

Parameters:
  • stablehlo_mlir – StableHLO MLIR module as a string

  • name – Function or module name

  • export_path – Export path for compiled artifacts

  • export_mlir – Whether to save MLIR modules to disk

  • mlir_entry_point – Entry point function name

  • mlir_tensorrt_compilation_flags – Additional compilation flags. If None, defaults to ["tensorrt-builder-opt-level=0"]. Flags should be provided without the '--' prefix. User-provided flags are merged with defaults and mandatory flags (duplicates removed). Available flags include tensorrt-builder-opt-level=<0-5>, tensorrt-fp16, tensorrt-workspace-memory-pool-limit=<size>, tensorrt-enable-timing-cache, tensorrt-timing-cache-path=<path>, and artifacts-dir=<path>; run 'mlir-tensorrt-compiler --help' for the complete list. A usage sketch appears at the end of the Examples section below.

  • enable_strongly_typed – Enable strongly-typed TensorRT mode (default: True). Prevents difficult-to-debug type-related issues. Set to False only if you understand the implications (expert use only).

  • trt_plugin_configs – Plugin configurations mapping target names to config dicts with keys: dso_path, plugin_version, plugin_namespace, creator_func, creator_params. Transforms custom calls to TensorRT opaque plugins

  • mlir_trt_compiler – Path to mlir-tensorrt-compiler binary. If None, searches PATH

Returns:

MLIR-TensorRT executable object

Raises:
  • RuntimeError – If MLIR transformation, compilation, or artifact generation fails

  • FileNotFoundError – If mlir-tensorrt-compiler is not found

Examples

>>> import os
>>> import jax
>>> import jax.export
>>> from jax import numpy as jnp
>>> from ran import mlir_trt_wrapper as mtw
>>> from pathlib import Path
>>>
>>> def my_func(x, y):
...     return x + y
>>>
>>> jit_func = jax.jit(my_func)
>>> inputs = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
>>> exported = jax.export.export(jit_func)(*inputs)
>>> stablehlo_mlir = exported.mlir_module()
>>>
>>> exe = mtw.compile(
...     stablehlo_mlir=stablehlo_mlir,
...     name="my_func",
...     export_path=Path("./output"),
... )
>>>
>>> # Expert use: disable strongly-typed mode
>>> exe = mtw.compile(
...     stablehlo_mlir=stablehlo_mlir,
...     name="my_func",
...     export_path=Path("./output"),
...     enable_strongly_typed=False,
... )
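
Tuning the compilation flags (a hedged sketch; the flag names come from the parameter description above, the values are illustrative):

>>> exe = mtw.compile(
...     stablehlo_mlir=stablehlo_mlir,
...     name="my_func",
...     export_path=Path("./output"),
...     mlir_tensorrt_compilation_flags=[
...         "tensorrt-builder-opt-level=3",
...         "tensorrt-enable-timing-cache",
...     ],
... )
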
ran.mlir_trt_wrapper.execute(
exe: mlir_tensorrt.compiler.api.Executable,
inputs: tuple[jax.typing.ArrayLike, ...],
outputs: tuple[jax.typing.ArrayLike, ...],
*,
sync_stream: bool,
mlir_entry_point: str = 'main',
validate: bool = True,
) → tuple[jax.typing.ArrayLike, ...][source]#

Execute an MLIR-TensorRT executable using the runtime.

Parameters:
  • exe – MLIR-TensorRT executable

  • inputs – Input arrays (must be C-contiguous with a supported dtype; see the Important note below)

  • outputs – Output arrays (pre-allocated, must be C-contiguous)

  • mlir_entry_point – Entry point function name

  • sync_stream – Whether to synchronize CUDA stream before/after copying results. Set to True when using TensorRT plugins that execute asynchronously. Set to False for standard operations without custom plugins.

  • validate – Whether to validate array requirements (recommended)

Returns:

Output arrays with results copied from device to host

Raises:
  • ValueError – If arrays don’t meet runtime requirements (when validate=True)

  • RuntimeError – If runtime initialization, execution, or result copying fails

Important

Arrays must meet these requirements:

  • C-contiguous memory layout (use np.ascontiguousarray(arr))

  • Supported dtypes: float32, float16, bfloat16 (if ml_dtypes is available), int32, bool

  • Not supported: float64, int64, complex64, complex128 (convert to supported types)

Examples

>>> import numpy as np
>>> import mlir_tensorrt.compiler.api as compiler
>>> from pathlib import Path
>>> from ran.mlir_trt_wrapper import execute
>>>
>>> # Load compiled executable
>>> with open(Path("output/add_func.bin"), "rb") as f:
...     exe = compiler.Executable.from_bytes(f.read())
>>>
>>> # Prepare C-contiguous float32 arrays
>>> x = np.array([1.0, 2.0], dtype=np.float32)
>>> y = np.array([3.0, 4.0], dtype=np.float32)
>>> output = np.zeros(2, dtype=np.float32)
>>>
>>> # Execute
>>> result = execute(exe, (x, y), (output,), sync_stream=False)
>>> print(result[0])  # [4.0, 6.0]

Notes

For JAX arrays, convert to NumPy first:

>>> jax_array = jnp.array([1.0, 2.0])
>>> np_array = np.asarray(jax_array, dtype=np.float32, order="C")

ran.mlir_trt_wrapper.mlir_tensorrt_compiler(
mlir_filepath: Path,
output_dir: Path,
*,
entrypoint: str = 'main',
host_target: str = 'emitc',
mlir_tensorrt_compilation_flags: list[str] | None = None,
mlir_trt_compiler: str | None = None,
timeout: int = 180,
) → bool[source]#

Wrap the mlir-tensorrt-compiler tool.

Parameters:
  • mlir_filepath – Path to the input MLIR file

  • output_dir – Directory to output the compiled artifacts

  • entrypoint – The entry point function name (default: “main”)

  • host_target – The host target for compilation (default: "emitc"). Available options:

    - "emitc": Compile host code to C++ (default)

    - "executor": Compile host code to an MLIR-TRT interpretable executable

    - "llvm": Compile host code to LLVM IR

  • mlir_tensorrt_compilation_flags – Optional list of compilation flags. Supported flags: tensorrt-builder-opt-level, tensorrt-workspace-memory-pool-limit, etc.

  • mlir_trt_compiler – Path to mlir-tensorrt-compiler binary. If None, checks MLIR_TRT_COMPILER_PATH env var, then searches PATH.

  • timeout – Timeout in seconds for compilation (default: 180). Prevents hanging processes that can leak GPU memory.

Returns:

True if compilation was successful, False otherwise

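Examples

A minimal invocation sketch (paths are illustrative):

>>> from pathlib import Path
>>> from ran.mlir_trt_wrapper import mlir_tensorrt_compiler
>>> ok = mlir_tensorrt_compiler(
...     mlir_filepath=Path("./output/my_func.mlir"),
...     output_dir=Path("./output/artifacts"),
...     host_target="emitc",
... )
>>> if not ok:
...     raise RuntimeError("mlir-tensorrt-compiler failed")
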
ran.mlir_trt_wrapper.register_custom_call_primitive(
name: str,
fn_lowering: Callable,
fn_abstract: Callable,
) → Callable[source]#

Register a custom call primitive.

Parameters:
  • name – The name of the custom call primitive

  • fn_lowering – The lowering function that implements the custom call primitive

  • fn_abstract – The abstract evaluation function that computes the shapes and dtypes of the outputs

Returns:

A callable that can be used to call the custom call primitive
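
Examples

A hedged sketch of the two callbacks; the exact signatures this library expects for fn_lowering and fn_abstract are assumptions based on common JAX custom-call conventions:

>>> import jax
>>> from jax import numpy as jnp
>>> from jax.interpreters import mlir
>>> from ran.mlir_trt_wrapper import register_custom_call_primitive
>>>
>>> def my_abstract(x):
...     # Illustrative rule: the output mirrors the input's shape and dtype.
...     return jax.core.ShapedArray(x.shape, x.dtype)
>>>
>>> def my_lowering(ctx, x):
...     # Emit a stablehlo.custom_call targeting "my_plugin" (hypothetical name).
...     op = mlir.custom_call(
...         "my_plugin",
...         result_types=[mlir.aval_to_ir_type(ctx.avals_out[0])],
...         operands=[x],
...         api_version=2,
...     )
...     return op.results
>>>
>>> my_op = register_custom_call_primitive("my_plugin", my_lowering, my_abstract)
>>> y = jax.jit(my_op)(jnp.ones(4, jnp.float32))  # doctest: +SKIP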

ran.mlir_trt_wrapper.transform_mlir_custom_call_to_trt_plugin(
mlir_string: str,
plugin_configs: dict[str, dict[str, str | dict[str, str]]],
) → str[source]#

Transform stablehlo.custom_call operations to tensorrt.opaque_plugin operations.

This function parses MLIR code containing stablehlo.custom_call operations and transforms them into equivalent tensorrt.opaque_plugin operations. This is a temporary solution until the functionality is added to MLIR-TensorRT itself.

The regex pattern matches MLIR custom call operations with the following structure:

%result_var = stablehlo.custom_call @target_name(operands) {attributes} : (input_types) -> output_types

Parameters:
  • mlir_string – The MLIR string from exported.mlir_module()

  • plugin_configs

    Dictionary mapping target names to their plugin configurations. Each target name (e.g., “tensorrt_dmrs_plugin”) maps to a configuration dict. The key can optionally include a suffix matching the backend_config attribute (e.g., “tensorrt_cufft_plugin_forward”) to provide different configurations for the same target name with different backend_config values.

    Configuration dict keys:

    - dso_path: Path to the plugin DSO file (required)

    - plugin_version: Version string for the plugin (optional, default: "1")

    - plugin_namespace: Namespace for the plugin (optional, default: "")

    - creator_func: Creator function name (optional, auto-generated if not provided)

    - creator_params: Dictionary of creator parameters (optional, defaults to {"dummy_param": 0} if empty)

    When a custom_call has a backend_config attribute, the lookup order is:

    1. Try "{target_name}_{backend_config}" (e.g., "tensorrt_cufft_plugin_forward")

    2. Fall back to "{target_name}" if the suffixed key doesn’t exist

Returns:

Transformed MLIR string with tensorrt.opaque_plugin operations

Raises:

MLIRTransformationError – If the transformation fails, the input is invalid, or a custom_call operation is found without a corresponding config

Examples

Single plugin:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_sequential_sum_plugin(%3)
...     {api_version = 2 : i32} : (tensor<30000xf32>) -> tensor<30000xf32>
... '''
>>> configs = {
...     "tensorrt_sequential_sum_plugin": {
...         "dso_path": "/path/to/plugin.so",
...         "creator_params": {"i32_param": 10},
...     }
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

Multiple plugins:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_dmrs_plugin(%3)
...     {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
...     {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
... '''
>>> configs = {
...     "tensorrt_dmrs_plugin": {
...         "dso_path": "/path/to/dmrs.so",
...         "creator_params": {"sequence_length": 3276},
...     },
...     "tensorrt_cufft_plugin": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {"fft_size": 2048},
...     },
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

Multiple plugins with backend_config for disambiguation:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_cufft_plugin(%3)
...     {api_version = 2 : i32, backend_config = "forward"} :
...     (tensor<100xf32>) -> tensor<100xf32>
... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
...     {api_version = 2 : i32, backend_config = "inverse"} :
...     (tensor<100xf32>) -> tensor<100xf32>
... '''
>>> configs = {
...     "tensorrt_cufft_plugin_forward": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {
...             "fft_size": 2048,
...             "direction": 0,
...         },
...     },
...     "tensorrt_cufft_plugin_inverse": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {
...             "fft_size": 2048,
...             "direction": 1,
...         },
...     },
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

ran.mlir_trt_wrapper.validate_array(
arr: jax.typing.ArrayLike,
arg_name: str = 'array',
) → None[source]#

Validate that an array meets MLIR-TensorRT runtime requirements.

Parameters:
  • arr – Array to validate

  • arg_name – Name of the argument for error messages

Raises:
  • TypeError – If the array is not numpy-compatible

  • ValueError – If the array format is unsupported

Examples

>>> import numpy as np
>>> x = np.array([1.0, 2.0], dtype=np.float32)
>>> validate_array(x, "input_0")  # OK
>>>
>>> bad = np.array([1.0, 2.0], dtype=np.float64)
>>> validate_array(bad, "input_1")  # Raises ValueError

Notes

MLIR-TensorRT runtime requires:

  • C-contiguous memory layout

  • Supported dtypes: float32, float16, bfloat16 (if ml_dtypes is available), int32, bool

  • Not supported: float64, int64, complex64, complex128 (convert to supported types)
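
A short sketch of normalizing an arbitrary NumPy array to meet these requirements (the helper name is illustrative):

>>> import numpy as np
>>> def to_runtime_array(arr):
...     # Downcast unsupported 64-bit dtypes, then force C-contiguous layout.
...     arr = np.asarray(arr)
...     if arr.dtype == np.float64:
...         arr = arr.astype(np.float32)
...     elif arr.dtype == np.int64:
...         arr = arr.astype(np.int32)
...     return np.ascontiguousarray(arr)
>>>
>>> validate_array(to_runtime_array(np.arange(4)), "input_0")  # OK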