Custom LLM Providers for NeMo Guardrails


This guide covers LangChain-based custom providers (BaseLLM and BaseChatModel). It applies only when NEMOGUARDRAILS_LLM_FRAMEWORK=langchain is set. Before 0.22, LangChain providers were the only extension path. For the built-in client (the default since 0.22), implement the LLMModel Protocol instead; see Custom LLM Model.

NeMo Guardrails supports two types of custom LLM providers:

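| Type            | Base Class    | Input            | Output           |
|-----------------|---------------|------------------|------------------|
| Text Completion | BaseLLM       | String prompt    | String response  |
| Chat Model      | BaseChatModel | List of messages | Message response |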

Text Completion Models (BaseLLM)

For models that work with string prompts:

    from typing import Any, List, Optional

    from langchain_core.callbacks.manager import CallbackManagerForLLMRun
    from langchain_core.language_models import BaseLLM

    from nemoguardrails.llm.providers import register_llm_provider


    class MyCustomLLM(BaseLLM):
        """Custom text completion LLM."""

        @property
        def _llm_type(self) -> str:
            # Identifier LangChain uses for this model type.
            return "my_custom_llm"

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            """Synchronous text completion."""
            # Your implementation here
            return "Generated text response"

        async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            """Asynchronous text completion (recommended)."""
            # Your async implementation here
            return "Generated text response"


    # Register the provider under the name used as `engine` in config.yml
    register_llm_provider("my_custom_llm", MyCustomLLM)
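
Both `_call` and `_acall` receive an optional `stop` list. If the backend does not support stop sequences natively, one option is to truncate the generated text before returning it. The following is a minimal sketch under that assumption; `truncate_at_stop` is a hypothetical helper name, not part of LangChain or NeMo Guardrails.

```python
from typing import List, Optional


def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence.

    Hypothetical helper: call it on the raw backend output inside
    `_call` / `_acall` before returning the string.
    """
    if not stop:
        return text
    # Position of each non-empty stop sequence that occurs in the text.
    positions = [text.find(seq) for seq in stop if seq]
    positions = [pos for pos in positions if pos != -1]
    # Keep everything before the earliest match, or the full text if none.
    return text[: min(positions)] if positions else text


print(truncate_at_stop("Hello there.\nUser: hi", stop=["\nUser:"]))
```

Inside `_call`, the return statement would then become something like `return truncate_at_stop(raw_text, stop)`, where `raw_text` is whatever your backend produced.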

Chat Models (BaseChatModel)

For models that work with message-based conversations:

    from typing import Any, List, Optional

    from langchain_core.callbacks.manager import CallbackManagerForLLMRun
    from langchain_core.language_models import BaseChatModel
    from langchain_core.messages import AIMessage, BaseMessage
    from langchain_core.outputs import ChatGeneration, ChatResult

    from nemoguardrails.llm.providers import register_chat_provider


    class MyCustomChatModel(BaseChatModel):
        """Custom chat model."""

        @property
        def _llm_type(self) -> str:
            # Identifier LangChain uses for this model type.
            return "my_custom_chat"

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            """Synchronous chat completion."""
            # Convert messages to your model's format
            response_text = "Generated chat response"

            # Wrap the reply text in the result types LangChain expects.
            message = AIMessage(content=response_text)
            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation])

        async def _agenerate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            """Asynchronous chat completion (recommended)."""
            response_text = "Generated chat response"

            message = AIMessage(content=response_text)
            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation])


    # Register the provider under the name used as `engine` in config.yml
    register_chat_provider("my_custom_chat", MyCustomChatModel)
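
The "convert messages to your model's format" step in `_generate` could look roughly like the sketch below. It assumes a hypothetical backend that expects a list of `{"role", "content"}` dictionaries, and it relies only on the `type` and `content` attributes that LangChain messages expose. `StubMessage` is a stand-in so the snippet runs without LangChain installed; in a real provider you would pass the `messages` argument directly.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List

# LangChain message `type` values mapped to the role names of a
# hypothetical backend (the role names are an assumption).
ROLE_BY_TYPE = {"system": "system", "human": "user", "ai": "assistant"}


@dataclass
class StubMessage:
    """Stand-in exposing the `type` and `content` attributes of BaseMessage."""

    type: str
    content: str


def to_backend_payload(messages: Iterable[Any]) -> List[Dict[str, str]]:
    """Convert message objects into role/content dictionaries."""
    payload = []
    for message in messages:
        # Unknown message types fall back to the "user" role.
        role = ROLE_BY_TYPE.get(message.type, "user")
        payload.append({"role": role, "content": str(message.content)})
    return payload


print(to_backend_payload([
    StubMessage("system", "Be brief."),
    StubMessage("human", "Hi"),
]))
```

Keeping the conversion in a standalone function lets `_generate` and `_agenerate` share it, so the two code paths cannot drift apart.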

Using Custom Providers

Register the provider in config.py, then reference its registered name as the engine in config.yml:

    models:
      - type: main
        engine: my_custom_llm  # or my_custom_chat
        model: optional-model-name

Required and Optional Methods

BaseLLM Methods

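| Method    | Required    | Description                     |
|-----------|-------------|---------------------------------|
| _call     | Yes         | Synchronous text completion     |
| _llm_type | Yes         | Returns the LLM type identifier |
| _acall    | Recommended | Asynchronous text completion    |
| _stream   | Optional    | Streaming text completion       |
| _astream  | Optional    | Async streaming text completion |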

BaseChatModel Methods

| Method     | Required    | Description                     |
|------------|-------------|---------------------------------|
| _generate  | Yes         | Synchronous chat completion     |
| _llm_type  | Yes         | Returns the LLM type identifier |
| _agenerate | Recommended | Asynchronous chat completion    |
| _stream    | Optional    | Streaming chat completion       |
| _astream   | Optional    | Async streaming chat completion |
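
The optional `_stream` and `_astream` methods are generators that yield the response piece by piece. The sketch below shows only the chunking pattern, using plain strings so that it runs without LangChain. In a real provider each piece would be wrapped in `GenerationChunk` (for BaseLLM) or `ChatGenerationChunk` (for BaseChatModel) from `langchain_core.outputs`. The helper names and the fixed chunk size are assumptions made for illustration.

```python
import asyncio
from typing import AsyncIterator, Iterator, List


def iter_chunks(text: str, size: int = 8) -> Iterator[str]:
    """Yield `text` in pieces of at most `size` characters (sync pattern)."""
    for start in range(0, len(text), size):
        yield text[start : start + size]


async def aiter_chunks(text: str, size: int = 8) -> AsyncIterator[str]:
    """Async counterpart: yields the same pieces from an async generator."""
    for piece in iter_chunks(text, size):
        await asyncio.sleep(0)  # hand control back to the event loop
        yield piece


async def collect(text: str) -> List[str]:
    """Gather all async pieces into a list (used here for demonstration)."""
    return [piece async for piece in aiter_chunks(text)]


print(list(iter_chunks("Generated text response")))
print(asyncio.run(collect("Generated text response")))
```

A real backend would usually yield pieces as they arrive from the network rather than slicing a finished string; the generator shape stays the same.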

Best Practices

  1. Implement async methods: Always implement _acall (for BaseLLM) or _agenerate (for BaseChatModel), so that a model call does not block other work while it waits for a response.

  2. Choose the right base class:

    • Use BaseLLM for text completion models (prompt → text)
    • Use BaseChatModel for chat models (messages → message)

  3. Import from langchain-core: Always import base classes from langchain_core.language_models.

  4. Use the correct registration function:

    • register_llm_provider() for BaseLLM subclasses
    • register_chat_provider() for BaseChatModel subclasses
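
As an illustration of the first practice: if your backend client only offers a blocking call, one way to write the async method is to run that call in a worker thread. This is a minimal sketch, assuming Python 3.9+ for `asyncio.to_thread`; `blocking_complete` and `complete_async` are hypothetical names that stand in for your client call and for the body of `_acall` or `_agenerate`.

```python
import asyncio
import time
from typing import List


def blocking_complete(prompt: str) -> str:
    """Hypothetical blocking backend call (stands in for an HTTP request)."""
    time.sleep(0.05)
    return f"echo: {prompt}"


async def complete_async(prompt: str) -> str:
    """What the body of `_acall` might do: offload the blocking call."""
    return await asyncio.to_thread(blocking_complete, prompt)


async def main() -> List[str]:
    # The three calls overlap instead of running one after another.
    return list(await asyncio.gather(*(complete_async(p) for p in "abc")))


results = asyncio.run(main())
print(results)
```

If the backend ships a native async client, awaiting it directly is the simpler choice; the thread offload is only a fallback for blocking clients.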