nemo_rl.models.megatron.draft.hidden_capture#

Module Contents#

Classes#

CapturedStates

Container for hidden states captured from the policy model.

HiddenStateCapture

Capture policy embeddings and auxiliary hidden states for Eagle training.

Functions#

get_eagle3_aux_hidden_state_layers

Pick the default auxiliary policy layers whose activations feed Eagle training.

get_capture_context

Return a no-op context unless draft training needs hidden-state capture for this step.

Data#

API#

nemo_rl.models.megatron.draft.hidden_capture.get_eagle3_aux_hidden_state_layers(num_layers: int) → tuple[int, ...]#

Pick the default auxiliary policy layers whose activations feed Eagle training.

nemo_rl.models.megatron.draft.hidden_capture._DTYPE_TO_CODE#

None

nemo_rl.models.megatron.draft.hidden_capture._CODE_TO_DTYPE#

None

class nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#

Container for hidden states captured from the policy model.

hidden_states: Optional[torch.Tensor]#

None

inputs_embeds: Optional[torch.Tensor]#

None

class nemo_rl.models.megatron.draft.hidden_capture.HiddenStateCapture(
model: torch.nn.Module,
aux_layer_indices: Optional[Tuple[int, ...]] = None,
)#

Capture policy embeddings and auxiliary hidden states for Eagle training.

Initialization

_compute_local_layer_mapping() → None#
_compute_layer_owner_map() → Dict[int, int]#
_make_layer_output_hook(global_idx: int)#
_make_embedding_hook()#
register_hooks() → None#
clear_hooks() → None#
capture_context()#
_assemble_local_states() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
_owner_rank_for_global_layer(global_layer_idx: int) → int#
static _send_tensor(
tensor: torch.Tensor,
dst_rank: int,
group: torch.distributed.ProcessGroup,
) → None#
static _recv_tensor(
src_rank: int,
group: torch.distributed.ProcessGroup,
device: torch.device,
) → torch.Tensor#
_gather_distributed() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
get_captured_states() → nemo_rl.models.megatron.draft.hidden_capture.CapturedStates#
nemo_rl.models.megatron.draft.hidden_capture.get_capture_context(
model: torch.nn.Module,
enabled: bool = False,
aux_layer_indices: Optional[Tuple[int, ...]] = None,
) → Tuple[ContextManager, Optional[nemo_rl.models.megatron.draft.hidden_capture.HiddenStateCapture]]#

Return a no-op context unless draft training needs hidden-state capture for this step.