Source code for ran.trt_plugins.manager.trt_plugin_manager

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorRT plugin manager for RAN package."""

from __future__ import annotations

import ctypes
import logging
import os
import shutil
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
import tensorrt as trt

from ran.trt_plugins.manager._trt_plugin_field_map import _convert_type_to_trt

logger = logging.getLogger(__name__)


def _load_env_file_if_needed() -> None:
    """Load .env.python file if RAN_TRT_PLUGIN_DSO_PATH is not set.

    Priority 1: Use RAN_ENV_PYTHON_FILE if set
    Priority 2: Search upward from the current file's directory for .env.python
    Only loads if environment variables are not already set.
    """
    if os.environ.get("RAN_TRT_PLUGIN_DSO_PATH"):
        return  # Already set, nothing to do

    # Priority 1: Check for explicit env file path
    env_file_path = os.environ.get("RAN_ENV_PYTHON_FILE")
    env_file: Path | None = None
    if env_file_path:
        env_file = Path(env_file_path)
        logger.debug(f"Using RAN_ENV_PYTHON_FILE: {env_file}")
        if not env_file.exists():
            logger.warning(
                f"RAN_ENV_PYTHON_FILE points to non-existent file: {env_file}. "
                "Falling back to searching parent directories."
            )
            env_file = None  # Fall through to search

    # Priority 2: Search upward from current file for .env.python
    if not env_file:
        logger.debug("Searching for .env.python in parent directories")
        current_dir = Path(__file__).resolve().parent
        while current_dir != current_dir.parent:  # Stop at filesystem root
            candidate = current_dir / ".env.python"
            logger.debug(f"Checking: {candidate}")
            if candidate.exists():
                env_file = candidate
                logger.debug(f"Found .env.python at: {env_file}")
                break
            current_dir = current_dir.parent

    if not env_file:
        logger.debug(
            "No .env.python file found. "
            "Environment variables must be set manually or tests may fail."
        )
        return

    # Load the .env file using python-dotenv
    logger.info(f"Auto-loading environment from: {env_file}")
    load_dotenv(dotenv_path=env_file, override=False)


[docs] def get_ran_trt_plugin_dso_path() -> str: """Get the RAN TensorRT plugin DSO path from environment variable. This function returns the path to libran_trt_plugin.so from the RAN_TRT_PLUGIN_DSO_PATH environment variable. When running tests via CMake targets (py_ran_test, py_ran_wheel_test), this environment variable is automatically set by CMake. For interactive/manual usage, attempts to auto-load from .env.python file. Returns ------- Absolute path to libran_trt_plugin.so Raises ------ EnvironmentError: If RAN_TRT_PLUGIN_DSO_PATH is not set FileNotFoundError: If the DSO file does not exist at the specified path """ _load_env_file_if_needed() env_dso_path = os.environ.get("RAN_TRT_PLUGIN_DSO_PATH") if not env_dso_path: raise EnvironmentError( "RAN_TRT_PLUGIN_DSO_PATH environment variable not set. " "This should be set automatically when running tests via CMake targets " "For manual testing, set it to the path of libran_trt_plugin.so." ) dso_path = Path(env_dso_path) if not dso_path.exists(): raise FileNotFoundError( f"TensorRT plugin DSO not found at path specified by " f"RAN_TRT_PLUGIN_DSO_PATH: {dso_path}" ) return str(dso_path.absolute())
def get_ran_trt_engine_path() -> str: """Get the RAN TensorRT engine directory path from environment variable. This function returns the path to the TensorRT engine directory from the RAN_TRT_ENGINE_PATH environment variable. When running tests via CMake targets (py_ran_test, py_ran_wheel_test), this environment variable is automatically set by CMake. For interactive/manual usage, attempts to auto-load from .env.python file. Returns ------- Absolute path to TensorRT engine directory Raises ------ EnvironmentError: If RAN_TRT_ENGINE_PATH is not set """ _load_env_file_if_needed() env_engine_path = os.environ.get("RAN_TRT_ENGINE_PATH") if not env_engine_path: raise EnvironmentError( "RAN_TRT_ENGINE_PATH environment variable not set. " "This should be set automatically when running tests via CMake targets. " "For manual testing, set it to the directory where TensorRT engines are stored." ) return str(Path(env_engine_path).absolute()) def get_ran_pytest_build_dir() -> str: """Get the RAN pytest build directory path from environment variable. This function returns the path to the pytest build directory from the RAN_PYTEST_BUILD_DIR environment variable. When running tests via CMake targets, this environment variable is automatically set by CMake and points to the ran/py build directory (e.g., out/build/clang-release/ran/py). For interactive/manual usage, attempts to auto-load from .env.python file. Returns ------- Absolute path to pytest build directory Raises ------ EnvironmentError: If RAN_PYTEST_BUILD_DIR is not set """ _load_env_file_if_needed() env_build_dir = os.environ.get("RAN_PYTEST_BUILD_DIR") if not env_build_dir: raise EnvironmentError( "RAN_PYTEST_BUILD_DIR environment variable not set. " "This should be set automatically when running tests via CMake targets. " "For manual testing, set it to the pytest build directory." ) return str(Path(env_build_dir).absolute()) def should_skip_engine_generation(required_engines: list[str]) -> bool: """Check if TRT engine generation should be skipped. Checks if SKIP_TRT_ENGINE_GENERATION=1 and all required engines exist in RAN_TRT_ENGINE_PATH. Returns True if engines should be skipped (already exist), False if engines should be regenerated. Args: required_engines: List of engine filenames (e.g., ["dmrs_test.trtengine"]) Returns: True if engine generation should be skipped (all engines exist), False if engines should be regenerated """ _load_env_file_if_needed() skip_generation = os.getenv("SKIP_TRT_ENGINE_GENERATION", "0") == "1" if not skip_generation: return False engine_dir = Path(get_ran_trt_engine_path()) all_engines_exist = all((engine_dir / engine_name).exists() for engine_name in required_engines) if all_engines_exist: logger.info(f"Skipping engine generation - all engines exist in {engine_dir}") logger.info("Set SKIP_TRT_ENGINE_GENERATION=0 to force regeneration") return True logger.info("Engines not found, proceeding with generation...") return False def copy_trt_engine_for_cpp_tests( source_dir: str | Path, dest_engine_name: str, *, required: bool = True, ) -> Path: """Copy TensorRT engine from build directory to RAN_TRT_ENGINE_PATH for C++ tests. Args: source_dir: Directory containing tensorrt_cluster_engine_data.trtengine dest_engine_name: Destination filename (e.g., "dmrs_test.trtengine") required: If True, raises FileNotFoundError if source missing Returns: Path to destination engine file Raises: FileNotFoundError: If required=True and source doesn't exist """ # Source: generated engine in build directory engine_source = Path(source_dir) / "tensorrt_cluster_engine_data.trtengine" # Destination: central engine location expected by C++ tests engine_dest_dir = Path(get_ran_trt_engine_path()) engine_dest_dir.mkdir(parents=True, exist_ok=True) engine_dest = engine_dest_dir / dest_engine_name # Copy engine file if it exists if engine_source.exists(): shutil.copy2(engine_source, engine_dest) logger.info(f"Copied TensorRT engine to {engine_dest}") elif required: raise FileNotFoundError( f"TensorRT engine not found at {engine_source}. Engine file is required for C++ tests." ) else: logger.warning(f"TensorRT engine not found at {engine_source}") return engine_dest def copy_test_data_for_cpp_tests( source_dir: str | Path, test_vector_subdir: str, file_patterns: list[str], ) -> Path: """Copy test data files from build directory to RAN_TRT_ENGINE_PATH/test_vectors for C++ tests. Args: source_dir: Directory containing the test data files test_vector_subdir: Subdirectory name under test_vectors/ (e.g., "ai_tukey_filter") file_patterns: List of file patterns to copy (e.g., ["*.bin", "*_meta.txt"]) Returns: Path to destination test_vectors subdirectory """ # Destination: test_vectors subdirectory under RAN_TRT_ENGINE_PATH engine_path = Path(get_ran_trt_engine_path()) test_vectors_dir = engine_path / "test_vectors" / test_vector_subdir test_vectors_dir.mkdir(parents=True, exist_ok=True) # Copy all matching files source_path = Path(source_dir) files_copied = 0 for pattern in file_patterns: for source_file in source_path.glob(pattern): if source_file.is_file(): dest_file = test_vectors_dir / source_file.name shutil.copy2(source_file, dest_file) files_copied += 1 logger.debug(f"Copied {source_file.name} to {test_vectors_dir}") if files_copied > 0: logger.info(f"Copied {files_copied} test data files to {test_vectors_dir}") else: logger.warning(f"No files matched patterns {file_patterns} in {source_path}") return test_vectors_dir
[docs] class TensorRTPluginManager: """Manager for RAN TensorRT plugins."""
[docs] def __init__(self) -> None: """Initialize the plugin manager.""" self._trt_plugin_lib_path: Path | None = None self._trt_plugin_lib_loaded = False
[docs] def load_plugin_library(self) -> bool: """Load the TensorRT plugin library from RAN_TRT_PLUGIN_DSO_PATH. The library path is obtained from the RAN_TRT_PLUGIN_DSO_PATH environment variable via config.get_ran_trt_plugin_dso_path(). When running tests via CMake targets, this environment variable is automatically set by CMake. Returns ------- True if library loaded successfully. Raises ------ EnvironmentError: If RAN_TRT_PLUGIN_DSO_PATH is not set. FileNotFoundError: If the DSO file does not exist at the specified path. RuntimeError: If plugin initialization fails. """ if self._trt_plugin_lib_loaded: return True # Get library path from config lib_path = Path(get_ran_trt_plugin_dso_path()) logger.info(f"Loading TensorRT plugin library: {lib_path}") # Load the library plugin_lib = ctypes.CDLL(str(lib_path)) self._trt_plugin_lib_path = lib_path # Initialize TensorRT plugins (standard ones) trt.init_libnvinfer_plugins(None, "") # Initialize our custom RAN plugins if not hasattr(plugin_lib, "init_ran_plugins"): err_msg = f"init_ran_plugins function not found in {lib_path}" raise RuntimeError(err_msg) init_func = plugin_lib.init_ran_plugins init_func.argtypes = [ctypes.c_void_p, ctypes.c_char_p] init_func.restype = ctypes.c_bool # Default to empty namespace for convenience result = init_func(None, b"") if not result: err_msg = "Failed to initialize custom TensorRT plugins." raise RuntimeError(err_msg) self._trt_plugin_lib_loaded = True logger.info(f"TensorRT plugins loaded successfully from {lib_path}") return True
[docs] def create_plugin( self, plugin_name: str, fields: dict[str, Any] | None = None ) -> object | None: """Create a plugin instance. Args: plugin_name: Name of the plugin to create (C++ TensorRT plugin name). fields: Dictionary of plugin fields to pass to the plugin creator. Returns ------- Plugin instance if successful, None otherwise. """ if fields is None: fields = {} # Load plugins if not already loaded if not self._trt_plugin_lib_loaded: self.load_plugin_library() plugin_registry = trt.get_plugin_registry() creators = plugin_registry.all_creators plugin_creator = None for creator in creators: if creator.name == plugin_name: plugin_creator = creator break if plugin_creator is None: err_msg = f"Plugin creator '{plugin_name}' not found in TensorRT registry." logger.warning(err_msg) return None # Create plugin field collection field_collection = None if fields is None: # No fields provided - create empty collection field_collection = trt.PluginFieldCollection([]) else: # Convert Python dict to TensorRT plugin fields plugin_fields = [] for field_name, field_value in fields.items(): trt_type = _convert_type_to_trt(field_value) plugin_field = trt.PluginField(field_name, field_value, trt_type) plugin_fields.append(plugin_field) field_collection = trt.PluginFieldCollection(plugin_fields) return plugin_creator.create_plugin(plugin_name, field_collection, trt.TensorRTPhase.BUILD)