nat.middleware#
Middleware implementations for NeMo Agent Toolkit.
Submodules#
Attributes#
Classes#
| FunctionMiddleware | Base class for function middleware with pre/post-invoke hooks. |
| FunctionMiddlewareChain | Composes middleware into an execution chain. |
| FunctionMiddlewareContext | Static metadata about the function being wrapped by middleware. |
| Middleware | Base class for middleware-style wrapping with pre/post-invoke hooks. |
| RedTeamingMiddleware | Middleware for red teaming that intercepts and modifies function inputs/outputs. |
Functions#
| validate_middleware(middleware) | Validate a sequence of middleware, enforcing ordering guarantees. |
Package Contents#
- class FunctionMiddleware(*, is_final: bool = False)#
Bases: nat.middleware.middleware.Middleware
Base class for function middleware with pre/post-invoke hooks.
Middleware intercepts function calls and can:
- Transform inputs before execution (pre_invoke)
- Transform outputs after execution (post_invoke)
- Override function_middleware_invoke for full control
Lifecycle:
- The framework checks the enabled property before calling any methods
- If disabled, middleware is skipped entirely (no methods called)
- Users do NOT need to check enabled in their implementations
Inherited abstract members that must be implemented:
- enabled: Property that returns whether middleware should run
- pre_invoke: Transform inputs before function execution
- post_invoke: Transform outputs after function execution
Context Flow:
- FunctionMiddlewareContext (frozen): Static function metadata only
- InvocationContext: Unified context for both pre-invoke and post-invoke phases
- Pre-invoke: output is None, modify modified_args/modified_kwargs
- Post-invoke: output has the result, modify output to transform
Example:
```python
class LoggingMiddleware(FunctionMiddleware):

    def __init__(self, config: LoggingConfig):
        super().__init__()
        self._config = config

    @property
    def enabled(self) -> bool:
        return self._config.enabled

    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        logger.info(f"Calling {context.function_context.name} with {context.modified_args}")
        logger.info(f"Original args: {context.original_args}")
        return None  # Pass through unchanged

    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        logger.info(f"Result: {context.output}")
        return None  # Pass through unchanged
```
- property enabled: bool#
Check if this middleware is enabled.
- Returns:
True if the middleware should be applied, False otherwise. Default implementation always returns True.
- async pre_invoke(context: nat.middleware.middleware.InvocationContext) → nat.middleware.middleware.InvocationContext | None#
Pre-invocation hook called before the function is invoked.
- Args:
context: Invocation context containing function metadata and args
- Returns:
InvocationContext if modified, or None to pass through unchanged. Default implementation does nothing.
- async post_invoke(context: nat.middleware.middleware.InvocationContext) → nat.middleware.middleware.InvocationContext | None#
Post-invocation hook called after the function returns.
- Args:
context: Invocation context containing function metadata, args, and output
- Returns:
InvocationContext if modified, or None to pass through unchanged. Default implementation does nothing.
- async middleware_invoke(*args: Any, call_next: nat.middleware.middleware.CallNext, context: nat.middleware.middleware.FunctionMiddlewareContext, **kwargs: Any)#
Delegate to function_middleware_invoke for function-specific handling.
- async middleware_stream(*args: Any, call_next: nat.middleware.middleware.CallNextStream, context: nat.middleware.middleware.FunctionMiddlewareContext, **kwargs: Any)#
Delegate to function_middleware_stream for function-specific handling.
- async function_middleware_invoke(*args: Any, call_next: nat.middleware.middleware.CallNext, context: nat.middleware.middleware.FunctionMiddlewareContext, **kwargs: Any)#
Execute middleware hooks around function call.
Default implementation orchestrates: pre_invoke → call_next → post_invoke
Override for full control over execution flow (e.g., caching, retry logic, conditional execution).
Note: The framework checks enabled before calling this method. You do NOT need to check enabled yourself.
- Args:
args: Positional arguments for the function (first arg is typically the input value).
call_next: Callable to invoke next middleware or target function.
context: Static function metadata.
kwargs: Keyword arguments for the function.
- Returns:
The (potentially transformed) function output.
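The default flow described for function_middleware_invoke (pre_invoke → call_next → post_invoke) can be sketched without the toolkit installed. This is a minimal illustration under stated assumptions, not the library's implementation: SketchContext, UppercaseMiddleware, and invoke_with_hooks are hypothetical stand-ins, and the real InvocationContext is a Pydantic model with additional fields.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchContext:
    """Hypothetical stand-in for InvocationContext (not the real model)."""
    original_args: tuple
    modified_args: tuple
    modified_kwargs: dict = field(default_factory=dict)
    output: Any = None


class UppercaseMiddleware:
    """Illustrative middleware: upper-cases the input, leaves the output alone."""

    async def pre_invoke(self, ctx: SketchContext) -> SketchContext | None:
        ctx.modified_args = (ctx.modified_args[0].upper(),)
        return ctx  # returning the context signals a modification

    async def post_invoke(self, ctx: SketchContext) -> SketchContext | None:
        return None  # None means "pass through unchanged"


async def invoke_with_hooks(mw, call_next, *args, **kwargs):
    """Assumed orchestration: pre_invoke, then call_next, then post_invoke."""
    ctx = SketchContext(original_args=args, modified_args=args,
                        modified_kwargs=dict(kwargs))
    updated = await mw.pre_invoke(ctx)
    if updated is not None:
        ctx = updated
    ctx.output = await call_next(*ctx.modified_args, **ctx.modified_kwargs)
    updated = await mw.post_invoke(ctx)
    if updated is not None:
        ctx = updated
    return ctx.output


async def echo(value: str) -> str:
    """Stand-in for the wrapped function."""
    return f"echo:{value}"


result = asyncio.run(invoke_with_hooks(UppercaseMiddleware(), echo, "hi"))
```

In this sketch the wrapped function receives the modified argument, while original_args keeps what entered the chain.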
- async function_middleware_stream(*args: Any, call_next: nat.middleware.middleware.CallNextStream, context: nat.middleware.middleware.FunctionMiddlewareContext, **kwargs: Any)#
Execute middleware hooks around streaming function call.
Pre-invoke runs once before streaming starts. Post-invoke runs per-chunk as they stream through.
Override for custom streaming behavior (e.g., buffering, aggregation, chunk filtering).
Note: The framework checks enabled before calling this method. You do NOT need to check enabled yourself.
- Args:
args: Positional arguments for the function (first arg is typically the input value).
call_next: Callable to invoke next middleware or target stream.
context: Static function metadata.
kwargs: Keyword arguments for the function.
- Yields:
Stream chunks (potentially transformed by post_invoke).
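The streaming behaviour described here (pre-invoke once, post-invoke per chunk) can be illustrated with plain async generators. The names stream_with_hooks, pre, post, and token_stream are hypothetical and exist only for this sketch; the real hooks receive an InvocationContext rather than bare values.

```python
import asyncio

calls = {"pre": 0, "post": 0}  # counts how often each hook runs


async def pre(value: str) -> str:
    calls["pre"] += 1
    return value.strip()


async def post(chunk: str) -> str:
    calls["post"] += 1
    return chunk.upper()


async def token_stream(value: str):
    """Stand-in for the wrapped streaming function."""
    for token in value.split():
        yield token


async def stream_with_hooks(pre_hook, post_hook, call_next, value):
    value = await pre_hook(value)          # runs once, before streaming starts
    async for chunk in call_next(value):
        yield await post_hook(chunk)       # runs for every chunk


async def collect():
    return [c async for c in stream_with_hooks(pre, post, token_stream, "  a b c ")]


chunks = asyncio.run(collect())
```

After the run, the counters show one pre call and one post call per chunk.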
- class FunctionMiddlewareChain(*, middleware: collections.abc.Sequence[FunctionMiddleware], context: nat.middleware.middleware.FunctionMiddlewareContext)#
Composes middleware into an execution chain.
The chain builder checks each middleware’s enabled property. Disabled middleware is skipped entirely; no methods are called.
Execution order:
- Pre-invoke: first middleware → last middleware → function
- Post-invoke: function → last middleware → first middleware
Context:
- FunctionMiddlewareContext contains only static function metadata
- Original args/kwargs are captured by the orchestration layer
- Middleware receives InvocationContext with frozen originals and mutable args/output
Initialize the middleware chain.
- Args:
middleware: Sequence of middleware to chain (order matters)
context: Static function metadata
- _middleware#
- _context#
- build_single(final_call: nat.middleware.middleware.CallNext)#
Build the middleware chain for single-output invocations.
Disabled middleware (enabled=False) is skipped entirely.
- Args:
final_call: The final function to call (the actual function implementation)
- Returns:
A callable that executes the entire middleware chain
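The execution order documented for the chain (pre-invoke first → last, post-invoke last → first, disabled middleware skipped) follows naturally from wrapping the final call from the inside out. The sketch below is one way to get that order, assuming a simplified interface; Tag and this standalone build_single are hypothetical and do not reproduce the toolkit's internals.

```python
import asyncio


class Tag:
    """Illustrative middleware that records when it runs."""

    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled

    async def invoke(self, value, call_next, trace):
        trace.append(f"pre:{self.name}")
        result = await call_next(value)
        trace.append(f"post:{self.name}")
        return result


def build_single(middleware, final_call, trace):
    """Wrap final_call so the first enabled middleware ends up outermost."""
    chain = final_call
    for mw in reversed([m for m in middleware if m.enabled]):
        # Default arguments bind the current mw and chain for this layer.
        async def wrapped(value, mw=mw, nxt=chain):
            return await mw.invoke(value, nxt, trace)
        chain = wrapped
    return chain


trace = []


async def target(value: int) -> int:
    trace.append("function")
    return value * 2


chain = build_single(
    [Tag("first"), Tag("skipped", enabled=False), Tag("last")], target, trace
)
result = asyncio.run(chain(21))
```

The recorded trace shows the disabled entry never appearing, which mirrors the "skipped entirely" rule.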
- build_stream(final_call: nat.middleware.middleware.CallNextStream)#
Build the middleware chain for streaming invocations.
Disabled middleware (enabled=False) is skipped entirely.
- Args:
final_call: The final function to call (the actual function implementation)
- Returns:
A callable that executes the entire middleware chain
- validate_middleware(middleware: collections.abc.Sequence[nat.middleware.middleware.Middleware] | None)#
Validate a sequence of middleware, enforcing ordering guarantees.
- CallNext#
- CallNextStream#
- class FunctionMiddlewareContext#
Static metadata about the function being wrapped by middleware.
Middleware receives this context object which describes the function they are wrapping. This allows middleware to make decisions based on the function’s name, configuration, schema, etc.
- config: Any#
Configuration object for the function.
- input_schema: type[pydantic.BaseModel] | None#
Schema describing expected inputs or NoneType when absent.
- class Middleware(*, is_final: bool = False)#
Bases: abc.ABC
Base class for middleware-style wrapping with pre/post-invoke hooks.
Middleware works like middleware in web frameworks:
Preprocess: Inspect and optionally modify inputs (via pre_invoke)
Call Next: Delegate to the next middleware or the target itself
Postprocess: Process, transform, or augment the output (via post_invoke)
Continue: Return or yield the final result
Example:
```python
class LoggingMiddleware(FunctionMiddleware):

    @property
    def enabled(self) -> bool:
        return True

    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Current args: {context.modified_args}")
        print(f"Original args: {context.original_args}")
        return None  # Pass through unchanged

    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Output: {context.output}")
        return None  # Pass through unchanged
```
- Attributes:
is_final: If True, this middleware terminates the chain. No subsequent middleware or the target will be called unless this middleware explicitly delegates to call_next.
- _is_final = False#
- abstractmethod async pre_invoke(context: InvocationContext) → InvocationContext | None#
Transform inputs before execution.
Called by specialized middleware invoke methods (e.g., function_middleware_invoke). Use to validate, transform, or augment inputs. At this phase, context.output is None.
- Args:
- context: Invocation context (Pydantic model) containing:
function_context: Static function metadata (frozen)
original_args: What entered the middleware chain (frozen)
original_kwargs: What entered the middleware chain (frozen)
modified_args: Current args (mutable)
modified_kwargs: Current kwargs (mutable)
output: None (function not yet called)
- Returns:
InvocationContext: Return the (modified) context to signal changes
None: Pass through unchanged (framework uses current context state)
- Note:
Frozen fields (original_args, original_kwargs) cannot be modified. Attempting to modify them raises ValidationError.
- Raises:
Any exception to abort execution
- abstractmethod async post_invoke(context: InvocationContext) → InvocationContext | None#
Transform output after execution.
Called by specialized middleware invoke methods (e.g., function_middleware_invoke). For streaming, called per-chunk. Use to validate, transform, or augment outputs.
- Args:
- context: Invocation context (Pydantic model) containing:
function_context: Static function metadata (frozen)
original_args: What entered the middleware chain (frozen)
original_kwargs: What entered the middleware chain (frozen)
modified_args: What the function received (mutable)
modified_kwargs: What the function received (mutable)
output: Current output value (mutable)
- Returns:
InvocationContext: Return the (modified) context to signal changes
None: Pass through unchanged (framework uses current context.output)
Example:
```python
async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
    # Wrap the output
    context.output = {"result": context.output, "processed": True}
    return context  # Signal modification
```
- Raises:
Any exception to abort and propagate error
- property is_final: bool#
Whether this middleware terminates the chain.
A final middleware prevents subsequent middleware and the target from running unless it explicitly calls call_next.
- async middleware_invoke(value: Any, call_next: CallNext, context: FunctionMiddlewareContext, **kwargs: Any)#
Middleware for single-output invocations.
- Args:
value: The input value to process
call_next: Callable to invoke the next middleware or target
context: Metadata about the target being wrapped
kwargs: Additional function arguments
- Returns:
The (potentially modified) output from the target
The default implementation simply delegates to call_next. Override this to add preprocessing, postprocessing, or to short-circuit execution:
```python
async def middleware_invoke(self, value, call_next, context, **kwargs):
    # Preprocess: modify input
    modified_input = transform(value)
    # Call next: delegate to next middleware/target
    result = await call_next(modified_input, **kwargs)
    # Postprocess: modify output
    modified_result = transform_output(result)
    # Continue: return final result
    return modified_result
```
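Short-circuiting means returning a result without delegating to call_next. A caching override is one common case; the sketch below shows the idea with a hypothetical CachingSketch class that does not inherit from the toolkit's Middleware, so it runs standalone. It assumes hashable inputs and is an illustration, not a recommended cache design.

```python
import asyncio


class CachingSketch:
    """Hypothetical middleware-like object that caches results by input value."""

    def __init__(self):
        self._cache = {}

    async def middleware_invoke(self, value, call_next, context=None, **kwargs):
        if value in self._cache:
            # Short-circuit: call_next is never reached on a cache hit.
            return self._cache[value]
        result = await call_next(value, **kwargs)
        self._cache[value] = result
        return result


target_calls = {"count": 0}  # how many times the wrapped function really ran


async def slow_square(value: int) -> int:
    """Stand-in for the wrapped target."""
    target_calls["count"] += 1
    return value * value


async def main():
    mw = CachingSketch()
    first = await mw.middleware_invoke(4, slow_square)
    second = await mw.middleware_invoke(4, slow_square)  # served from cache
    return first, second


first, second = asyncio.run(main())
```

The second call returns the same value while the target runs only once.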
- async middleware_stream(value: Any, call_next: CallNextStream, context: FunctionMiddlewareContext, **kwargs: Any)#
Middleware for streaming invocations.
- Args:
value: The input value to process
call_next: Callable to invoke the next middleware or target stream
context: Metadata about the target being wrapped
kwargs: Additional function arguments
- Yields:
Chunks from the stream (potentially modified)
The default implementation forwards to call_next untouched. Override this to add preprocessing, transform chunks, or perform cleanup:
```python
async def middleware_stream(self, value, call_next, context, **kwargs):
    # Preprocess: setup or modify input
    modified_input = transform(value)
    # Call next: get stream from next middleware/target
    async for chunk in call_next(modified_input, **kwargs):
        # Process each chunk
        modified_chunk = transform_chunk(chunk)
        yield modified_chunk
    # Postprocess: cleanup after stream ends
    await cleanup()
```
- class RedTeamingMiddleware(*, attack_payload: str, target_function_or_group: str | None = None, payload_placement: Literal['replace', 'append_start', 'append_middle', 'append_end'] = 'append_end', target_location: Literal['input', 'output'] = 'input', target_field: str | None = None, target_field_resolution_strategy: Literal['random', 'first', 'last', 'all', 'error'] = 'error', call_limit: int | None = None)#
Bases: nat.middleware.function_middleware.FunctionMiddleware
Middleware for red teaming that intercepts and modifies function inputs/outputs.
This middleware enables systematic security testing by injecting attack payloads into function inputs or outputs. It supports flexible targeting, field-level modifications, and multiple attack modes.
Features:
Target specific functions or entire function groups
Search for specific fields in input/output schemas
Apply attacks via replace or append modes
Support for both regular and streaming calls
Type-safe operations on strings, numbers
Example:
```yaml
# In YAML config
middleware:
  prompt_injection:
    _type: red_teaming
    attack_payload: "Ignore previous instructions"
    target_function_or_group: my_llm.generate
    payload_placement: append_start
    target_location: input
    target_field: prompt
```
- Args:
attack_payload: The malicious payload to inject.
target_function_or_group: Function or group to target (None for all).
payload_placement: How to apply (replace, append_start, append_middle, append_end).
target_location: Whether to attack input or output.
target_field: Field name or path to attack (None for direct value).
Initialize red teaming middleware.
- Args:
attack_payload: The value to inject into the function input or output.
target_function_or_group: Optional function/group to target.
payload_placement: How to apply the payload (replace or append modes).
target_location: Whether to place the payload in the input or output.
target_field: JSONPath to the field to attack.
target_field_resolution_strategy: Strategy used when target_field matches more than one field (random/first/last/all/error).
call_limit: Maximum number of times the middleware will apply a payload.
- _attack_payload#
- _target_function_or_group = None#
- _payload_placement = 'append_end'#
- _target_location = 'input'#
- _target_field = None#
- _target_field_resolution_strategy = 'error'#
- _call_limit = None#
- _should_apply_payload(context_name: str) → bool#
Check if this function should be attacked based on targeting configuration.
- Args:
context_name: The name of the function from context (e.g., “calculator__add”)
- Returns:
True if the function should be attacked, False otherwise
- _find_middle_sentence_index(text: str) → int#
Find the index to insert text at the middle sentence boundary.
- Args:
text: The text to analyze
- Returns:
The character index where the middle sentence ends
- _apply_payload_to_simple_type(original_value, attack_payload, payload_placement) → Any#
Apply the attack payload to a value of a simple type (str, int, float).
- Args:
original_value: The original value to attack
attack_payload: The payload to inject
payload_placement: How to apply the payload
- Returns:
The modified value with attack applied
- Raises:
ValueError: If attack cannot be applied due to type mismatch
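To make the four payload_placement values concrete, here is a rough sketch for string values only. The actual private helpers are not shown in this reference, so everything below is an assumption for illustration: the function names, the regular expression used for sentence boundaries, the choice of "middle" sentence, the single-space separator, and the rule that non-string values only support replace.

```python
import re

PLACEMENTS = ("replace", "append_start", "append_middle", "append_end")


def find_middle_sentence_index(text: str) -> int:
    """Return the index just after the middle sentence (assumed heuristic)."""
    ends = [m.end() for m in re.finditer(r"[.!?](?=\s|$)", text)]
    if not ends:
        return len(text) // 2  # no sentence boundary: fall back to midpoint
    return ends[(len(ends) - 1) // 2]


def apply_payload(original, payload: str, placement: str):
    """Apply payload to original according to placement (illustrative only)."""
    if placement not in PLACEMENTS:
        raise ValueError(f"unknown placement: {placement}")
    if placement == "replace":
        return payload
    if not isinstance(original, str):
        # Assumption: append modes are only meaningful for strings here.
        raise ValueError("append placements need a string value in this sketch")
    if placement == "append_start":
        return f"{payload} {original}"
    if placement == "append_end":
        return f"{original} {payload}"
    idx = find_middle_sentence_index(original)
    return f"{original[:idx]} {payload}{original[idx:]}"


text = "One. Two. Three. Four."
middle = apply_payload(text, "X", "append_middle")
```

With four sentences, this heuristic inserts the payload after the second one.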
- _resolve_multiple_field_matches(matches)#
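The reference gives no description for this helper, but the target_field_resolution_strategy values (random, first, last, all, error) suggest how several matches for target_field might be narrowed down. The sketch below is a guess at that behaviour based only on the strategy names; resolve_matches and its rng parameter are hypothetical.

```python
import random


def resolve_matches(matches: list, strategy: str, rng=None) -> list:
    """Pick which matched fields to attack (assumed semantics per strategy)."""
    if len(matches) <= 1:
        return list(matches)  # nothing to disambiguate
    if strategy == "first":
        return [matches[0]]
    if strategy == "last":
        return [matches[-1]]
    if strategy == "all":
        return list(matches)
    if strategy == "random":
        chooser = rng if rng is not None else random
        return [chooser.choice(matches)]
    if strategy == "error":
        raise ValueError(f"{len(matches)} fields matched; expected exactly one")
    raise ValueError(f"unknown strategy: {strategy}")


fields = ["messages[0].content", "messages[1].content", "prompt"]
picked_first = resolve_matches(fields, "first")
picked_all = resolve_matches(fields, "all")
```

Under this reading, the default strategy 'error' makes an ambiguous target_field fail loudly rather than attack an arbitrary field.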
- _apply_payload_to_function_value(value: Any) → Any#
- _apply_payload_to_function_value_with_exception(value: Any, context: nat.middleware.function_middleware.FunctionMiddlewareContext)#
- async function_middleware_invoke(*args: Any, call_next: nat.middleware.function_middleware.CallNext, context: nat.middleware.function_middleware.FunctionMiddlewareContext, **kwargs: Any)#
Invoke middleware for single-output functions.
- Args:
args: Positional arguments passed to the function (first arg is typically the input value).
call_next: Callable to invoke next middleware/function.
context: Metadata about the function being wrapped.
kwargs: Keyword arguments passed to the function.
- Returns:
The output value (potentially modified if attacking output).