nemo_rl.models.generation.sglang.sglang_generation#

Module Contents#

Classes#

Data#

API#

nemo_rl.models.generation.sglang.sglang_generation.TOP_K_THRESHOLD#

8000

nemo_rl.models.generation.sglang.sglang_generation.TOP_P_THRESHOLD#

0.99
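
The exact semantics of these thresholds are not documented here; a plausible reading, based on the names and values, is that sampling parameters beyond them no longer meaningfully restrict the vocabulary and are treated as disabled. A minimal sketch of that assumption (the clamping behavior is hypothetical, not the confirmed implementation):

```python
# Hypothetical sketch: how such thresholds might gate sampling parameters.
TOP_K_THRESHOLD = 8000
TOP_P_THRESHOLD = 0.99

def normalize_sampling_params(top_k: int, top_p: float) -> tuple[int, float]:
    # ASSUMPTION: a top_k above the threshold is so large it no longer
    # restricts sampling, so treat it as disabled (-1 by convention).
    if top_k > TOP_K_THRESHOLD:
        top_k = -1
    # ASSUMPTION: likewise, a top_p above the threshold is effectively
    # nucleus-sampling-off.
    if top_p > TOP_P_THRESHOLD:
        top_p = 1.0
    return top_k, top_p
```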

nemo_rl.models.generation.sglang.sglang_generation.logger#

'getLogger(...)'

class nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
config: nemo_rl.models.generation.sglang.config.SGLangConfig,
name_prefix: str = 'sglang_policy',
workers_per_node: Optional[Union[int, list[int]]] = None,
)#

Bases: nemo_rl.models.generation.interfaces.GenerationInterface
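
A minimal construction sketch following the signature above; the cluster and config arguments are placeholders whose real contents depend on your deployment:

```python
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.models.generation.sglang.config import SGLangConfig
from nemo_rl.models.generation.sglang.sglang_generation import SGLangGeneration

cluster = RayVirtualCluster(...)  # placeholder arguments, deployment-specific
config = SGLangConfig(...)        # placeholder arguments, deployment-specific

policy = SGLangGeneration(
    cluster=cluster,
    config=config,
    name_prefix="sglang_policy",  # default shown in the signature
    workers_per_node=None,        # let the class infer workers per node
)
```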

_allocate_bundles_for_servers(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
num_servers: int,
gpus_per_server: int,
) list[tuple[int, list[int]]]#

Allocate GPU bundles to each SGLang server.

Each server gets consecutive bundles within the same placement group (node). Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, …, gpus_per_server-1.

Parameters:
  • cluster – The Ray virtual cluster

  • num_servers – Total number of SGLang servers to create

  • gpus_per_server – Number of GPUs each server needs

Returns:

List of (node_idx, [bundle_indices]) tuples for each server
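
To illustrate the allocation described above, a standalone sketch (not the actual implementation) that assigns consecutive bundle indices per server and never splits a server across nodes:

```python
def allocate_bundles(
    bundles_per_node: list[int],  # number of GPU bundles on each node
    num_servers: int,
    gpus_per_server: int,
) -> list[tuple[int, list[int]]]:
    """Assign each server consecutive bundles on a single node."""
    allocations: list[tuple[int, list[int]]] = []
    for node_idx, n_bundles in enumerate(bundles_per_node):
        # Pack as many servers as fit on this node; a server is never
        # split across nodes because placement groups are per node.
        for start in range(0, n_bundles - gpus_per_server + 1, gpus_per_server):
            if len(allocations) == num_servers:
                return allocations
            allocations.append(
                (node_idx, list(range(start, start + gpus_per_server)))
            )
    if len(allocations) < num_servers:
        raise ValueError("Not enough GPU bundles to place all servers")
    return allocations

# e.g., two 4-GPU nodes, three 2-GPU servers:
# allocate_bundles([4, 4], 3, 2) -> [(0, [0, 1]), (0, [2, 3]), (1, [0, 1])]
```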

init_collective(
ip: str,
port: int,
world_size: int,
*,
train_world_size: int,
) list[ray.ObjectRef]#

Initialize the collective communication.

TODO: implement if weight updates via NCCL are needed in the future.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of data using SGLang.
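
A hedged usage sketch; the batch keys shown are assumptions about what GenerationDatumSpec requires, so consult the interface for the authoritative schema:

```python
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

# ASSUMPTION: these keys approximate GenerationDatumSpec; `input_ids` and
# `input_lengths` are placeholder tensors prepared elsewhere.
data = BatchedDataDict(
    {
        "input_ids": input_ids,          # batch of prompt token ids
        "input_lengths": input_lengths,  # per-sample prompt lengths
    }
)

outputs = policy.generate(data, greedy=False)  # sampled decoding
# `outputs` conforms to GenerationOutputSpec (e.g., generated token ids).
```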

prepare_refit_info(state_dict_info: dict[str, Any]) None#

update_weights_via_ipc_zmq() list[ray.ObjectRef]#

update_weights_from_collective() list[ray.ObjectRef]#

get_sglang_server_urls() list[str]#

Get base URLs of all SGLang servers.

Returns:

List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"])

Return type:

list[str]

get_sglang_url_to_gpu_uuids() dict[str, list[str]]#

Get mapping from SGLang server URL to list of GPU UUIDs it uses.

Returns:

Dict mapping server URL to the list of GPU UUIDs it uses, e.g., {"http://localhost:30000": ["GPU-aaa", "GPU-bbb"], ...}
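
A sketch of how these two getters might be used together to inspect the serving topology; the health-check endpoint in the comment is illustrative, not a documented route:

```python
urls = policy.get_sglang_server_urls()
url_to_gpus = policy.get_sglang_url_to_gpu_uuids()

for url in urls:
    gpu_uuids = url_to_gpus.get(url, [])
    print(f"{url} -> {len(gpu_uuids)} GPU(s): {gpu_uuids}")

# The base URLs can also be queried directly over HTTP, e.g.:
# requests.get(f"{urls[0]}/health")  # endpoint path is an assumption
```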

prepare_for_generation(
*args: Any,
**kwargs: Any,
) bool#

Wake workers up for colocated inference.

finish_generation(*args: Any, **kwargs: Any) bool#

Put workers to sleep and reset the prefix cache.
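
A sketch of the colocated-inference lifecycle these two methods imply, assuming training and generation share the same GPUs:

```python
# Wake the SGLang workers before a generation round...
assert policy.prepare_for_generation()

outputs = policy.generate(data, greedy=True)  # greedy decoding, e.g. for eval

# ...then hand the GPUs back to training: sleep the workers and drop the
# prefix cache so stale state does not leak into the next round.
assert policy.finish_generation()
```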

shutdown() bool#

Shut down all SGLang workers and clean up resources.

__del__() None#

Shuts down the worker groups when the object is deleted or garbage collected.

This is an extra safety net in case the user forgets to call shutdown() and the reference to the object is lost when it goes out of scope. Calling shutdown() explicitly is always recommended.
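
Since __del__ is only a safety net, wrapping the object's lifetime in try/finally makes cleanup deterministic:

```python
policy = SGLangGeneration(cluster=cluster, config=config)
try:
    outputs = policy.generate(data)
finally:
    # Explicit shutdown releases Ray workers promptly instead of waiting
    # for garbage collection to trigger __del__.
    policy.shutdown()
```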

invalidate_kv_cache() bool#

Invalidate KV cache before weight updates (Megatron-style).

This flushes the KV cache before weight updates so that stale entries are cleared. Only primary workers (TP rank 0, the model owners) flush their cache.

Returns:

True if all caches were flushed successfully, False otherwise

Return type:

bool
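
A sketch of a refit step combining the cache invalidation above with the IPC/ZMQ weight-update path; the ordering (flush, then update, then wait) is an assumption about how these pieces compose:

```python
import ray

# Flush stale KV cache on the primary workers before new weights land.
assert policy.invalidate_kv_cache()

# Kick off the weight transfer; each server returns an ObjectRef.
refs = policy.update_weights_via_ipc_zmq()

# Block until every server has applied the new weights.
ray.get(refs)
```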