nemo_rl.models.generation.sglang.sglang_worker#

Module Contents#

Classes#

Functions#

_require_sglang

Import sglang lazily so test collection works without the optional extra.

Data#

API#

nemo_rl.models.generation.sglang.sglang_worker.logger#

‘getLogger(…)’

nemo_rl.models.generation.sglang.sglang_worker._require_sglang()#

Import sglang lazily so test collection works without the optional extra.

class nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker(
config: nemo_rl.models.generation.sglang.config.SGLangConfig,
bundle_indices: Optional[list[int]] = None,
fraction_of_gpus: float = 1.0,
seed: Optional[int] = None,
)#

Initialization

Initialize an SGLang worker for distributed inference.

Parameters:
  • config – Configuration dictionary for the policy

  • bundle_indices – List of local bundle indices for this server. The length of this list determines tp_size (number of GPUs per server). Only needed for the first worker in each server group (model owner).

  • fraction_of_gpus – Fraction of GPUs to use for this worker

  • seed – Random seed for initialization; if None, defaults to the config’s seed
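
A minimal construction sketch (values are illustrative; in practice the worker is usually launched through Ray by the generation backend, and make_config() here is a hypothetical stand-in for however the SGLangConfig is built):

```python
from nemo_rl.models.generation.sglang.sglang_worker import SGLangGenerationWorker

config = make_config()  # hypothetical helper producing an SGLangConfig

worker = SGLangGenerationWorker(
    config,
    bundle_indices=[0, 1],  # two local GPU bundles -> tp_size=2 for this server
    fraction_of_gpus=1.0,
    seed=42,                # None would fall back to the config's seed
)
```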

__repr__() str#

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

static configure_worker(
num_gpus: int | float,
bundle_indices: Optional[tuple[int, list[int]]] = None,
) tuple[dict[str, Any], dict[str, str], dict[str, Any]]#

Provides the complete worker configuration for an SGLang server.

This method configures the worker based on bundle_indices, which determines how many GPUs this server should use.

Parameters:
  • num_gpus – Original GPU allocation for this worker based on the placement group

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for this server

Returns:

  • ‘resources’: Resource allocation (e.g., num_gpus)

  • ‘env_vars’: Environment variables for this worker

  • ‘init_kwargs’: Parameters to pass to the worker’s __init__

Return type:

tuple with complete worker configuration
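
A hypothetical call following the documented signature; the concrete contents of the three dictionaries depend on the deployment:

```python
resources, env_vars, init_kwargs = SGLangGenerationWorker.configure_worker(
    num_gpus=1.0,
    bundle_indices=(0, [0, 1]),  # (node_idx, local_bundle_indices)
)
# resources   -> e.g. {"num_gpus": ...}, used when scheduling the Ray actor
# env_vars    -> environment variables to set on the worker process
# init_kwargs -> extra keyword arguments forwarded to __init__
```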

get_base_url() str#

Get the base URL of this SGLang server.

invalidate_kv_cache() bool#

Invalidate KV cache before weight updates (Megatron-style).

This flushes the cache before weight updates so that stale entries are cleared. Uses retry logic to handle cases where requests are still pending.

Returns:

True if flush was successful, False otherwise

Return type:

bool
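
A generic sketch of the retry pattern described above, not the actual implementation; flush_once stands in for whatever call asks the server to flush its cache:

```python
import time

def flush_with_retry(flush_once, attempts: int = 5, backoff_s: float = 0.5) -> bool:
    """Retry the flush until it succeeds or attempts are exhausted."""
    for i in range(attempts):
        if flush_once():  # e.g. a POST to the server's flush endpoint
            return True
        time.sleep(backoff_s * (i + 1))  # back off while pending requests drain
    return False
```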

get_gpu_uuids() list[str]#

Get list of GPU UUIDs used by this SGLang server.

Returns:

List of GPU UUIDs (e.g., [“GPU-xxxxx”, “GPU-yyyyy”])

_merge_stop_strings(batch_stop_strings)#

Merge stop strings from config and batch.

Parameters:

batch_stop_strings – List of stop strings from batch (one per sample)

Returns:

List of merged stop strings (one per sample)
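
An illustrative merge, assuming config-level stop strings apply to every sample and each batch entry may carry its own additions; the real method may handle edge cases differently:

```python
from typing import Optional

def merge_stop_strings(
    config_stop: list[str],
    batch_stop_strings: list[Optional[list[str]]],
) -> list[list[str]]:
    merged = []
    for per_sample in batch_stop_strings:
        combined = list(config_stop)  # config stop strings apply everywhere
        if per_sample:
            combined.extend(s for s in per_sample if s not in combined)
        merged.append(combined)
    return merged
```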

_build_sampling_params(
*,
greedy: bool,
stop_strings,
max_new_tokens: Optional[int] = None,
input_len: Optional[int] = None,
context_length: Optional[int] = None,
sample_index: Optional[int] = None,
) dict[str, Any]#

Build sampling parameters dictionary for SGLang API.

Parameters:
  • greedy – Whether to use greedy decoding (temperature=0.0)

  • stop_strings – Merged stop strings (not used here, handled per sample)

  • max_new_tokens – Override max_new_tokens from config if provided

  • input_len – Input length for this sample (used for context_length adjustment)

  • context_length – Maximum context length (if provided, adjusts max_new_tokens)

  • sample_index – Sample index (used for warning messages, 0-indexed)

Returns:

Dictionary of sampling parameters compatible with SGLang API
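
A sketch of the documented behavior, assuming config fields named temperature, top_p, and max_new_tokens; the actual field names and full parameter set may differ:

```python
from typing import Any, Optional

def build_sampling_params(
    cfg: dict[str, Any],
    *,
    greedy: bool,
    max_new_tokens: Optional[int] = None,
    input_len: Optional[int] = None,
    context_length: Optional[int] = None,
) -> dict[str, Any]:
    budget = max_new_tokens if max_new_tokens is not None else cfg["max_new_tokens"]
    if context_length is not None and input_len is not None:
        # Never generate past the model's context window.
        budget = min(budget, context_length - input_len)
    return {
        "temperature": 0.0 if greedy else cfg["temperature"],  # greedy => 0.0
        "top_p": cfg.get("top_p", 1.0),
        "max_new_tokens": budget,
    }
```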

async _ensure_session()#
async _generate_single_sample(
input_ids: list[int],
sampling_params: dict[str, Any],
stop_string: Optional[str] = None,
) tuple[list[int], list[float]]#

Generate a single sample using the SGLang API (async function).

Parameters:
  • input_ids – List of input token IDs (without padding)

  • sampling_params – Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.)

  • stop_string – Optional stop string for this sample

Returns:

  • generated_tokens: List of generated token IDs

  • logprobs: List of log probabilities for generated tokens

Return type:

Tuple of (generated_tokens, logprobs)
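
A hypothetical async call site mirroring the documented signature; the token IDs, sampling values, and stop string are illustrative:

```python
async def demo(worker):
    # From sync code, run with e.g. asyncio.run(demo(worker)).
    tokens, logprobs = await worker._generate_single_sample(
        input_ids=[101, 7592],
        sampling_params={"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 16},
        stop_string="</s>",
    )
    return tokens, logprobs
```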

async _generate_async(tasks)#

Execute generation tasks with concurrency control.

TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. A router-based solution is preferred in the future.
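
The semaphore pattern the note above describes, in generic form; the concurrency limit here is illustrative:

```python
import asyncio

async def run_with_limit(coros, max_concurrent: int = 32):
    sem = asyncio.Semaphore(max_concurrent)

    async def guarded(coro):
        async with sem:  # at most max_concurrent requests in flight per server
            return await coro

    return await asyncio.gather(*(guarded(c) for c in coros))
```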

_launch_server_process(
server_args: Any,
) multiprocessing.Process#

Launch the SGLang server process and wait for it to be ready.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of outputs using SGLang.

Parameters:
  • data – BatchedDataDict containing input_ids and input_lengths tensors

  • greedy – Whether to use greedy decoding instead of sampling

Returns:

  • output_ids: input + generated token IDs with proper padding

  • logprobs: Log probabilities for tokens

  • generation_lengths: Lengths of each response

  • unpadded_sequence_lengths: Lengths of each input + generated sequence

Return type:

BatchedDataDict conforming to GenerationOutputSpec
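
A hypothetical call site, assuming worker is an initialized SGLangGenerationWorker and that a BatchedDataDict can be built from a plain dict; tensor contents are illustrative:

```python
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

data = BatchedDataDict(
    {
        "input_ids": torch.tensor([[101, 7592, 0], [101, 2088, 2003]]),  # padded
        "input_lengths": torch.tensor([2, 3]),  # true lengths before padding
    }
)
out = worker.generate(data, greedy=True)
tokens = out["output_ids"]                   # input + generated ids, padded
lengths = out["unpadded_sequence_lengths"]   # true per-sample total lengths
```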

sleep()#
wake_up(**kwargs)#
shutdown() bool#

Shut down the SGLang server process and clean up async resources.

Returns:

True if shutdown was successful, False otherwise

Return type:

bool

_make_request(endpoint: str, payload: Optional[dict] = None)#

Make a POST request to the specified endpoint with the given payload.

Parameters:
  • endpoint – The API endpoint to call

  • payload – The JSON payload to send (default: empty dict)

Returns:

The JSON response from the server
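
A generic sketch of such a helper, assuming the standard requests library and a base URL like the one returned by get_base_url(); the real method may use an async HTTP client instead:

```python
from typing import Optional

import requests

def make_request(base_url: str, endpoint: str, payload: Optional[dict] = None) -> dict:
    resp = requests.post(f"{base_url}/{endpoint}", json=payload or {})
    resp.raise_for_status()  # surface HTTP errors early
    return resp.json()
```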