Custom Embedding Providers for NeMo Guardrails
Custom embedding providers enable you to use your own embedding models for semantic similarity search in the knowledge base and intent detection.
Creating a Custom Embedding Provider
Create a class that inherits from EmbeddingModel:
import asyncio
from typing import List

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class CustomEmbedding(EmbeddingModel):
    """Custom embedding provider.

    Template for plugging your own embedding model into NeMo Guardrails.
    """

    # Identifier used for the `engine` field in config.yml.
    engine_name = "custom_embedding"

    def __init__(self, embedding_model: str, **kwargs):
        """Initialize the embedding model.

        Args:
            embedding_model: The model name from config.yml
            **kwargs: Additional parameters from config.yml
        """
        self.model_name = embedding_model
        # Initialize your model here. `load_model` is a placeholder: replace it
        # with whatever call constructs your model object.
        self.model = load_model(embedding_model)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode

        Returns:
            List of embedding vectors, one per input document, in input order
        """
        return [self.model.encode(doc) for doc in documents]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (asynchronous).

        Args:
            documents: List of text documents to encode

        Returns:
            List of embedding vectors, one per input document, in input order
        """
        # Model inference is blocking; calling `self.encode` directly here
        # would stall the event loop (and every other coroutine) until it
        # finishes. Run it in the default thread pool instead.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)
Registering the Provider
Register the provider in your config.py:
from nemoguardrails import LLMRails


def init(app: LLMRails):
    """Register ``CustomEmbedding`` on ``app`` as the "custom_embedding" engine.

    Args:
        app: The LLMRails instance to register the provider on.
    """
    # The provider class is imported from the sibling `embeddings` module at
    # call time, inside the function body.
    from .embeddings import CustomEmbedding as provider_cls

    app.register_embedding_provider(provider_cls, "custom_embedding")
Using the Provider
Configure in config.yml:
models:
  - type: embeddings
    engine: custom_embedding
    model: my-model-name
Example: Sentence Transformers
import asyncio
from typing import List

from sentence_transformers import SentenceTransformer

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding provider using sentence-transformers."""

    # Identifier used for the `engine` field in config.yml.
    engine_name = "sentence_transformers"

    def __init__(self, embedding_model: str, **kwargs):
        """Load the sentence-transformers model.

        Args:
            embedding_model: Model name passed to ``SentenceTransformer``
                (e.g. the ``model`` value from config.yml).
            **kwargs: Additional parameters from config.yml (unused here).
        """
        self.model = SentenceTransformer(embedding_model)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            The model output converted to nested Python lists via ``tolist()``.
        """
        embeddings = self.model.encode(documents)
        return embeddings.tolist()

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (asynchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            Same result as :meth:`encode`.
        """
        # Inference is blocking; offload it to the default thread pool so the
        # event loop stays responsive instead of calling `encode` inline.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)
config.py:
from nemoguardrails import LLMRails


def init(app: LLMRails):
    """Register ``SentenceTransformerEmbedding`` as the "sentence_transformers" engine.

    Args:
        app: The LLMRails instance to register the provider on.
    """
    # Without this import the name below is undefined in config.py.
    # Assumes the SentenceTransformerEmbedding class shown above is saved in
    # an `embeddings.py` module next to this config.py -- adjust the import
    # path if you placed it elsewhere.
    from .embeddings import SentenceTransformerEmbedding

    app.register_embedding_provider(
        SentenceTransformerEmbedding,
        "sentence_transformers",
    )
config.yml:
models:
  - type: embeddings
    engine: sentence_transformers
    model: all-MiniLM-L6-v2
Example: OpenAI-Compatible API
from typing import List

import httpx

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class OpenAICompatibleEmbedding(EmbeddingModel):
    """Embedding provider for OpenAI-compatible APIs."""

    # Identifier used for the `engine` field in config.yml.
    engine_name = "openai_compatible"

    def __init__(self, embedding_model: str, **kwargs):
        """Store the model name and endpoint URL.

        Args:
            embedding_model: Model name sent in the request body.
            **kwargs: Additional parameters from config.yml. ``api_url``
                overrides the default endpoint
                ``http://localhost:8080/v1/embeddings``.
        """
        self.model = embedding_model
        self.api_url = kwargs.get("api_url", "http://localhost:8080/v1/embeddings")

    def _payload(self, documents: List[str]) -> dict:
        """Build the JSON request body shared by the sync and async paths."""
        return {"input": documents, "model": self.model}

    @staticmethod
    def _parse(response) -> List[List[float]]:
        """Validate the HTTP response and extract the embedding vectors."""
        # Fail with a clear HTTP error on 4xx/5xx instead of an opaque
        # KeyError when the error body has no "data" field.
        response.raise_for_status()
        data = response.json()
        return [item["embedding"] for item in data["data"]]

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            One embedding vector per item in the response's ``data`` list.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        if not documents:
            # Nothing to embed; skip the network round trip.
            return []
        response = httpx.post(self.api_url, json=self._payload(documents))
        return self._parse(response)

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (asynchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            One embedding vector per item in the response's ``data`` list.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        if not documents:
            # Nothing to embed; skip the network round trip.
            return []
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url,
                json=self._payload(documents),
            )
        return self._parse(response)
Required Methods
| Method | Description |
|---|---|
| `__init__(embedding_model: str, **kwargs)` | Initialize with model name and additional parameters from config |
| `encode(documents: List[str])` | Synchronous encoding |
| `encode_async(documents: List[str])` | Asynchronous encoding |
Class Attributes
| Attribute | Description |
|---|---|
| `engine_name` | Identifier used in config.yml |