nemo_rl.models.policy.hf_policy#
Module Contents#
Classes#
API#
- class nemo_rl.models.policy.hf_policy.HfPolicy(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: transformers.AutoTokenizer,
- name_prefix: str = 'hf_policy',
- workers_per_node: Optional[Union[int, List[int]]] = None,
- init_optimizer: bool = True,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_reference_model: bool = True,
- )[source]#
Bases: nemo_rl.models.interfaces.PolicyInterface, nemo_rl.models.generation.interfaces.GenerationInterface
- _get_tied_worker_bundle_indices(cluster)[source]#
Calculate bundle indices for tensor parallel workers.
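The grouping logic is internal to nemo_rl, but the idea can be sketched with a small stand-alone helper. This is an illustration only, not the actual implementation: it assumes workers are partitioned into consecutive groups of the tensor-parallel size.

```python
# Illustration (not the nemo_rl implementation): group `num_workers` ranks
# into consecutive tensor-parallel bundles of size `tp_size`.

def tied_worker_bundle_indices(num_workers: int, tp_size: int) -> list[list[int]]:
    """Return worker indices grouped into bundles of size tp_size."""
    if num_workers % tp_size != 0:
        raise ValueError("num_workers must be divisible by tp_size")
    return [
        list(range(start, start + tp_size))
        for start in range(0, num_workers, tp_size)
    ]
```

Under this assumption, 8 workers with a tensor-parallel size of 2 yield four bundles: `[[0, 1], [2, 3], [4, 5], [6, 7]]`.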
- get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- )[source]#
Get the logprobs of the model for a data dict.
- Returns:
a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
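To make the alignment convention concrete, here is a self-contained sketch in plain Python (no nemo_rl). The per-position next-token distributions are invented for the example; only the positioning convention matches the description above.

```python
import math

# Illustration of the logprob alignment convention: position 0 holds 0.0
# (the first token has no preceding context), and position i holds the
# logprob of token i given tokens 0..i-1. The distributions are made up.

def align_logprobs(
    token_ids: list[int],
    next_token_probs: list[dict[int, float]],
) -> list[float]:
    """next_token_probs[i] maps candidate token -> probability after prefix 0..i."""
    out = [0.0]  # convention: logprob of the first token is 0
    for i in range(1, len(token_ids)):
        out.append(math.log(next_token_probs[i - 1][token_ids[i]]))
    return out

tokens = [5, 7, 2]
dists = [{7: 0.5, 2: 0.5}, {2: 0.25, 7: 0.75}]
logprobs = align_logprobs(tokens, dists)
# len(logprobs) == len(tokens), and logprobs[0] == 0.0
```

Note that the output length equals the input sequence length, as stated in the Returns description.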
- get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- micro_batch_size: Optional[int] = None,
- )[source]#
Get the logprobs of the reference policy for a data dict.
Returns: Identical to get_logprobs.
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
- )[source]#
Train the policy on a batch of data with a given loss function.
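The gbs/mbs parameters suggest the usual global-batch / micro-batch split with gradient accumulation. The following is a hypothetical sketch of that relationship in plain Python; the loss function here is a stand-in and does not reproduce nemo_rl's LossFunction interface.

```python
# Hypothetical illustration of processing a global batch (gbs) as
# micro-batches (mbs), weighting each micro-batch's loss by its share
# of the global batch so the total equals the global-batch mean loss.

def train_step(batch: list[float], gbs: int, mbs: int, loss_fn) -> float:
    assert len(batch) == gbs and gbs % mbs == 0
    total_loss = 0.0
    for start in range(0, gbs, mbs):  # one forward/backward per micro-batch
        micro = batch[start:start + mbs]
        total_loss += loss_fn(micro) * (mbs / gbs)  # fraction of global batch
    return total_loss

mean_loss = lambda xs: sum(xs) / len(xs)
loss = train_step([1.0, 2.0, 3.0, 4.0], gbs=4, mbs=2, loss_fn=mean_loss)
```

With the stand-in mean loss, the accumulated result equals the mean over the whole global batch (2.5 here), which is the invariant micro-batching is meant to preserve.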
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )[source]#
Generate a batch of data using the policy.
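The greedy flag follows the usual decoding distinction: greedy decoding picks the most likely next token, while the default samples from the distribution. A minimal stand-alone sketch of that distinction (the vocabulary and probabilities are invented; this is not nemo_rl's decoding loop):

```python
import random

# Sketch of the usual meaning of a `greedy` flag: argmax decoding vs.
# sampling from the next-token distribution. Toy vocabulary for illustration.

def pick_next_token(probs: dict[str, float], greedy: bool, rng: random.Random) -> str:
    if greedy:
        return max(probs, key=probs.get)  # deterministic argmax
    tokens, weights = zip(*probs.items())
    return rng.choices(tokens, weights=weights, k=1)[0]  # stochastic draw

probs = {"the": 0.6, "a": 0.3, "an": 0.1}
greedy_token = pick_next_token(probs, greedy=True, rng=random.Random(0))
```

Greedy decoding is deterministic given the model outputs, which is why it is commonly used for evaluation.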
- prepare_weights_for_ipc()[source]#
Prepare the weights for IPC.
- Returns:
A dictionary containing the state_dict_info of the model.
- Return type:
dict
- get_weights_ipc_handles(key)[source]#
Fetch weight IPC handles from all workers.
- Returns:
A dictionary mapping device UUIDs to parameter IPC handles.
- Return type:
dict
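The documented return value maps device UUIDs to parameter IPC handles. The sketch below shows only that dictionary shape from the consumer's side; the UUID strings and handle payloads are placeholders (real handles are CUDA IPC objects), and the helper function is hypothetical.

```python
# Placeholder illustration of the documented return shape of
# get_weights_ipc_handles: {device_uuid: {param_name: ipc_handle}}.
# UUIDs and handle payloads below are invented for the example.

handles_by_device = {
    "GPU-aaaa-1111": {"model.embed.weight": b"<ipc-handle-bytes>"},
    "GPU-bbbb-2222": {"model.lm_head.weight": b"<ipc-handle-bytes>"},
}

def devices_with_param(handles: dict, param_name: str) -> list[str]:
    """List device UUIDs that expose an IPC handle for param_name (hypothetical helper)."""
    return [uuid for uuid, params in handles.items() if param_name in params]
```

Keying by device UUID (rather than rank) lets a consumer open each handle on the physical GPU that owns the memory.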
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- tokenizer_path: Optional[str] = None,
- )[source]#
Save a checkpoint of the model.
- __del__()[source]#
Shuts down the worker groups when the object is deleted or is garbage collected.
This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to the object is lost when it leaves a function scope. It is always recommended that the user call worker_group.shutdown() explicitly.
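The explicit-shutdown-with-destructor-fallback pattern described above can be sketched with a minimal hypothetical class (this is not the nemo_rl worker group, just the pattern):

```python
# Minimal sketch of the shutdown safety-net pattern: call shutdown()
# explicitly; __del__ only covers the case where the object is
# garbage-collected without it. Hypothetical class for illustration.

class WorkerGroup:
    def __init__(self) -> None:
        self.shut_down = False

    def shutdown(self) -> None:
        if not self.shut_down:
            self.shut_down = True  # release workers exactly once (idempotent)

    def __del__(self) -> None:
        # Safety net: runs at garbage collection if shutdown() was skipped.
        self.shutdown()

wg = WorkerGroup()
wg.shutdown()  # the recommended explicit call
```

Relying on `__del__` alone is fragile because Python gives no guarantee about when (or, at interpreter exit, whether) it runs, which is why the explicit call is recommended.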