nemo_rl.models.policy.hf_policy#

Module Contents#

Classes#

API#

class nemo_rl.models.policy.hf_policy.HfPolicy(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: transformers.AutoTokenizer,
name_prefix: str = 'hf_policy',
workers_per_node: Optional[Union[int, List[int]]] = None,
init_optimizer: bool = True,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_reference_model: bool = True,
)[source]#

Bases: nemo_rl.models.interfaces.PolicyInterface, nemo_rl.models.generation.interfaces.GenerationInterface

_get_tied_worker_bundle_indices(cluster)[source]#

Calculate bundle indices for tensor parallel workers.

get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Get the logprobs of the model for a data dict.

Returns:

a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.

get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
micro_batch_size: int = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Get the logprobs of the reference policy for a data dict.

Returns: Identical to get_logprobs.

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
)[source]#

Train the policy on a batch of data with a given loss function.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec][source]#

Generate a batch of data using the policy.

prepare_for_generation(*args, **kwargs)[source]#
prepare_for_training(*args, **kwargs)[source]#
prepare_for_lp_inference(*args, **kwargs)[source]#
finish_generation(*args, **kwargs)[source]#
finish_training(*args, **kwargs)[source]#
prepare_weights_for_ipc()[source]#

Prepare the weights for IPC.

Returns:

A dictionary containing the state_dict_info of the model.

Return type:

dict

get_weights_ipc_handles(key)[source]#

Fetch weight IPC handles from all workers.

Returns:

A dictionary mapping device UUIDs to parameter IPC handles.

Return type:

dict

offload_before_refit()[source]#

Offload the optimizer and buffers to the CPU.

offload_after_refit()[source]#

Offload the optimizer and buffers to the CPU.

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
)[source]#

Save a checkpoint of the model.

shutdown() → bool[source]#

Shut down all HF workers and clean up resources.

__del__()[source]#

Shuts down the worker groups when the object is deleted or is garbage collected.

This is an extra safety net in case the user forgets to call worker_group.shutdown() and the last reference to the object is lost (for example, when it goes out of scope at the end of a function). It’s always recommended that the user calls worker_group.shutdown() explicitly.