nat.middleware

Middleware implementations for NeMo Agent Toolkit.

Submodules

Attributes

Classes

FunctionMiddleware

Base class for function middleware with pre/post-invoke hooks.

FunctionMiddlewareChain

Composes middleware into an execution chain.

FunctionMiddlewareContext

Static metadata about the function being wrapped by middleware.

Middleware

Base class for middleware-style wrapping with pre/post-invoke hooks.

RedTeamingMiddleware

Middleware for red teaming that intercepts and modifies function inputs/outputs.

Functions

validate_middleware(...)

Validate a sequence of middleware, enforcing ordering guarantees.

Package Contents

class FunctionMiddleware(*, is_final: bool = False)

Bases: nat.middleware.middleware.Middleware

Base class for function middleware with pre/post-invoke hooks.

Middleware intercepts function calls and can:

  • Transform inputs before execution (pre_invoke)

  • Transform outputs after execution (post_invoke)

  • Take full control of execution when function_middleware_invoke is overridden

Lifecycle:

  • The framework checks the enabled property before calling any methods

  • If disabled, the middleware is skipped entirely (no methods are called)

  • Users do NOT need to check enabled in their implementations

Members inherited from Middleware, where they are abstract. FunctionMiddleware provides defaults for all three (enabled returns True; pre_invoke and post_invoke pass through unchanged), so subclasses override only what they need:

  • enabled: Property that returns whether the middleware should run

  • pre_invoke: Transform inputs before function execution

  • post_invoke: Transform outputs after function execution

Context Flow:

  • FunctionMiddlewareContext (frozen): Static function metadata only

  • InvocationContext: Unified context for both the pre-invoke and post-invoke phases

  • Pre-invoke: output is None; modify modified_args/modified_kwargs

  • Post-invoke: output holds the result; modify output to transform it

Example:

import logging

from nat.middleware import FunctionMiddleware
from nat.middleware.middleware import InvocationContext

logger = logging.getLogger(__name__)


class LoggingMiddleware(FunctionMiddleware):
    # LoggingConfig is a placeholder for your own config type exposing an `enabled` flag.
    def __init__(self, config: "LoggingConfig"):
        super().__init__()
        self._config = config

    @property
    def enabled(self) -> bool:
        return self._config.enabled

    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        logger.info("Calling %s with %s", context.function_context.name, context.modified_args)
        logger.info("Original args: %s", context.original_args)
        return None  # Pass through unchanged

    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        logger.info("Result: %s", context.output)
        return None  # Pass through unchanged

property enabled: bool

Check if this middleware is enabled.

Returns:

True if the middleware should be applied, False otherwise. Default implementation always returns True.

async pre_invoke(
context: nat.middleware.middleware.InvocationContext,
) -> nat.middleware.middleware.InvocationContext | None

Pre-invocation hook called before the function is invoked.

Args:

context: Invocation context containing function metadata and args

Returns:

InvocationContext if modified, or None to pass through unchanged. Default implementation does nothing.

async post_invoke(
context: nat.middleware.middleware.InvocationContext,
) -> nat.middleware.middleware.InvocationContext | None

Post-invocation hook called after the function returns.

Args:

context: Invocation context containing function metadata, args, and output

Returns:

InvocationContext if modified, or None to pass through unchanged. Default implementation does nothing.

async middleware_invoke(
*args: Any,
call_next: nat.middleware.middleware.CallNext,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) -> Any

Delegate to function_middleware_invoke for function-specific handling.

async middleware_stream(
*args: Any,
call_next: nat.middleware.middleware.CallNextStream,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) -> collections.abc.AsyncIterator[Any]

Delegate to function_middleware_stream for function-specific handling.

async function_middleware_invoke(
*args: Any,
call_next: nat.middleware.middleware.CallNext,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) -> Any

Execute middleware hooks around function call.

Default implementation orchestrates: pre_invoke → call_next → post_invoke

Override for full control over execution flow (e.g., caching, retry logic, conditional execution).

Note: Framework checks enabled before calling this method. You do NOT need to check enabled yourself.

Args:

args: Positional arguments for the function (first arg is typically the input value).
call_next: Callable to invoke next middleware or target function.
context: Static function metadata.
kwargs: Keyword arguments for the function.

Returns:

The (potentially transformed) function output.
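As a rough illustration of the default pre_invoke → call_next → post_invoke flow, the sketch below reimplements it in plain Python without importing nat. SketchContext, UppercaseMiddleware, orchestrate, and echo are hypothetical names. SketchContext only approximates InvocationContext (for example, it does not freeze the original args). The handling of a None return follows the Returns descriptions of pre_invoke and post_invoke above.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable


@dataclass
class SketchContext:
    """Hypothetical stand-in for InvocationContext; not the nat class."""

    function_name: str
    original_args: tuple[Any, ...]
    modified_args: tuple[Any, ...]
    modified_kwargs: dict[str, Any] = field(default_factory=dict)
    output: Any = None


class UppercaseMiddleware:
    """Hypothetical middleware: uppercases the first argument and brackets the result."""

    async def pre_invoke(self, ctx: SketchContext) -> SketchContext | None:
        ctx.modified_args = (ctx.modified_args[0].upper(), *ctx.modified_args[1:])
        return ctx  # signal that the inputs changed

    async def post_invoke(self, ctx: SketchContext) -> SketchContext | None:
        ctx.output = f"[{ctx.output}]"
        return ctx  # signal that the output changed


async def orchestrate(middleware: Any, call_next: Callable[..., Awaitable[Any]],
                      name: str, *args: Any, **kwargs: Any) -> Any:
    """Sketch of the default flow: pre_invoke -> call_next -> post_invoke."""
    ctx = SketchContext(name, original_args=args, modified_args=args,
                        modified_kwargs=dict(kwargs))
    # Pre-invoke phase: ctx.output is still None; a None return keeps ctx as-is.
    updated = await middleware.pre_invoke(ctx)
    if updated is not None:
        ctx = updated
    # Call the next middleware (or the target) with the possibly modified inputs.
    ctx.output = await call_next(*ctx.modified_args, **ctx.modified_kwargs)
    # Post-invoke phase: ctx.output holds the result and may be replaced.
    updated = await middleware.post_invoke(ctx)
    if updated is not None:
        ctx = updated
    return ctx.output


async def echo(text: str) -> str:
    return f"echo: {text}"


print(asyncio.run(orchestrate(UppercaseMiddleware(), echo, "echo", "hi")))  # [echo: HI]
```

A hook that returns None keeps the current context. That is why the sketch checks `is not None` instead of replacing the context unconditionally.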

async function_middleware_stream(
*args: Any,
call_next: nat.middleware.middleware.CallNextStream,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) -> collections.abc.AsyncIterator[Any]

Execute middleware hooks around streaming function call.

Pre-invoke runs once, before streaming starts. Post-invoke runs once per chunk as chunks stream through.

Override for custom streaming behavior (e.g., buffering, aggregation, chunk filtering).

Note: Framework checks enabled before calling this method. You do NOT need to check enabled yourself.

Args:

args: Positional arguments for the function (first arg is typically the input value).
call_next: Callable to invoke next middleware or target stream.
context: Static function metadata.
kwargs: Keyword arguments for the function.

Yields:

Stream chunks (potentially transformed by post_invoke).
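The streaming variant differs mainly in how often each hook runs. The following self-contained sketch does not use nat, and its names (StreamSketchContext, CountingMiddleware, orchestrate_stream, count_up) are hypothetical. It runs pre_invoke once and post_invoke once per chunk, mirroring the description above. The framework's actual implementation may differ in detail.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable


@dataclass
class StreamSketchContext:
    """Hypothetical stand-in for InvocationContext; not the nat class."""

    modified_args: tuple[Any, ...]
    modified_kwargs: dict[str, Any] = field(default_factory=dict)
    output: Any = None


class CountingMiddleware:
    """Hypothetical middleware that counts hook calls and scales each chunk."""

    def __init__(self) -> None:
        self.pre_calls = 0
        self.post_calls = 0

    async def pre_invoke(self, ctx: StreamSketchContext) -> StreamSketchContext | None:
        self.pre_calls += 1
        return None  # pass through unchanged

    async def post_invoke(self, ctx: StreamSketchContext) -> StreamSketchContext | None:
        self.post_calls += 1
        ctx.output = ctx.output * 10  # transform this chunk
        return ctx


async def orchestrate_stream(middleware: Any, call_next: Callable[..., AsyncIterator[Any]],
                             *args: Any, **kwargs: Any) -> AsyncIterator[Any]:
    ctx = StreamSketchContext(modified_args=args, modified_kwargs=dict(kwargs))
    # Pre-invoke runs exactly once, before the first chunk is requested.
    updated = await middleware.pre_invoke(ctx)
    if updated is not None:
        ctx = updated
    async for chunk in call_next(*ctx.modified_args, **ctx.modified_kwargs):
        # Post-invoke runs once per chunk, with ctx.output set to that chunk.
        ctx.output = chunk
        updated = await middleware.post_invoke(ctx)
        if updated is not None:
            ctx = updated
        yield ctx.output


async def count_up(n: int) -> AsyncIterator[int]:
    for i in range(1, n + 1):
        yield i


async def main() -> tuple[list[int], int, int]:
    middleware = CountingMiddleware()
    chunks = [chunk async for chunk in orchestrate_stream(middleware, count_up, 3)]
    return chunks, middleware.pre_calls, middleware.post_calls


print(asyncio.run(main()))  # ([10, 20, 30], 1, 3)
```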

class FunctionMiddlewareChain(
*,
middleware: collections.abc.Sequence[FunctionMiddleware],
context: nat.middleware.middleware.FunctionMiddlewareContext,
)

Composes middleware into an execution chain.

The chain builder checks each middleware’s enabled property. Disabled middleware is skipped entirely: no methods are called.

Execution order:

  • Pre-invoke: first middleware → last middleware → function

  • Post-invoke: function → last middleware → first middleware

Context:

  • FunctionMiddlewareContext contains only static function metadata

  • Original args/kwargs are captured by the orchestration layer

  • Middleware receives an InvocationContext with frozen originals and mutable args/output

Initialize the middleware chain.

Args:

middleware: Sequence of middleware to chain (order matters).
context: Static function metadata.
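To make the execution order concrete, here is a minimal sketch of a chain builder in the spirit of build_single. It assumes each middleware exposes an enabled flag and an invoke(*args, call_next=..., **kwargs) coroutine. TraceMiddleware, build_single_sketch, and _wrap are hypothetical and do not reflect nat's internal code. Wrapping from the last middleware outward makes the first middleware the outermost layer: its pre-step runs first, its post-step runs last, and disabled middleware never appears in the chain.

```python
from __future__ import annotations

import asyncio
from typing import Any, Awaitable, Callable

CallNextSketch = Callable[..., Awaitable[Any]]


class TraceMiddleware:
    """Hypothetical middleware that records its pre/post steps in a shared list."""

    def __init__(self, name: str, trace: list[str], enabled: bool = True) -> None:
        self.name = name
        self.trace = trace
        self.enabled = enabled

    async def invoke(self, *args: Any, call_next: CallNextSketch, **kwargs: Any) -> Any:
        self.trace.append(f"{self.name}:pre")
        result = await call_next(*args, **kwargs)
        self.trace.append(f"{self.name}:post")
        return result


def _wrap(mw: TraceMiddleware, nxt: CallNextSketch) -> CallNextSketch:
    async def wrapped(*args: Any, **kwargs: Any) -> Any:
        return await mw.invoke(*args, call_next=nxt, **kwargs)
    return wrapped


def build_single_sketch(middleware: list[TraceMiddleware],
                        final_call: CallNextSketch) -> CallNextSketch:
    call = final_call
    # Wrap from the last enabled middleware outward; disabled ones are never wrapped.
    for mw in reversed([m for m in middleware if m.enabled]):
        call = _wrap(mw, call)
    return call


trace: list[str] = []


async def target(x: int) -> int:
    trace.append("function")
    return x + 1


chain = build_single_sketch(
    [TraceMiddleware("A", trace), TraceMiddleware("B", trace, enabled=False),
     TraceMiddleware("C", trace)],
    target,
)
result = asyncio.run(chain(41))
print(result, trace)  # 42 ['A:pre', 'C:pre', 'function', 'C:post', 'A:post']
```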

_middleware
_context
build_single(
final_call: nat.middleware.middleware.CallNext,
) -> nat.middleware.middleware.CallNext

Build the middleware chain for single-output invocations.

Disabled middleware (enabled=False) is skipped entirely.

Args:

final_call: The final function to call (the actual function implementation)

Returns:

A callable that executes the entire middleware chain

build_stream(
final_call: nat.middleware.middleware.CallNextStream,
) -> nat.middleware.middleware.CallNextStream

Build the middleware chain for streaming invocations.

Disabled middleware (enabled=False) is skipped entirely.

Args:

final_call: The final function to call (the actual function implementation)

Returns:

A callable that executes the entire middleware chain
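A streaming chain can be composed the same way with async generators. In this hypothetical sketch (PrefixChunks, _wrap_stream, build_stream_sketch, and words are illustrative names, not nat APIs), each middleware prefixes the chunks it forwards. Because the first middleware is outermost, its prefix is applied last.

```python
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Callable

CallNextStreamSketch = Callable[..., AsyncIterator[Any]]


class PrefixChunks:
    """Hypothetical streaming middleware that prefixes every chunk it forwards."""

    def __init__(self, prefix: str, enabled: bool = True) -> None:
        self.prefix = prefix
        self.enabled = enabled

    async def stream(self, *args: Any, call_next: CallNextStreamSketch,
                     **kwargs: Any) -> AsyncIterator[Any]:
        async for chunk in call_next(*args, **kwargs):
            yield f"{self.prefix}{chunk}"


def _wrap_stream(mw: PrefixChunks, nxt: CallNextStreamSketch) -> CallNextStreamSketch:
    async def wrapped(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        async for chunk in mw.stream(*args, call_next=nxt, **kwargs):
            yield chunk
    return wrapped


def build_stream_sketch(middleware: list[PrefixChunks],
                        final_call: CallNextStreamSketch) -> CallNextStreamSketch:
    call = final_call
    # Same composition as the single-output case: the first enabled middleware is outermost.
    for mw in reversed([m for m in middleware if m.enabled]):
        call = _wrap_stream(mw, call)
    return call


async def words(text: str) -> AsyncIterator[str]:
    for word in text.split():
        yield word


async def collect() -> list[str]:
    chain = build_stream_sketch(
        [PrefixChunks("A>"), PrefixChunks("X>", enabled=False), PrefixChunks("B>")], words
    )
    return [chunk async for chunk in chain("hello world")]


print(asyncio.run(collect()))  # ['A>B>hello', 'A>B>world']
```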

validate_middleware(
middleware: collections.abc.Sequence[nat.middleware.middleware.Middleware] | None,
) -> tuple[nat.middleware.middleware.Middleware, ...]

Validate a sequence of middleware, enforcing ordering guarantees.

CallNext
CallNextStream
class FunctionMiddlewareContext

Static metadata about the function being wrapped by middleware.

Middleware receives this context object, which describes the function it wraps. This lets middleware make decisions based on the function’s name, configuration, schema, and so on.

name: str

Name of the function being wrapped.

config: Any

Configuration object for the function.

description: str | None

Optional description of the function.

input_schema: type[pydantic.BaseModel] | None

Schema describing expected inputs or NoneType when absent.

single_output_schema: type[pydantic.BaseModel] | type[None]

Schema describing single outputs or types.NoneType when absent.

stream_output_schema: type[pydantic.BaseModel] | type[None]

Schema describing streaming outputs or types.NoneType when absent.
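Because the context is static, read-only metadata, a middleware can branch on it without side effects. The sketch below uses a frozen dataclass, FunctionContextSketch, as a hypothetical stand-in with a subset of the fields listed above. should_act is a hypothetical helper that filters by name prefix. The real class may enforce immutability differently.

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FunctionContextSketch:
    """Hypothetical read-only stand-in mirroring some FunctionMiddlewareContext fields."""

    name: str
    config: Any = None
    description: str | None = None


def should_act(ctx: FunctionContextSketch, prefixes: tuple[str, ...]) -> bool:
    """Hypothetical decision rule: act only on functions whose name has a given prefix."""
    return ctx.name.startswith(prefixes)


ctx = FunctionContextSketch(name="calculator__add", description="Add two numbers")
print(should_act(ctx, ("calculator__",)))  # True
print(should_act(ctx, ("search__",)))  # False

try:
    ctx.name = "something_else"  # static metadata: assignment is rejected
except dataclasses.FrozenInstanceError:
    print("context is read-only")
```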

class Middleware(*, is_final: bool = False)

Bases: abc.ABC

Base class for middleware-style wrapping with pre/post-invoke hooks.

Middleware follows the same pattern as middleware in web frameworks:

  1. Preprocess: Inspect and optionally modify inputs (via pre_invoke)

  2. Call Next: Delegate to the next middleware or the target itself

  3. Postprocess: Process, transform, or augment the output (via post_invoke)

  4. Continue: Return or yield the final result

Example:

from nat.middleware import FunctionMiddleware
from nat.middleware.middleware import InvocationContext


class LoggingMiddleware(FunctionMiddleware):
    @property
    def enabled(self) -> bool:
        return True

    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Current args: {context.modified_args}")
        print(f"Original args: {context.original_args}")
        return None  # Pass through unchanged

    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Output: {context.output}")
        return None  # Pass through unchanged

Attributes:
is_final: If True, this middleware terminates the chain. No subsequent middleware or the target will be called unless this middleware explicitly delegates to call_next.

_is_final = False
abstractmethod property enabled: bool

Whether this middleware should execute.

abstractmethod async pre_invoke(context: InvocationContext) -> InvocationContext | None

Transform inputs before execution.

Called by specialized middleware invoke methods (e.g., function_middleware_invoke). Use to validate, transform, or augment inputs. At this phase, context.output is None.

Args:
context: Invocation context (Pydantic model) containing:
  • function_context: Static function metadata (frozen)

  • original_args: What entered the middleware chain (frozen)

  • original_kwargs: What entered the middleware chain (frozen)

  • modified_args: Current args (mutable)

  • modified_kwargs: Current kwargs (mutable)

  • output: None (function not yet called)

Returns:

InvocationContext: Return the (modified) context to signal changes.
None: Pass through unchanged (framework uses current context state).

Note:

Frozen fields (original_args, original_kwargs) cannot be modified. Attempting to modify them raises ValidationError.

Raises:

Any exception to abort execution

abstractmethod async post_invoke(context: InvocationContext) -> InvocationContext | None

Transform output after execution.

Called by specialized middleware invoke methods (e.g., function_middleware_invoke). For streaming, called per-chunk. Use to validate, transform, or augment outputs.

Args:
context: Invocation context (Pydantic model) containing:
  • function_context: Static function metadata (frozen)

  • original_args: What entered the middleware chain (frozen)

  • original_kwargs: What entered the middleware chain (frozen)

  • modified_args: What the function received (mutable)

  • modified_kwargs: What the function received (mutable)

  • output: Current output value (mutable)

Returns:

InvocationContext: Return the (modified) context to signal changes.
None: Pass through unchanged (framework uses current context.output).

Example:

async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
    # Wrap the output
    context.output = {"result": context.output, "processed": True}
    return context  # Signal modification

Raises:

Any exception, to abort execution and propagate the error

property is_final: bool

Whether this middleware terminates the chain.

A final middleware prevents subsequent middleware and the target from running unless it explicitly calls call_next.
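A cache is a typical example of a middleware that decides for itself whether call_next runs. The hypothetical CacheMiddlewareSketch below follows the middleware_invoke signature documented in this section and returns cached results without delegating. It sets is_final only as a label and does not model how the framework interprets that flag.

```python
from __future__ import annotations

import asyncio
from typing import Any, Awaitable, Callable


class CacheMiddlewareSketch:
    """Hypothetical middleware that serves repeated inputs from a cache."""

    is_final = True  # label only; this sketch does not model framework handling

    def __init__(self) -> None:
        self._cache: dict[Any, Any] = {}

    async def middleware_invoke(self, value: Any, call_next: Callable[..., Awaitable[Any]],
                                context: Any = None, **kwargs: Any) -> Any:
        if value in self._cache:
            return self._cache[value]  # short-circuit: call_next is never awaited
        result = await call_next(value, **kwargs)  # explicit delegation on a cache miss
        self._cache[value] = result
        return result


async def main() -> tuple[int, int, int]:
    calls = 0

    async def slow_square(x: int) -> int:
        nonlocal calls
        calls += 1
        return x * x

    middleware = CacheMiddlewareSketch()
    first = await middleware.middleware_invoke(4, slow_square)
    second = await middleware.middleware_invoke(4, slow_square)  # cache hit
    return first, second, calls


print(asyncio.run(main()))  # (16, 16, 1)
```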

async middleware_invoke(
value: Any,
call_next: CallNext,
context: FunctionMiddlewareContext,
**kwargs: Any,
) -> Any

Middleware for single-output invocations.

Args:

value: The input value to process.
call_next: Callable to invoke the next middleware or target.
context: Metadata about the target being wrapped.
kwargs: Additional function arguments.

Returns:

The (potentially modified) output from the target

The default implementation simply delegates to call_next. Override this to add preprocessing, postprocessing, or to short-circuit execution:

async def middleware_invoke(self, value, call_next, context, **kwargs):
    # transform() and transform_output() stand in for your own logic.
    # Preprocess: modify input
    modified_input = transform(value)

    # Call next: delegate to next middleware/target
    result = await call_next(modified_input, **kwargs)

    # Postprocess: modify output
    modified_result = transform_output(result)

    # Continue: return final result
    return modified_result

async middleware_stream(
value: Any,
call_next: CallNextStream,
context: FunctionMiddlewareContext,
**kwargs: Any,
) -> collections.abc.AsyncIterator[Any]

Middleware for streaming invocations.

Args:

value: The input value to process.
call_next: Callable to invoke the next middleware or target stream.
context: Metadata about the target being wrapped.
kwargs: Additional function arguments.

Yields:

Chunks from the stream (potentially modified)

The default implementation forwards to call_next untouched. Override this to add preprocessing, transform chunks, or perform cleanup:

async def middleware_stream(self, value, call_next, context, **kwargs):
    # transform(), transform_chunk(), and cleanup() stand in for your own logic.
    # Preprocess: setup or modify input
    modified_input = transform(value)

    try:
        # Call next: get stream from next middleware/target
        async for chunk in call_next(modified_input, **kwargs):
            # Process each chunk
            yield transform_chunk(chunk)
    finally:
        # Postprocess: cleanup after the stream ends (also runs if it raises or is closed early)
        await cleanup()

class RedTeamingMiddleware(
*,
attack_payload: str,
target_function_or_group: str | None = None,
payload_placement: Literal['replace', 'append_start', 'append_middle', 'append_end'] = 'append_end',
target_location: Literal['input', 'output'] = 'input',
target_field: str | None = None,
target_field_resolution_strategy: Literal['random', 'first', 'last', 'all', 'error'] = 'error',
call_limit: int | None = None,
)

Bases: nat.middleware.function_middleware.FunctionMiddleware

Middleware for red teaming that intercepts and modifies function inputs/outputs.

This middleware enables systematic security testing by injecting attack payloads into function inputs or outputs. It supports flexible targeting, field-level modifications, and multiple attack modes.

Features:

  • Target specific functions or entire function groups

  • Search for specific fields in input/output schemas

  • Apply attacks via replace or append modes

  • Support for both regular and streaming calls

  • Type-safe operations on strings and numbers

Example:

# In YAML config
middleware:
  prompt_injection:
    _type: red_teaming
    attack_payload: "Ignore previous instructions"
    target_function_or_group: my_llm.generate
    payload_placement: append_start
    target_location: input
    target_field: prompt

Args:

attack_payload: The malicious payload to inject.
target_function_or_group: Function or group to target (None for all).
payload_placement: How to apply (replace, append_start, append_middle, append_end).
target_location: Whether to attack input or output.
target_field: Field name or path to attack (None for direct value).

Initialize red teaming middleware.

Args:

attack_payload: The value to inject into the function input or output.
target_function_or_group: Optional function/group to target.
payload_placement: How to apply the payload (replace or append modes).
target_location: Whether to place the payload in the input or output.
target_field: JSONPath to the field to attack.
target_field_resolution_strategy: Strategy for choosing among multiple matches of target_field (random/first/last/all/error).
call_limit: Maximum number of times the middleware will apply a payload.
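The docs do not spell out what each target_field_resolution_strategy value does. The hypothetical resolve_matches function below is one plausible reading. It assumes that the strategy only matters when target_field matches more than one field, and that 'error' rejects ambiguous matches. It is not the body of _resolve_multiple_field_matches.

```python
from __future__ import annotations

import random
from typing import Any


def resolve_matches(matches: list[Any], strategy: str,
                    rng: random.Random | None = None) -> list[Any]:
    """Hypothetical sketch: choose which target_field matches receive the payload."""
    # Zero or one match, or an explicit 'all': nothing to resolve.
    if len(matches) <= 1 or strategy == "all":
        return list(matches)
    if strategy == "first":
        return [matches[0]]
    if strategy == "last":
        return [matches[-1]]
    if strategy == "random":
        chooser = rng if rng is not None else random
        return [chooser.choice(matches)]
    if strategy == "error":
        raise ValueError(f"target_field matched {len(matches)} fields; choose another strategy")
    raise ValueError(f"unknown resolution strategy: {strategy!r}")


paths = ["$.messages[0].content", "$.messages[1].content"]
print(resolve_matches(paths, "last"))  # ['$.messages[1].content']
```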

_attack_payload
_target_function_or_group = None
_payload_placement = 'append_end'
_target_location = 'input'
_target_field = None
_target_field_resolution_strategy = 'error'
_call_count: int = 0
_call_limit = None
_should_apply_payload(context_name: str) -> bool

Check if this function should be attacked based on targeting configuration.

Args:

context_name: The name of the function from context (e.g., “calculator__add”)

Returns:

True if the function should be attacked, False otherwise
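The targeting check and call_limit can be sketched together as follows. TargetingSketch is hypothetical, and its group rule is an assumption: it treats a function named "<group>__<function>" (as in the “calculator__add” example) as a member of <group>. The real _should_apply_payload may match names differently, for example for dotted targets such as my_llm.generate.

```python
from __future__ import annotations


class TargetingSketch:
    """Hypothetical re-creation of targeting plus call_limit bookkeeping."""

    def __init__(self, target_function_or_group: str | None = None,
                 call_limit: int | None = None) -> None:
        self.target = target_function_or_group
        self.call_limit = call_limit
        self.call_count = 0

    def should_apply_payload(self, context_name: str) -> bool:
        # Stop attacking once the payload budget is used up.
        if self.call_limit is not None and self.call_count >= self.call_limit:
            return False
        # No target configured: every function is attacked.
        if self.target is None:
            return True
        # Exact function match, or (assumed) "<group>__<function>" group membership.
        return context_name == self.target or context_name.startswith(f"{self.target}__")

    def record_application(self) -> None:
        self.call_count += 1


targeting = TargetingSketch("calculator", call_limit=2)
for name in ["calculator__add", "search__web", "calculator__sub", "calculator__mul"]:
    if targeting.should_apply_payload(name):
        targeting.record_application()
        print(f"attacking {name}")  # prints for calculator__add and calculator__sub only
```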

_find_middle_sentence_index(text: str) -> int

Find the index to insert text at the middle sentence boundary.

Args:

text: The text to analyze

Returns:

The character index where the middle sentence ends
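One way to locate the middle sentence boundary is shown below. find_middle_sentence_index is a hypothetical re-creation. It assumes that sentences end at ., !, or ? followed by whitespace or the end of the text, and it falls back to the character midpoint when no boundary exists. The real helper's rules may differ.

```python
import re

# Assumption: a sentence ends at '.', '!' or '?' followed by whitespace or end of text.
_SENTENCE_END = re.compile(r"[.!?](?=\s|$)")


def find_middle_sentence_index(text: str) -> int:
    """Hypothetical sketch: character index just past the end of the middle sentence."""
    ends = [match.end() for match in _SENTENCE_END.finditer(text)]
    if not ends:
        return len(text) // 2  # assumption: no boundary found, use the midpoint
    # ends[k] is where sentence k ends; choose the middle one (the earlier one for even counts).
    return ends[(len(ends) - 1) // 2]


text = "One. Two. Three."
index = find_middle_sentence_index(text)
print(text[:index] + " PAYLOAD" + text[index:])  # One. Two. PAYLOAD Three.
```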

_apply_payload_to_simple_type(
original_value: list | str | int | float,
attack_payload: str,
payload_placement: str,
) -> Any

Apply the attack payload to a value of a simple type (str, int, float).

Args:

original_value: The original value to attack.
attack_payload: The payload to inject.
payload_placement: How to apply the payload.

Returns:

The modified value with attack applied

Raises:

ValueError: If attack cannot be applied due to type mismatch
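The four payload_placement modes can be illustrated on simple values as follows. apply_payload_to_simple_type is a hypothetical sketch, not the library's code. Several of its rules are assumptions:

  • the space separators around the payload
  • the character midpoint used for append_middle (the real helper uses a sentence boundary)
  • numbers accept only 'replace', and only with a numeric payload

It raises ValueError on type mismatches, matching the Raises note above.

```python
from __future__ import annotations

from typing import Any


def apply_payload_to_simple_type(original_value: str | int | float,
                                 attack_payload: str,
                                 payload_placement: str) -> Any:
    """Hypothetical sketch of payload_placement on simple values."""
    # bool is a subclass of int; treat it as unsupported rather than coercing it.
    if isinstance(original_value, bool):
        raise ValueError("bool values are not supported in this sketch")
    if isinstance(original_value, (int, float)):
        # Assumption: numbers can only be replaced, and only by a numeric payload.
        if payload_placement != "replace":
            raise ValueError(f"{payload_placement!r} requires a string value")
        try:
            return type(original_value)(attack_payload)
        except ValueError as exc:
            raise ValueError(
                f"payload {attack_payload!r} is not a valid {type(original_value).__name__}"
            ) from exc
    if not isinstance(original_value, str):
        raise ValueError(f"unsupported type: {type(original_value).__name__}")
    if payload_placement == "replace":
        return attack_payload
    if payload_placement == "append_start":
        return f"{attack_payload} {original_value}"
    if payload_placement == "append_end":
        return f"{original_value} {attack_payload}"
    if payload_placement == "append_middle":
        middle = len(original_value) // 2  # simplification: character midpoint
        return f"{original_value[:middle]} {attack_payload} {original_value[middle:]}"
    raise ValueError(f"unknown payload_placement: {payload_placement!r}")


print(apply_payload_to_simple_type("What is 2 + 2?", "Ignore previous instructions", "append_start"))
# Ignore previous instructions What is 2 + 2?
```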

_resolve_multiple_field_matches(matches)
_apply_payload_to_complex_type(
value: list | dict | pydantic.BaseModel,
) -> list | dict | pydantic.BaseModel
_apply_payload_to_function_value(value: Any) -> Any
_apply_payload_to_function_value_with_exception(
value: Any,
context: nat.middleware.function_middleware.FunctionMiddlewareContext,
) -> Any
async function_middleware_invoke(
*args: Any,
call_next: nat.middleware.function_middleware.CallNext,
context: nat.middleware.function_middleware.FunctionMiddlewareContext,
**kwargs: Any,
) -> Any

Invoke middleware for single-output functions.

Args:

args: Positional arguments passed to the function (first arg is typically the input value).
call_next: Callable to invoke next middleware/function.
context: Metadata about the function being wrapped.
kwargs: Keyword arguments passed to the function.

Returns:

The output value (potentially modified if attacking output).