nemo_rl.models.policy.lm_policy#

Module Contents#

Classes#

Data#

API#

nemo_rl.models.policy.lm_policy.PathLike#

None

class nemo_rl.models.policy.lm_policy.Policy(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: transformers.PreTrainedTokenizerBase,
name_prefix: str = 'lm_policy',
workers_per_node: Optional[Union[int, list[int]]] = None,
init_optimizer: bool = True,
weights_path: Optional[nemo_rl.models.policy.lm_policy.PathLike] = None,
optimizer_path: Optional[nemo_rl.models.policy.lm_policy.PathLike] = None,
init_reference_model: bool = True,
)[source]#

Bases: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.generation.interfaces.GenerationInterface

init_collective(
ip: str,
port: int,
world_size: int,
) → list[ray.ObjectRef][source]#

Initialize the collective communication.

get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec][source]#

Get the logprobs of the model for a data dict.

Returns:

a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.

get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec][source]#

Get the logprobs of the reference policy for a data dict.

Returns: Identical to get_logprobs.

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any][source]#

Train the policy on a batch of data with a given loss function.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec][source]#

Generate a batch of data using the policy.

prepare_for_generation(
*args: Any,
**kwargs: Any,
) → bool[source]#
prepare_for_training(*args: Any, **kwargs: Any) → None[source]#
prepare_for_lp_inference(
*args: Any,
**kwargs: Any,
) → None[source]#
finish_generation(*args: Any, **kwargs: Any) → bool[source]#
finish_training(*args: Any, **kwargs: Any) → None[source]#
prepare_refit_info() → Optional[dict[str, Any]][source]#

Prepare the info for refit.

Returns:

A dictionary containing the info for refit.

Return type:

dict

prepare_weights_for_ipc(
_refit_buffer_size_gb: Optional[int] = None,
) → list[list[str]][source]#

Prepare the weights for IPC.

Returns:

A list containing the keys of the parameters, grouped by size.

Return type:

list

get_weights_ipc_handles(keys: list[str]) → dict[str, Any][source]#

Fetch weight IPC handles from all workers.

Returns:

A dictionary mapping device UUIDs to parameter IPC handles.

Return type:

dict

broadcast_weights_for_collective() → list[ray.ObjectRef][source]#

Broadcast the weights for collective communication.

offload_before_refit() → None[source]#

Offload the optimizer and buffers to the CPU.

offload_after_refit() → None[source]#

Offload the optimizer and buffers to the CPU.

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
) → None[source]#

Save a checkpoint of the model.

shutdown() → bool[source]#

Shut down all HF workers and clean up resources.

__del__() → None[source]#

Shuts down the worker groups when the object is deleted or is garbage collected.

This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to the object is lost due to leaving a function scope. It’s always recommended that the user call worker_group.shutdown().

start_gpu_profiling() → None[source]#

Start GPU profiling.

stop_gpu_profiling() → None[source]#

Stop GPU profiling.