# Source code for ran.mlir_trt_wrapper.mlir_trt_compiler_wrapper
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap mlir-tensorrt-compiler subprocess calls."""
import logging
import os
import shutil
import subprocess
import sys
from pathlib import Path
logger = logging.getLogger(__name__)
def _build_compiler_env() -> dict[str, str]:
"""Build environment for mlir-tensorrt-compiler subprocess.
The mlir-tensorrt-compiler binary needs to find libtvm_ffi.so at runtime.
This function adds the TVM library path to LD_LIBRARY_PATH for the subprocess
without modifying the parent process environment.
Returns
-------
Environment dict with TVM library path added to LD_LIBRARY_PATH
"""
env = os.environ.copy()
# Add TVM library path for libtvm_ffi.so
venv_path = Path(sys.executable).parent.parent
tvm_lib_path = (
venv_path
/ "lib"
/ f"python{sys.version_info.major}.{sys.version_info.minor}"
/ "site-packages"
/ "tvm_ffi"
/ "lib"
)
if tvm_lib_path.exists():
current_ld = env.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = f"{tvm_lib_path}:{current_ld}" if current_ld else str(tvm_lib_path)
logger.debug(f"Added TVM library path to subprocess environment: {tvm_lib_path}")
else:
logger.warning(f"TVM library path not found: {tvm_lib_path}")
return env
def _post_process_emitc_cpp(output_dir: Path) -> None:
"""Post-process generated C++ code to fix EmitC backend issues.
This function fixes the EmitC backend issue where void main() is generated:
It changes 'void main(...)' to 'void _main(...)' to avoid returning void from main
and multiple main functions
Args:
output_dir: Directory containing the generated output.cpp file
"""
output_cpp = output_dir / "output.cpp"
if not output_cpp.exists():
logger.warning(f"Expected output.cpp not found in {output_dir}")
return
# Read the generated C++ file
with output_cpp.open("r") as f:
content = f.read()
# Apply fixes
original_content = content
# Fix: Change 'void main(...)' to 'void _main(...)'
content = content.replace("void main(", "void _main(")
# Only write if changes were made
if content != original_content:
with output_cpp.open("w") as f:
f.write(content)
logger.info(f"Post-processed {output_cpp} to fix EmitC backend issues")
else:
logger.info(f"No changes needed for {output_cpp}")
# [docs]
def mlir_tensorrt_compiler(
mlir_filepath: Path,
output_dir: Path,
*,
entrypoint: str = "main",
host_target: str = "emitc",
mlir_tensorrt_compilation_flags: list[str] | None = None,
mlir_trt_compiler: str | None = None,
timeout: int = 180,
) -> bool:
"""
Wrap mlir-tensorrt-compiler tool.
Args:
mlir_filepath: Path to the input MLIR file
output_dir: Directory to output the compiled artifacts
entrypoint: The entry point function name (default: "main")
host_target: The host target for compilation (default: "emitc").
Available options:
- "emitc": Compile host code to C++ (default)
- "executor": Compile host code to MLIR-TRT interpretable executable
- "llvm": Compile host code to LLVM IR
mlir_tensorrt_compilation_flags: Optional list of compilation flags.
Supported flags: tensorrt-builder-opt-level, tensorrt-workspace-memory-pool-limit, etc.
mlir_trt_compiler: Path to mlir-tensorrt-compiler binary.
If None, checks MLIR_TRT_COMPILER_PATH env var, then searches PATH.
timeout: Timeout in seconds for compilation (default: 180).
Prevents hanging processes that can leak GPU memory.
Returns
-------
True if compilation was successful, False otherwise
Raises
------
FileNotFoundError: If mlir-tensorrt-compiler is not found
RuntimeError: If compilation fails or times out
"""
if mlir_tensorrt_compilation_flags is None:
mlir_tensorrt_compilation_flags = []
# --------------------------------
# Find the MLIR-TensorRT compiler binary
# --------------------------------
if mlir_trt_compiler is None:
# Priority 1: Check environment variable (set by CMake)
mlir_trt_compiler = os.getenv("MLIR_TRT_COMPILER_PATH")
if mlir_trt_compiler:
logger.debug(f"Using MLIR-TensorRT compiler from env: {mlir_trt_compiler}")
else:
# Priority 2: Try to find mlir-tensorrt-compiler in PATH
mlir_trt_compiler = shutil.which("mlir-tensorrt-compiler")
if mlir_trt_compiler:
logger.debug(f"Found MLIR-TensorRT compiler in PATH: {mlir_trt_compiler}")
if not mlir_trt_compiler:
error_msg = (
"mlir-tensorrt-compiler not found. "
"Either set MLIR_TRT_COMPILER_PATH environment variable, "
"add it to PATH, or pass mlir_trt_compiler parameter explicitly."
)
logger.error(error_msg)
raise FileNotFoundError(error_msg)
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# --------------------------------
# Build the command for the mlir-tensorrt-compiler tool
# --------------------------------
cmd = [mlir_trt_compiler]
# Build the options string for --opts flag
# Remove '--' prefix from flags if present (CLI tool expects flags without --)
cleaned_flags = [flag.removeprefix("--") for flag in mlir_tensorrt_compilation_flags]
# Add the host-target and entrypoint flags
cleaned_flags.extend([f"host-target={host_target}", f"entrypoint={entrypoint}"])
# Join all flags into a single string for --opts
opts_str = " ".join(cleaned_flags)
cmd.extend(["--opts", opts_str])
# Add the input MLIR file and output directory
cmd.extend([str(mlir_filepath), "-o", str(output_dir)])
logger.info(f"Running mlir-tensorrt-compiler: {' '.join(cmd)}")
logger.info(f"Compilation timeout: {timeout}s")
# cmd contains only hardcoded paths and compiler flags, no user input
# ruff: noqa: S603
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
check=False,
env=_build_compiler_env(),
)
except FileNotFoundError as e:
error_msg = f"mlir-tensorrt-compiler not found: {e}"
logger.exception(error_msg)
raise FileNotFoundError(error_msg) from e
except subprocess.TimeoutExpired:
error_msg = f"MLIR-TensorRT compilation timed out after {timeout}s"
logger.error(error_msg)
raise RuntimeError(error_msg)
if result.returncode != 0:
error_msg = f"MLIR-TensorRT compilation failed with return code {result.returncode}"
if result.stdout:
logger.error(f"Compiler stdout:\n{result.stdout}")
error_msg += f"\nCompiler stdout:\n{result.stdout}"
if result.stderr:
logger.error(f"Compiler stderr:\n{result.stderr}")
error_msg += f"\nCompiler stderr:\n{result.stderr}"
logger.exception(error_msg)
raise RuntimeError(error_msg)
logger.info(f"MLIR-TensorRT compilation successful: {result.stdout}")
# Post-process generated C++ code to fix EmitC backend issues
if host_target == "emitc":
_post_process_emitc_cpp(output_dir)
# Force filesystem sync to ensure compiler output files are visible
# Prevents race conditions in parallel compilation where subprocess may exit
# before OS flushes .trtengine and other artifacts to disk
os.sync()
return True