# Generation Interface
This document explains the token generation interface and various backends for the NeMo RL framework. The generation system is designed with a unified interface that allows different backends (like VLLM, Megatron, Hugging Face, SGLang, and TRT-LLM) to provide token generation capabilities while adhering to the same API.
## Generation Interface
The core of the generation system is defined in interfaces.py, which establishes an abstract interface that all generation backends must implement. This ensures consistency across different implementations and makes it easy to swap backends without changing the calling code.
### Key Components
`GenerationConfig`: A `TypedDict` that defines the configuration for generation:

```python
class GenerationConfig(TypedDict):
    """Configuration for generation."""

    backend: str            # The backend to use (e.g., "vllm", "megatron", "hf")
    max_new_tokens: int     # Maximum number of tokens to generate
    temperature: float      # Sampling temperature
    top_p: float            # Top-p sampling parameter
    top_k: int | None       # Top-k sampling parameter
    model_name: str         # Name or path of the model
```
`GenerationDatumSpec`: A `TypedDict` that defines the input data format:

```python
class GenerationDatumSpec(TypedDict):
    input_ids: torch.Tensor       # Input token IDs
    attention_mask: torch.Tensor  # Attention mask
    __extra__: Any                # Additional data specific to the backend
```
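For example, a batch in this format can be built by tokenizing a list of prompts and padding them to a common length. The sketch below uses a Hugging Face tokenizer purely for illustration; the prompts and model name are hypothetical, and wrapping the result in a `BatchedDataDict` before handing it to a backend is omitted.

```python
from transformers import AutoTokenizer

# Hypothetical prompts and tokenizer; any tokenizer matching the generation model works.
prompts = ["Explain reinforcement learning in one sentence.", "What is a KV cache?"]
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize and pad so the batch can be stacked into rectangular tensors.
encoded = tokenizer(prompts, return_tensors="pt", padding=True)

datum = {
    "input_ids": encoded["input_ids"],            # (batch, seq_len) token IDs
    "attention_mask": encoded["attention_mask"],  # 1 for real tokens, 0 for padding
}
# In NeMo RL this dictionary would be wrapped in a BatchedDataDict before being
# passed to a generation backend.
```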
`GenerationOutputSpec`: A `TypedDict` that defines the output data format:

```python
class GenerationOutputSpec(TypedDict):
    output_ids: torch.Tensor
    generation_lengths: torch.Tensor         # Length of just the generated response part
    unpadded_sequence_lengths: torch.Tensor  # Length of the full valid sequence (input + generated response)
    logprobs: torch.Tensor
    __extra__: Any                           # Additional output data specific to the backend
```
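To illustrate how these fields relate, the helper below slices the generated tokens out of a padded output batch. It is a sketch under the assumption that each row of `output_ids` holds the prompt followed by the generated response and then padding, consistent with the length fields above; it is not the framework's own post-processing.

```python
import torch

def split_generations(
    output_ids: torch.Tensor,
    generation_lengths: torch.Tensor,
    unpadded_sequence_lengths: torch.Tensor,
) -> list[torch.Tensor]:
    """Return the generated tokens for each sample in a padded output batch.

    Assumes each row of output_ids is laid out as
    [prompt tokens | generated tokens | padding], so the valid region ends at
    unpadded_sequence_lengths[i] and the generated part spans its last
    generation_lengths[i] tokens.
    """
    generations = []
    for i in range(output_ids.shape[0]):
        end = int(unpadded_sequence_lengths[i])
        start = end - int(generation_lengths[i])
        generations.append(output_ids[i, start:end])
    return generations
```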
`GenerationInterface`: An abstract base class that all generation backends must implement:

```python
class GenerationInterface(ABC):
    """Abstract base class defining the interface for RL policies."""

    @abstractmethod
    def generate(
        self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool
    ) -> BatchedDataDict["GenerationOutputSpec"]:
        pass

    @abstractmethod
    def prepare_for_generation(self, *args, **kwargs):
        pass

    @abstractmethod
    def finish_generation(self, *args, **kwargs):
        pass
```
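Because every backend exposes the same three methods, calling code can be written once against `GenerationInterface` and reused with any backend. The helper below is a hypothetical sketch of that pattern (the import paths follow the `interfaces.py` module described above); it is not part of the framework.

```python
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.generation.interfaces import (
    GenerationDatumSpec,
    GenerationInterface,
    GenerationOutputSpec,
)


def run_generation(
    backend: GenerationInterface,
    data: BatchedDataDict["GenerationDatumSpec"],
    greedy: bool = False,
) -> BatchedDataDict["GenerationOutputSpec"]:
    """Run one generation pass against any backend implementing the interface."""
    backend.prepare_for_generation()
    try:
        return backend.generate(data, greedy)
    finally:
        # Always release backend resources, even if generation raises.
        backend.finish_generation()
```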
A key design principle for generation backends is that they process tokens directly, without involving the tokenizer. By ensuring that only tokens are exchanged, we eliminate the risk of inconsistencies arising from different tokenizer versions or specifications between the training and generation frameworks.
## Generation Backends
NeMo RL supports multiple generation backends that implement the GenerationInterface to provide efficient text generation for different use cases.
### VLLM Backend
The VLLM backend (models/generation/vllm/vllm_generation.py) implements the GenerationInterface to provide efficient text generation using the VLLM library, which is optimized for large language models.
#### VllmGeneration Class
The VllmGeneration class is the main implementation of the GenerationInterface for VLLM. It performs the following functions:
- Sets up VLLM workers in a distributed environment using Ray.
- Manages the lifecycle of these workers (initialization, generation, shutdown).
- Distributes inputs to workers and collects outputs.
- Handles weight updates and synchronization.
#### VllmGenerationWorker
The VllmGenerationWorker is a Ray actor that:
- Initializes and manages a VLLM model instance.
- Performs the actual generation on a GPU.
- Supports dynamic weight updates through IPC handles.
- Implements sleep/wake mechanisms for efficient resource utilization.
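Putting the lifecycle together, a typical rollout-then-train loop alternates between waking the generation workers, generating, and putting them back to sleep so the training step has the GPUs to itself. The loop below is a hypothetical sketch: `collect_prompts`, `train_step`, `policy`, and `num_steps` are placeholders for the surrounding RL algorithm, and only the three `GenerationInterface` calls come from the documented API.

```python
# Hypothetical RL-style loop; `generator` is a VllmGeneration instance (or any
# other GenerationInterface implementation).
for step in range(num_steps):
    prompts = collect_prompts(step)      # placeholder: BatchedDataDict of tokenized prompts

    generator.prepare_for_generation()   # wake the vLLM workers / reclaim GPU memory
    rollouts = generator.generate(prompts, greedy=False)
    generator.finish_generation()        # put workers to sleep, freeing GPU memory

    train_step(policy, rollouts)         # placeholder: update the policy on the rollouts
    # Updated weights are then synchronized to the generation workers (e.g., via the
    # IPC-handle mechanism mentioned above) before the next rollout.
```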
#### Custom VLLM Extensions
The UpdatableVllmInternalWorker class in vllm_backend.py extends the VLLM worker with additional capabilities:
- Reporting device IDs to allow mapping of workers to specific GPUs.
- Updating weights from IPC handles for efficient weight sharing.
- Checking that weights have been updated correctly.
### Megatron Backend
The Megatron backend provides native Megatron-Core inference capabilities, eliminating the need for weight conversion between training and generation. This backend is particularly beneficial when using Megatron for training, as it enables seamless integration and optimal performance.
#### Key Features
- **No Weight Conversion**: Uses the same Megatron model format for both training and generation, eliminating conversion overhead and potential inconsistencies.
- **CUDA Graph Support**: Leverages CUDA graphs for optimized inference performance.
- **Dynamic Inference Engine**: Utilizes Megatron Core's `DynamicInferenceEngine` for efficient batched generation.
- **Integrated with Training**: The generation capability is built directly into the `MegatronPolicyWorker`, enabling efficient co-located training and generation.
#### MegatronPolicyWorker Generation
The Megatron generation backend is implemented within the `MegatronPolicyWorker` class. Its `generate` method (`nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker.generate`) performs the following steps:

1. Wraps the Megatron model with `GPTInferenceWrapper` for inference optimization.
2. Creates a `DynamicInferenceContext` to manage inference state and memory.
3. Initializes a `DynamicInferenceEngine` with CUDA graph support enabled.
4. Processes batched requests with the configured sampling parameters (temperature, top_k, top_p).
5. Returns outputs conforming to `GenerationOutputSpec`.
#### Configuration
To use the Megatron generation backend, configure your YAML file as follows:
```yaml
policy:
  megatron_cfg:
    enabled: true
  generation:
    backend: megatron
    max_new_tokens: 512
    temperature: 1.0
    top_p: 1.0
    top_k: null
    mcore_generation_config:
      buffer_size_gb: 20               # Memory buffer size for the inference context
      buffer_guaranteed_fraction: 0.1  # Fraction of the buffer guaranteed to be available for active requests
      num_cuda_graphs: 16              # Number of CUDA graphs to pre-allocate
      max_tokens: 16384                # Maximum number of tokens for inference
```
#### Configuration Parameters
The `mcore_generation_config` section controls the behavior of the Megatron Core inference engine:

- `buffer_size_gb`: Total memory buffer size (in GB) allocated for the dynamic inference context. This determines how much GPU memory is reserved for KV caches and intermediate states; a larger buffer lets the engine admit more requests at once.
- `buffer_guaranteed_fraction`: Fraction of the buffer (between 0.0 and 1.0) that is guaranteed to remain available, so active requests always have enough memory to run to completion.
- `num_cuda_graphs`: Number of CUDA graphs to pre-allocate for different batch sizes. More graphs can improve performance by avoiding runtime graph capture, but they consume additional memory.
- `max_tokens`: Maximum total number of tokens (across all requests) that can be processed simultaneously. This bounds the combination of batch size and sequence length; increasing it can lead to out-of-memory errors depending on the vocabulary size and the allocated buffer size.
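As a rough illustration of how the first two parameters interact (using the example values from the YAML above):

```python
buffer_size_gb = 20               # total buffer for the dynamic inference context
buffer_guaranteed_fraction = 0.1  # fraction reserved so active requests can finish

# About 2 GB of the 20 GB buffer is kept available for in-flight requests;
# the rest is used to admit new requests into the batch.
guaranteed_gb = buffer_size_gb * buffer_guaranteed_fraction  # 2.0
```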
## Usage Examples
### Using VLLM Backend
To use the VLLM generation backend:
```python
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.generation.interfaces import configure_generation_config
from nemo_rl.models.generation.vllm import VllmGeneration, VllmConfig

# Set up the configuration
config = VllmConfig(
    model_name="Qwen/Qwen2.5-1.5B",
    max_new_tokens=100,
    temperature=0.7,
    top_p=1,
    top_k=None,
    backend="vllm",
    vllm_cfg={
        "tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.8,
        "max_model_len": 2048,
    },
)

# Configure config with tokenizer
tokenizer = get_tokenizer(config["model_name"])
config = configure_generation_config(config, tokenizer)

# Initialize the cluster and generation backend
cluster = RayVirtualCluster(...)
generator = VllmGeneration(cluster, config)

# Prepare input data
input_data = BatchedDataDict(...)

# Generate text
generator.prepare_for_generation()
output = generator.generate(input_data, greedy=False)
generator.finish_generation()
```
### Using Megatron Backend
To use the Megatron generation backend, configure your YAML file:
```yaml
policy:
  model_name: meta-llama/Llama-3.2-1B-Instruct
  megatron_cfg:
    enabled: true
  generation:
    backend: megatron
    max_new_tokens: 512
    temperature: 1.0
    top_p: 1.0
    top_k: null
    mcore_generation_config:
      buffer_size_gb: 20
      buffer_guaranteed_fraction: 0.1
      num_cuda_graphs: 16
      max_tokens: 16384
```
For a complete example, see:
- Configuration: `examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml`
- Test Script: `tests/functional/grpo_megatron_generation.sh`
## Extend with New Backends
To add a new generation backend:
1. Create a new class that implements `GenerationInterface`.
2. Implement the required methods: `generate`, `prepare_for_generation`, and `finish_generation`.
3. Ensure your implementation works with the standard `GenerationConfig` and `GenerationDatumSpec` structures.
4. Register your backend with the system (if needed) to make it accessible.
This modular design allows for easy extension with new backends while maintaining a consistent interface for the rest of the system.
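As a starting point, a new backend might look like the skeleton below. This is a hypothetical sketch rather than an existing backend: the `EchoGeneration` class and its trivial "echo" behavior are placeholders, the import paths assume the definitions live in `interfaces.py` as described above, and constructing a `BatchedDataDict` directly from a dictionary is assumed to work.

```python
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.generation.interfaces import (
    GenerationConfig,
    GenerationDatumSpec,
    GenerationInterface,
    GenerationOutputSpec,
)


class EchoGeneration(GenerationInterface):
    """Hypothetical backend that simply echoes the input tokens back.

    A real backend would load a model here and produce new tokens in generate().
    """

    def __init__(self, config: GenerationConfig):
        self.config = config

    def prepare_for_generation(self, *args, **kwargs):
        # A real backend would allocate GPU resources / wake its workers here.
        return True

    def generate(
        self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool
    ) -> BatchedDataDict["GenerationOutputSpec"]:
        input_ids = data["input_ids"]
        input_lengths = data["attention_mask"].sum(dim=1)
        # "Generate" nothing: return the inputs as the full valid sequences.
        return BatchedDataDict(
            {
                "output_ids": input_ids,
                "generation_lengths": torch.zeros_like(input_lengths),
                "unpadded_sequence_lengths": input_lengths,
                "logprobs": torch.zeros_like(input_ids, dtype=torch.float32),
            }
        )

    def finish_generation(self, *args, **kwargs):
        # A real backend would release GPU resources / put its workers to sleep here.
        return True
```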