nemo_rl.models.megatron.draft.hidden_capture#
Module Contents#
Classes#
CapturedStates | Container for hidden states captured from the policy model.
HiddenStateCapture | Capture policy embeddings and auxiliary hidden states for Eagle training.
Functions#
get_eagle3_aux_hidden_state_layers | Pick the default auxiliary policy layers whose activations feed Eagle training.
get_capture_context | Return a no-op context unless draft training needs hidden-state capture for this step.
Data#
_DTYPE_TO_CODE |
_CODE_TO_DTYPE |
API#
- nemo_rl.models.megatron.draft.hidden_capture.get_eagle3_aux_hidden_state_layers(num_layers: int) → tuple[int, ...]#
Pick the default auxiliary policy layers whose activations feed Eagle training.
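The page does not state which layers the function selects. As a rough illustration only, the sketch below assumes one plausible heuristic: one early, one middle, and one late layer, clamped to the valid range. The name `pick_aux_layers` and the specific offsets are hypothetical and are not taken from the library.

```python
from typing import Tuple


def pick_aux_layers(num_layers: int) -> Tuple[int, ...]:
    """Hypothetical heuristic: pick one early, one middle, and one late layer."""
    if num_layers < 1:
        raise ValueError("num_layers must be positive")
    # Assumed offsets; the real defaults may differ.
    candidates = (1, num_layers // 2, num_layers - 2)
    # Clamp each index into [0, num_layers - 1], then deduplicate and sort.
    clamped = {min(max(idx, 0), num_layers - 1) for idx in candidates}
    return tuple(sorted(clamped))


print(pick_aux_layers(32))
```

Clamping and deduplicating keeps the result valid for very shallow models, where the three candidates can collapse onto the same index.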
- nemo_rl.models.megatron.draft.hidden_capture._DTYPE_TO_CODE#
None
- nemo_rl.models.megatron.draft.hidden_capture._CODE_TO_DTYPE#
None
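The contents of these two tables are not shown on this page. Judging only by their names, they look like a pair of inverse lookups between dtypes and integer codes, which is a common way to describe a tensor in a small integer header before transmitting it. The sketch below illustrates that idea with string dtype names standing in for `torch.dtype` objects so that it runs without PyTorch. The codes, the header layout, and the helper names `encode_header` and `decode_header` are all assumptions.

```python
from typing import Dict, List, Sequence, Tuple

# Stand-in tables; the real ones presumably key on torch.dtype objects.
DTYPE_TO_CODE: Dict[str, int] = {"float32": 0, "float16": 1, "bfloat16": 2}
# Build the inverse table from the forward one so the two cannot drift apart.
CODE_TO_DTYPE: Dict[int, str] = {code: name for name, code in DTYPE_TO_CODE.items()}


def encode_header(dtype_name: str, shape: Sequence[int]) -> List[int]:
    """Pack the dtype code, the number of dims, and the dims into a flat int list."""
    return [DTYPE_TO_CODE[dtype_name], len(shape), *shape]


def decode_header(header: Sequence[int]) -> Tuple[str, Tuple[int, ...]]:
    """Recover the dtype name and shape from a header built by encode_header."""
    code, ndim = header[0], header[1]
    return CODE_TO_DTYPE[code], tuple(header[2:2 + ndim])


header = encode_header("bfloat16", (4, 128, 4096))
print(header, decode_header(header))
```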
- class nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
Container for hidden states captured from the policy model.
- hidden_states: Optional[torch.Tensor]#
None
- inputs_embeds: Optional[torch.Tensor]#
None
- class nemo_rl.models.megatron.draft.hidden_capture.HiddenStateCapture(
- model: torch.nn.Module,
- aux_layer_indices: Optional[Tuple[int, ...]] = None,
)#
Capture policy embeddings and auxiliary hidden states for Eagle training.
Initialization
- _compute_local_layer_mapping() → None#
- _compute_layer_owner_map() → Dict[int, int]#
- _make_layer_output_hook(global_idx: int)#
- _make_embedding_hook()#
- register_hooks() → None#
- clear_hooks() → None#
- capture_context()#
- _assemble_local_states() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
- _owner_rank_for_global_layer(global_layer_idx: int) → int#
- static _send_tensor(
- tensor: torch.Tensor,
- dst_rank: int,
- group: torch.distributed.ProcessGroup,
)#
- static _recv_tensor(
- src_rank: int,
- group: torch.distributed.ProcessGroup,
- device: torch.device,
)#
- _gather_distributed() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
- get_captured_states() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
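The methods `_compute_layer_owner_map` and `_owner_rank_for_global_layer` suggest that, under pipeline parallelism, each requested global layer index is resolved to the pipeline rank that holds it. The sketch below shows one way such a map could be built. It assumes layers are split evenly across stages with no virtual pipeline stages, and that ranks are pipeline-stage indices. The function name and its parameters are hypothetical.

```python
from typing import Dict, Sequence


def compute_layer_owner_map(
    num_layers: int,
    pp_size: int,
    aux_layer_indices: Sequence[int],
) -> Dict[int, int]:
    """Map each global layer index to the pipeline stage assumed to own it."""
    if pp_size < 1 or num_layers % pp_size != 0:
        raise ValueError("this sketch assumes an even split of layers across stages")
    layers_per_stage = num_layers // pp_size
    owner: Dict[int, int] = {}
    for idx in aux_layer_indices:
        if not 0 <= idx < num_layers:
            raise IndexError(f"layer index {idx} is out of range")
        # Stage k owns global layers [k * layers_per_stage, (k + 1) * layers_per_stage).
        owner[idx] = idx // layers_per_stage
    return owner


print(compute_layer_owner_map(32, 4, (1, 16, 30)))
```

With a map like this, a rank can tell which captured tensors it already holds locally and which it would have to receive from another stage.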
- nemo_rl.models.megatron.draft.hidden_capture.get_capture_context(
- model: torch.nn.Module,
- enabled: bool = False,
- aux_layer_indices: Optional[Tuple[int, ...]] = None,
)#
Return a no-op context unless draft training needs hidden-state capture for this step.
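The "no-op unless enabled" behavior can be sketched with `contextlib.nullcontext` from the standard library. `ToyCapture` and `toy_get_capture_context` below are hypothetical stand-ins that only track whether hooks are active. They do not reproduce the real signature, which takes a model and optional layer indices.

```python
import contextlib


class ToyCapture:
    """Hypothetical stand-in for HiddenStateCapture; tracks hook state only."""

    def __init__(self) -> None:
        self.hooks_active = False

    def register_hooks(self) -> None:
        self.hooks_active = True

    def clear_hooks(self) -> None:
        self.hooks_active = False

    @contextlib.contextmanager
    def capture_context(self):
        # Hooks live only for the duration of the with-block, even on error.
        self.register_hooks()
        try:
            yield self
        finally:
            self.clear_hooks()


def toy_get_capture_context(capture: ToyCapture, enabled: bool = False):
    """Return a no-op context when capture is disabled, else the capture context."""
    if not enabled:
        return contextlib.nullcontext()
    return capture.capture_context()


capture = ToyCapture()
with toy_get_capture_context(capture, enabled=True) as active:
    print("hooks active inside block:", active.hooks_active)
print("hooks active after block:", capture.hooks_active)
```

Returning a context manager in both branches lets the caller write a single `with` statement around the forward pass, whether or not capture is needed on that step.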