Custom LLM Models
The NVIDIA NeMo Guardrails library defines a small LLMModel protocol that every backend implements. The built-in DefaultFramework ships an OpenAIChatModel for any OpenAI-compatible HTTP endpoint, and the optional LangChainFramework ships a LangChainLLMAdapter that wraps any LangChain BaseChatModel or BaseLLM. When neither matches your backend, you can implement LLMModel directly.
This guide covers when to do that, the contract you must satisfy, a minimal worked example, and pointers to the reference implementations and to the testing helpers.
When to Use a Custom LLMModel
There are three ways to connect a backend to the NVIDIA NeMo Guardrails library. Pick the one that matches your backend.
| Backend shape | Recommended path | Where it lives |
|---|---|---|
| OpenAI-compatible HTTP endpoint, such as vLLM, TGI, OpenRouter, NIM, or another self-hosted server | Use OpenAIChatModel | Custom LLM Providers and the configuration reference |
| You already have a LangChain BaseChatModel or BaseLLM | Use LangChainLLMAdapter | |
| Custom HTTP API that is not OpenAI-shaped, and you do not want a LangChain dependency | Implement LLMModel | This guide |
Concretely, choose a custom LLMModel when:
- Your provider speaks a non-OpenAI wire format and you do not want to depend on LangChain.
- You want full control over retries, headers, streaming parsing, and tool-call accumulation.
- You want a lean install footprint (no langchain-* packages) and you control the HTTP layer yourself.
The LLMModel Contract
The protocol is nemoguardrails.types.LLMModel. It is decorated with @runtime_checkable, so the framework registry can check conformance with isinstance(model, LLMModel).
A custom model class must implement two async methods and three properties.
```python
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable

from nemoguardrails import (
    ChatMessage,
    LLMResponse,
    LLMResponseChunk,
)


@runtime_checkable
class LLMModel(Protocol):
    async def generate_async(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> LLMResponse: ...

    async def stream_async(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> AsyncIterator[LLMResponseChunk]:
        yield ...  # async generator: implementations use `yield`, not `return`

    @property
    def model_name(self) -> str: ...

    @property
    def provider_name(self) -> Optional[str]: ...

    @property
    def provider_url(self) -> Optional[str]: ...
```
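Because the protocol is runtime_checkable, the registry's isinstance check only confirms that the named members exist; it does not compare parameter lists or check that methods are async. The sketch below illustrates this with a trimmed, hypothetical stand-in protocol (HasGenerate) rather than the library's own class, so it runs without the library installed.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasGenerate(Protocol):
    """Hypothetical, trimmed stand-in for LLMModel (two members only)."""

    @property
    def model_name(self) -> str: ...

    async def generate_async(self, prompt, **kwargs): ...


class Complete:
    model_name = "demo-model"

    async def generate_async(self, prompt, **kwargs):
        return prompt


class WrongShape:
    model_name = "demo-model"

    def generate_async(self):  # not async and wrong parameters, yet it still passes
        return None


class MissingMethod:
    model_name = "demo-model"  # no generate_async at all


print(isinstance(Complete(), HasGenerate))       # True
print(isinstance(WrongShape(), HasGenerate))     # True: only presence is checked
print(isinstance(MissingMethod(), HasGenerate))  # False
```

A structurally wrong class can therefore pass registration and fail only when the pipeline calls it, which is why unit tests that actually await generate_async are worth having.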
prompt
Adapters must accept either a plain string or a list of ChatMessage objects. ChatMessage is a stdlib dataclass with role, content, optional tool_calls, optional tool_call_id, optional name, and a provider_metadata dict for non-standard fields. Convert messages to whatever shape your SDK expects.
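A minimal sketch of that conversion, assuming a hypothetical SDK that takes a list of role/content dicts. to_sdk_messages is an illustrative name, not a library function. It reads only the ChatMessage fields listed above, leaves tool_calls and provider_metadata out for brevity, and treats a bare string as a single user turn; a completion-only backend might instead send the string verbatim.

```python
from typing import Any, Dict, List, Sequence, Union


def to_sdk_messages(prompt: Union[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Convert a str or a list of ChatMessage-like objects into role/content dicts.

    The dict shape is a hypothetical target; adapt the keys to your SDK.
    """
    if isinstance(prompt, str):
        # Assumption: a bare string becomes a single user turn.
        return [{"role": "user", "content": prompt}]

    messages: List[Dict[str, Any]] = []
    for msg in prompt:
        entry: Dict[str, Any] = {"role": msg.role, "content": msg.content}
        # Optional fields are copied only when set, keeping the payload minimal.
        if getattr(msg, "tool_call_id", None):
            entry["tool_call_id"] = msg.tool_call_id
        if getattr(msg, "name", None):
            entry["name"] = msg.name
        messages.append(entry)
    return messages


print(to_sdk_messages("hi"))  # [{'role': 'user', 'content': 'hi'}]
```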
stop and **kwargs
stop is the canonical name for stop sequences; keep it as a keyword-only argument. **kwargs carries everything the caller passed under parameters in config.yml plus any per-call overrides, such as temperature, max_tokens, and top_p. Forward these to the underlying SDK.
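If your adapter also keeps constructor-time defaults (as the EchoLLMModel example later in this guide does in self._default_kwargs), one option is to merge them with the per-call kwargs and the stop list right before the SDK call. build_request_params is a hypothetical helper; the precedence (per-call values override defaults) and the "stop" key name are assumptions to adapt to your SDK.

```python
from typing import Any, Dict, List, Optional


def build_request_params(
    default_kwargs: Dict[str, Any],
    stop: Optional[List[str]],
    call_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge construction-time defaults with per-call kwargs and stop sequences."""
    # Assumed precedence: per-call values win over defaults captured in __init__.
    params = {**default_kwargs, **call_kwargs}
    if stop:
        # Forward stop under the SDK's own key; "stop" is assumed here.
        params["stop"] = list(stop)
    return params


params = build_request_params({"temperature": 0.2, "max_tokens": 256}, ["END"], {"temperature": 0.0})
print(params)  # {'temperature': 0.0, 'max_tokens': 256, 'stop': ['END']}
```

Unknown keys pass straight through, so new SDK options keep working without changes to the adapter.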
generate_async returns LLMResponse
LLMResponse is a dataclass in nemoguardrails/types.py:
```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Excerpt: ToolCall, FinishReason, and UsageInfo are other library types.


@dataclass
class LLMResponse:
    content: str
    reasoning: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    stop_sequence: Optional[str] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None
```
content is required and must be a string; use the empty string when the model produced only tool calls. finish_reason is one of "stop", "length", "tool_calls", "content_filter", "error", or "other". Populate tool_calls only when the model actually requested one or more tool calls.
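If your provider reports stop reasons in its own vocabulary, normalize them before building the LLMResponse. The raw values in this sketch (end_turn, tool_use, blocked, and so on) are hypothetical placeholders for whatever your provider returns, and normalize_finish_reason is an illustrative name; unrecognized values fall back to "other".

```python
from typing import Dict, Optional

# Hypothetical raw stop reasons from a non-OpenAI provider, mapped onto the
# documented finish_reason values.
_FINISH_REASON_MAP: Dict[str, str] = {
    "end_turn": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "tool_use": "tool_calls",
    "blocked": "content_filter",
}


def normalize_finish_reason(raw: Optional[str]) -> Optional[str]:
    """Map a provider-specific stop reason onto "stop", "length", "tool_calls", etc."""
    if raw is None:
        return None  # the provider has not finished yet (for example, mid-stream)
    return _FINISH_REASON_MAP.get(raw, "other")


print(normalize_finish_reason("tool_use"))  # tool_calls
```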
stream_async is an async generator
Implementations must be async def generator functions that yield LLMResponseChunk objects. The protocol’s return type is AsyncIterator[LLMResponseChunk]. Each chunk has the shape:
```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Excerpt: ToolCall, FinishReason, and UsageInfo are other library types.


@dataclass
class LLMResponseChunk:
    delta_content: Optional[str] = None
    delta_reasoning: Optional[str] = None
    delta_tool_calls: Optional[List[ToolCall]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None
```
Follow these conventions so the rest of the pipeline works:
- Yield text deltas in delta_content as soon as they arrive.
- Yield delta_reasoning for chain-of-thought tokens emitted before the visible answer (OpenAI reasoning models, NIM reasoning_content).
- Tool-call streaming is incremental on the wire: provider chunks usually carry argument fragments. Accumulate them and emit a single completed delta_tool_calls list on the chunk whose finish_reason == "tool_calls". The reference OpenAIChatModel._finalize_tool_calls shows the pattern.
- Set finish_reason only on the chunk that ends the generation. Earlier chunks should leave it None.
- Emit a final usage-only chunk (no delta_content, only usage and request_id) when the provider sends an end-of-stream usage record. The pipeline tolerates usage reported either inline or in a trailing chunk.
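The accumulation step can be isolated in a small helper. This is a sketch, not the library's OpenAIChatModel._finalize_tool_calls: ToolCallAccumulator is a hypothetical name, and it assumes OpenAI-style deltas in which each fragment carries an index, the first fragment for an index carries the call id and function name, and later fragments carry only argument text.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class _PendingCall:
    id: str = ""
    name: str = ""
    arguments: List[str] = field(default_factory=list)


class ToolCallAccumulator:
    """Collects streamed tool-call fragments until the provider signals completion."""

    def __init__(self) -> None:
        self._calls: Dict[int, _PendingCall] = {}

    def add(self, index: int, *, call_id: Optional[str] = None,
            name: Optional[str] = None, arguments: str = "") -> None:
        call = self._calls.setdefault(index, _PendingCall())
        # id and name usually arrive once, on the first fragment for an index.
        if call_id:
            call.id = call_id
        if name:
            call.name = name
        call.arguments.append(arguments)

    def finalize(self) -> List[Dict[str, str]]:
        """Return completed calls in index order with argument fragments joined.

        "arguments" is still a JSON string here; parse it into a dict before
        constructing ToolCall objects.
        """
        return [
            {"id": c.id, "name": c.name, "arguments": "".join(c.arguments)}
            for _, c in sorted(self._calls.items())
        ]


acc = ToolCallAccumulator()
acc.add(0, call_id="call_1", name="get_weather", arguments='{"city": ')
acc.add(0, arguments='"Paris"}')
print(acc.finalize())
# [{'id': 'call_1', 'name': 'get_weather', 'arguments': '{"city": "Paris"}'}]
```

Inside stream_async, call add() for every provider delta, and only when the provider reports the end of the tool-call turn yield one LLMResponseChunk with delta_tool_calls built from finalize() and finish_reason="tool_calls".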
Tool calling
ToolCall and ToolCallFunction are dataclasses:
```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToolCallFunction:
    name: str
    arguments: Dict[str, Any]


@dataclass
class ToolCall:
    id: str
    type: str = "function"
    function: ToolCallFunction = field(
        default_factory=lambda: ToolCallFunction(name="", arguments={})
    )
```
function.arguments is a Dict[str, Any], not a JSON string. If your provider returns arguments as a JSON string, json.loads() it before constructing the ToolCall. If parsing fails for a streamed response, fall back to an empty dict; the tool layer will surface the real error when the function is invoked.
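A minimal sketch of that parsing step. parse_tool_arguments is a hypothetical helper; the guide prescribes the empty-dict fallback for streamed responses, and the sketch applies it everywhere for brevity. It also maps valid JSON that is not an object (for example a list) to an empty dict, because arguments must be a Dict[str, Any].

```python
import json
from typing import Any, Dict, Union


def parse_tool_arguments(raw: Union[str, Dict[str, Any], None]) -> Dict[str, Any]:
    """Turn provider tool-call arguments into the dict that ToolCallFunction expects."""
    if isinstance(raw, dict):
        return raw  # already parsed by the SDK
    if not raw:
        return {}  # None or "" (for example, a tool that takes no arguments)
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Truncated or malformed JSON: fall back to {} and let the tool layer
        # surface the real error when the function is invoked.
        return {}
    # Valid JSON that is not an object cannot serve as keyword arguments.
    return parsed if isinstance(parsed, dict) else {}


print(parse_tool_arguments('{"city": "Paris"}'))  # {'city': 'Paris'}
print(parse_tool_arguments('{"city": '))          # {}
```

The result plugs directly into the documented dataclasses, for example ToolCall(id=call_id, function=ToolCallFunction(name=name, arguments=parse_tool_arguments(raw))).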
Properties
- model_name returns the concrete model identifier (for example, gpt-4o-mini or meta/llama-3.1-70b-instruct). It is used in logs and error contexts.
- provider_name returns the engine name as it appears in config.yml (for example, openai, nim, or my_engine). Return None only if you genuinely cannot determine it.
- provider_url returns the base URL for HTTP backends, or None for backends that do not have one (for example, a SageMaker endpoint addressed by ARN).
Error handling
The pipeline expects errors to be normalized. Raise the exception classes defined in nemoguardrails.exceptions:
- LLMConnectionError for network or DNS failures.
- LLMTimeoutError for read or connect timeouts.
- LLMAuthenticationError for HTTP 401 or 403.
- LLMRateLimitError for HTTP 429.
- LLMResponseValidationError for malformed provider responses.
- LLMClientError is the common base class; use it as a generic fallback.
Populate model_name, provider_name, and base_url on the exception when you raise it so downstream logs are usable. The reference OpenAIChatModel._enrich shows the pattern.
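A sketch of that normalization, not the library's _enrich helper. error_for_status, error_for_transport, and attach_context are hypothetical names. The sketch assumes each exception class can be constructed from a single message string and that setting model_name, provider_name, and base_url as plain attributes is acceptable; check OpenAIChatModel._enrich for the library's actual pattern. The try/except import falls back to stand-in classes only so the sketch runs without the library. If your HTTP client raises its own exception types, map those instead of the built-in TimeoutError and OSError used here.

```python
from typing import Any, Optional

try:
    from nemoguardrails.exceptions import (
        LLMAuthenticationError,
        LLMClientError,
        LLMConnectionError,
        LLMRateLimitError,
        LLMTimeoutError,
    )
except ImportError:  # stand-ins so this sketch runs without the library installed
    class LLMClientError(Exception): ...
    class LLMConnectionError(LLMClientError): ...
    class LLMTimeoutError(LLMClientError): ...
    class LLMAuthenticationError(LLMClientError): ...
    class LLMRateLimitError(LLMClientError): ...


def attach_context(err: Exception, *, model_name: str, provider_name: Optional[str],
                   base_url: Optional[str]) -> Exception:
    # Context for downstream logs; attribute names follow this guide.
    err.model_name = model_name
    err.provider_name = provider_name
    err.base_url = base_url
    return err


def error_for_status(status: int, body: str, **context: Any) -> Exception:
    """Build a normalized exception for a non-2xx HTTP status."""
    if status in (401, 403):
        cls = LLMAuthenticationError
    elif status == 429:
        cls = LLMRateLimitError
    else:
        cls = LLMClientError  # generic fallback
    return attach_context(cls(f"HTTP {status}: {body}"), **context)


def error_for_transport(exc: OSError, **context: Any) -> Exception:
    """Map low-level failures; TimeoutError is checked first because it subclasses OSError."""
    cls = LLMTimeoutError if isinstance(exc, TimeoutError) else LLMConnectionError
    return attach_context(cls(str(exc)), **context)


ctx = {"model_name": "my-model-v1", "provider_name": "my_engine", "base_url": "https://llm.example.com"}
err = error_for_status(429, "slow down", **ctx)
print(type(err).__name__, err.model_name)  # LLMRateLimitError my-model-v1
```

In an adapter you would write raise error_for_transport(e, **ctx) from e, so the original traceback stays attached.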
Minimal Working Example
Below is a short EchoLLMModel that returns a canned response without making any network call. It is useful as a starting skeleton and as a sanity check for new framework wiring.
Create a config directory my_config/ next to your smoke-test script with two files:
```
my_config/
├── config.py   # EchoLLMModel + register_provider call, run at import time
└── config.yml  # references the registered engine name
```
my_config/config.py:
```python
import asyncio
from typing import Any, AsyncIterator, List, Optional, Union

from nemoguardrails import (
    ChatMessage,
    LLMResponse,
    LLMResponseChunk,
    UsageInfo,
    register_provider,
)


class EchoLLMModel:
    """Returns a canned response. Useful as a skeleton or in offline tests."""

    def __init__(self, model: str, response: str = "echo", **kwargs: Any):
        self._model = model
        self._response = response  # set from `parameters.response` in config.yml
        self._default_kwargs = kwargs  # remaining parameters, kept for a real backend

    @property
    def model_name(self) -> str:
        return self._model

    @property
    def provider_name(self) -> Optional[str]:
        return "echo"  # matches `engine: echo` in config.yml

    @property
    def provider_url(self) -> Optional[str]:
        return None  # no HTTP endpoint

    async def generate_async(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        # Ignores the prompt and returns the configured text.
        return LLMResponse(
            content=self._response,
            model=self._model,
            finish_reason="stop",
            # The character count stands in for a token count in this fake.
            usage=UsageInfo(input_tokens=0, output_tokens=len(self._response)),
        )

    async def stream_async(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[LLMResponseChunk]:
        # One chunk per whitespace-separated word, then a terminal chunk.
        for token in self._response.split():
            await asyncio.sleep(0)  # hand control back to the event loop between chunks
            yield LLMResponseChunk(delta_content=token + " ", model=self._model)
        yield LLMResponseChunk(model=self._model, finish_reason="stop")


# Runs at import time, when LLMRails loads config.py.
register_provider("echo", EchoLLMModel)
```
The register_provider call attaches EchoLLMModel as the echo engine on whichever framework is currently active. By default, that is DefaultFramework. For the framework layer, refer to Custom LLM Framework.
my_config/config.yml:
```yaml
models:
  - type: main
    engine: echo
    model: echo-v1
    parameters:
      response: "Hello from echo"
```
Trying it out
Run a smoke test from the parent directory of my_config/. LLMRails imports config.py automatically, which triggers the register_provider call at the bottom of that file:
```python
# smoke.py (next to my_config/)
from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("./my_config")
rails = LLMRails(config)

result = rails.generate(messages=[{"role": "user", "content": "hi"}])
print(result["content"])  # -> "Hello from echo"
```
If the smoke test prints Hello from echo, your provider is registered correctly. From there, replace EchoLLMModel.generate_async and stream_async with real backend calls.
What register_provider does
register_provider(name, cls) from nemoguardrails.llm.providers resolves the active framework with get_default_framework() and calls framework.register_provider(name, cls) on it. For DefaultFramework, that adds name to an in-memory dict of providers, and subsequent create_model("echo", ...) calls use your class as the factory. The active framework is selected once per process, either by the NEMOGUARDRAILS_LLM_FRAMEWORK environment variable or by calling set_default_framework() from config.py, so you register on that one framework only, not on several.
Calling-convention contract for your __init__
framework.create_model(model_name, provider_name, model_kwargs) instantiates your class as EchoLLMModel(model=model_name, **model_kwargs). Make model a required parameter and accept additional **kwargs so that future configuration keys do not break instantiation.
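To make the register-then-create flow concrete, here is a deliberately simplified, hypothetical registry. ToyFramework and TinyModel are illustrative names; the real DefaultFramework does more (for example, client pooling), but the name-to-class lookup and the cls(model=model_name, **model_kwargs) call follow the description above.

```python
from typing import Any, Callable, Dict


class ToyFramework:
    """Hypothetical, simplified registry; not the library's DefaultFramework."""

    def __init__(self) -> None:
        self._providers: Dict[str, Callable[..., Any]] = {}

    def register_provider(self, name: str, cls: Callable[..., Any]) -> None:
        self._providers[name] = cls

    def create_model(self, model_name: str, provider_name: str,
                     model_kwargs: Dict[str, Any]) -> Any:
        cls = self._providers[provider_name]  # in this toy, unknown engines raise KeyError
        # model is passed by keyword; everything else is splatted.
        return cls(model=model_name, **model_kwargs)


class TinyModel:
    # `model` is required; **kwargs absorbs keys this class does not know yet.
    def __init__(self, model: str, response: str = "echo", **kwargs: Any) -> None:
        self.model_name = model
        self.response = response
        self.extra = kwargs


framework = ToyFramework()
framework.register_provider("echo", TinyModel)
model = framework.create_model("echo-v1", "echo", {"response": "hi", "future_key": 1})
print(model.model_name, model.response, model.extra)  # echo-v1 hi {'future_key': 1}
```

Without **kwargs in TinyModel.__init__, the unknown future_key would raise a TypeError at startup.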
Reference Implementations
Review these production-grade LLMModel implementations:
- nemoguardrails/llm/models/openai_chat.py: OpenAIChatModel, for any OpenAI-compatible HTTP endpoint. It shows tool-call accumulation, reasoning-content extraction, response validation, and exception enrichment, and uses OpenAICompatibleClient for the HTTP layer.
- nemoguardrails/integrations/langchain/llm_adapter.py: LangChainLLMAdapter, which bridges any LangChain BaseChatModel or BaseLLM. It shows how to map LangChain's tool_call_chunks, usage_metadata, response_metadata, and additional_kwargs onto the LLMResponse and LLMResponseChunk shapes.
Both files import their types directly from nemoguardrails.types. Custom models should do the same.
Testing Your Model
The NVIDIA NeMo Guardrails library ships a pytest-friendly FakeLLMModel in nemoguardrails.testing. It implements the LLMModel protocol and accepts a list of canned strings or LLMResponse objects:
```python
from nemoguardrails.testing import FakeLLMModel
```
There are two recommended approaches:
- Write unit tests for your LLMModel class in isolation: instantiate it, call await model.generate_async(prompt), and assert on the returned LLMResponse. No framework is needed.
- Write end-to-end tests with a real LLMRails instance by registering a FakeLLMModel (or a FakeLLMModel-style class) as a custom provider in the test's config.py, then driving the full pipeline.
The contract is small enough that property-based tests are straightforward: any string prompt and any list of ChatMessage objects must produce a non-None LLMResponse.content, and stream_async must always yield a final chunk with a non-None finish_reason.
Best Practices
- Implement both methods even if your backend has no native streaming. A simple stream_async that yields a single chunk built from generate_async keeps the streaming consumer paths working.
- Validate provider responses before parsing them. The reference OpenAIChatModel._validate_response rejects non-dict bodies and missing choices entries up front, which keeps user-facing errors actionable.
- Forward **kwargs to the SDK. Anything the user wrote under parameters in config.yml lands here. Letting unknown keys pass through means new SDK options work without a library release.
- Pool shared backend clients on the framework. create_model is called once per models: entry at LLMRails startup; after that, your model handles every request. If multiple models: entries point at the same backend, the framework, not the model, should hold the underlying client so that they share one connection pool. DefaultFramework._get_or_create_client keys clients by (base_url, api_key, ...) for exactly this reason. Single-model configs can build the client directly in __init__.
- Do not raise a plain Exception. Use the nemoguardrails.exceptions hierarchy so that retries and structured logging behave correctly.
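The first practice, deriving stream_async from generate_async, fits in a small mixin. SingleChunkStreamMixin and BlockingModel are hypothetical names. The try/except import falls back to trimmed stand-ins that mirror the documented LLMResponse and LLMResponseChunk fields only so the sketch runs without the library, and defaulting a missing finish_reason to "stop" is an assumption.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional

try:
    from nemoguardrails import LLMResponse, LLMResponseChunk
except ImportError:
    # Trimmed stand-ins mirroring the documented fields, so the sketch runs standalone.
    @dataclass
    class LLMResponse:
        content: str
        reasoning: Optional[str] = None
        tool_calls: Optional[List[Any]] = None
        model: Optional[str] = None
        finish_reason: Optional[str] = None
        request_id: Optional[str] = None
        usage: Optional[Any] = None
        provider_metadata: Optional[Dict[str, Any]] = None

    @dataclass
    class LLMResponseChunk:
        delta_content: Optional[str] = None
        delta_reasoning: Optional[str] = None
        delta_tool_calls: Optional[List[Any]] = None
        model: Optional[str] = None
        finish_reason: Optional[str] = None
        request_id: Optional[str] = None
        usage: Optional[Any] = None
        provider_metadata: Optional[Dict[str, Any]] = None


class SingleChunkStreamMixin:
    """Derives stream_async from generate_async for backends without native streaming."""

    async def stream_async(self, prompt, *, stop=None, **kwargs) -> AsyncIterator[LLMResponseChunk]:
        response = await self.generate_async(prompt, stop=stop, **kwargs)
        # One chunk carries the whole answer plus the terminal metadata.
        yield LLMResponseChunk(
            delta_content=response.content,
            delta_reasoning=response.reasoning,
            delta_tool_calls=response.tool_calls,
            model=response.model,
            finish_reason=response.finish_reason or "stop",  # assumed default
            request_id=response.request_id,
            usage=response.usage,
            provider_metadata=response.provider_metadata,
        )


class BlockingModel(SingleChunkStreamMixin):
    """Hypothetical non-streaming backend."""

    async def generate_async(self, prompt, *, stop=None, **kwargs):
        return LLMResponse(content=f"you said: {prompt}", model="blocking-v1", finish_reason="stop")


async def _demo():
    return [chunk async for chunk in BlockingModel().stream_async("hi")]


chunks = asyncio.run(_demo())
print(chunks[0].delta_content, chunks[0].finish_reason)  # you said: hi stop
```

Because the mixin only calls self.generate_async, any model class can inherit it and later replace it with a native stream_async without touching callers.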