core.recompute

Module Contents

Functions

checkpointed_forward

Forward method with activation checkpointing.

Data

te_checkpoint

API

core.recompute.te_checkpoint

None

core.recompute.checkpointed_forward(
self: megatron.core.transformer.module.MegatronModule,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: Optional[torch.Tensor],
context_mask: Optional[torch.Tensor],
rotary_pos_emb: torch.Tensor,
attention_bias: Optional[torch.Tensor],
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
use_inner_quantization_context: bool,
padding_mask: Optional[torch.Tensor] = None,
extract_layer_indices: Optional[Set[int]] = None,
layer_offset: int = 0,
) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Forward method with activation checkpointing.

Parameters:
  • extract_layer_indices (Set[int], optional) – Global layer indices (across all pipeline stages) from which to extract features.

  • layer_offset (int) – The global layer offset for the current pipeline stage. Used to convert local layer indices to global indices when checking extract_layer_indices.

Returns:
  • If extract_layer_indices is empty: the hidden_states tensor.

  • If extract_layer_indices is non-empty: a (hidden_states, intermediate_hidden_states) tuple.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
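
The interaction between extract_layer_indices, layer_offset, and the two possible return shapes can be illustrated with a toy sketch. This is not the Megatron implementation: run_stage is a hypothetical stand-in, plain floats play the role of tensors, each "layer" just adds 1.0, the extracted features are collected in a Python list rather than a tensor, and global layer indices are assumed to be 0-based.

```python
from typing import List, Optional, Set, Tuple, Union


def run_stage(
    hidden_states: float,
    num_local_layers: int,
    extract_layer_indices: Optional[Set[int]] = None,
    layer_offset: int = 0,
) -> Union[float, Tuple[float, List[float]]]:
    """Hypothetical toy model of one pipeline stage (floats stand in for tensors)."""
    extract = extract_layer_indices or set()
    intermediates: List[float] = []
    for local_idx in range(num_local_layers):
        hidden_states = hidden_states + 1.0  # placeholder for a real layer
        # Convert the stage-local index to a global index (assumed 0-based).
        global_idx = local_idx + layer_offset
        if global_idx in extract:
            intermediates.append(hidden_states)
    # The return shape depends only on whether extraction was requested.
    if extract:
        return hidden_states, intermediates
    return hidden_states


# A stage holding global layers 4..7, with features requested from layers 5 and 7.
out = run_stage(0.0, num_local_layers=4, extract_layer_indices={5, 7}, layer_offset=4)
if isinstance(out, tuple):
    final_hidden, features = out
else:
    final_hidden, features = out, []
```

Because the return type is a Union, callers that may pass a non-empty extract_layer_indices need to handle both shapes, for example with an isinstance check as above.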