nat.middleware.middleware

Base middleware class for the NeMo Agent toolkit.

This module provides the base Middleware class, which defines the middleware pattern for wrapping and modifying function calls. Middleware works like middleware in web frameworks: it can modify inputs, call the next middleware in the chain, process outputs, and then continue.

Attributes

CallNext

CallNextStream

Classes

FunctionMiddlewareContext

Static metadata about the function being wrapped by middleware.

InvocationContext

Unified context for pre-invoke and post-invoke phases.

Middleware

Base class for middleware-style wrapping with pre/post-invoke hooks.

Module Contents

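CallNext

Callable used to invoke the next middleware or the target in a single-output chain.

CallNextStream

Callable used to invoke the next middleware or the target stream in a streaming chain.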
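This page does not show how CallNext and CallNextStream are defined. Judging from how they are used further down (awaited in middleware_invoke, iterated with async for in middleware_stream), a plausible shape is sketched below. Treat both aliases as an assumption inferred from usage, not as the toolkit's actual definitions; echo_target and char_stream are hypothetical targets.

```python
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any

# ASSUMPTION: shapes inferred from how this page uses the aliases
# (awaited in middleware_invoke, iterated with `async for` in middleware_stream).
CallNext = Callable[..., Awaitable[Any]]
CallNextStream = Callable[..., AsyncIterator[Any]]


async def echo_target(value: Any, **kwargs: Any) -> Any:
    """Hypothetical single-output target compatible with the assumed CallNext."""
    return value


async def char_stream(value: Any, **kwargs: Any) -> AsyncIterator[Any]:
    """Hypothetical streaming target compatible with the assumed CallNextStream."""
    for ch in str(value):
        yield ch


async def demo() -> tuple[Any, list[Any]]:
    call_next: CallNext = echo_target
    call_stream: CallNextStream = char_stream
    single = await call_next("hi")  # awaited once for a single result
    chunks = [chunk async for chunk in call_stream("ab")]  # iterated for chunks
    return single, chunks


if __name__ == "__main__":
    print(asyncio.run(demo()))  # ('hi', ['a', 'b'])
```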

class FunctionMiddlewareContext

Static metadata about the function being wrapped by middleware.

Middleware receives this context object, which describes the function it wraps. This lets middleware make decisions based on the function's name, configuration, schema, and so on.

name: str

Name of the function being wrapped.

config: Any

Configuration object for the function.

description: str | None

Optional description of the function.

input_schema: type[pydantic.BaseModel] | None

Schema describing the expected inputs, or None when absent.

single_output_schema: type[pydantic.BaseModel] | type[None]

Schema describing single (non-streaming) outputs, or types.NoneType when absent.

stream_output_schema: type[pydantic.BaseModel] | type[None]

Schema describing streaming outputs, or types.NoneType when absent.

class InvocationContext(/, **data: Any)

Bases: pydantic.BaseModel

Unified context for pre-invoke and post-invoke phases.

Used for both phases of middleware execution:

  • Pre-invoke: output is None; modify modified_args/modified_kwargs to transform inputs.

  • Post-invoke: output contains the function result; modify output to transform results.

This unified context simplifies the middleware interface by using a single context type for both hooks.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

model_config

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

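function_context: FunctionMiddlewareContext = None

Static metadata about the function being wrapped (frozen).

original_args: tuple[Any, ...] = None

Positional arguments as they entered the middleware chain (frozen).

original_kwargs: dict[str, Any] = None

Keyword arguments as they entered the middleware chain (frozen).

modified_args: tuple[Any, ...] = None

Current positional arguments; these are what the function receives (mutable).

modified_kwargs: dict[str, Any] = None

Current keyword arguments; these are what the function receives (mutable).

output: Any = None

None during pre-invoke; the function result (one chunk at a time when streaming) during post-invoke (mutable).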
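The two phases share one context object. The stdlib-only sketch below walks through that flow: pre-invoke rewrites modified_kwargs while original_kwargs stays intact, the function receives the modified arguments, and post-invoke rewrites output. InvocationContext here is a simplified dataclass stand-in (the real one is a Pydantic model and raises ValidationError rather than AttributeError), and run_with_hooks, clamp_limit, tag_output, and search are hypothetical.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional


@dataclass
class InvocationContext:
    """Simplified stand-in for the documented Pydantic model (stdlib only)."""

    original_args: tuple[Any, ...]
    original_kwargs: dict[str, Any]
    modified_args: tuple[Any, ...]
    modified_kwargs: dict[str, Any]
    output: Any = None

    def __setattr__(self, key: str, value: Any) -> None:
        # Mimic the frozen originals: the first assignment (from __init__) is
        # allowed, later reassignments are rejected.
        if key.startswith("original_") and key in self.__dict__:
            raise AttributeError(f"{key} is frozen")
        super().__setattr__(key, value)


Hook = Callable[[InvocationContext], Awaitable[Optional[InvocationContext]]]


async def run_with_hooks(func, args: tuple, kwargs: dict, pre: Hook, post: Hook) -> Any:
    """Hypothetical driver: pre-invoke -> call -> post-invoke."""
    # Separate dict copies so edits to modified_kwargs never touch the originals.
    ctx = InvocationContext(args, dict(kwargs), args, dict(kwargs))
    returned = await pre(ctx)  # output is still None here
    ctx = returned if returned is not None else ctx
    ctx.output = await func(*ctx.modified_args, **ctx.modified_kwargs)
    returned = await post(ctx)  # output now holds the result
    ctx = returned if returned is not None else ctx
    return ctx.output


async def clamp_limit(ctx: InvocationContext) -> InvocationContext:
    ctx.modified_kwargs["limit"] = min(ctx.modified_kwargs.get("limit", 10), 5)
    return ctx


async def tag_output(ctx: InvocationContext) -> InvocationContext:
    ctx.output = {"result": ctx.output, "limit_requested": ctx.original_kwargs.get("limit")}
    return ctx


async def search(query: str, limit: int = 10) -> list[str]:
    return [f"{query}-{i}" for i in range(limit)]


if __name__ == "__main__":
    print(asyncio.run(run_with_hooks(search, ("cat",), {"limit": 8}, clamp_limit, tag_output)))
```

The caller asked for 8 results, pre-invoke clamped the call to 5, and post-invoke can still report the original request because original_kwargs was never touched.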

class Middleware(*, is_final: bool = False)

Bases: abc.ABC

Base class for middleware-style wrapping with pre/post-invoke hooks.

Middleware works like middleware in web frameworks:

  1. Preprocess: Inspect and optionally modify inputs (via pre_invoke)

  2. Call Next: Delegate to the next middleware or the target itself

  3. Postprocess: Process, transform, or augment the output (via post_invoke)

  4. Continue: Return or yield the final result
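These four steps compose like layers of an onion: each middleware's preprocessing runs on the way in and its postprocessing on the way out, in reverse order. The sketch below assumes a hypothetical build_chain helper and a toy Tagger class that only mimics the middleware_invoke signature documented below; it does not subclass the toolkit's Middleware, and the real framework's chain assembly may differ.

```python
import asyncio
from functools import partial
from typing import Any, Awaitable, Callable

CallNext = Callable[..., Awaitable[Any]]


class Tagger:
    """Toy middleware that records when its pre- and post-steps run."""

    def __init__(self, name: str, trace: list[str]) -> None:
        self.name = name
        self.trace = trace

    async def middleware_invoke(
        self, value: Any, call_next: CallNext, context: Any, **kwargs: Any
    ) -> Any:
        self.trace.append(f"{self.name}:pre")       # 1. preprocess
        result = await call_next(value, **kwargs)   # 2. call next
        self.trace.append(f"{self.name}:post")      # 3. postprocess
        return result                                # 4. continue


def build_chain(middlewares: list[Any], target: CallNext, context: Any) -> CallNext:
    """Hypothetical helper: wrap target so middlewares[0] is the outermost layer."""
    call_next = target
    for mw in reversed(middlewares):
        # partial binds the current call_next eagerly, so each layer wraps the next.
        call_next = partial(mw.middleware_invoke, call_next=call_next, context=context)
    return call_next


if __name__ == "__main__":
    trace: list[str] = []

    async def target(value: int, **kwargs: Any) -> int:
        trace.append("target")
        return value * 2

    chain = build_chain([Tagger("A", trace), Tagger("B", trace)], target, context=None)
    print(asyncio.run(chain(21)), trace)
    # 42 ['A:pre', 'B:pre', 'target', 'B:post', 'A:post']
```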

Example:

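from nat.middleware.middleware import InvocationContext

# FunctionMiddleware is assumed to be a function-level Middleware subclass;
# its import path is not shown on this page.
class LoggingMiddleware(FunctionMiddleware):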
    @property
    def enabled(self) -> bool:
        return True

    async def pre_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Current args: {context.modified_args}")
        print(f"Original args: {context.original_args}")
        return None  # Pass through unchanged

    async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
        print(f"Output: {context.output}")
        return None  # Pass through unchanged

Attributes:

is_final: If True, this middleware terminates the chain. No subsequent middleware or the target will be called unless this middleware explicitly delegates to call_next.

_is_final = False

abstractmethod property enabled: bool

Whether this middleware should execute.
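Since enabled is abstract, each subclass decides how it is switched on. One common pattern, sketched here, reads the flag from configuration at construction time. The Middleware class below is a stand-in carrying only the enabled contract, and the run driver that bypasses a disabled middleware is hypothetical; this page does not say exactly where the toolkit checks enabled.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable


class Middleware(ABC):
    """Stand-in with only the abstract `enabled` contract from this page."""

    @property
    @abstractmethod
    def enabled(self) -> bool: ...


class ConfigGated(Middleware):
    """Hypothetical middleware whose on/off switch comes from configuration."""

    def __init__(self, active: bool) -> None:
        self._active = active  # e.g. loaded from a config file or env var

    @property
    def enabled(self) -> bool:
        return self._active

    async def middleware_invoke(self, value, call_next, context, **kwargs):
        return f"[{await call_next(value, **kwargs)}]"  # wrap the result


async def run(mw: ConfigGated, value: Any, target: Callable[..., Awaitable[Any]]) -> Any:
    # Hypothetical driver: skip the middleware entirely when it is disabled.
    if not mw.enabled:
        return await target(value)
    return await mw.middleware_invoke(value, target, context=None)


if __name__ == "__main__":
    async def upper(v: str, **kwargs: Any) -> str:
        return v.upper()

    print(asyncio.run(run(ConfigGated(True), "x", upper)))   # [X]
    print(asyncio.run(run(ConfigGated(False), "x", upper)))  # X
```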

abstractmethod async pre_invoke(context: InvocationContext) -> InvocationContext | None

Transform inputs before execution.

Called by specialized middleware invoke methods (e.g., function_middleware_invoke). Use it to validate, transform, or augment inputs. In this phase, context.output is None.

Args:
context: Invocation context (Pydantic model) containing:
  • function_context: Static function metadata (frozen)

  • original_args: What entered the middleware chain (frozen)

  • original_kwargs: What entered the middleware chain (frozen)

  • modified_args: Current args (mutable)

  • modified_kwargs: Current kwargs (mutable)

  • output: None (function not yet called)

Returns:

InvocationContext: Return the (modified) context to signal changes.

None: Pass through unchanged (the framework uses the current context state).

Note:

Frozen fields (original_args, original_kwargs) cannot be modified. Attempting to modify them raises ValidationError.

Raises:

Any exception to abort execution

abstractmethod async post_invoke(context: InvocationContext) -> InvocationContext | None

Transform output after execution.

Called by specialized middleware invoke methods (e.g., function_middleware_invoke). For streaming, it is called once per chunk. Use it to validate, transform, or augment outputs.

Args:
context: Invocation context (Pydantic model) containing:
  • function_context: Static function metadata (frozen)

  • original_args: What entered the middleware chain (frozen)

  • original_kwargs: What entered the middleware chain (frozen)

  • modified_args: What the function received (mutable)

  • modified_kwargs: What the function received (mutable)

  • output: Current output value (mutable)

Returns:

InvocationContext: Return the (modified) context to signal changes.

None: Pass through unchanged (the framework uses the current context.output).

Example:

async def post_invoke(self, context: InvocationContext) -> InvocationContext | None:
    # Wrap the output
    context.output = {"result": context.output, "processed": True}
    return context  # Signal modification

Raises:

Any exception to abort and propagate error

property is_final: bool

Whether this middleware terminates the chain.

A final middleware prevents subsequent middleware and the target from running unless it explicitly calls call_next.
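A final middleware is a natural place for short-circuiting, for example serving a cached result. The sketch below is a hypothetical in-memory cache that mirrors the documented is_final property and middleware_invoke signature without subclassing the toolkit's class. On a hit it returns without calling call_next; on a miss it explicitly delegates. It assumes the input value and keyword values are hashable.

```python
import asyncio
from typing import Any, Awaitable, Callable


class CacheMiddleware:
    """Hypothetical cache mirroring the documented is_final and middleware_invoke."""

    def __init__(self, *, is_final: bool = True) -> None:
        self._is_final = is_final
        self._cache: dict[Any, Any] = {}

    @property
    def is_final(self) -> bool:
        return self._is_final

    async def middleware_invoke(
        self,
        value: Any,
        call_next: Callable[..., Awaitable[Any]],
        context: Any,
        **kwargs: Any,
    ) -> Any:
        key = (value, tuple(sorted(kwargs.items())))  # assumes hashable inputs
        if key in self._cache:
            return self._cache[key]  # short-circuit: call_next is never invoked
        result = await call_next(value, **kwargs)  # explicit delegation on a miss
        self._cache[key] = result
        return result


if __name__ == "__main__":
    calls: list[Any] = []

    async def target(v: int, **kwargs: Any) -> int:
        calls.append(v)
        return v * 10

    mw = CacheMiddleware()
    print(asyncio.run(mw.middleware_invoke(3, target, None)))  # 30 (miss)
    print(asyncio.run(mw.middleware_invoke(3, target, None)))  # 30 (hit)
    print(calls)  # [3]
```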

async middleware_invoke(value: Any, call_next: CallNext, context: FunctionMiddlewareContext, **kwargs: Any) -> Any

Middleware for single-output invocations.

Args:

  • value: The input value to process

  • call_next: Callable to invoke the next middleware or the target

  • context: Metadata about the target being wrapped

  • kwargs: Additional function arguments

Returns:

The (potentially modified) output from the target

The default implementation simply delegates to call_next. Override this to add preprocessing, postprocessing, or to short-circuit execution:

async def middleware_invoke(self, value, call_next, context, **kwargs):
    # Preprocess: modify input (transform is a placeholder for your own logic)
    modified_input = transform(value)

    # Call next: delegate to the next middleware or the target
    result = await call_next(modified_input, **kwargs)

    # Postprocess: modify output (transform_output is also a placeholder)
    modified_result = transform_output(result)

    # Continue: return the final result
    return modified_result
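A concrete, runnable version of the template above: a timing middleware that preprocesses by stripping whitespace from string input, delegates, and postprocesses by wrapping the result with the elapsed time. The class name, the elapsed_ms key, and passing None as the context are illustrative assumptions; it follows the documented signature without subclassing the toolkit's Middleware.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable


class TimingMiddleware:
    """Hypothetical middleware following the documented middleware_invoke shape."""

    async def middleware_invoke(
        self,
        value: Any,
        call_next: Callable[..., Awaitable[Any]],
        context: Any,
        **kwargs: Any,
    ) -> Any:
        # Preprocess: normalize string input.
        if isinstance(value, str):
            value = value.strip()
        start = time.perf_counter()
        # Call next: delegate to the next middleware or the target.
        result = await call_next(value, **kwargs)
        # Postprocess: attach timing information to the result.
        elapsed_ms = (time.perf_counter() - start) * 1000
        # Continue: return the final result.
        return {"result": result, "elapsed_ms": elapsed_ms}


if __name__ == "__main__":
    async def target(v: str, **kwargs: Any) -> str:
        await asyncio.sleep(0.01)
        return v.upper()

    print(asyncio.run(TimingMiddleware().middleware_invoke("  hi ", target, None)))
```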

async middleware_stream(value: Any, call_next: CallNextStream, context: FunctionMiddlewareContext, **kwargs: Any) -> collections.abc.AsyncIterator[Any]

Middleware for streaming invocations.

Args:

  • value: The input value to process

  • call_next: Callable to invoke the next middleware or the target stream

  • context: Metadata about the target being wrapped

  • kwargs: Additional function arguments

Yields:

Chunks from the stream (potentially modified)

The default implementation forwards to call_next untouched. Override this to add preprocessing, transform chunks, or perform cleanup:

async def middleware_stream(self, value, call_next, context, **kwargs):
    # Preprocess: set up or modify input (transform is a placeholder)
    modified_input = transform(value)

    try:
        # Call next: get the stream from the next middleware or the target
        async for chunk in call_next(modified_input, **kwargs):
            # Process each chunk (transform_chunk is a placeholder)
            modified_chunk = transform_chunk(chunk)
            yield modified_chunk
    finally:
        # Postprocess: clean up after the stream ends, even on error or early close
        await cleanup()
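A runnable variant of the streaming template: it numbers each chunk as it passes through and reports the chunk count in a finally block, so the report also runs if the upstream raises or the generator is closed early with aclose(). ChunkCounter and its on_done callback are hypothetical, and the class follows the documented middleware_stream signature without subclassing the toolkit's Middleware.

```python
import asyncio
from typing import Any, AsyncIterator, Callable


class ChunkCounter:
    """Hypothetical streaming middleware following the documented shape."""

    def __init__(self, on_done: Callable[[int], None]) -> None:
        self.on_done = on_done

    async def middleware_stream(
        self,
        value: Any,
        call_next: Callable[..., AsyncIterator[Any]],
        context: Any,
        **kwargs: Any,
    ) -> AsyncIterator[Any]:
        count = 0
        try:
            async for chunk in call_next(value, **kwargs):
                count += 1
                yield f"{count}:{chunk}"  # per-chunk transformation
        finally:
            self.on_done(count)  # cleanup/report when the stream ends or is closed


if __name__ == "__main__":
    async def upstream(v: str, **kwargs: Any) -> AsyncIterator[str]:
        for word in v.split():
            yield word

    async def main() -> list[str]:
        mw = ChunkCounter(lambda n: print(f"done after {n} chunks"))
        return [c async for c in mw.middleware_stream("a b c", upstream, None)]

    print(asyncio.run(main()))  # prints "done after 3 chunks", then ['1:a', '2:b', '3:c']
```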