core.transformer.module#

Megatron Module.

Module Contents#

Classes#

MegatronModule

Base Megatron module inherited by all models.

GraphableMegatronModule

Megatron module that can be used to capture and replay CUDA graphs. Currently only TransformerLayer and MambaLayer are graphable.

Float16Module

Float 16 Module.

Functions#

param_is_not_shared

conversion_helper

Recursively applies a conversion function to values in nested data structures.

fp32_to_float16

Converts floating-point values from fp32 to fp16.

float16_to_fp32

Converts floating-point values from fp16 to fp32.

Data#

API#

core.transformer.module._FLOAT_TYPES#

()

core.transformer.module._HALF_TYPES#

()

core.transformer.module._BF16_TYPES#

()

core.transformer.module.param_is_not_shared(param)#
class core.transformer.module.MegatronModule(
config: megatron.core.transformer.transformer_config.TransformerConfig,
)#

Bases: torch.nn.Module

Base Megatron module inherited by all models.

Megatron specific extensions of torch Module with support for pipelining

Parameters:

config (TransformerConfig) – Transformer config

Initialization
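
A minimal sketch of subclassing MegatronModule, assuming only the constructor documented above; the MyModel name and its single linear layer are illustrative, not part of the library.

```python
import torch
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig


class MyModel(MegatronModule):
    """Hypothetical model built on top of MegatronModule."""

    def __init__(self, config: TransformerConfig):
        # MegatronModule stores the config so that checkpointing and
        # pipelining helpers can read parallelism settings from it.
        super().__init__(config)
        self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)
```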

state_dict_for_save_checkpoint(
prefix: str = '',
keep_vars: bool = False,
)#

Use this function to override the state dict for saving checkpoints.

Parameters:
  • prefix (str, optional) – Prefix added to the state dict keys. Defaults to ''.

  • keep_vars (bool, optional) – If True, tensors in the returned state dict are not detached from autograd. Defaults to False.

Returns:

The module state dict to save in the checkpoint.

Return type:

dict
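
A hedged sketch of how a subclass might override this hook to drop entries from the saved checkpoint; the popped key is purely illustrative.

```python
class MyModel(MegatronModule):
    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # Hypothetical override: start from the regular state dict and
        # remove entries that should not be persisted in checkpoints.
        state_dict = self.state_dict(prefix=prefix, keep_vars=keep_vars)
        state_dict.pop(f'{prefix}rotary_pos_emb.inv_freq', None)  # illustrative key
        return state_dict
```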

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Default implementation for sharded state dict for distributed checkpointing.

The general definition of sharded_state_dict simply calls sharded_state_dict_default (which calls the sharded_state_dict method if available, or a default implementation otherwise) recursively on all submodules.

Parameters:
  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor.

  • metadata (dict, optional) – metadata passed recursively to sharded_state_dict methods

Returns:

dictionary of state dict keys mapped to ShardedTensors

Return type:

dict
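
A hedged usage sketch that feeds the result into megatron.core.dist_checkpointing; the helper function name and checkpoint directory are illustrative.

```python
from megatron.core import dist_checkpointing
from megatron.core.transformer.module import MegatronModule


def save_distributed_checkpoint(model: MegatronModule, checkpoint_dir: str) -> None:
    """Hypothetical save flow built on the default sharded_state_dict."""
    # Collect ShardedTensors from the module and all of its submodules.
    sharded_sd = model.sharded_state_dict(prefix='')
    # Persist them with the distributed checkpointing package.
    dist_checkpointing.save(sharded_sd, checkpoint_dir)
```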

set_is_first_microbatch()#

Sets the is_first_microbatch flag if it exists and config.fp8 is enabled. When this flag is set, TE modules will update their fp8 parameter cache. If kitchen is being used, kitchen controls the quantization level.
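
A hedged sketch of where this hook is typically invoked, assuming a loop that runs one global batch split into microbatches; the helper below is illustrative, not Megatron's actual training loop.

```python
def run_global_batch(model, microbatches):
    """Hypothetical forward passes over one global batch."""
    # Mark the first microbatch so TE modules refresh their fp8 parameter cache.
    model.set_is_first_microbatch()
    outputs = []
    for microbatch in microbatches:
        outputs.append(model(microbatch))
    return outputs
```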

set_symmetric_ar(set_to: Optional[str] = None) None#

Set symmetric all-reduce functionality across all eligible modules.

This method traverses the model’s module hierarchy to find all modules with the ‘symmetric_ar_type’ attribute, caches them, and then sets their ‘_symmetric_ar_cache’ attribute to the specified value to enable or disable symmetric all-reduce operations.

Parameters:
  • set_to (str, optional) – Value to set 'symmetric_ar_type' to. Allowed choices: 'two_shot', 'one_shot', 'multimem_all_reduce', or None.
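
A short usage sketch using the documented allowed values; the helper name is illustrative.

```python
def toggle_symmetric_ar(model, enable: bool) -> None:
    """Hypothetical helper to switch symmetric all-reduce on or off."""
    # Allowed values: 'two_shot', 'one_shot', 'multimem_all_reduce', or None.
    model.set_symmetric_ar('two_shot' if enable else None)
```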

class core.transformer.module.GraphableMegatronModule(
config: megatron.core.transformer.transformer_config.TransformerConfig,
vp_stage: Optional[int] = None,
)#

Bases: core.transformer.module.MegatronModule

Megatron module that can be used to capture and replay CUDA graphs. Currently only TransformerLayer and MambaLayer are graphable.

Parameters:
  • config (TransformerConfig) – Transformer config

  • vp_stage (int, optional) – Virtual pipeline parallel stage of this module, if any.

Initialization

get_layer_static_inputs(seq_length, micro_batch_size)#

Get the static inputs for the layer. We assume that the module has one hidden_states input, whose shape is inferred from the seq_length, micro_batch_size, and parallel config. Override this method if the module has other inputs.

Returns:

A dictionary containing the static inputs for the layer.

Return type:

Dict[str, torch.Tensor]
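
A hedged sketch of such an override for a layer whose forward takes one extra tensor; the extra input's name and shape are purely illustrative.

```python
import torch
from megatron.core.transformer.module import GraphableMegatronModule


class MyGraphableLayer(GraphableMegatronModule):
    """Hypothetical graphable layer with one extra static input."""

    def get_layer_static_inputs(self, seq_length, micro_batch_size):
        # Start from the default hidden_states placeholder ...
        static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size)
        # ... and add any other tensor this layer's forward expects.
        static_inputs['extra_bias'] = torch.zeros(
            seq_length,
            micro_batch_size,
            self.config.hidden_size,
            device=torch.cuda.current_device(),
        )
        return static_inputs
```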

setup_manual_hooks(make_hook_func)#

Set CUDA Graph manual hooks for the submodules that contain direct parameters and are covered by cudagraphs.

_get_submodules_under_cudagraphs()#

Get the submodules that are covered by cudagraphs. Return a list that only contains the module itself if the whole layer is covered by cudagraphs.

_te_cuda_graph_capture(*args, **kwargs)#

CUDA graph capture for this layer using the TE interface. Normally it's just a forward pass if we're capturing the entire layer.

_te_cuda_graph_replay(*args, **kwargs)#

CUDA graph replay for this layer and microbatch self.current_microbatch using the TE interface. TransformerEngine versions >= 1.10 allow keyword arguments with CUDA graphs; however, CUDA graphs accept only tensor inputs, so this method checks that the arguments are all tensors.

_get_te_cuda_graph_replay_args(*args, **kwargs)#

Helper function to get tensor arguments for TE CUDA graph.

_should_call_local_cudagraph(*args, **kwargs)#

Check if we should call the local cudagraph path.

_should_call_te_cudagraph(*args, **kwargs)#

Check if we should call the TE cudagraph path.

__call__(*args, **kwargs)#
core.transformer.module.conversion_helper(val, conversion)#

Recursively applies a conversion function to values in nested data structures.

Parameters:
  • val – A single value or a nested structure (tuple/list) of values to convert

  • conversion (callable) – A function that performs the desired conversion on a single value

Returns:

The converted value, maintaining the same nested structure as the input. If input is a single value, returns the converted value. If input is a tuple/list, returns a tuple/list with all elements converted.
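
A minimal sketch of the recursive pattern described above (not necessarily the library's exact implementation):

```python
def conversion_helper(val, conversion):
    # Apply the conversion function to leaf values; recurse into tuples and
    # lists while preserving the container type.
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    converted = [conversion_helper(v, conversion) for v in val]
    return tuple(converted) if isinstance(val, tuple) else converted
```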

core.transformer.module.fp32_to_float16(val, float16_convertor)#

Converts floating-point values from fp32 to fp16.

Parameters:
  • val – The value to convert. Can be a single number, a tuple, or a list.

  • float16_convertor – A function that converts a single fp32 value to fp16

core.transformer.module.float16_to_fp32(val)#

Converts floating-point values from fp16 to fp32.

Parameters:

val – The value to convert. Can be a single number, a tuple, or a list.
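
A hedged round-trip example, assuming plain fp32 CPU tensors and a simple convertor function; which leaf types actually get converted is governed by the _FLOAT_TYPES/_HALF_TYPES tuples above.

```python
import torch
from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32

# A nested structure of fp32 tensors, as might be passed between pipeline stages.
inputs = (torch.randn(2, 4), [torch.randn(3)])

# Downcast fp32 leaves to fp16 with a simple convertor, then upcast back.
half_inputs = fp32_to_float16(inputs, lambda t: t.half())
full_inputs = float16_to_fp32(half_inputs)
```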

class core.transformer.module.Float16Module(
config: megatron.core.transformer.transformer_config.TransformerConfig,
module: torch.nn.Module,
)#

Bases: core.transformer.module.MegatronModule

Float 16 Module.

Attributes:
  • config (TransformerConfig) – Transformer config

  • fp16 (bool) – Specifies if the model runs in fp16 mode

  • bf16 (bool) – Specifies if the model runs in bf16 mode

Parameters:
  • config (TransformerConfig) – The transformer config used to initialize the model

  • module (torch.nn.Module) – The module to be wrapped and run in fp16/bf16

Initialization

set_input_tensor(input_tensor)#
forward(*inputs, fp32_output=True, **kwargs)#

Execute the wrapped module in model precision and optionally upcast outputs to fp32.

On the first pipeline stage, positional/keyword tensor inputs are converted to the module precision (fp16 or bf16) before invoking the wrapped module. The wrapped module is called with the provided inputs and keyword arguments. On the last pipeline stage only, outputs are upcast to fp32 if fp32_output is True; otherwise, outputs are returned in the model precision (fp16/bf16).

Parameters:
  • *inputs – Positional inputs forwarded to the wrapped module (converted to fp16/bf16 on the first pipeline stage).

  • fp32_output (bool, keyword-only) – If True (default), upcast outputs to fp32 on the last pipeline stage. Has no effect on non-last stages. Set to False to keep outputs in model precision when downstream consumers expect half precision or to avoid extra casts.

  • **kwargs – Keyword arguments forwarded to the wrapped module.

Returns:

The wrapped module’s outputs, potentially upcast to fp32 depending on pipeline stage and fp32_output.
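
A hedged usage sketch, assuming a TransformerConfig with fp16 or bf16 set, an already-built fp32 model, and GPT-style inputs; all of those names are placeholders.

```python
from megatron.core.transformer.module import Float16Module

# `config` (with fp16 or bf16 enabled) and `gpt_model` are assumed to exist.
model = Float16Module(config, gpt_model)

# Default: outputs are upcast to fp32 on the last pipeline stage.
logits = model(tokens, position_ids, attention_mask)

# Keep outputs in model precision (fp16/bf16) instead.
logits_half = model(tokens, position_ids, attention_mask, fp32_output=False)
```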

state_dict(destination=None, prefix='', keep_vars=False)#
state_dict_for_save_checkpoint(prefix='', keep_vars=False)#

Retrieve state_dict from the module being wrapped.

sharded_state_dict(prefix='', *args, **kwargs)#

Retrieve sharded_state_dict from the module being wrapped.

load_state_dict(state_dict, strict=True)#