Custom Embedding Providers for NeMo Guardrails

View as Markdown

Custom embedding providers enable you to use your own embedding models for semantic similarity search in the knowledge base and intent detection.

Creating a Custom Embedding Provider

Create a class that inherits from EmbeddingModel:

from typing import List

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class CustomEmbedding(EmbeddingModel):
    """Custom embedding provider.

    Template showing the three methods a provider implements:
    ``__init__``, ``encode`` and ``encode_async``.
    """

    # Identifier used for this provider in config.yml.
    engine_name = "custom_embedding"

    def __init__(self, embedding_model: str, **kwargs):
        """Initialize the embedding model.

        Args:
            embedding_model: The model name from config.yml
            **kwargs: Additional parameters from config.yml
        """
        self.model_name = embedding_model
        # Initialize your model here.
        # NOTE: ``load_model`` is a placeholder, not a real function --
        # replace it with whatever call loads your model.
        self.model = load_model(embedding_model)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode

        Returns:
            List of embedding vectors, one per input document, in the
            same order as ``documents``.
        """
        # Each document is encoded on its own; the result keeps input order.
        return [self.model.encode(doc) for doc in documents]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (asynchronous).

        Args:
            documents: List of text documents to encode

        Returns:
            List of embedding vectors
        """
        # For simple models, can just call sync version.
        # NOTE: this runs ``encode`` directly inside the coroutine, so the
        # event loop is blocked until encoding finishes.
        return self.encode(documents)

Registering the Provider

Register the provider in your config.py:

from nemoguardrails import LLMRails


def init(app: LLMRails):
    """Register ``CustomEmbedding`` with the given rails app.

    Args:
        app: The ``LLMRails`` instance to register the provider on.
    """
    # The provider class is imported from the sibling ``embeddings`` module.
    from .embeddings import CustomEmbedding

    # The second argument is the name the provider is registered under;
    # it is the value used for ``engine`` in config.yml.
    app.register_embedding_provider(CustomEmbedding, "custom_embedding")

Using the Provider

Configure in config.yml:

models:
  - type: embeddings
    # Must match the name the provider was registered under.
    engine: custom_embedding
    # Passed to the provider's __init__ as `embedding_model`.
    model: my-model-name

Example: Sentence Transformers

import asyncio
from typing import List

from sentence_transformers import SentenceTransformer

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding provider using sentence-transformers."""

    # Identifier used for this provider in config.yml.
    engine_name = "sentence_transformers"

    def __init__(self, embedding_model: str, **kwargs):
        """Load the sentence-transformers model.

        Args:
            embedding_model: Model name handed to ``SentenceTransformer``.
            **kwargs: Additional parameters from config.yml (not used here).
        """
        self.model = SentenceTransformer(embedding_model)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            List of embedding vectors as plain Python lists.
        """
        # The whole batch is encoded in a single call.
        embeddings = self.model.encode(documents)
        # ``tolist()`` turns the model's array output into nested lists.
        return embeddings.tolist()

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings without blocking the event loop.

        Args:
            documents: List of text documents to encode.

        Returns:
            List of embedding vectors as plain Python lists.
        """
        # ``encode`` is synchronous; calling it directly here would stall
        # every other coroutine until it returns. Run it in the loop's
        # default executor instead.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

config.py:

from nemoguardrails import LLMRails


def init(app: LLMRails):
    """Register ``SentenceTransformerEmbedding`` with the given rails app.

    Args:
        app: The ``LLMRails`` instance to register the provider on.
    """
    # NOTE(review): assumes the provider class lives in ``embeddings.py``
    # next to this config.py -- adjust the module name, or drop this import
    # if the class is defined in config.py itself.
    from .embeddings import SentenceTransformerEmbedding

    # The registered name is the value used for ``engine`` in config.yml.
    app.register_embedding_provider(
        SentenceTransformerEmbedding,
        "sentence_transformers",
    )

config.yml:

models:
  - type: embeddings
    # Must match the name the provider was registered under.
    engine: sentence_transformers
    # Passed to the provider's __init__ as `embedding_model`.
    model: all-MiniLM-L6-v2

Example: OpenAI-Compatible API

from typing import List

import httpx

from nemoguardrails.embeddings.providers.base import EmbeddingModel


class OpenAICompatibleEmbedding(EmbeddingModel):
    """Embedding provider for OpenAI-compatible APIs."""

    # Identifier used for this provider in config.yml.
    engine_name = "openai_compatible"

    def __init__(self, embedding_model: str, **kwargs):
        """Store the model name and the endpoint to call.

        Args:
            embedding_model: Model name sent in each request body.
            **kwargs: Additional parameters from config; ``api_url``
                overrides the default endpoint
                ``http://localhost:8080/v1/embeddings``.
        """
        self.model = embedding_model
        self.api_url = kwargs.get("api_url", "http://localhost:8080/v1/embeddings")

    def _request_body(self, documents: List[str]) -> dict:
        """Build the JSON body shared by the sync and async requests."""
        return {"input": documents, "model": self.model}

    @staticmethod
    def _parse_embeddings(response: httpx.Response) -> List[List[float]]:
        """Extract the embedding vectors from an API response.

        Raises:
            httpx.HTTPStatusError: If the server answered with a 4xx/5xx
                status.
        """
        # Fail on the HTTP status rather than with a KeyError on "data"
        # when the server returns an error body.
        response.raise_for_status()
        data = response.json()
        return [item["embedding"] for item in data["data"]]

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (synchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            List of embedding vectors; empty if ``documents`` is empty.

        Raises:
            httpx.HTTPStatusError: If the server answered with a 4xx/5xx
                status.
        """
        if not documents:
            # Nothing to embed; skip the network round trip.
            return []
        response = httpx.post(self.api_url, json=self._request_body(documents))
        return self._parse_embeddings(response)

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode documents into embeddings (asynchronous).

        Args:
            documents: List of text documents to encode.

        Returns:
            List of embedding vectors; empty if ``documents`` is empty.

        Raises:
            httpx.HTTPStatusError: If the server answered with a 4xx/5xx
                status.
        """
        if not documents:
            # Nothing to embed; skip the network round trip.
            return []
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url,
                json=self._request_body(documents),
            )
            return self._parse_embeddings(response)

Required Methods

| Method | Description |
| --- | --- |
| `__init__(embedding_model: str, **kwargs)` | Initialize with model name and additional parameters from config |
| `encode(documents: List[str])` | Synchronous encoding |
| `encode_async(documents: List[str])` | Asynchronous encoding |

Class Attributes

| Attribute | Description |
| --- | --- |
| `engine_name` | Identifier used in config.yml |