nemo_rl.models.generation.sglang.sglang_worker#
Module Contents#
Classes#
Functions#
Data#
API#
- nemo_rl.models.generation.sglang.sglang_worker.logger#
getLogger(...)
- nemo_rl.models.generation.sglang.sglang_worker._require_sglang()#
Import sglang lazily so test collection works without the optional extra.
- class nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker(
- config: nemo_rl.models.generation.sglang.config.SGLangConfig,
- bundle_indices: Optional[list[int]] = None,
- fraction_of_gpus: float = 1.0,
- seed: Optional[int] = None,
)#
Initialization
Initialize a SGLang worker for distributed inference.
- Parameters:
config – Configuration dictionary for the policy
bundle_indices – List of local bundle indices for this server. The length of this list determines tp_size (number of GPUs per server). Only needed for the first worker in each server group (model owner).
fraction_of_gpus – Fraction of GPUs to use for this worker
seed – Random seed for initialization; if None, defaults to the config's seed
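A minimal construction sketch follows. The config field names are assumptions (consult SGLangConfig for the real schema), and in practice the generation controller usually creates this worker as a Ray actor rather than instantiating it directly.

```python
from nemo_rl.models.generation.sglang.sglang_worker import SGLangGenerationWorker

# Field names here are illustrative assumptions, not the full SGLangConfig.
config = {
    "model_name": "meta-llama/Llama-3.1-8B-Instruct",  # assumed field
    "seed": 1234,                                       # assumed field
}

worker = SGLangGenerationWorker(
    config=config,
    bundle_indices=[0, 1],  # two local GPUs -> tp_size = 2 (model owner only)
    fraction_of_gpus=1.0,
    seed=None,              # fall back to the seed in the config
)
```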
- __repr__() str#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- static configure_worker(
- num_gpus: int | float,
- bundle_indices: Optional[tuple[int, list[int]]] = None,
)#
Provides complete worker configuration for SGLang server.
This method configures the worker based on bundle_indices, which determines how many GPUs this server should use.
- Parameters:
num_gpus – Original GPU allocation for this worker based on the placement group
bundle_indices – Tuple of (node_idx, local_bundle_indices) for this server
- Returns:
'resources': Resource allocation (e.g., num_gpus)
'env_vars': Environment variables for this worker
'init_kwargs': Keyword arguments to pass to the worker's __init__
- Return type:
tuple with complete worker configuration
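A sketch of how a controller might consume this configuration. The tuple order (resources, env_vars, init_kwargs) is assumed from the keys listed above, and the actor-creation options are hypothetical.

```python
import ray

from nemo_rl.models.generation.sglang.sglang_worker import SGLangGenerationWorker

resources, env_vars, init_kwargs = SGLangGenerationWorker.configure_worker(
    num_gpus=1.0,
    bundle_indices=(0, [0, 1]),  # node 0, local bundle indices 0 and 1
)

# Hypothetical actor creation; real options depend on the placement group.
actor_cls = ray.remote(**resources)(SGLangGenerationWorker)
worker = actor_cls.options(runtime_env={"env_vars": env_vars}).remote(**init_kwargs)
```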
- get_base_url() str#
Get the base URL of this SGLang server.
- invalidate_kv_cache() bool#
Invalidate KV cache before weight updates (Megatron-style).
This flushes the cache before weight updates so that stale entries are not reused. Uses retry logic to handle cases where pending requests block the flush.
- Returns:
True if flush was successful, False otherwise
- Return type:
bool
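A sketch of the Megatron-style update flow this method supports; update_weights_from_trainer is a hypothetical stand-in for the actual weight-transfer step.

```python
import logging

logger = logging.getLogger(__name__)

if worker.invalidate_kv_cache():  # flush stale KV entries first
    update_weights_from_trainer(worker)  # hypothetical weight-transfer helper
else:
    logger.warning("KV cache flush failed; skipping weight update")
```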
- get_gpu_uuids() list[str]#
Get list of GPU UUIDs used by this SGLang server.
- Returns:
List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"])
- _merge_stop_strings(batch_stop_strings)#
Merge stop strings from config and batch.
- Parameters:
batch_stop_strings – List of stop strings from batch (one per sample)
- Returns:
List of merged stop strings (one per sample)
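A plausible merging rule, sketched under the assumption that config-level stop strings are unioned with each sample's batch-level stop strings:

```python
from typing import Optional

def merge_stop_strings_sketch(
    config_stop: Optional[list[str]],
    batch_stop_strings: list[Optional[list[str]]],
) -> list[Optional[list[str]]]:
    merged = []
    for per_sample in batch_stop_strings:
        combined = set(config_stop or []) | set(per_sample or [])
        merged.append(sorted(combined) or None)  # None when no stops apply
    return merged

merge_stop_strings_sketch(["</s>"], [["###"], None])
# -> [['###', '</s>'], ['</s>']]
```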
- _build_sampling_params(
- *,
- greedy: bool,
- stop_strings,
- max_new_tokens: Optional[int] = None,
- input_len: Optional[int] = None,
- context_length: Optional[int] = None,
- sample_index: Optional[int] = None,
)#
Build sampling parameters dictionary for SGLang API.
- Parameters:
greedy – Whether to use greedy decoding (temperature=0.0)
stop_strings – Merged stop strings (not used here, handled per sample)
max_new_tokens – Override max_new_tokens from config if provided
input_len – Input length for this sample (used for context_length adjustment)
context_length – Maximum context length (if provided, adjusts max_new_tokens)
sample_index – Sample index (used for warning messages, 0-indexed)
- Returns:
Dictionary of sampling parameters compatible with SGLang API
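A sketch of the documented clamping behavior; the key names follow SGLang's sampling-params conventions but are assumptions here, as is the non-greedy temperature.

```python
def build_sampling_params_sketch(
    *, greedy, max_new_tokens, input_len=None, context_length=None
):
    params = {
        "temperature": 0.0 if greedy else 0.7,  # 0.7 is a placeholder default
        "max_new_tokens": max_new_tokens,
    }
    # If a context length is given, cap generation at the remaining budget.
    if context_length is not None and input_len is not None:
        params["max_new_tokens"] = min(
            max_new_tokens, max(context_length - input_len, 0)
        )
    return params

build_sampling_params_sketch(
    greedy=True, max_new_tokens=512, input_len=4000, context_length=4096
)
# -> {'temperature': 0.0, 'max_new_tokens': 96}
```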
- async _ensure_session()#
- async _generate_single_sample(
- input_ids: list[int],
- sampling_params: dict[str, Any],
- stop_string: Optional[str] = None,
)#
Generate a single sample using SGLang API (async function).
- Parameters:
input_ids – List of input token IDs (without padding)
sampling_params – Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.)
stop_string – Optional stop string for this sample
- Returns:
generated_tokens: List of generated token IDs
logprobs: List of log probabilities for generated tokens
- Return type:
Tuple of (generated_tokens, logprobs)
- async _generate_async(tasks)#
Execute generation tasks with concurrency control.
TEMP: uses a semaphore to limit the number of concurrent requests per server, preventing server overload. A router-based solution is preferred in the future.
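A sketch of the semaphore pattern described above; the concurrency cap and the task shape (input_ids, sampling_params, stop_string) are assumptions.

```python
import asyncio

MAX_CONCURRENCY = 64  # assumed per-server cap

async def generate_with_limit(worker, tasks):
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    async def run_one(input_ids, sampling_params, stop_string):
        async with sem:  # waits while MAX_CONCURRENCY requests are in flight
            return await worker._generate_single_sample(
                input_ids, sampling_params, stop_string
            )

    return await asyncio.gather(*(run_one(*t) for t in tasks))
```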
- _launch_server_process(
- server_args: Any,
)#
Launch the SGLang server process and wait for it to be ready.
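A sketch of the readiness wait, assuming the server exposes a /health endpoint (SGLang servers do, but treat the exact path and timing here as assumptions).

```python
import time
import requests

def wait_until_ready(base_url: str, timeout_s: float = 300.0) -> None:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if requests.get(f"{base_url}/health", timeout=5).ok:
                return  # server is up and answering
        except requests.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(2.0)
    raise TimeoutError(f"SGLang server at {base_url} never became ready")
```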
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)#
Generate a batch of data using SGLang generation.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
greedy – Whether to use greedy decoding instead of sampling
- Returns:
output_ids: input + generated token IDs with proper padding
logprobs: Log probabilities for tokens
generation_lengths: Lengths of each response
unpadded_sequence_lengths: Lengths of each input + generated sequence
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
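A usage sketch; the dict-style BatchedDataDict construction and the token values are illustrative.

```python
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

data = BatchedDataDict(
    {
        # Right-padded token IDs plus the true (unpadded) lengths.
        "input_ids": torch.tensor([[101, 2009, 0, 0], [101, 3000, 4000, 5000]]),
        "input_lengths": torch.tensor([2, 4]),
    }
)
out = worker.generate(data, greedy=False)
# out["output_ids"], out["logprobs"], out["generation_lengths"],
# out["unpadded_sequence_lengths"]
```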
- sleep()#
- wake_up(**kwargs)#
- shutdown() bool#
Shutdown the SGLang server process and cleanup async resources.
- Returns:
True if shutdown was successful, False otherwise
- Return type:
bool
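A lifecycle sketch. sleep() and wake_up() are undocumented here, so the assumption that they release and restore engine resources around training steps mirrors similar generation backends rather than anything stated above.

```python
worker.sleep()            # assumed: release engine resources during training
worker.wake_up()          # assumed: restore the engine before the next rollout
assert worker.shutdown()  # stop the server process and clean up async resources
```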
- _make_request(endpoint: str, payload: Optional[dict] = None)#
Make a POST request to the specified endpoint with the given payload.
- Parameters:
endpoint – The API endpoint to call
payload – The JSON payload to send (default: empty dict)
- Returns:
The JSON response from the server
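A sketch of what this helper plausibly does: POST JSON to the server and return the decoded body. The error handling and the example endpoint name are assumptions.

```python
import requests

def make_request_sketch(base_url: str, endpoint: str, payload: dict | None = None):
    resp = requests.post(f"{base_url}/{endpoint}", json=payload or {})
    resp.raise_for_status()  # assumed: surface HTTP errors to the caller
    return resp.json()

# e.g. make_request_sketch(worker.get_base_url(), "flush_cache")
```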