Llama Stack API (Experimental)
Support for the Llama Stack API in NIMs is experimental!
The Llama Stack API is a comprehensive set of interfaces developed by Meta for ML developers building on top of Llama foundation models. This API aims to standardize interactions with Llama models, simplifying the developer experience and fostering innovation across the Llama ecosystem. The Llama Stack encompasses various components of the model lifecycle, including inference, fine-tuning, evaluations, and synthetic data generation.
With the Llama Stack API, developers can easily integrate Llama models into their applications, leverage tool-calling capabilities, and build sophisticated AI systems. This documentation provides an overview of how to use the Python bindings for the Llama Stack API, focusing on chat completions and tool use.
For the full API documentation and source code, please visit the Llama Stack GitHub repository.
To get started with the Llama Stack API, you’ll need to install the necessary packages. You can do this using pip:
pip install llama-toolchain llama-models llama-agentic-system
These packages provide the core functionality for working with the Llama Stack API.
The following example stores common components in the file `inference.py`. This file contains the `InferenceClient` class and utility functions that are used across different examples.
Here’s the content of `inference.py`:
import json
from typing import Union, Generator
import requests
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk
)
class InferenceClient:
    """Small HTTP client for the Llama Stack chat-completion endpoint.

    The client only stores ``base_url``; every call builds the endpoint URL
    by appending ``/inference/chat_completion`` to it.
    """

    def __init__(self, base_url: str):
        """Store the base URL that endpoint paths are appended to."""
        self.base_url = base_url

    def chat_completion(self, request: ChatCompletionRequest) -> Generator[Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk], None, None]:
        """POST ``request`` to the server and yield the parsed result(s).

        When ``request.stream`` is truthy, one
        ``ChatCompletionResponseStreamChunk`` is yielded per ``data: `` line
        of the response body. Otherwise a single ``ChatCompletionResponse``
        is yielded.

        Raises:
            Exception: if the server answers with a status other than 200.
        """
        url = f"{self.base_url}/inference/chat_completion"
        # Serialise via request.json() and re-parse so the payload handed to
        # requests is built from plain JSON types.
        payload = json.loads(request.json())

        # The context manager guarantees the connection is released even if
        # the caller stops iterating early or parsing a chunk raises; with
        # stream=True the connection is otherwise held until the body is
        # fully consumed.
        with requests.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"},
            stream=request.stream,
        ) as response:
            if response.status_code != 200:
                raise Exception(f"Error: HTTP {response.status_code} {response.text}")

            if request.stream:
                for line in response.iter_lines():
                    if not line:
                        # Skip blank lines in the streamed body.
                        continue
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        # Drop the 6-character 'data: ' prefix before parsing.
                        data = json.loads(line[6:])
                        yield ChatCompletionResponseStreamChunk(**data)
            else:
                response_data = response.json()
                # Replace None tool-call arguments with an empty string
                # before building the response object. A null
                # completion_message or tool_calls is treated as "nothing
                # to normalise" instead of raising TypeError.
                completion = response_data.get('completion_message') or {}
                for tool_call in completion.get('tool_calls') or []:
                    if 'arguments' in tool_call and tool_call['arguments'] is None:
                        tool_call['arguments'] = ''
                yield ChatCompletionResponse(**response_data)
def process_chat_completion(response: Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]):
    """Print a chat-completion result to standard output.

    A ``ChatCompletionResponse`` is printed as its message content followed
    by any tool calls. A ``ChatCompletionResponseStreamChunk`` is printed as
    its delta without a trailing newline, followed by the stop reason when
    one is set. Any other object is ignored.
    """
    if isinstance(response, ChatCompletionResponse):
        message = response.completion_message
        print("Response content:", message.content)
        if not message.tool_calls:
            return
        print("Tool calls:")
        for call in message.tool_calls:
            print(f" Tool:{call.tool_name}")
            print(f" Arguments:{call.arguments}")
        return

    if not isinstance(response, ChatCompletionResponseStreamChunk):
        return

    event = response.event
    # Flush on every chunk so streamed text shows up as it arrives.
    print(event.delta, end='', flush=True)
    if event.stop_reason:
        print(f"\nStop reason:{event.stop_reason}")
Use these common components in the following basic usage example:
from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage
from llama_models.llama3.api.datatypes import SamplingParams
def chat():
    """Send one non-streaming chat request and print whatever comes back."""
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    prompt = UserMessage(content="Explain the concept of recursion in programming.")
    sampling = SamplingParams(max_tokens=1024)

    request = ChatCompletionRequest(
        model="meta/llama-3.1-70b-instruct",
        messages=[prompt],
        stream=False,
        sampling_params=sampling,
    )

    for result in client.chat_completion(request):
        process_chat_completion(result)


if __name__ == "__main__":
    chat()
For streaming responses, use the same structure but set `stream=True` in the request:
from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage
from llama_models.llama3.api.datatypes import SamplingParams
def stream_chat():
    """Send one streaming chat request and print each chunk as it is yielded."""
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    prompt = UserMessage(content="Write a short story about a time-traveling scientist.")
    sampling = SamplingParams(max_tokens=1024)

    # stream=True makes the client yield one object per streamed chunk.
    request = ChatCompletionRequest(
        model="meta/llama-3.1-70b-instruct",
        messages=[prompt],
        stream=True,
        sampling_params=sampling,
    )

    for chunk in client.chat_completion(request):
        process_chat_completion(chunk)


if __name__ == "__main__":
    stream_chat()
The Llama Stack API supports tool calling, allowing the model to interact with external functions.
Unlike the OpenAI API, the Llama Stack API only supports the tool choices `"auto"`, `"required"`, or `None`.
from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage, ToolDefinition, ToolParamDefinition
from llama_models.llama3.api.datatypes import SamplingParams, ToolChoice
# Tool definition offered to the model in the tool-calling example: a weather
# lookup taking two required string parameters, "location" and "unit".
weather_tool = ToolDefinition(
    tool_name="get_current_weather",
    description="Get the current weather for a location",
    parameters=dict(
        location=ToolParamDefinition(
            param_type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
        unit=ToolParamDefinition(
            param_type="string",
            description="The temperature unit (celsius or fahrenheit)",
            required=True,
        ),
    ),
)
def tool_calling_example():
    """Ask a weather question with ``weather_tool`` offered and print the result."""
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    prompt = UserMessage(content="Get me the weather in New York City, NY.")
    sampling = SamplingParams(max_tokens=200)

    # ToolChoice.auto is passed together with the single tool definition.
    request = ChatCompletionRequest(
        model="meta/llama-3.1-8b-instruct",
        messages=[prompt],
        tools=[weather_tool],
        tool_choice=ToolChoice.auto,
        sampling_params=sampling,
    )

    for result in client.chat_completion(request):
        process_chat_completion(result)


if __name__ == "__main__":
    tool_calling_example()