nemo_rl.models.policy.lm_policy

Module Contents

Classes

- Policy

Data

- PathLike

API
- nemo_rl.models.policy.lm_policy.PathLike

  None
- class nemo_rl.models.policy.lm_policy.Policy(
      cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
      config: nemo_rl.models.policy.PolicyConfig,
      tokenizer: transformers.PreTrainedTokenizerBase,
      name_prefix: str = 'lm_policy',
      workers_per_node: Optional[Union[int, list[int]]] = None,
      init_optimizer: bool = True,
      weights_path: Optional[nemo_rl.models.policy.lm_policy.PathLike] = None,
      optimizer_path: Optional[nemo_rl.models.policy.lm_policy.PathLike] = None,
      init_reference_model: bool = True,
      processor: Optional[transformers.AutoProcessor] = None,
  )

  Bases: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.generation.interfaces.GenerationInterface

- init_collective(
      ip: str,
      port: int,
      world_size: int,
      *,
      train_world_size: int,
  )

  Initialize the collective communication.
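A minimal sketch of constructing a Policy and joining a collective group. The cluster setup, config contents, model name, helper function, and rendezvous values below are illustrative placeholders, not defaults from the library:

```python
# Illustrative sketch only: cluster sizing, the PolicyConfig contents,
# the model name, and the rendezvous values are placeholders.
from transformers import AutoTokenizer

from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.models.policy.lm_policy import Policy

cluster = RayVirtualCluster(...)  # assumed to be configured elsewhere
config = make_policy_config()     # hypothetical helper returning a PolicyConfig
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

policy = Policy(
    cluster=cluster,
    config=config,
    tokenizer=tokenizer,
    init_optimizer=True,        # build optimizer state for training
    init_reference_model=True,  # keep a frozen copy for reference logprobs
)

# Join a collective group that also spans a generation backend
# (hypothetical rendezvous values).
policy.init_collective(ip="10.0.0.1", port=29500, world_size=16,
                       train_world_size=8)
```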
- get_logprobs(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
  )

  Get the logprobs of the model for a data dict.

  Returns:
    a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
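A sketch of the convention in practice, assuming BatchedDataDict accepts a plain mapping and that input_ids and input_lengths are the relevant GenerationDatumSpec fields:

```python
# Sketch: per-token logprobs. The token ids are arbitrary.
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict

batch = BatchedDataDict(
    {
        "input_ids": torch.tensor([[101, 2023, 2003, 1037, 3231, 102]]),
        "input_lengths": torch.tensor([6]),
    }
)

out = policy.get_logprobs(batch)
logprobs = out["logprobs"]           # [batch_size, sequence_length]
assert logprobs[0, 0].item() == 0.0  # first-token logprob is 0 by convention
```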
- get_reference_policy_logprobs(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
      micro_batch_size: Optional[int] = None,
  )

  Get the logprobs of the reference policy for a data dict.

  Returns: identical to get_logprobs.
- get_topk_logits(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
      k: int,
      micro_batch_size: Optional[int] = None,
  )

  Dispatch get_topk_logits to workers (no context-parallel or packed-sequence support initially).
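For example (a sketch; the structure of the returned value is not documented above, and the micro batch size is a hypothetical override):

```python
# Sketch: request the top-8 logits per position.
topk = policy.get_topk_logits(batch, k=8, micro_batch_size=2)
```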
- train(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
      loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
      eval_mode: bool = False,
      gbs: Optional[int] = None,
      mbs: Optional[int] = None,
  )

  Train the policy on a batch of data with a given loss function.
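A sketch of one training step. Here `train_batch` and `my_loss_fn` are placeholders for a batch built by your pipeline and any LossFunction implementation; the `gbs`/`mbs` values are hypothetical overrides of the configured global and micro batch sizes (interpretation assumed from the parameter names):

```python
# Sketch: one optimization step over a batch.
metrics = policy.train(
    data=train_batch,   # a BatchedDataDict[Any] from your data pipeline
    loss_fn=my_loss_fn, # any nemo_rl.algorithms.interfaces.LossFunction
    eval_mode=False,    # True presumably skips the optimizer step (assumed)
    gbs=512,
    mbs=8,
)
```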
- generate(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
      greedy: bool = False,
  )

  Generate a batch of data using the policy.
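For instance, greedy decoding (a sketch; `prompt_batch` is a placeholder BatchedDataDict of prompts):

```python
# Sketch: greedy decoding; greedy=True disables sampling.
gen_out = policy.generate(prompt_batch, greedy=True)
```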
- score(
      data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
  )

  Score a batch of data using the policy.
- prepare_for_generation(*args: Any, **kwargs: Any)
- prepare_for_training(*args: Any, **kwargs: Any) -> None
- prepare_for_lp_inference(*args: Any, **kwargs: Any)
- finish_generation(*args: Any, **kwargs: Any) -> bool

- finish_training(*args: Any, **kwargs: Any) -> None
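The prepare_*/finish_* pairs switch the workers between modes. A sketch of the pairing in an RL-style loop, with the ordering inferred from the method names and all batch variables and the loss function as placeholders:

```python
# Sketch: mode switching across one iteration. Any on/offload behavior
# is internal to the workers; only the pairing is shown.
policy.prepare_for_generation()
rollouts = policy.generate(prompt_batch, greedy=False)
policy.finish_generation()

policy.prepare_for_lp_inference()
prev_logprobs = policy.get_logprobs(rollout_batch)
ref_logprobs = policy.get_reference_policy_logprobs(rollout_batch)

policy.prepare_for_training()
policy.train(train_batch, loss_fn=my_loss_fn)
policy.finish_training()
```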
- prepare_refit_info() -> Optional[dict[str, Any]]

  Prepare the info for refit.

  Returns:
    A dictionary containing the info for refit.

  Return type:
    dict
- get_free_memory_bytes() -> int

  Get the available free memory, in bytes.
- stream_weights_via_ipc_zmq(
      buffer_size_bytes: int,
  )

  Send the weights as IPC handles over a ZMQ socket.
- broadcast_weights_for_collective() -> list[ray.ObjectRef]

  Broadcast the weights for collective communication.
- offload_before_refit() -> None

  Offload the optimizer and buffers to the CPU.

- offload_after_refit() -> None

  Offload the optimizer and buffers to the CPU.
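A hedged sketch of how these helpers might be sequenced when pushing refreshed weights to a colocated generation engine. The ordering is an assumption based on the method names, the buffer sizing is a made-up heuristic, and the receiving side is omitted:

```python
# Sketch: weight refit from the training policy to an inference engine.
info = policy.prepare_refit_info()  # metadata describing the weights

policy.offload_before_refit()       # free GPU memory before the transfer
free_bytes = policy.get_free_memory_bytes()

# IPC-handle transfer over ZMQ (implies engines share a node)...
policy.stream_weights_via_ipc_zmq(buffer_size_bytes=free_bytes // 2)
# ...or a collective broadcast instead:
# refs = policy.broadcast_weights_for_collective()

policy.offload_after_refit()
```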
- save_checkpoint(
      weights_path: str,
      optimizer_path: Optional[str] = None,
      tokenizer_path: Optional[str] = None,
      checkpointing_cfg: Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None,
  )

  Save a checkpoint of the model.
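For example (the paths are placeholders):

```python
# Sketch: periodic checkpointing of weights, optimizer, and tokenizer.
policy.save_checkpoint(
    weights_path="results/step_1000/policy",
    optimizer_path="results/step_1000/optim",
    tokenizer_path="results/step_1000/tokenizer",
)
```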
- shutdown() -> bool

  Shut down all HF workers and clean up resources.
- __del__() -> None

  Shuts down the worker groups when the object is deleted or garbage collected.

  This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to the object is lost when it leaves a function scope. It is always recommended that the user call worker_group.shutdown() explicitly.
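In line with that recommendation, an explicit teardown at the end of a run rather than relying on garbage collection:

```python
# Preferred: release Ray workers deterministically instead of via __del__.
try:
    ...  # run training / generation
finally:
    policy.shutdown()
```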
- start_gpu_profiling() -> None

  Start GPU profiling.

- stop_gpu_profiling() -> None

  Stop GPU profiling.
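A sketch of bracketing a region of interest with profiling (`train_batch` and `my_loss_fn` are the same placeholders as above):

```python
# Sketch: profile a single training step.
policy.start_gpu_profiling()
policy.train(train_batch, loss_fn=my_loss_fn)
policy.stop_gpu_profiling()
```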
- print_node_ip_and_gpu_id() -> list[tuple[str, int]]

  Print the node IP and GPU ID of the current worker.