core.recompute
Module Contents
Functions
checkpointed_forward | Forward method with activation checkpointing.
Data
te_checkpoint
API
- core.recompute.te_checkpoint = None
- core.recompute.checkpointed_forward(
      self: megatron.core.transformer.module.MegatronModule,
      hidden_states: torch.Tensor,
      attention_mask: torch.Tensor,
      context: Optional[torch.Tensor],
      context_mask: Optional[torch.Tensor],
      rotary_pos_emb: torch.Tensor,
      attention_bias: Optional[torch.Tensor],
      packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
      use_inner_quantization_context: bool,
      padding_mask: Optional[torch.Tensor] = None,
      extract_layer_indices: Optional[Set[int]] = None,
      layer_offset: int = 0,
  )
Forward method with activation checkpointing.
- Parameters:
extract_layer_indices (Set[int], optional) – Global layer indices (counted across all pipeline stages) of the layers whose hidden states are extracted as intermediate features.
layer_offset (int) – The global layer offset for the current pipeline stage. Used to convert local layer indices to global indices when checking extract_layer_indices.
- Returns:
If extract_layer_indices is empty: the hidden_states tensor. If extract_layer_indices is non-empty: a (hidden_states, intermediate_hidden_states) tuple.
- Return type:
torch.Tensor or tuple
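The interaction between extract_layer_indices and layer_offset can be illustrated with a minimal pure-Python sketch. It is not the Megatron implementation: the helper name run_layers_with_extraction and the Layer alias are hypothetical, floats stand in for tensors, activation checkpointing itself is omitted, and it assumes (not confirmed by this page) that each selected layer's output is collected into a list, in layer order.

```python
from typing import Callable, List, Optional, Sequence, Set, Tuple, Union

# Hypothetical stand-in for a transformer layer: hidden states in, hidden states out.
Layer = Callable[[float], float]


def run_layers_with_extraction(
    layers: Sequence[Layer],
    hidden_states: float,
    extract_layer_indices: Optional[Set[int]] = None,
    layer_offset: int = 0,
) -> Union[float, Tuple[float, List[float]]]:
    """Apply this stage's layers in order, collecting outputs of selected global layers.

    Sketch only: the real checkpointed_forward also recomputes activations
    (activation checkpointing), which is not modeled here.
    """
    extract = extract_layer_indices or set()
    intermediate: List[float] = []  # assumption: intermediates are gathered in a list
    for local_idx, layer in enumerate(layers):
        hidden_states = layer(hidden_states)
        # Convert the stage-local index to a global index across all pipeline stages.
        global_idx = layer_offset + local_idx
        if global_idx in extract:
            intermediate.append(hidden_states)
    if not extract:
        # Empty selection: return only the final hidden states.
        return hidden_states
    # Non-empty selection: return (hidden_states, intermediate_hidden_states).
    return hidden_states, intermediate


# A second pipeline stage holding global layers 4..7, each adding 1.
layers = [lambda x: x + 1] * 4
print(run_layers_with_extraction(layers, 0.0, {1, 5, 6}, layer_offset=4))  # (4.0, [2.0, 3.0])
```

Global index 1 belongs to an earlier stage, so this stage only contributes the outputs of global layers 5 and 6; the tuple form is returned whenever the selection is non-empty, even if no selected layer lives on the current stage.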