Source code for ran.mlir_trt_wrapper.execute

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Executor for compiled StableHLO MLIR modules using MLIR-TensorRT runtime."""

import ctypes
import logging

import mlir_tensorrt.compiler.api as compiler  # type: ignore
import mlir_tensorrt.runtime.api as runtime  # type: ignore
import numpy as np
from jax.typing import ArrayLike

from ran.mlir_trt_wrapper.mlir_trt_types import (
    HAS_ML_DTYPES,
    validate_array,
)

if HAS_ML_DTYPES:
    from ml_dtypes import bfloat16
else:
    bfloat16 = None

logger = logging.getLogger(__name__)


def _create_device_memref(client, device, array):
    """Create a device memref on ``device`` holding a copy of ``array``.

    Args:
        client: Runtime client used to create the memref.
        device: Target device object (an element of ``client.get_devices()``).
        array: Host array; converted with ``np.asarray`` before upload.

    Returns:
        The memref object returned by ``client.create_memref``.
    """
    host_arr = np.asarray(array)
    if HAS_ML_DTYPES and host_arr.dtype == bfloat16:
        # bfloat16 is provided by ml_dtypes on the NumPy side, so the raw
        # 16-bit payload is passed as uint16 and the element type is given
        # to the runtime explicitly.
        return client.create_memref(
            host_arr.view(np.uint16),
            dtype=runtime.ScalarTypeCode.bf16,
            device=device,
        )
    return client.create_memref(host_arr, device=device)


def execute(
    exe: compiler.Executable,
    inputs: tuple[ArrayLike, ...],
    outputs: tuple[ArrayLike, ...],
    *,
    sync_stream: bool,
    mlir_entry_point: str = "main",
    validate: bool = True,
) -> tuple[ArrayLike, ...]:
    """Execute MLIR-TensorRT executable using the runtime.

    A new runtime client and session are created on every call. Inputs and
    outputs are uploaded to device 0, the entry point is run, and the results
    are copied back into the caller's pre-allocated ``outputs`` buffers.

    Args:
        exe: MLIR-TensorRT executable.
        inputs: Input arrays (must be C-contiguous; see "Important" below for
            the dtype requirements).
        outputs: Output arrays (pre-allocated, must be C-contiguous). They are
            overwritten in place through their raw data pointer, so pass
            NumPy arrays: if ``np.asarray`` has to copy an element, the
            results land in that temporary copy instead of the caller's
            object.
        sync_stream: Whether to synchronize the CUDA stream of device 0 after
            execution and after the device-to-host copies. Required,
            keyword-only. Set to True when using TensorRT plugins that
            execute asynchronously. Set to False for standard operations
            without custom plugins.
        mlir_entry_point: Entry point function name.
        validate: Whether to validate array requirements (recommended). When
            True, ``validate_array`` is called on every input and output
            before any runtime object is created.

    Returns:
        The same ``outputs`` tuple that was passed in, with results copied
        from device to host.

    Raises:
        ValueError: If arrays don't meet runtime requirements (when
            validate=True).
        RuntimeError: If runtime initialization, memref creation, execution,
            or result copying fails. The original exception is chained as
            ``__cause__`` and logged with its traceback.

    Important:
        Arrays must meet these requirements (``validate_array`` is the
        authoritative check):

        - **C-contiguous memory layout** (use `np.ascontiguousarray(arr)`)
        - **Supported dtypes**: float32, float16, bfloat16 (if ml_dtypes
          available), int32, bool
        - **Not supported**: float64, int64, complex64, complex128 (convert
          to supported types)

    Examples:
        >>> import numpy as np
        >>> import mlir_tensorrt.compiler.api as compiler
        >>> from pathlib import Path
        >>>
        >>> # Load compiled executable
        >>> with open(Path("output/add_func.bin"), "rb") as f:
        ...     exe = compiler.Executable.from_bytes(f.read())
        >>>
        >>> # Prepare C-contiguous float32 arrays
        >>> x = np.array([1.0, 2.0], dtype=np.float32)
        >>> y = np.array([3.0, 4.0], dtype=np.float32)
        >>> output = np.zeros(2, dtype=np.float32)
        >>>
        >>> # Execute (sync_stream is a required keyword argument)
        >>> result = execute(exe, (x, y), (output,), sync_stream=False)
        >>> print(result[0])  # [4.0, 6.0]

    Notes:
        For JAX arrays, convert to numpy first:
        `jax_array = jnp.array([1.0, 2.0])`
        `np_array = np.asarray(jax_array, dtype=np.float32, order='C')`
    """
    if validate:
        for i, inp in enumerate(inputs):
            validate_array(inp, f"input[{i}]")
        for i, out in enumerate(outputs):
            validate_array(out, f"output[{i}]")

    try:
        client = runtime.RuntimeClient()
        devices = client.get_devices()
        session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
        session = runtime.RuntimeSession(client, session_options, exe)
    except Exception as e:
        error_msg = "Failed to initialize MLIR-TensorRT runtime session"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e

    # Create device memrefs for inputs and outputs. Outputs are uploaded as
    # well so the device-side buffers have the caller's shape and dtype.
    try:
        inputs_device = [
            _create_device_memref(client, devices[0], x) for x in inputs
        ]
        outputs_device = [
            _create_device_memref(client, devices[0], y) for y in outputs
        ]
    except Exception as e:
        error_msg = "Failed to create device memrefs for inputs/outputs"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e

    # Execute on device
    try:
        session.execute_function(
            name=mlir_entry_point,
            in_args=inputs_device,
            out_args=outputs_device,
        )
        # Optionally synchronize stream to ensure all GPU work completes
        # before copying results. This is necessary when TensorRT plugins
        # execute asynchronously on their own streams.
        if sync_stream:
            devices[0].stream.sync()
            logger.debug("CUDA stream synchronized after execute")
    except Exception as e:
        error_msg = f"TensorRT execution failed for function '{mlir_entry_point}'"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e

    # Copy results from device to host output buffers
    try:
        # Stage every output on the host first. This holds all host staging
        # buffers at once, but allows one sync to cover every copy.
        host_results = [client.copy_to_host(mem) for mem in outputs_device]

        # Synchronize BEFORE reading the staging buffers: if copy_to_host
        # completes asynchronously, reading earlier could see stale data.
        if sync_stream:
            devices[0].stream.sync()
            logger.debug("CUDA stream synchronized after copy to host")

        for out, host_result in zip(outputs, host_results):
            output_arr = np.asarray(out)
            # Raw byte copy into the caller's buffer; relies on the output
            # being C-contiguous, which validate_array is expected to enforce.
            ctypes.memmove(
                output_arr.ctypes.data, host_result.ptr, output_arr.nbytes
            )
    except Exception as e:
        error_msg = "Failed to copy results from device to host"
        logger.exception(error_msg)
        raise RuntimeError(error_msg) from e

    return outputs