Source code for nv_ingest.framework.util.flow_control.udf_intercept

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
import hashlib
import inspect
import logging
import os
import time
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass

from nv_ingest_api.internal.primitives.ingest_control_message import IngestControlMessage, remove_all_tasks_by_type
from nv_ingest_api.util.imports.callable_signatures import ingest_callable_signature

logger = logging.getLogger(__name__)


[docs] @dataclass class CachedUDF: """Cached UDF function with metadata""" function: callable function_name: str signature_validated: bool created_at: float last_used: float use_count: int
[docs] class UDFCache: """LRU cache for compiled and validated UDF functions""" def __init__(self, max_size: int = 128, ttl_seconds: Optional[int] = 3600): self.max_size = max_size self.ttl_seconds = ttl_seconds self.cache: Dict[str, CachedUDF] = {} self.access_order: List[str] = [] # For LRU tracking def _generate_cache_key(self, udf_function_str: str, udf_function_name: str) -> str: """Generate cache key from UDF string and function name""" content = f"{udf_function_str.strip()}:{udf_function_name}" return hashlib.sha256(content.encode()).hexdigest() def _evict_lru(self): """Remove least recently used item""" if self.access_order: lru_key = self.access_order.pop(0) self.cache.pop(lru_key, None) def _cleanup_expired(self): """Remove expired entries if TTL is configured""" if not self.ttl_seconds: return current_time = time.time() expired_keys = [ key for key, cached_udf in self.cache.items() if current_time - cached_udf.created_at > self.ttl_seconds ] for key in expired_keys: self.cache.pop(key, None) if key in self.access_order: self.access_order.remove(key)
[docs] def get(self, udf_function_str: str, udf_function_name: str) -> Optional[CachedUDF]: """Get cached UDF function if available""" self._cleanup_expired() cache_key = self._generate_cache_key(udf_function_str, udf_function_name) if cache_key in self.cache: # Update access tracking if cache_key in self.access_order: self.access_order.remove(cache_key) self.access_order.append(cache_key) # Update usage stats cached_udf = self.cache[cache_key] cached_udf.last_used = time.time() cached_udf.use_count += 1 return cached_udf return None
[docs] def put( self, udf_function_str: str, udf_function_name: str, function: callable, signature_validated: bool = True ) -> str: """Cache a compiled and validated UDF function""" cache_key = self._generate_cache_key(udf_function_str, udf_function_name) # Evict LRU if at capacity while len(self.cache) >= self.max_size: self._evict_lru() current_time = time.time() cached_udf = CachedUDF( function=function, function_name=udf_function_name, signature_validated=signature_validated, created_at=current_time, last_used=current_time, use_count=1, ) self.cache[cache_key] = cached_udf self.access_order.append(cache_key) return cache_key
[docs] def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" total_uses = sum(udf.use_count for udf in self.cache.values()) most_used = max(self.cache.values(), key=lambda x: x.use_count, default=None) return { "size": len(self.cache), "max_size": self.max_size, "total_uses": total_uses, "most_used_function": most_used.function_name if most_used else None, "most_used_count": most_used.use_count if most_used else 0, }
# Global cache instance _udf_cache = UDFCache(max_size=128, ttl_seconds=3600)
[docs] def compile_and_validate_udf(udf_function_str: str, udf_function_name: str, task_num: int) -> callable: """Compile and validate UDF function (extracted for caching)""" # Execute the UDF function string in a controlled namespace namespace: Dict[str, Any] = {} try: exec(udf_function_str, namespace) except Exception as e: raise ValueError(f"UDF task {task_num} failed to execute: {str(e)}") # Extract the specified function from the namespace if udf_function_name in namespace and callable(namespace[udf_function_name]): udf_function = namespace[udf_function_name] else: raise ValueError(f"UDF task {task_num}: Specified UDF function '{udf_function_name}' not found or not callable") # Validate the UDF function signature try: ingest_callable_signature(inspect.signature(udf_function)) except Exception as e: raise ValueError(f"UDF task {task_num} has invalid function signature: {str(e)}") return udf_function
[docs] def execute_targeted_udfs( control_message: IngestControlMessage, stage_name: str, directive: str ) -> IngestControlMessage: """Execute UDFs that target this stage with the given directive.""" # Early exit if no UDF tasks exist - check by task type, not task ID udf_tasks_exist = any(task.type == "udf" for task in control_message.get_tasks()) if not udf_tasks_exist: return control_message # Remove all UDF tasks and get them - handle case where no tasks found try: all_udf_tasks = remove_all_tasks_by_type(control_message, "udf") except ValueError: # No UDF tasks found - this can happen due to race conditions logger.debug(f"No UDF tasks found for stage '{stage_name}' directive '{directive}'") return control_message # Execute applicable UDFs and collect remaining ones remaining_tasks = [] for task_properties in all_udf_tasks: # Check if this UDF targets this stage with the specified directive target_stage = task_properties.get("target_stage", "") run_before = task_properties.get("run_before", False) run_after = task_properties.get("run_after", False) # Determine if this UDF should execute should_execute = False if directive == "run_before" and run_before and target_stage == stage_name: should_execute = True elif directive == "run_after" and run_after and target_stage == stage_name: should_execute = True if should_execute: try: # Get UDF function details udf_function_str = task_properties.get("udf_function", "").strip() udf_function_name = task_properties.get("udf_function_name", "").strip() task_id = task_properties.get("task_id", "unknown") # Skip empty UDF functions if not udf_function_str: logger.debug(f"UDF task {task_id} has empty function, skipping") remaining_tasks.append(task_properties) continue # Validate function name if not udf_function_name: raise ValueError(f"UDF task {task_id} missing required 'udf_function_name' property") # Get or compile UDF function cached_udf = _udf_cache.get(udf_function_str, udf_function_name) if cached_udf: udf_function = cached_udf.function logger.debug(f"UDF task {task_id}: Using cached function '{udf_function_name}'") else: udf_function = compile_and_validate_udf(udf_function_str, udf_function_name, task_id) _udf_cache.put(udf_function_str, udf_function_name, udf_function) logger.debug(f"UDF task {task_id}: Cached function '{udf_function_name}'") # Execute the UDF control_message = udf_function(control_message) # Validate return type if not isinstance(control_message, IngestControlMessage): raise ValueError(f"UDF task {task_id} must return IngestControlMessage") logger.info(f"Executed UDF {task_id} '{udf_function_name}' {directive} stage '{stage_name}'") except Exception as e: logger.error(f"UDF {task_id} failed {directive} stage '{stage_name}': {e}") # Keep failed task for next stage remaining_tasks.append(task_properties) else: # Keep non-applicable task for next stage remaining_tasks.append(task_properties) # Re-add all remaining UDF tasks for task_properties in remaining_tasks: from nv_ingest_api.internal.primitives.control_message_task import ControlMessageTask task = ControlMessageTask(type="udf", id=task_properties.get("task_id", "unknown"), properties=task_properties) control_message.add_task(task) return control_message
[docs] def remove_task_by_id(control_message: IngestControlMessage, task_id: str) -> IngestControlMessage: """Remove a specific task by ID from the control message""" try: control_message.remove_task(task_id) except RuntimeError as e: logger.warning(f"Could not remove task {task_id}: {e}") return control_message
[docs] def udf_intercept_hook(stage_name: Optional[str] = None, enable_run_before: bool = True, enable_run_after: bool = True): """ Decorator that executes UDFs targeted at this stage. This decorator integrates with the existing UDF system, providing full UDF compilation, caching, and execution capabilities. UDFs can target specific stages using run_before or run_after directives. Args: stage_name: Name of the stage (e.g., "image_dedup", "text_extract"). If None, will attempt to use self.stage_name from the decorated method's instance. enable_run_before: Whether to execute UDFs with run_before=True (default: True) enable_run_after: Whether to execute UDFs with run_after=True (default: True) Examples: # Automatic stage name detection (recommended) @traceable("image_deduplication") @udf_intercept_hook() # Uses self.stage_name automatically @filter_by_task(required_tasks=["dedup"]) def on_data(self, control_message: IngestControlMessage) -> IngestControlMessage: return control_message # Explicit stage name (fallback/override) @traceable("data_sink") @udf_intercept_hook("data_sink", enable_run_after=False) @filter_by_task(required_tasks=["store"]) def on_data(self, control_message: IngestControlMessage) -> IngestControlMessage: return control_message # Only run_after UDFs (e.g., for source stages) @traceable("data_source") @udf_intercept_hook(enable_run_before=False) # Uses self.stage_name automatically def on_data(self, control_message: IngestControlMessage) -> IngestControlMessage: return control_message """ def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: # Check if UDF processing is globally disabled if os.getenv("INGEST_DISABLE_UDF_PROCESSING"): logger.debug("UDF processing is disabled via INGEST_DISABLE_UDF_PROCESSING environment variable") return func(*args, **kwargs) # Determine the stage name to use resolved_stage_name = stage_name # If no explicit stage_name provided, try to get it from self.stage_name if resolved_stage_name is None and len(args) >= 1: stage_instance = args[0] # 'self' in method calls if hasattr(stage_instance, "stage_name") and stage_instance.stage_name: resolved_stage_name = stage_instance.stage_name logger.debug(f"Using auto-detected stage name: '{resolved_stage_name}'") else: logger.warning( "No stage_name provided and could not auto-detect from instance. Skipping UDF intercept." ) return func(*args, **kwargs) elif resolved_stage_name is None: logger.warning( "No stage_name provided and no instance available for auto-detection. Skipping UDF intercept." ) return func(*args, **kwargs) # Extract control_message from args (handle both self.method and function cases) control_message = None if len(args) >= 2 and hasattr(args[1], "get_tasks"): control_message = args[1] # self.method case args_list = list(args) elif len(args) >= 1 and hasattr(args[0], "get_tasks"): control_message = args[0] # function case args_list = list(args) if control_message: # Execute UDFs that should run before this stage (if enabled) if enable_run_before: control_message = execute_targeted_udfs(control_message, resolved_stage_name, "run_before") # Update args with modified control_message if len(args) >= 2 and hasattr(args[1], "get_tasks"): args_list[1] = control_message else: args_list[0] = control_message # Execute the original stage logic result = func(*tuple(args_list), **kwargs) # Execute UDFs that should run after this stage (if enabled) if enable_run_after and hasattr(result, "get_tasks"): # Result is control_message result = execute_targeted_udfs(result, resolved_stage_name, "run_after") return result else: return func(*args, **kwargs) return wrapper return decorator
[docs] def get_udf_cache_stats() -> Dict[str, Any]: """Get UDF cache performance statistics""" return _udf_cache.get_stats()