# Source code for ran.mlir_trt_wrapper.transform_mlir_custom_call_to_trt_plugin

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLIR transformation utilities for converting custom calls to TensorRT plugins."""

import os
import re


class MLIRTransformationError(Exception):
    """Signal that rewriting MLIR custom calls into TensorRT plugin ops failed.

    Carries a human-readable message describing which step of the
    transformation could not be completed.
    """


def transform_mlir_custom_call_to_trt_plugin(  # noqa: PLR0915
    mlir_string: str,
    plugin_configs: dict[str, dict[str, str | dict[str, str]]],
) -> str:
    """Transform stablehlo.custom_call operations to tensorrt.opaque_plugin operations.

    This function parses MLIR code containing stablehlo.custom_call operations and
    transforms them into equivalent tensorrt.opaque_plugin operations. This is a
    temporary solution until the functionality is added to MLIR-TensorRT.

    The regex pattern matches MLIR custom call operations with the following
    structure (the pieces may be separated by any whitespace, including line
    breaks, but the output types must end on the line where they start):

        %result_var = stablehlo.custom_call @target_name(operands) {attributes}
            : (input_types) -> output_types

    Text that does not have this shape is left unchanged. This includes the
    generic ``"stablehlo.custom_call"(...)`` syntax, calls without operands,
    and calls whose attribute dictionary contains nested braces.

    Args:
        mlir_string: The MLIR string from exported.mlir_module()
        plugin_configs: Dictionary mapping target names to their plugin
            configurations. Each target name (e.g., "tensorrt_dmrs_plugin") maps
            to a configuration dict. The key can optionally include a suffix
            matching the backend_config attribute (e.g.,
            "tensorrt_cufft_plugin_forward") to provide different configurations
            for the same target name with different backend_config values.

            Configuration dict keys:
                - dso_path: Path to the plugin DSO file, as a str or
                  os.PathLike (required)
                - plugin_version: Version string for the plugin
                  (optional, default: "1")
                - plugin_namespace: Namespace for the plugin (optional, default: "")
                - creator_func: Creator function name (optional, default:
                  "get<PluginName>Creator" where <PluginName> is the plugin name
                  derived from the target name with its first letter capitalized)
                - creator_params: Dictionary of creator parameters (optional,
                  defaults to {"dummy_param": "0"} if missing or empty)
                - layer_name: Human-readable layer name emitted on the operation
                  (optional, omitted when empty)

            Values of the wrong type for the optional string keys fall back to
            their defaults.

            When a custom_call has a backend_config attribute, the lookup order is:
                1. Try "{target_name}_{backend_config}"
                   (e.g., "tensorrt_cufft_plugin_forward")
                2. Fall back to "{target_name}" if the suffixed key doesn't exist

    Returns
    -------
        Transformed MLIR string with tensorrt.opaque_plugin operations

    Raises
    ------
        MLIRTransformationError: If the transformation fails, input is invalid,
            a custom_call operation is found without a corresponding config, or
            a config has a missing/invalid dso_path

    Examples
    --------
    Single plugin:

    >>> mlir_input = '''
    ... %4 = stablehlo.custom_call @tensorrt_sequential_sum_plugin(%3)
    ...     {api_version = 2 : i32} : (tensor<30000xf32>) -> tensor<30000xf32>
    ... '''
    >>> configs = {
    ...     "tensorrt_sequential_sum_plugin": {
    ...         "dso_path": "/path/to/plugin.so",
    ...         "creator_params": {"i32_param": 10},
    ...     }
    ... }
    >>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

    Multiple plugins:

    >>> mlir_input = '''
    ... %4 = stablehlo.custom_call @tensorrt_dmrs_plugin(%3)
    ...     {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
    ... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
    ...     {api_version = 2 : i32} : (tensor<100xf32>) -> tensor<100xf32>
    ... '''
    >>> configs = {
    ...     "tensorrt_dmrs_plugin": {
    ...         "dso_path": "/path/to/dmrs.so",
    ...         "creator_params": {"sequence_length": 3276},
    ...     },
    ...     "tensorrt_cufft_plugin": {
    ...         "dso_path": "/path/to/cufft.so",
    ...         "creator_params": {"fft_size": 2048},
    ...     },
    ... }
    >>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)

    Multiple plugins with backend_config for disambiguation:

    >>> mlir_input = '''
    ... %4 = stablehlo.custom_call @tensorrt_cufft_plugin(%3)
    ...     {api_version = 2 : i32, backend_config = "forward"} :
    ...     (tensor<100xf32>) -> tensor<100xf32>
    ... %5 = stablehlo.custom_call @tensorrt_cufft_plugin(%4)
    ...     {api_version = 2 : i32, backend_config = "inverse"} :
    ...     (tensor<100xf32>) -> tensor<100xf32>
    ... '''
    >>> configs = {
    ...     "tensorrt_cufft_plugin_forward": {
    ...         "dso_path": "/path/to/cufft.so",
    ...         "creator_params": {
    ...             "fft_size": 2048,
    ...             "direction": 0,
    ...         },
    ...     },
    ...     "tensorrt_cufft_plugin_inverse": {
    ...         "dso_path": "/path/to/cufft.so",
    ...         "creator_params": {
    ...             "fft_size": 2048,
    ...             "direction": 1,
    ...         },
    ...     },
    ... }
    >>> result = transform_mlir_custom_call_to_trt_plugin(mlir_input, configs)
    """
    # Regex pattern to match stablehlo.custom_call operations.
    # Groups: 1=result_var, 2=target_name, 3=operands, 4=attributes,
    # 5=input_types, 6=output_types
    # Separators use \s so custom calls wrapped across lines still match.
    # The output types stop at the first newline: if they could run across
    # lines, the match would swallow any following custom_call (up to its
    # first ')'), and that call would never be transformed.
    custom_call_pattern = re.compile(
        r"%(\w+(?::\d+)?)\s*=\s*stablehlo\.custom_call\s+@(\w+)\(([^)]+)\)\s*"
        r"\{([^}]*)\}\s*:\s*\(([^)]+)\)\s*->\s*([^)\n]+)"
    )
    backend_config_pattern = re.compile(r'backend_config\s*=\s*"([^"]+)"')

    def _resolve_plugin_config(target_name: str, attributes: str) -> tuple[str, dict]:
        """Return (config_key, config) for a custom call, preferring a backend_config suffix."""
        backend_config_match = backend_config_pattern.search(attributes)
        backend_config = backend_config_match.group(1) if backend_config_match else None

        # Try "{target_name}_{backend_config}" first (e.g.
        # "tensorrt_cufft_plugin_forward"), then fall back to the bare target name.
        config_key = target_name
        if backend_config:
            suffixed_key = f"{target_name}_{backend_config}"
            if suffixed_key in plugin_configs:
                config_key = suffixed_key

        plugin_config = plugin_configs.get(config_key)
        if plugin_config is None:
            msg = (
                f"No configuration found for plugin '{config_key}' "
                f"(target_name='{target_name}', backend_config='{backend_config}'). "
                f"Available plugins: {list(plugin_configs.keys())}"
            )
            raise MLIRTransformationError(msg)
        return config_key, plugin_config

    def _get_dso_path(plugin_config: dict, config_key: str) -> str:
        """Return the plugin DSO path as a str, raising if it is missing or invalid."""
        dso_path = plugin_config.get("dso_path")
        # Accept pathlib.Path and other path-like objects instead of silently
        # emitting an empty dso_path for them.
        if isinstance(dso_path, os.PathLike):
            dso_path = os.fspath(dso_path)
        if not dso_path:
            msg = f"dso_path is required in plugin_config for plugin '{config_key}'"
            raise MLIRTransformationError(msg)
        if not isinstance(dso_path, str):
            msg = (
                f"dso_path for plugin '{config_key}' must be a str or os.PathLike, "
                f"got {type(dso_path).__name__}"
            )
            raise MLIRTransformationError(msg)
        return dso_path

    def _get_str_option(plugin_config: dict, key: str, default: str) -> str:
        """Return a string option, using *default* when absent or not a str."""
        value = plugin_config.get(key, default)
        return value if isinstance(value, str) else default

    def replace_custom_call(match: re.Match[str]) -> str:
        """Replace a single custom call match with a tensorrt.opaque_plugin operation."""
        result_var, target_name, operands, attributes, input_types, output_types = (
            match.groups()
        )
        try:
            config_key, plugin_config = _resolve_plugin_config(target_name, attributes)
            dso_path = _get_dso_path(plugin_config, config_key)

            # Extract plugin name from target name (remove common prefixes/suffixes)
            plugin_name = _extract_plugin_name(target_name)
            default_creator_func = f"get{_capitalize_first(plugin_name)}Creator"

            creator_params_raw = plugin_config.get("creator_params", {})
            creator_params: dict[str, str | int | float | bool] = dict(
                creator_params_raw if isinstance(creator_params_raw, dict) else {}
            )
            if not creator_params:
                # Provide a default parameter to avoid compiler issues with empty
                # creator_params
                creator_params = {"dummy_param": "0"}

            replacement = _build_tensorrt_plugin_operation(
                result_var,
                dso_path,
                plugin_name,
                _get_str_option(plugin_config, "plugin_version", "1"),
                _get_str_option(plugin_config, "plugin_namespace", ""),
                _get_str_option(plugin_config, "creator_func", default_creator_func),
                _build_creator_params_string(creator_params),
                operands,
                input_types,
                output_types,
                # Optional layer_name for better profiling visibility
                _get_str_option(plugin_config, "layer_name", ""),
            )
        except MLIRTransformationError:
            # Already descriptive; re-wrapping would only nest the message.
            raise
        except Exception as e:
            msg = f"Failed to transform custom call '@{target_name}': {e}"
            raise MLIRTransformationError(msg) from e
        else:
            return replacement

    # Perform the transformation
    try:
        transformed_mlir = custom_call_pattern.sub(replace_custom_call, mlir_string)
    except MLIRTransformationError:
        # Errors raised by replace_custom_call propagate with their own message.
        raise
    except Exception as e:
        msg = f"Regex substitution failed: {e}"
        raise MLIRTransformationError(msg) from e
    else:
        return transformed_mlir
def _extract_plugin_name(target_name: str) -> str: """Extract clean plugin name from target name. Args: target_name: The target name from the custom call Returns ------- Clean plugin name with common prefixes/suffixes removed """ # Handle specific plugin name mappings if target_name == "tensorrt_fft_plugin": return "FftTrt" if target_name == "tensorrt_dmrs_plugin": return "DmrsTrt" if target_name == "tensorrt_sequential_sum_plugin": return "SequentialSum" if target_name == "tensorrt_cholesky_factor_inv_plugin": return "CholeskyFactorInv" # Remove common prefixes and suffixes for other plugins return target_name.replace("tensorrt_", "").replace("_plugin", "") def _capitalize_first(text: str) -> str: """Capitalize the first letter of a string. Args: text: Input string Returns ------- String with first letter capitalized """ return text[0].upper() + text[1:] if text else "" def _build_creator_params_string( creator_params: dict[str, str | int | float | bool], ) -> str: """Build MLIR string representation of creator parameters. 
Args: creator_params: Dictionary of parameter names to values Returns ------- MLIR-formatted parameter string """ if not creator_params: return "" param_parts = [] for key, value in creator_params.items(): # Handle different data types properly if isinstance(value, int): # For TensorRT plugins, integer parameters should be i32 to match # the expected PluginFieldType::kINT32 in the C++ plugin code param_parts.append(f"{key} = {value} : i32") elif isinstance(value, float): param_parts.append(f"{key} = {value}") elif isinstance(value, bool): param_parts.append(f"{key} = {str(value).lower()}") else: # String values need quotes param_parts.append(f'{key} = "{value}"') return ",\n ".join(param_parts) def _build_tensorrt_plugin_operation( # noqa: PLR0913 result_var: str, dso_path: str, plugin_name: str, plugin_version: str, plugin_namespace: str, creator_func: str, params_str: str, operands: str, input_types: str, output_types: str, layer_name: str = "", ) -> str: """Build the complete tensorrt.opaque_plugin operation string. 
Args: result_var: Result variable name dso_path: Path to the plugin DSO plugin_name: Name of the plugin plugin_version: Version of the plugin plugin_namespace: Namespace for the plugin creator_func: Creator function name params_str: Formatted parameter string operands: Input operands input_types: Input type specification output_types: Output type specification layer_name: Optional human-readable layer name for profiling Returns ------- Complete MLIR tensorrt.opaque_plugin operation string """ # Build the creator_params section if parameters exist creator_params_section = "" if params_str: creator_params_section = f",\n creator_params = {{\n {params_str}\n }}" # Build the layer_name section if provided layer_name_section = "" if layer_name: layer_name_section = f',\n layer_name = "{layer_name}"' # Construct the complete operation return f"""%{result_var} = tensorrt.opaque_plugin {{ dso_path = "{dso_path}", plugin_name = "{plugin_name}", plugin_version = "{plugin_version}", plugin_namespace = "{plugin_namespace}", creator_func = "{creator_func}"{creator_params_section}{layer_name_section} }} ({operands}) : ({input_types}) -> {output_types}"""