Creating Custom Actions


This section describes how to create custom actions in the actions.py file.

The @action Decorator

Use the @action decorator from nemoguardrails.actions to define custom actions:

from nemoguardrails.actions import action

@action()
async def my_custom_action():
    """A simple custom action."""
    return "result"

Decorator Parameters

| Parameter | Type | Description | Default |
|---|---|---|---|
| name | str | Custom name for the action | Function name |
| is_system_action | bool | Always run locally, bypassing the actions server | False |
| execute_async | bool | Don’t block event processing while the action runs (Colang 2.x only) | False |
| output_mapping | Callable[[Any], bool] | Function that interprets the action result to decide whether to block | default_output_mapping |
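To make the table concrete, here is a simplified, self-contained stand-in for the decorator that only records the four parameters as metadata and returns the function unchanged. It is a sketch that assumes the decorator works by attaching metadata rather than wrapping the call; `sketch_action` and the `sketch_meta` attribute are hypothetical names, not the library's API.

```python
from typing import Any, Callable, Optional


def sketch_action(
    name: Optional[str] = None,
    is_system_action: bool = False,
    execute_async: bool = False,
    output_mapping: Optional[Callable[[Any], bool]] = None,
):
    """Hypothetical stand-in for @action: records metadata, leaves fn unchanged."""

    def decorator(fn):
        fn.sketch_meta = {
            # Fall back to the function name, as described in the table.
            "name": name or fn.__name__,
            "is_system_action": is_system_action,
            "execute_async": execute_async,
            # None here stands in for "use the default mapping".
            "output_mapping": output_mapping,
        }
        return fn

    return decorator


@sketch_action()
async def my_custom_action():
    return "result"


@sketch_action(name="validate_user_input")
async def check_input(text: str):
    return len(text) > 0


print(check_input.sketch_meta["name"])  # validate_user_input
```

Because the function itself is returned untouched, it can still be awaited like any other coroutine; only the metadata changes how a runtime would register and dispatch it.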

Custom Action Name

Override the default action name:

@action(name="validate_user_input")
async def check_input(text: str):
    """Validates user input."""
    return len(text) > 0

Call from Colang:

$is_valid = execute validate_user_input(text=$user_message)

System Actions

When is_system_action=True, the action always runs locally, even when an actions_server_url is configured. This is important for actions that need access to special parameters like context, llm, config, and events, which are only injected for locally-run actions.

When no actions_server_url is configured, all actions run locally and receive special parameters regardless of the is_system_action setting. The flag only affects behavior when an actions server is in use.

from typing import Optional

from nemoguardrails.actions import action

@action(is_system_action=True)
async def check_policy_compliance(context: Optional[dict] = None):
    """Check if message complies with policy."""
    context = context or {}
    message = context.get("last_user_message", "")
    # Validation logic goes here
    return True
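One way to picture parameter injection is a runner that inspects the action's signature and passes only the special parameters the function declares. The sketch below illustrates that idea with `inspect.signature`; it is not the library's dispatcher, and `run_locally`, the `special` dictionary, and the "forbidden" policy rule are all hypothetical.

```python
import asyncio
import inspect
from typing import Any, Optional


async def check_policy_compliance(context: Optional[dict] = None):
    """Same shape as the system action above (decorator omitted to run standalone)."""
    context = context or {}
    message = context.get("last_user_message", "")
    # Hypothetical policy: reject messages containing the word "forbidden".
    return "forbidden" not in message.lower()


def run_locally(fn, special: dict, **kwargs: Any):
    """Hypothetical local runner: inject only the special params fn declares."""
    params = inspect.signature(fn).parameters
    injected = {k: v for k, v in special.items() if k in params}
    return asyncio.run(fn(**kwargs, **injected))


special = {
    "context": {"last_user_message": "Hello there"},
    "llm": None,  # placeholders for objects a real runtime would supply
    "config": None,
    "events": [],
}
print(run_locally(check_policy_compliance, special))  # True
```

Here `llm`, `config`, and `events` are dropped because the function does not declare them, which is why an action only needs to list the special parameters it actually uses.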

Async Execution

When execute_async=True, the event processing loop does not wait for the action to complete before continuing. The action runs in the background and the result is picked up later via polling. This is useful for long-running operations where you don’t need the result immediately.

This flag is only supported in the Colang 2.x runtime. In the Colang 1.0 runtime, it is stored in metadata but has no effect.

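import aiohttp
from nemoguardrails.actions import action

@action(execute_async=True)
async def call_external_api(endpoint: str):
    """Call an external API without blocking event processing."""
    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint) as response:
            return await response.json()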
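Conceptually, this resembles starting a background task with plain asyncio: the caller keeps processing while the slow coroutine runs, then collects the result afterwards. The sketch below uses only the standard library and a simulated delay (`slow_lookup` and `event_loop_demo` are hypothetical); it is an analogy for the non-blocking behavior, not a description of how the Colang 2.x runtime schedules actions internally.

```python
import asyncio


async def slow_lookup(query: str) -> str:
    # Simulated long-running call standing in for an external API.
    await asyncio.sleep(0.1)
    return f"results for {query}"


async def event_loop_demo() -> list:
    log = []
    task = asyncio.create_task(slow_lookup("weather"))  # start, don't wait
    log.append("task started")
    # Other events could be handled here while the lookup is still pending.
    log.append(f"still running: {not task.done()}")
    result = await task  # pick up the result later
    log.append(result)
    return log


print(asyncio.run(event_loop_demo()))
# ['task started', 'still running: True', 'results for weather']
```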

Output Mapping

The output_mapping parameter controls how the action’s return value is interpreted to determine if output should be blocked. It accepts a callable that takes the return value and returns True if the output is not safe (should be blocked).

When no output_mapping is provided, the default behavior is:

  • Boolean results: True means allowed, False means blocked
  • Numeric results: Values below 0.5 are blocked
  • Other types: Allowed by default
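
For example, if an action returns True when it detects a problem, pass the value through unchanged:

from typing import Optional

from nemoguardrails.actions import action

@action(output_mapping=lambda value: value)
async def check_hallucination(context: Optional[dict] = None):
    """Return True if hallucination detected (blocked), False if safe."""
    context = context or {}
    # detect_hallucination is a placeholder for your own detection logic
    return detect_hallucination(context.get("bot_message", ""))

If an action returns True when the output is safe, invert the value:

@action(is_system_action=True, output_mapping=lambda value: not value)
async def check_output_safety(context: Optional[dict] = None):
    """Return True if safe (allowed); the mapping turns this into not-blocked."""
    context = context or {}
    # is_safe is a placeholder for your own safety check
    return is_safe(context.get("bot_message", ""))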
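The default rules in the list above can be restated as a plain function, which makes them easy to compare with the two custom mappings. This is an illustrative restatement of the documented behavior, not the library's source; `default_mapping_sketch`, `identity_mapping`, and `invert_mapping` are hypothetical names. In every mapping, a return value of True means "block".

```python
from typing import Any


def default_mapping_sketch(result: Any) -> bool:
    """Return True if the output should be blocked, per the default rules."""
    if isinstance(result, bool):  # check bool first: bool is a subclass of int
        return not result  # True means allowed, False means blocked
    if isinstance(result, (int, float)):
        return result < 0.5  # low scores are blocked
    return False  # any other type is allowed


def identity_mapping(value: Any) -> bool:
    return value  # like check_hallucination: True means blocked


def invert_mapping(value: Any) -> bool:
    return not value  # like check_output_safety: True means safe


print(default_mapping_sketch(False), default_mapping_sketch(0.3), default_mapping_sketch("ok"))
# True True False
```

Note that for boolean results, `invert_mapping` behaves the same as the default rules; an explicit mapping mainly documents intent or handles non-boolean results.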

You can also define a custom mapping function for more complex logic:

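from typing import Optional

from nemoguardrails.actions import action

def my_custom_mapping(result):
    # Block when the result is a dict whose score falls below 0.7
    if isinstance(result, dict):
        return result.get("score", 1.0) < 0.7
    return False

@action(output_mapping=my_custom_mapping)
async def score_safety(context: Optional[dict] = None):
    """Return a dict with a safety score."""
    context = context or {}
    # compute_score is a placeholder for your own scoring logic
    return {"score": compute_score(context.get("bot_message", ""))}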

Function Parameters

Actions can accept parameters of the following types:

| Type | Example |
|---|---|
| str | "hello" |
| int | 42 |
| float | 3.14 |
| bool | True |
| list | ["a", "b", "c"] |
| dict | {"key": "value"} |
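As a minimal sketch of the container types from the table, the action body below takes a list and a dict. The @action() decorator is left off so the snippet runs without NeMo Guardrails installed; in actions.py you would decorate it as shown earlier. `tag_items` is a hypothetical example.

```python
import asyncio


async def tag_items(items: list, labels: dict, limit: int = 2) -> list:
    """Attach a label to each of the first `limit` items, defaulting to "unknown"."""
    return [f"{item}:{labels.get(item, 'unknown')}" for item in items[:limit]]


print(asyncio.run(tag_items(["a", "b", "c"], {"a": "alpha"})))
# ['a:alpha', 'b:unknown']
```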

Basic Parameters

@action()
async def greet_user(name: str, formal: bool = False):
    """Generate a greeting."""
    if formal:
        return f"Good day, {name}."
    return f"Hello, {name}!"

Call from Colang:

$greeting = execute greet_user(name="Alice", formal=True)

Optional Parameters with Defaults

@action()
async def search_documents(
    query: str,
    max_results: int = 10,
    include_metadata: bool = False,
):
    """Search documents with optional parameters."""
    # perform_search is a placeholder for your own search backend
    results = perform_search(query, limit=max_results)
    if include_metadata:
        return {"results": results, "count": len(results)}
    return results

Return Values

Actions can return various types:

Simple Return

@action()
async def get_status():
    return "active"

Dictionary Return

@action()
async def get_user_info(user_id: str):
    return {
        "id": user_id,
        "name": "John Doe",
        "role": "admin",
    }

Boolean Return (for validation)

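from typing import Optional

@action(is_system_action=True)
async def is_safe_content(context: Optional[dict] = None):
    context = context or {}
    content = context.get("bot_message", "")
    # contains_harmful_content is a placeholder for your own check.
    # Returns True if safe (allowed), False if harmful (blocked)
    return not contains_harmful_content(content)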

Error Handling

Handle errors gracefully within actions:

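import logging

import aiohttp
from nemoguardrails.actions import action

logger = logging.getLogger(__name__)

@action()
async def fetch_data(url: str):
    """Fetch data with error handling."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                response.raise_for_status()
                return await response.json()
    except Exception as e:
        # Log the error instead of printing it
        logger.error("Error fetching data from %s: %s", url, e)
        # Return a safe default (or re-raise if the caller should handle it)
        return None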
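To apply the same "log and return a safe default" pattern around any slow call, one option is a small wrapper that also enforces a timeout with `asyncio.wait_for`. This is a standard-library sketch; `call_with_fallback`, `slow_fetch`, and `always_fails` are hypothetical, and the timeout values are arbitrary.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def call_with_fallback(coro, default=None, timeout: float = 1.0):
    """Await coro; on timeout or any error, log it and return default."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning("Call timed out after %.2fs", timeout)
    except Exception:
        logger.exception("Call failed")
    return default


async def slow_fetch(delay: float):
    # Simulated network call that takes `delay` seconds.
    await asyncio.sleep(delay)
    return {"status": "ok"}


async def always_fails():
    raise ValueError("boom")


async def main():
    fast = await call_with_fallback(slow_fetch(0.01), timeout=0.5)
    slow = await call_with_fallback(slow_fetch(1.0), default={}, timeout=0.05)
    failed = await call_with_fallback(always_fails(), default="fallback")
    return fast, slow, failed


print(asyncio.run(main()))  # ({'status': 'ok'}, {}, 'fallback')
```

Returning a default keeps the flow moving; re-raising instead is preferable when the caller must know the call failed.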

Example Actions

Input Validation Action

from typing import Optional
from nemoguardrails.actions import action

@action(is_system_action=True)
async def check_input_length(context: Optional[dict] = None):
    """Ensure user input is not too long."""
    context = context or {}
    user_message = context.get("last_user_message", "")
    max_length = 1000

    if len(user_message) > max_length:
        return False  # Block the input

    return True  # Allow the input

Output Filtering Action

import re
from typing import Optional

# Returns True when sensitive data is found, so pass the value through
# unchanged: under the default mapping, True would mean "allowed"
@action(is_system_action=True, output_mapping=lambda value: value)
async def filter_sensitive_data(context: Optional[dict] = None):
    """Check for sensitive data in bot response."""
    context = context or {}
    bot_response = context.get("bot_message", "")

    sensitive_patterns = [
        r"\b\d{3}-\d{2}-\d{4}\b",  # SSN pattern
        r"\b\d{16}\b",  # Credit card pattern
    ]

    for pattern in sensitive_patterns:
        if re.search(pattern, bot_response):
            return True  # Contains sensitive data

    return False  # No sensitive data found
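Before relying on regular expressions like these, it helps to check what they do and do not catch. The quick check below runs the same two patterns against a few made-up samples (`contains_sensitive_data` and the samples are hypothetical). It shows that the 16-digit pattern misses card numbers written with spaces or dashes, so you may want to broaden it for your data.

```python
import re

SENSITIVE_PATTERNS = [
    re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),  # SSN pattern
    re.compile(r"\b\d{16}\b"),  # 16 consecutive digits
]


def contains_sensitive_data(text: str) -> bool:
    return any(p.search(text) for p in SENSITIVE_PATTERNS)


samples = {
    "My SSN is 123-45-6789": True,
    "Card: 4111111111111111": True,
    "Card: 4111 1111 1111 1111": False,  # spaced digits are not matched
    "Order 12345 shipped": False,
}
for text, expected in samples.items():
    assert contains_sensitive_data(text) == expected, text
```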

External API Action

import aiohttp
from nemoguardrails.actions import action

@action(execute_async=True)
async def query_knowledge_base(query: str, top_k: int = 5):
    """Query an external knowledge base API."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.example.com/search",
            json={"query": query, "limit": top_k},
        ) as response:
            data = await response.json()
            return data.get("results", [])