nemo_rl.models.generation.sglang.sglang_generation#
Module Contents#
Classes#
Data#
API#
- nemo_rl.models.generation.sglang.sglang_generation.TOP_K_THRESHOLD#
8000
- nemo_rl.models.generation.sglang.sglang_generation.TOP_P_THRESHOLD#
0.99
- nemo_rl.models.generation.sglang.sglang_generation.logger#
‘getLogger(…)’
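The two module-level thresholds above appear to act as cutoffs beyond which top-k / top-p filtering is effectively disabled (a `top_k` of 8000 or more, or a `top_p` of 0.99 or more, covers essentially the whole distribution). A minimal sketch of that interpretation; the `normalize_sampling` helper and the `-1` "no top-k" convention are assumptions for illustration, not the module's actual code:

```python
TOP_K_THRESHOLD = 8000
TOP_P_THRESHOLD = 0.99

def normalize_sampling(top_k: int, top_p: float) -> tuple[int, float]:
    """Illustrative sketch: treat very large top_k / near-1.0 top_p as disabled."""
    if top_k >= TOP_K_THRESHOLD:
        top_k = -1  # assumed convention: -1 means no top-k filtering
    if top_p >= TOP_P_THRESHOLD:
        top_p = 1.0  # sample from the full distribution
    return top_k, top_p
```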
- class nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- config: nemo_rl.models.generation.sglang.config.SGLangConfig,
- name_prefix: str = 'sglang_policy',
- workers_per_node: Optional[Union[int, list[int]]] = None,
Bases:
nemo_rl.models.generation.interfaces.GenerationInterface
- _allocate_bundles_for_servers(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- num_servers: int,
- gpus_per_server: int,
Allocate GPU bundles to each SGLang server.
Each server gets consecutive bundles within the same placement group (node). Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, …, gpus_per_server-1.
- Parameters:
cluster – The Ray virtual cluster
num_servers – Total number of SGLang servers to create
gpus_per_server – Number of GPUs each server needs
- Returns:
List of (node_idx, [bundle_indices]) tuples for each server
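The consecutive-bundle strategy described above can be sketched as follows. This is an illustrative reimplementation, not the method's actual body: it assumes the cluster can be summarized as a per-node bundle count, and the helper name and error handling are hypothetical.

```python
def allocate_bundles_for_servers(
    bundles_per_node: list[int],  # bundles (GPUs) available on each node
    num_servers: int,
    gpus_per_server: int,
) -> list[tuple[int, list[int]]]:
    """Assign each server gpus_per_server consecutive bundle indices on a
    single node, filling nodes in order so no server spans two nodes."""
    allocations: list[tuple[int, list[int]]] = []
    for node_idx, n_bundles in enumerate(bundles_per_node):
        offset = 0
        while offset + gpus_per_server <= n_bundles and len(allocations) < num_servers:
            allocations.append(
                (node_idx, list(range(offset, offset + gpus_per_server)))
            )
            offset += gpus_per_server
    if len(allocations) < num_servers:
        raise ValueError("not enough GPU bundles to place all servers")
    return allocations
```

For example, with two 8-GPU nodes and three 4-GPU servers, the first two servers take bundles 0-3 and 4-7 on node 0, and the third takes bundles 0-3 on node 1. Because each server's bundles sit in one placement group, Ray renumbers them via CUDA_VISIBLE_DEVICES so every server sees logical GPUs 0 through gpus_per_server-1.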
- init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
Initialize the collective communication.
TODO: implement this if weight updates via NCCL are needed in the future.
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
Generate a batch of data using SGLang.
- prepare_refit_info(state_dict_info: dict[str, Any]) → None#
- update_weights_via_ipc_zmq() → list[ray.ObjectRef]#
- update_weights_from_collective() → list[ray.ObjectRef]#
- get_sglang_server_urls() → list[str]#
Get base URLs of all SGLang servers.
- Returns:
List of base URLs (e.g., [“http://localhost:30000”, “http://localhost:30001”])
- Return type:
list[str]
- get_sglang_url_to_gpu_uuids() → dict[str, list[str]]#
Get mapping from SGLang server URL to list of GPU UUIDs it uses.
- Returns:
Dict mapping each server URL to the list of GPU UUIDs it uses, e.g., {“http://localhost:30000”: [“GPU-aaa”, “GPU-bbb”], …}
- prepare_for_generation(
- *args: Any,
- **kwargs: Any,
Wake workers up for colocated inference.
- finish_generation(*args: Any, **kwargs: Any) → bool#
Sleep workers and reset prefix cache.
- shutdown() → bool#
Shut down all SGLang workers and clean up resources.
- __del__() → None#
Shuts down the worker groups when the object is deleted or is garbage collected.
This is an extra safety net in case the user forgets to call shutdown() and the reference to the object is lost when leaving a function scope. It is always recommended that the user call shutdown() explicitly.
- invalidate_kv_cache() → bool#
Invalidate KV cache before weight updates (Megatron-style).
This flushes the KV cache before weight updates so that stale entries computed with the old weights are discarded. Only primary workers (TP rank 0, the model owners) flush their cache.
- Returns:
True if all caches were flushed successfully, False otherwise
- Return type:
bool
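The primary-worker gating described above can be sketched as follows; the `Worker` class and its `flush_cache` method are hypothetical stand-ins for the actual worker handles, shown only to illustrate the TP-rank-0 rule and the all-succeeded return value.

```python
class Worker:
    """Illustrative stand-in for an SGLang worker handle (not the real class)."""

    def __init__(self, tp_rank: int):
        self.tp_rank = tp_rank
        self.flushed = False

    def flush_cache(self) -> bool:
        self.flushed = True
        return True


def invalidate_kv_cache(workers: list[Worker]) -> bool:
    # Only TP-rank-0 workers (the model owners) flush their cache;
    # the result is True only if every flush succeeded.
    results = [w.flush_cache() for w in workers if w.tp_rank == 0]
    return all(results)
```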