nat.plugins.security.middleware.guardrails.nemo_guardrails_middleware#

NeMo Guardrails policy middleware for NAT function boundaries.

Exceptions#

PostInvokeBlockedError

Raised by on_post_invoke_blocked when a post_invoke rail blocks the function output.

Classes#

GuardrailsMiddleware

Hosts NeMo Guardrails as a policy engine at configured function boundaries.

Module Contents#

exception PostInvokeBlockedError(message: str)#

Bases: RuntimeError

Raised by on_post_invoke_blocked when a post_invoke rail blocks the function output.

Carries the rail’s message so callers can surface it or re-raise as appropriate. Override GuardrailsMiddleware.on_post_invoke_blocked to return a value instead.


block_message: str#
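
A minimal sketch of how a caller might handle this exception, assuming a subclass has overridden on_post_invoke_blocked to raise it. The class below is a local stand-in that mirrors the documented signature and block_message attribute; it is not the NAT source, and call_guarded_function and safe_call are hypothetical names.

```python
class PostInvokeBlockedError(RuntimeError):
    """Local stand-in mirroring the documented exception (assumption)."""

    def __init__(self, message: str):
        super().__init__(message)
        # Assumed: the attribute simply stores the rail's message.
        self.block_message = message


def call_guarded_function(output_is_blocked: bool) -> str:
    """Hypothetical guarded function whose middleware raises on a block."""
    if output_is_blocked:
        raise PostInvokeBlockedError("I'm sorry, I can't share that.")
    return "normal output"


def safe_call(output_is_blocked: bool) -> str:
    """Surface the rail's message instead of letting the error propagate."""
    try:
        return call_guarded_function(output_is_blocked)
    except PostInvokeBlockedError as err:
        return f"[blocked] {err.block_message}"
```

Catching the exception at the call site keeps the refusal text available to the caller while still distinguishing a blocked result from a normal return value.
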
class GuardrailsMiddleware(
config: nat.plugins.security.middleware.guardrails.nemo_guardrails_middleware_config.GuardrailsMiddlewareConfig,
builder: nat.builder.builder.Builder,
)#

Bases: nat.middleware.dynamic.dynamic_function_middleware.DynamicFunctionMiddleware

Hosts NeMo Guardrails as a policy engine at configured function boundaries.

Input rails run on pre_invoke against the function’s input value(s); output rails run on post_invoke against the function’s output value(s). Each selected value is evaluated on its own:

  • pre_invoke (input): a passed value leaves the argument unchanged; a modified value is written back into the invocation context’s modified_args/modified_kwargs (siblings untouched) so the function runs on the modified input; a blocked value skips the function call and returns the refusal message.

  • post_invoke (output): a passed value returns the original output; a modified value is written back in place and the structurally-preserved output is returned; a blocked value returns the refusal message in place of the output.
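
The passed, modified, and blocked outcomes listed above can be sketched for a single string value as follows. This is an illustration under stated assumptions, not the middleware's code: Verdict, RailResult, guarded_call, and the toy rails are hypothetical names, and real rails are evaluated by NeMo Guardrails rather than by plain callables.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable


class Verdict(Enum):  # hypothetical labels for the three outcomes
    PASSED = "passed"
    MODIFIED = "modified"
    BLOCKED = "blocked"


@dataclass
class RailResult:  # hypothetical container, not a NeMo Guardrails type
    verdict: Verdict
    text: str


Rail = Callable[[str], RailResult]


def guarded_call(value: str, input_rail: Rail, output_rail: Rail,
                 func: Callable[[str], str]) -> str:
    pre = input_rail(value)
    if pre.verdict is Verdict.BLOCKED:
        return pre.text          # function call is skipped entirely
    if pre.verdict is Verdict.MODIFIED:
        value = pre.text         # function runs on the rewritten input
    output = func(value)
    post = output_rail(output)
    if post.verdict is Verdict.PASSED:
        return output            # original output is returned unchanged
    return post.text             # rewritten output or refusal message


def toy_input_rail(text: str) -> RailResult:
    if "attack" in text:
        return RailResult(Verdict.BLOCKED, "Request refused.")
    if "555-0100" in text:
        return RailResult(Verdict.MODIFIED, text.replace("555-0100", "[phone]"))
    return RailResult(Verdict.PASSED, text)


def toy_output_rail(text: str) -> RailResult:
    if "SECRET" in text:
        return RailResult(Verdict.BLOCKED, "Response withheld.")
    return RailResult(Verdict.PASSED, text)
```
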

A NeMo Guardrails policy can declare many rails, and the library runs them as a chain (all input rails, then all output rails, within one LLMRails instance), so one middleware can apply every rail a function needs.

Setting is_final=True makes GuardrailsMiddleware operate directly on the function call itself, so call_next always invokes the function rather than another middleware.

Initialize Guardrails middleware and register configured function targets.

Args:

config: Guardrails middleware configuration with required RailsConfig on config.guardrails.

builder: Workflow builder used for rail LLM bindings.

_llm_rails: nemoguardrails.LLMRails#
_guardrails_config: nat.plugins.security.middleware.guardrails.nemo_guardrails_middleware_config.GuardrailsMiddlewareConfig#
_rail_llms: set[str]#
_rail_llms_bound: bool = False#
async pre_invoke(
context: nat.middleware.middleware.InvocationContext,
) nat.middleware.middleware.InvocationContext | None#

Run input rails over the input fields and block on refusal.

Args:

context: Invocation context for the current boundary.

Returns:

Updated context when an input is blocked or rewritten; otherwise None.

async post_invoke(
context: nat.middleware.middleware.InvocationContext,
) nat.middleware.middleware.InvocationContext | None#

Run output rails over the output fields and block on refusal.

Args:

context: Invocation context including function output.

Returns:

Updated context when an output is blocked or rewritten; otherwise None.

async function_middleware_invoke(
*args: Any,
call_next: nat.middleware.function_middleware.CallNext,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) Any#

Run input and output Guardrails rails around a non-streaming function call.

Args:

args: Positional arguments for the wrapped function.

call_next: Next middleware or target function in the chain.

context: Static metadata for the wrapped function.

kwargs: Keyword arguments for the wrapped function.

Returns:

Function output, possibly rewritten or replaced by policy.

async function_middleware_stream(
*args: Any,
call_next: nat.middleware.function_middleware.CallNextStream,
context: nat.middleware.middleware.FunctionMiddlewareContext,
**kwargs: Any,
) collections.abc.AsyncIterator[Any]#

Run Guardrails rails around a streaming call.

When stream_output_rails is False (default), the full stream is buffered and evaluated with generate_async() before any output is yielded.

When stream_output_rails is True, output rails are applied token-by-token via LLMRails.stream_async(). The Colang policy must set rails.output.streaming.enabled: true. Blocks are signalled by a JSON sentinel yielded into the stream; the middleware translates these into on_post_invoke_blocked.

Args:

args: Positional arguments for the wrapped function.

call_next: Next middleware or target stream in the chain.

context: Static metadata for the wrapped function.

kwargs: Keyword arguments for the wrapped function.

Yields:

Stream output after output rails evaluate the payload.
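
The buffered mode (stream_output_rails left at False) can be sketched as below. This is a simplified illustration, not the middleware's implementation: evaluate is a hypothetical stand-in for the generate_async() evaluation, and yielding the evaluated text as a single chunk is an assumption about what happens after buffering.

```python
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable


async def buffered_guarded_stream(
    stream: AsyncIterator[str],
    evaluate: Callable[[str], Awaitable[str]],
) -> AsyncIterator[str]:
    # Drain the whole stream first; nothing is yielded until rails have run.
    chunks = [chunk async for chunk in stream]
    evaluated = await evaluate("".join(chunks))
    # Assumption: the evaluated text is emitted as one chunk.
    yield evaluated


async def toy_tokens(tokens: list[str]) -> AsyncIterator[str]:
    for token in tokens:
        yield token


async def toy_evaluate(text: str) -> str:
    # Hypothetical output rail: refuse when a PIN-like token appears.
    return "Response withheld." if "1234" in text else text


async def collect(stream: AsyncIterator[str]) -> list[str]:
    return [chunk async for chunk in stream]
```

Buffering trades latency for certainty: no partial output reaches the caller before the rails have seen the complete text.
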

async _stream_with_output_rails(
ctx: nat.middleware.middleware.InvocationContext,
call_next: nat.middleware.function_middleware.CallNextStream,
) collections.abc.AsyncIterator[Any]#

Apply output rails to a live token stream via LLMRails.stream_async().

Iterates the stream produced by call_next, passes it through the configured output rails, and translates any JSON block sentinel into on_post_invoke_blocked.

Args:

ctx: Invocation context carrying the (possibly modified) arguments.

call_next: Next middleware or target stream in the chain.

Yields:

Evaluated stream chunks, or the block message when a rail fires.

_set_modified_rail_value(
obj: Any,
name: str,
) collections.abc.Callable[[str], None]#

Build a setter that writes a modified rail value back to an object attribute.

Args:

obj: Object whose attribute is reassigned.

name: Attribute name to write.

Returns:

A callable assigning its argument to obj.name.

_set_modified_rail_value_in_list(
items: list[Any],
index: int,
) collections.abc.Callable[[str], None]#

Build a setter that writes a modified rail value back to a list element.

Args:

items: List whose element is reassigned.

index: Position to write.

Returns:

A callable assigning its argument to items[index].
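
Both setter builders follow the same closure pattern: capture the container and the location now, write the new text later. A minimal sketch, assuming nothing beyond the behavior described above; make_attr_setter and make_list_setter are hypothetical names, not the private methods themselves.

```python
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any


def make_attr_setter(obj: Any, name: str) -> Callable[[str], None]:
    """Return a callable that assigns its argument to obj.<name>."""
    def setter(new_text: str) -> None:
        setattr(obj, name, new_text)
    return setter


def make_list_setter(items: list[Any], index: int) -> Callable[[str], None]:
    """Return a callable that assigns its argument to items[index]."""
    def setter(new_text: str) -> None:
        items[index] = new_text
    return setter


# Usage: the rail result can be written back without knowing the container type.
review = SimpleNamespace(text="call 555-0100")
make_attr_setter(review, "text")("call [phone]")

tags = ["ok", "555-0100"]
make_list_setter(tags, 1)("[phone]")
```
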

_iter_targets_at_path(
value: Any,
path: str,
) collections.abc.Iterator[tuple[str, collections.abc.Callable[[str], None]]]#

Yield each string reached by a dotted path with a setter to rewrite it in place.

Args:

value: Root object to traverse (model instance or list of them).

path: Dotted attribute path, e.g. reviews.review.

Yields:

(text, setter) pairs where setter(new_text) rewrites that leaf.
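
One way to picture the traversal is the sketch below, which walks a dotted path, fans out over lists, and pairs each string leaf with a setter. It uses dataclasses as stand-ins for the model instances and is an assumption about the general approach, not the method's actual code; iter_targets_at_path, Review, and Product are hypothetical names.

```python
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from typing import Any

Setter = Callable[[str], None]


def iter_targets_at_path(value: Any, path: str) -> Iterator[tuple[str, Setter]]:
    # A list at any level fans out to each element with the same path.
    if isinstance(value, list):
        for item in value:
            yield from iter_targets_at_path(item, path)
        return
    head, _, rest = path.partition(".")
    child = getattr(value, head, None)
    if rest:
        # More segments remain: descend if the attribute exists.
        if child is not None:
            yield from iter_targets_at_path(child, rest)
        return
    if isinstance(child, str):
        yield child, lambda new, obj=value, name=head: setattr(obj, name, new)
    elif isinstance(child, list):
        for i, item in enumerate(child):
            if isinstance(item, str):
                yield item, lambda new, items=child, idx=i: items.__setitem__(idx, new)


@dataclass
class Review:
    review: str
    stars: int


@dataclass
class Product:
    name: str
    reviews: list[Review] = field(default_factory=list)


product = Product("lamp", [Review("great", 5), Review("call 555-0100", 1)])
for text, setter in iter_targets_at_path(product, "reviews.review"):
    if "555-0100" in text:
        setter(text.replace("555-0100", "[phone]"))
```

After the loop, only the second review has been rewritten; the first review and the non-string stars field are untouched.
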

_handle_modified_rail_response(
response: nemoguardrails.rails.llm.options.GenerationResponse,
*,
fallback: str,
) str#

Resolve the output text from a rail response that passed or was modified.

Args:

response: Rail generation response from a passed or modified evaluation.

fallback: Value returned when the response carries no extractable text.

Returns:

The response text, preferring the last assistant-role message when the response is a message list; fallback when nothing can be extracted.
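
The extraction rule described above can be sketched on plain Python values. This assumes the response payload is either a string or a list of role/content message dictionaries; extract_response_text is a hypothetical helper that takes that payload directly rather than a GenerationResponse object.

```python
from typing import Any


def extract_response_text(payload: Any, *, fallback: str) -> str:
    # Plain string payload: use it as-is when non-empty.
    if isinstance(payload, str) and payload:
        return payload
    # Message list: scan from the end for an assistant message with text.
    if isinstance(payload, list):
        for message in reversed(payload):
            if isinstance(message, dict) and message.get("role") == "assistant":
                content = message.get("content")
                if isinstance(content, str) and content:
                    return content
    # Nothing extractable: return the caller's fallback.
    return fallback


messages = [
    {"role": "user", "content": "What is the return policy?"},
    {"role": "assistant", "content": "Draft answer."},
    {"role": "assistant", "content": "Returns are accepted within 30 days."},
]
final_text = extract_response_text(messages, fallback="original output")
```
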

_handle_blocked_rail_response(
response: nemoguardrails.rails.llm.options.GenerationResponse,
) str#

Return the safe refusal message from a blocked NeMo Guardrails response.

Args:

response: Generation response from a rail that returned a blocked verdict.

Returns:

The refusal text from the rail, or _DEFAULT_REFUSAL when the blocked response carries no message.

_apply_modified_input(
context: nat.middleware.middleware.InvocationContext,
text: str,
) None#

Write the rail’s modified input text back to the invocation context.

Args:

context: Invocation context whose modified_args[0] is updated in-place.

text: Modified input text returned by the rail.

_resolve_guarded_targets(name: str) list[str]#

Expand the workflow_functions config entry for a function into dotted field paths.

Args:

name: Fully-qualified function name (e.g. retail_tools__get_product_info).

Returns:

Dotted field paths for the function; empty when workflow_functions is a list or the function has no explicit field selection.
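
The return rule suggests that workflow_functions may be either a plain list of function names or a mapping from function name to field paths. Under that assumption (the config shape is inferred here, not confirmed), the lookup could be sketched as follows; resolve_guarded_targets and retail_tools__search are hypothetical names.

```python
from typing import Union

# Assumed config shapes: a list of names, or a name -> field-paths mapping.
WorkflowFunctions = Union[list[str], dict[str, list[str]]]


def resolve_guarded_targets(workflow_functions: WorkflowFunctions, name: str) -> list[str]:
    if isinstance(workflow_functions, dict):
        # Missing entries and empty selections both resolve to no paths.
        return list(workflow_functions.get(name) or [])
    # A plain list names functions only, so there is no field selection.
    return []


config = {
    "retail_tools__get_product_info": ["reviews.review"],
    "retail_tools__search": [],
}
paths = resolve_guarded_targets(config, "retail_tools__get_product_info")
```
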

_register_function(
discovered: nat.middleware.utils.workflow_inventory.DiscoveredFunction,
) None#

Validate configured field paths, then register the function for interception.

Args:

discovered: Discovered workflow function from the inventory.

_validate_guarded_field_paths(
discovered: nat.middleware.utils.workflow_inventory.DiscoveredFunction,
) None#

Raise when a configured field path matches no string field on the function schema.

Args:

discovered: Registered workflow function carrying the input and output schemas.

Raises:

ValueError: When a configured dotted path resolves to no string field on either the input or the output schema.

_validation_schemas(
schema: type[pydantic.BaseModel],
) list[type[pydantic.BaseModel]]#

Return the schemas a path may resolve against, accounting for NAT’s output wrapper.

Args:

schema: Declared input or output schema.

Returns:

The schema itself, plus the unwrapped type of its value field when present.

_path_resolves_to_string(
schema: type[pydantic.BaseModel],
path: str,
) bool#

Return whether a dotted path resolves to a string (or list-of-string) leaf on a schema.

Args:

schema: Pydantic model to walk from.

path: Dotted attribute path, e.g. reviews.review.

Returns:

True when every segment exists and the leaf annotation is str or a list of str.

_unwrap_optional_and_list(annotation: Any) Any#

Strip Optional and list-like wrappers from a type annotation.

Args:

annotation: Field annotation to unwrap.

Returns:

The inner type with Optional and sequence layers removed; the annotation unchanged when a union is ambiguous or a wrapper carries no element type.
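
The unwrapping step and the path check that relies on it can be sketched with the standard typing helpers. This is an illustration only: it uses dataclasses and get_type_hints in place of Pydantic models so that it runs without extra dependencies, and unwrap_optional_and_list, path_resolves_to_string, Review, and Product are hypothetical stand-ins.

```python
import types
from collections.abc import Sequence
from dataclasses import dataclass, is_dataclass
from typing import Any, Optional, Union, get_args, get_origin, get_type_hints

_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))
_LIST_ORIGINS = (list, Sequence)


def unwrap_optional_and_list(annotation: Any) -> Any:
    while True:
        origin, args = get_origin(annotation), get_args(annotation)
        if origin in _UNION_ORIGINS:
            non_none = [arg for arg in args if arg is not type(None)]
            if len(non_none) != 1:
                return annotation  # ambiguous union: stop unwrapping
            annotation = non_none[0]
        elif origin in _LIST_ORIGINS:
            if not args:
                return annotation  # wrapper without an element type
            annotation = args[0]
        else:
            return annotation


def path_resolves_to_string(schema: Any, path: str) -> bool:
    current = schema
    for segment in path.split("."):
        if not is_dataclass(current):
            return False  # cannot descend into a non-model type
        hints = get_type_hints(current)
        if segment not in hints:
            return False  # unknown field name
        current = unwrap_optional_and_list(hints[segment])
    return current is str


@dataclass
class Review:
    review: str
    stars: int


@dataclass
class Product:
    name: Optional[str]
    reviews: list[Review]
```

With these stand-ins, reviews.review and name resolve to string leaves, while reviews.stars and any misspelled segment do not.
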

_gather_guardrail_inputs(
value: Any,
paths: list[str],
whole_setter: collections.abc.Callable[[str], None],
) collections.abc.Iterator[tuple[str, collections.abc.Callable[[str], None]]]#

Yield (text, setter) rail targets for one boundary value.

Args:

value: Boundary value (input argument for pre, output for post).

paths: Configured dotted field paths; empty selects the default targets.

whole_setter: Setter applied when guarding a non-model value as a whole.

Yields:

(text, setter) pairs to evaluate against a rail.

_iter_default_targets(
value: Any,
whole_setter: collections.abc.Callable[[str], None],
) collections.abc.Iterator[tuple[str, collections.abc.Callable[[str], None]]]#

Yield rail targets when no field paths are configured.

Args:

value: Boundary value (input object for pre, output for post).

whole_setter: Setter used when the value is guarded as a single string.

Yields:

(text, setter) pairs for each top-level string (or list-of-string element).
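
A rough sketch of the default selection, assuming top-level fields can be read from the object's attribute dictionary. A bare string is guarded as a whole through whole_setter; otherwise each top-level string and each string inside a top-level list becomes a target. iter_default_targets is a hypothetical stand-in, and SimpleNamespace substitutes for a model instance.

```python
from collections.abc import Callable, Iterator
from types import SimpleNamespace
from typing import Any

Setter = Callable[[str], None]


def iter_default_targets(value: Any, whole_setter: Setter) -> Iterator[tuple[str, Setter]]:
    if isinstance(value, str):
        # Non-model string value: guard it as a single unit.
        yield value, whole_setter
        return
    # Snapshot the fields so setters can safely reassign them mid-iteration.
    for name, child in list(getattr(value, "__dict__", {}).items()):
        if isinstance(child, str):
            yield child, lambda new, n=name: setattr(value, n, new)
        elif isinstance(child, list):
            for i, item in enumerate(child):
                if isinstance(item, str):
                    yield item, lambda new, items=child, idx=i: items.__setitem__(idx, new)


request = SimpleNamespace(query="find lamps", tags=["home", "sale"], limit=3)
texts = [text for text, _ in iter_default_targets(request, lambda new: None)]
```
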

_rail_blocked(
response: nemoguardrails.rails.llm.options.GenerationResponse,
) bool#

Return whether any activated rail signaled a block.

Args:

response: Rail generation response.

Returns:

True when an activated rail set stop.
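
The check amounts to scanning the activated rails for a stop flag. The sketch below uses local stand-in dataclasses whose field names (activated_rails, stop) are assumed to mirror the generation log attached to the response; it does not import NeMo Guardrails.

```python
from dataclasses import dataclass, field


@dataclass
class ActivatedRail:  # stand-in for an entry in the response log
    name: str
    stop: bool = False


@dataclass
class GenerationLog:  # stand-in for the log attached to the response
    activated_rails: list[ActivatedRail] = field(default_factory=list)


def rail_blocked(log: GenerationLog) -> bool:
    # A single stopping rail is enough to treat the evaluation as blocked.
    return any(rail.stop for rail in log.activated_rails)


log = GenerationLog([ActivatedRail("check pii"), ActivatedRail("check jailbreak", stop=True)])
```
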

async bind_llms_to_rail() None#

Register NAT-configured LLMs as NeMo Guardrails rail action parameters.

_should_intercept_llm(llm_name: str) bool#

Return whether the middleware should wrap LLM creation for the given name.

Args:

llm_name: NAT LLM component name.

Returns:

False for rail-bound LLMs so bindings are not double-wrapped.

on_post_invoke_blocked(
context: nat.middleware.middleware.InvocationContext,
block_message: str,
) Any#

Extension point called when a post_invoke rail blocks the function output.

The default returns the rail’s block_message as the function output, which is the policy’s own response to the blocked content. Override in a subclass to raise or return a different value.
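
A sketch of what such an override might look like. BaseGuardrails is a local stand-in for GuardrailsMiddleware that reproduces only the documented default, so the example runs without NAT installed; the two subclasses and the returned dictionary shape are hypothetical.

```python
from typing import Any


class PostInvokeBlockedError(RuntimeError):
    """Local stand-in for the documented exception (assumption)."""

    def __init__(self, message: str):
        super().__init__(message)
        self.block_message = message


class BaseGuardrails:
    """Stand-in for GuardrailsMiddleware with the documented default."""

    def on_post_invoke_blocked(self, context: Any, block_message: str) -> Any:
        return block_message


class RaisingGuardrails(BaseGuardrails):
    """Turn a blocked output into an exception for the caller to handle."""

    def on_post_invoke_blocked(self, context: Any, block_message: str) -> Any:
        raise PostInvokeBlockedError(block_message)


class StructuredGuardrails(BaseGuardrails):
    """Return a structured value instead of the bare refusal text."""

    def on_post_invoke_blocked(self, context: Any, block_message: str) -> Any:
        return {"blocked": True, "reason": block_message}
```

Raising suits callers that must not mistake a refusal for real output; returning a structured value suits callers that render the refusal themselves.
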

Args:

context: Invocation context at the time of the block.

block_message: Message from the blocking rail.

Returns:

Value to use as the function output. The rail’s block message by default.