MLIR-TensorRT Wrapper
Python wrapper for the MLIR-TensorRT compiler and runtime.
Overview
The MLIR-TensorRT Wrapper module provides Python interfaces for:
- Compilation – Compile JAX/StableHLO MLIR to TensorRT engines
- Execution – Execute TensorRT engines from Python
- Custom Calls – Register and transform MLIR custom calls
- Type Mapping – Map between MLIR and TensorRT types
API Reference
A wrapper for the MLIR-TensorRT compiler and runtime.
ran.mlir_trt_wrapper.compile(
    stablehlo_mlir: str,
    name: str,
    export_path: Path,
    *,
    export_mlir: bool = True,
    mlir_entry_point: str = 'main',
    mlir_tensorrt_compilation_flags: list[str] | None = None,
    enable_strongly_typed: bool = True,
    trt_plugin_configs: dict[str, dict[str, str | dict[str, str]]] | None = None,
    mlir_trt_compiler: str | None = None,
)
Compile StableHLO MLIR module to GPU executable.
- Parameters:
stablehlo_mlir – StableHLO MLIR module as a string
name – Function or module name
export_path – Export path for compiled artifacts
export_mlir – Whether to save MLIR modules to disk
mlir_entry_point – Entry point function name
mlir_tensorrt_compilation_flags – Additional compilation flags. If None, defaults to ["tensorrt-builder-opt-level=0"]. Flags should be provided without the '--' prefix. User-provided flags are merged with defaults and mandatory flags (duplicates removed). Available flags: tensorrt-builder-opt-level=<0-5>, tensorrt-fp16, tensorrt-workspace-memory-pool-limit=<size>, tensorrt-enable-timing-cache, tensorrt-timing-cache-path=<path>, artifacts-dir=<path>, etc. Run 'mlir-tensorrt-compiler --help' for the complete list
enable_strongly_typed – Enable strongly-typed TensorRT mode (default: True). Prevents difficult-to-debug type-related issues. Set to False only if you understand the implications (expert use only).
trt_plugin_configs – Plugin configurations mapping target names to config dicts with keys: dso_path, plugin_version, plugin_namespace, creator_func, creator_params. Transforms custom calls to TensorRT opaque plugins
mlir_trt_compiler – Path to mlir-tensorrt-compiler binary. If None, searches PATH
- Return type:
MLIR-TensorRT executable object
- Raises:
RuntimeError – If MLIR transformation, compilation, or artifact generation fails
FileNotFoundError – If mlir-tensorrt-compiler is not found
Examples
>>> import jax
>>> import jax.export
>>> from jax import numpy as jnp
>>> from pathlib import Path
>>> from ran import mlir_trt_wrapper as mtw
>>>
>>> def my_func(x, y):
...     return x + y
>>>
>>> jit_func = jax.jit(my_func)
>>> inputs = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
>>> exported = jax.export.export(jit_func)(*inputs)
>>> stablehlo_mlir = exported.mlir_module()
>>>
>>> exe = mtw.compile(
...     stablehlo_mlir=stablehlo_mlir,
...     name="my_func",
...     export_path=Path("./output"),
... )
>>>
>>> # Expert use: disable strongly-typed mode
>>> exe = mtw.compile(
...     stablehlo_mlir=stablehlo_mlir,
...     name="my_func",
...     export_path=Path("./output"),
...     enable_strongly_typed=False,
... )
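The merge-with-defaults behavior documented for mlir_tensorrt_compilation_flags (user flags override defaults with the same key, duplicates removed) can be illustrated with a small pure-Python sketch; merge_flags is a hypothetical helper shown for clarity, not the wrapper's actual implementation:

```python
def merge_flags(user_flags=None):
    """Merge user flags with the documented default, keeping the user's
    value when both set the same flag key (hypothetical sketch)."""
    defaults = ["tensorrt-builder-opt-level=0"]
    merged = {f.split("=", 1)[0]: f for f in defaults}
    for flag in user_flags or []:
        merged[flag.split("=", 1)[0]] = flag  # user value wins, no duplicates
    return list(merged.values())

# No user flags: the documented default applies
default_flags = merge_flags()
# A user-provided opt level replaces the default; unrelated flags are appended
custom_flags = merge_flags(["tensorrt-builder-opt-level=3", "tensorrt-fp16"])
```

The real compile() additionally appends mandatory flags; the sketch only shows the duplicate-removal rule.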
ran.mlir_trt_wrapper.execute(
    exe: mlir_tensorrt.compiler.api.Executable,
    inputs: tuple[jax.typing.ArrayLike, ...],
    outputs: tuple[jax.typing.ArrayLike, ...],
    *,
    sync_stream: bool,
    mlir_entry_point: str = 'main',
    validate: bool = True,
)
Execute MLIR-TensorRT executable using the runtime.
- Parameters:
exe – MLIR-TensorRT executable
inputs – Input arrays (must be C-contiguous; supported dtypes: float32, float16, bfloat16, int32, bool)
outputs – Output arrays (pre-allocated, must be C-contiguous)
mlir_entry_point – Entry point function name
sync_stream – Whether to synchronize CUDA stream before/after copying results. Set to True when using TensorRT plugins that execute asynchronously. Set to False for standard operations without custom plugins.
validate – Whether to validate array requirements (recommended)
- Return type:
Output arrays with results copied from device to host
- Raises:
ValueError – If arrays don't meet runtime requirements (when validate=True)
RuntimeError – If runtime initialization, execution, or result copying fails
Important
Arrays must meet these requirements:
- C-contiguous memory layout (use np.ascontiguousarray(arr))
- Supported dtypes: float32, float16, bfloat16 (if ml_dtypes is available), int32, bool
- Not supported: float64, int64, complex64, complex128 (convert to supported types)
Examples
>>> import numpy as np
>>> import mlir_tensorrt.compiler.api as compiler
>>> from pathlib import Path
>>> from ran.mlir_trt_wrapper import execute
>>>
>>> # Load compiled executable
>>> with open(Path("output/add_func.bin"), "rb") as f:
...     exe = compiler.Executable.from_bytes(f.read())
>>>
>>> # Prepare C-contiguous float32 arrays
>>> x = np.array([1.0, 2.0], dtype=np.float32)
>>> y = np.array([3.0, 4.0], dtype=np.float32)
>>> output = np.zeros(2, dtype=np.float32)
>>>
>>> # Execute (no custom plugins, so no extra stream synchronization needed)
>>> result = execute(exe, (x, y), (output,), sync_stream=False)
>>> print(result[0])  # [4.0, 6.0]
Notes
For JAX arrays, convert to NumPy first:
>>> jax_array = jnp.array([1.0, 2.0])
>>> np_array = np.asarray(jax_array, dtype=np.float32, order='C')
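The contiguity and dtype requirements above can usually be satisfied with a single conversion step; a minimal NumPy sketch:

```python
import numpy as np

# A transposed float64 array: wrong dtype and not C-contiguous
arr = np.arange(6, dtype=np.float64).reshape(2, 3).T
assert not arr.flags["C_CONTIGUOUS"]

# One call fixes both: cast to a supported dtype and force C order
prepared = np.ascontiguousarray(arr, dtype=np.float32)
```

np.ascontiguousarray is a no-op (no copy) when the array already meets both requirements, so it is cheap to apply unconditionally.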
ran.mlir_trt_wrapper.mlir_tensorrt_compiler(
    mlir_filepath: Path,
    output_dir: Path,
    *,
    entrypoint: str = 'main',
    host_target: str = 'emitc',
    mlir_tensorrt_compilation_flags: list[str] | None = None,
    mlir_trt_compiler: str | None = None,
    timeout: int = 180,
)
Wrap the mlir-tensorrt-compiler command-line tool.
- Parameters:
mlir_filepath – Path to the input MLIR file
output_dir – Directory to output the compiled artifacts
entrypoint – The entry point function name (default: “main”)
host_target – The host target for compilation (default: "emitc"). Available options:
- "emitc": compile host code to C++ (default)
- "executor": compile host code to an MLIR-TRT interpretable executable
- "llvm": compile host code to LLVM IR
mlir_tensorrt_compilation_flags – Optional list of compilation flags. Supported flags: tensorrt-builder-opt-level, tensorrt-workspace-memory-pool-limit, etc.
mlir_trt_compiler – Path to mlir-tensorrt-compiler binary. If None, checks MLIR_TRT_COMPILER_PATH env var, then searches PATH.
timeout – Timeout in seconds for compilation (default: 180). Prevents hanging processes that can leak GPU memory.
- Return type:
bool – True if compilation was successful, False otherwise
- Raises:
FileNotFoundError – If mlir-tensorrt-compiler is not found
RuntimeError – If compilation fails or times out
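The timeout guard described above follows the standard subprocess pattern: on expiry the child process is killed before the exception propagates, so a hung compiler cannot keep holding GPU memory. A generic sketch (the command shown is illustrative, not the wrapper's exact invocation):

```python
import subprocess
import sys

def run_with_timeout(cmd, timeout=180):
    """Run a command, returning True on success; kill it if it hangs."""
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        return proc.returncode == 0
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising, so no process leaks
        return False

ok = run_with_timeout([sys.executable, "-c", "pass"], timeout=30)
slow = run_with_timeout([sys.executable, "-c", "import time; time.sleep(5)"], timeout=1)
```

The real wrapper raises RuntimeError on timeout rather than returning False; the sketch only demonstrates the kill-on-expiry mechanism.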
ran.mlir_trt_wrapper.register_custom_call_primitive(name, fn_lowering, fn_abstract) → Callable
Register a custom call primitive.
- Parameters:
name – The name of the custom call primitive
fn_lowering – The lowering function that implements the custom call primitive
fn_abstract – The abstract function that evaluates the shapes and dtypes of the outputs
- Return type:
A callable that can be used to call the custom call primitive
ran.mlir_trt_wrapper.transform_mlir_custom_call_to_trt_plugin(mlir_string, plugin_configs) → str
Transform stablehlo.custom_call operations to tensorrt.opaque_plugin operations.
This function parses MLIR code containing stablehlo.custom_call operations and transforms them into equivalent tensorrt.opaque_plugin operations. This is a temporary solution until equivalent functionality is added to MLIR-TensorRT itself.
The regex pattern matches MLIR custom call operations with the following structure:
%result_var = stablehlo.custom_call @target_name(operands) {attributes} : (input_types) -> output_types
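A simplified version of such a pattern can be sketched directly in Python; the regex below illustrates the structure described above and is not the module's actual pattern:

```python
import re

# Matches: %result = stablehlo.custom_call @target(operands) {attrs} : (in) -> out
CUSTOM_CALL_RE = re.compile(
    r"%(?P<result>\w+)\s*=\s*stablehlo\.custom_call\s*@(?P<target>\w+)"
    r"\((?P<operands>[^)]*)\)\s*\{(?P<attrs>[^}]*)\}\s*:"
    r"\s*\((?P<in_types>[^)]*)\)\s*->\s*(?P<out_types>.+)"
)

line = ("%4 = stablehlo.custom_call @tensorrt_dmrs_plugin(%3) "
        "{api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>")
m = CUSTOM_CALL_RE.search(line)
```

Each named group corresponds to one piece of the structure above (result variable, target name, operands, attributes, input types, output types), which is what the transformation needs in order to emit a tensorrt.opaque_plugin with the same operands and types.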
- Parameters:
mlir_string – The MLIR string from exported.mlir_module()
plugin_configs –
Dictionary mapping target names to their plugin configurations. Each target name (e.g., “tensorrt_dmrs_plugin”) maps to a configuration dict. The key can optionally include a suffix matching the backend_config attribute (e.g., “tensorrt_cufft_plugin_forward”) to provide different configurations for the same target name with different backend_config values.
Configuration dict keys:
- dso_path: path to the plugin DSO file (required)
- plugin_version: version string for the plugin (optional, default: "1")
- plugin_namespace: namespace for the plugin (optional, default: "")
- creator_func: creator function name (optional, auto-generated if not provided)
- creator_params: dictionary of creator parameters (optional, defaults to {"dummy_param": 0} if empty)
When a custom_call has a backend_config attribute, the lookup order is:
1. Try "{target_name}_{backend_config}" (e.g., "tensorrt_cufft_plugin_forward")
2. Fall back to "{target_name}" if the suffixed key doesn't exist
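The two-step lookup can be expressed directly; a small sketch (resolve_plugin_config is a hypothetical helper name, and unlike the real function, which raises MLIRTransformationError when no config matches, the sketch returns None):

```python
def resolve_plugin_config(target_name, backend_config, plugin_configs):
    """Return the config for a custom_call, preferring the suffixed key."""
    if backend_config:
        suffixed = f"{target_name}_{backend_config}"
        if suffixed in plugin_configs:        # step 1: target_name + backend_config
            return plugin_configs[suffixed]
    return plugin_configs.get(target_name)    # step 2: fall back to target_name

configs = {
    "tensorrt_cufft_plugin_forward": {"dso_path": "/path/to/cufft.so"},
    "tensorrt_dmrs_plugin": {"dso_path": "/path/to/dmrs.so"},
}
```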
- Return type:
Transformed MLIR string with tensorrt.opaque_plugin operations
- Raises:
MLIRTransformationError – If the transformation fails, the input is invalid, or a custom_call operation is found without a corresponding config
Examples
Single plugin:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_sequential_sum_plugin(%3)
...   {api_version = 2 : i32} : (tensor<30000xf32>) -> tensor<30000xf32>
... '''
>>> configs = {
...     "tensorrt_sequential_sum_plugin": {
...         "dso_path": "/path/to/plugin.so",
...         "creator_params": {"i32_param": 10},
...     }
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

Multiple plugins:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_dmrs_plugin(%3)
...   {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
...   {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
... '''
>>> configs = {
...     "tensorrt_dmrs_plugin": {
...         "dso_path": "/path/to/dmrs.so",
...         "creator_params": {"sequence_length": 3276},
...     },
...     "tensorrt_cufft_plugin": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {"fft_size": 2048},
...     },
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

Multiple plugins with backend_config for disambiguation:

>>> mlir_input = '''
... %4 = stablehlo.custom_call @tensorrt_cufft_plugin(%3)
...   {api_version = 2 : i32, backend_config = "forward"} :
...   (tensor<100xf32>) -> tensor<100xf32>
... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
...   {api_version = 2 : i32, backend_config = "inverse"} :
...   (tensor<100xf32>) -> tensor<100xf32>
... '''
>>> configs = {
...     "tensorrt_cufft_plugin_forward": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {
...             "fft_size": 2048,
...             "direction": 0,
...         },
...     },
...     "tensorrt_cufft_plugin_inverse": {
...         "dso_path": "/path/to/cufft.so",
...         "creator_params": {
...             "fft_size": 2048,
...             "direction": 1,
...         },
...     },
... }
>>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)
ran.mlir_trt_wrapper.validate_array(
    arr: jax.typing.ArrayLike,
    arg_name: str = 'array',
)
Validate array meets MLIR-TensorRT runtime requirements.
- Parameters:
arr – Array to validate
arg_name – Name of the argument for error messages
- Raises:
TypeError – If array is not numpy-compatible
ValueError – If array format is unsupported
Examples
>>> import numpy as np
>>> x = np.array([1.0, 2.0], dtype=np.float32)
>>> validate_array(x, "input_0")  # OK
>>>
>>> bad = np.array([1.0, 2.0], dtype=np.float64)
>>> validate_array(bad, "input_1")  # Raises ValueError
Notes
MLIR-TensorRT runtime requires:
- C-contiguous memory layout
- Supported dtypes: float32, float16, bfloat16 (if ml_dtypes is available), int32, bool
- Not supported: float64, int64, complex64, complex128 (convert to supported types)
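The checks described above amount to a dtype allow-list plus a contiguity test; a minimal NumPy sketch of equivalent logic (not the module's actual code):

```python
import numpy as np

SUPPORTED_DTYPES = {np.float32, np.float16, np.int32, np.bool_}
# bfloat16 would be added to this set when ml_dtypes is available

def validate_array_sketch(arr, arg_name="array"):
    """Reject arrays that the runtime cannot accept (illustration only)."""
    a = np.asarray(arr)
    if a.dtype.type not in SUPPORTED_DTYPES:
        raise ValueError(f"{arg_name}: unsupported dtype {a.dtype}")
    if not a.flags["C_CONTIGUOUS"]:
        raise ValueError(f"{arg_name}: array must be C-contiguous")
```

Including arg_name in the message, as the real function does, makes it easy to tell which of several inputs failed validation.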