core.transformer.module#
Megatron Module.
Module Contents#
Classes#
MegatronModule – Base Megatron module inherited by all Models.
GraphableMegatronModule – Megatron module that can be used to capture and replay CUDA graphs. Currently only TransformerLayer and MambaLayer are graphable.
Float16Module – Float 16 Module.
Functions#
conversion_helper – Recursively applies a conversion function to values in nested data structures.
fp32_to_float16 – Converts floating-point values from fp32 to fp16.
float16_to_fp32 – Converts floating-point values from fp16 to fp32.
Data#
API#
- core.transformer.module._FLOAT_TYPES#
()
- core.transformer.module._HALF_TYPES#
()
- core.transformer.module._BF16_TYPES#
()
- class core.transformer.module.MegatronModule(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
Bases:
torch.nn.Module
Base Megatron module inherited by all Models.
Megatron-specific extensions of torch Module with support for pipelining.
- Parameters:
config (TransformerConfig) – Transformer config
Initialization
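A minimal sketch of subclassing MegatronModule. The subclass, its layer, and the config values are illustrative; it assumes a working Megatron-Core install, and a real run would typically initialize model-parallel state before building modules.

```python
import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.module import MegatronModule


class MyProjection(MegatronModule):
    """Toy subclass illustrating the MegatronModule base class."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config=config)  # stores the config on self.config
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states)


# Illustrative config values; real models set many more fields.
config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)
layer = MyProjection(config)
```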
- state_dict_for_save_checkpoint(
- prefix: str = '',
- keep_vars: bool = False,
Override the state dict used for saving checkpoints. Use this function to customize what is stored in a checkpoint.
- Parameters:
prefix (str, optional) – Prefix added to state dict keys. Defaults to ''.
keep_vars (bool, optional) – If True, parameters are not detached from the autograd graph. Defaults to False.
- Returns:
The module state dict to be saved in the checkpoint.
- Return type:
dict
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int]] = (),
- metadata: Optional[dict] = None,
Default implementation for sharded state dict for distributed checkpointing.
The general definition of sharded_state_dict simply calls sharded_state_dict_default (which calls the sharded_state_dict method if possible, or falls back to a default implementation) recursively on all submodules.
- Parameters:
prefix (str) – prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor.
metadata (dict, optional) – metadata passed recursively to sharded_state_dict methods
- Returns:
dictionary of state dict keys mapped to ShardedTensors
- Return type:
dict
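A hedged sketch of how the result is typically fed to Megatron-Core distributed checkpointing. The checkpoint path is illustrative, and the snippet assumes a torch.distributed run with megatron.core.dist_checkpointing available and a matching sharded layout when loading.

```python
from megatron.core import dist_checkpointing

# Build the sharded state dict for this rank and save it collectively.
sharded_sd = model.sharded_state_dict(prefix="")
dist_checkpointing.save(sharded_sd, "/path/to/checkpoint_dir")  # illustrative path

# Later: load back into a freshly built model with the same sharding.
loaded_sd = dist_checkpointing.load(model.sharded_state_dict(prefix=""), "/path/to/checkpoint_dir")
model.load_state_dict(loaded_sd)
```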
- set_is_first_microbatch()#
Sets the is_first_microbatch flag if it exists and config.fp8==True. When this flag is set, TE modules will update their fp8 parameter cache. If kitchen is being used, kitchen controls quantization level.
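A hedged training-loop fragment showing where this flag is typically set; the data iterator and microbatch-splitting helper are hypothetical names, not Megatron-Core APIs.

```python
# Hypothetical loop; 'model' is a MegatronModule and config.fp8 is enabled.
for global_batch in data_iterator:                       # 'data_iterator' is a placeholder
    model.set_is_first_microbatch()                      # TE modules refresh their fp8 param cache
    for microbatch in split_microbatches(global_batch):  # hypothetical helper
        loss = model(microbatch).sum()
        loss.backward()
```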
- set_symmetric_ar(set_to: Optional[str] = None) None#
Set symmetric all-reduce functionality across all eligible modules.
This method traverses the model’s module hierarchy to find all modules with the ‘symmetric_ar_type’ attribute, caches them, and then sets their ‘_symmetric_ar_cache’ attribute to the specified value to enable or disable symmetric all-reduce operations.
- Parameters:
set_to (str, optional) – Value to set 'symmetric_ar_type' to. Allowed choices: 'two_shot', 'one_shot', 'multimem_all_reduce', or None.
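Illustrative usage; only modules that expose symmetric_ar_type are affected, and this assumes a tensor-parallel setup where symmetric all-reduce is supported.

```python
# Enable one of the allowed symmetric all-reduce variants...
model.set_symmetric_ar("two_shot")

# ...and pass None to turn symmetric all-reduce back off.
model.set_symmetric_ar(None)
```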
- class core.transformer.module.GraphableMegatronModule(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- vp_stage: Optional[int] = None,
Bases:
core.transformer.module.MegatronModule
Megatron module that can be used to capture and replay CUDA graphs. Currently only TransformerLayer and MambaLayer are graphable.
- Parameters:
config (TransformerConfig) – Transformer config
Initialization
- get_layer_static_inputs(seq_length, micro_batch_size)#
Get the static inputs for the layer. We assume that the module has one hidden_states input, whose shape is inferred from the seq_length, micro_batch_size, and parallel config. Override this method if the module has other inputs.
- Returns:
A dictionary containing the static inputs for the layer.
- Return type:
Dict[str, torch.Tensor]
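A hedged sketch of overriding this hook in a custom graphable layer that consumes one extra tensor input; the extra input name, shape, and dtype are illustrative assumptions.

```python
import torch
from megatron.core.transformer.module import GraphableMegatronModule


class MyGraphableLayer(GraphableMegatronModule):
    def get_layer_static_inputs(self, seq_length, micro_batch_size):
        # Start from the default hidden_states buffer...
        static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size)
        # ...and add a static buffer for the extra input this layer consumes.
        static_inputs["extra_bias"] = torch.zeros(
            seq_length,
            micro_batch_size,
            self.config.hidden_size,
            dtype=torch.bfloat16,
            device="cuda",
        )
        return static_inputs
```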
- setup_manual_hooks(make_hook_func)#
Set CUDA Graph manual hooks for the submodules that contain direct parameters and are covered by cudagraphs.
- _get_submodules_under_cudagraphs()#
Get the submodules that are covered by cudagraphs. Return a list that only contains the module itself if the whole layer is covered by cudagraphs.
- _te_cuda_graph_capture(*args, **kwargs)#
CUDA Graph capture for this layer using TE interface. Normally it’s just a forward pass if we’re capturing the entire layer.
- _te_cuda_graph_replay(*args, **kwargs)#
CUDA graph replay for this layer and microbatch self.current_microbatch, using the TE interface. TransformerEngine versions >= 1.10 allow keyword arguments with CUDA graph. However, CUDA graph accepts only Tensor inputs, so this method checks that all arguments are tensors.
- _get_te_cuda_graph_replay_args(*args, **kwargs)#
Helper function to get tensor arguments for TE CUDA graph.
- _should_call_local_cudagraph(*args, **kwargs)#
Check if we should call the local cudagraph path.
- _should_call_te_cudagraph(*args, **kwargs)#
Check if we should call the TE cudagraph path.
- __call__(*args, **kwargs)#
- core.transformer.module.conversion_helper(val, conversion)#
Recursively applies a conversion function to values in nested data structures.
- Parameters:
val – A single value or a nested structure (tuple/list) of values to convert
conversion (callable) – A function that performs the desired conversion on a single value
- Returns:
The converted value, maintaining the same nested structure as the input. If input is a single value, returns the converted value. If input is a tuple/list, returns a tuple/list with all elements converted.
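A small sketch of the structure-preserving behaviour; the values are arbitrary.

```python
from megatron.core.transformer.module import conversion_helper

nested = (1.0, [2.0, 3.0], 4.0)
halved = conversion_helper(nested, lambda v: v / 2)
# halved == (0.5, [1.0, 1.5], 2.0): tuples stay tuples, lists stay lists.
```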
- core.transformer.module.fp32_to_float16(val, float16_convertor)#
Converts floating-point values from fp32 to fp16.
- Parameters:
val – The value to convert. Can be a single number, a tuple, or a list.
float16_convertor – A function that converts a single fp32 value to fp16
- core.transformer.module.float16_to_fp32(val)#
Converts floating-point values from fp16 to fp32.
- Parameters:
val – The value to convert. Can be a single number, a tuple, or a list.
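A hedged round-trip sketch for the two converters above. The convertor function here is an assumption for illustration (inside Float16Module it is derived from the config), and per the module's internal type checks only fp32 tensors are expected to be converted on the way down.

```python
import torch
from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32


def to_half(t):
    # Illustrative convertor: cast fp32 tensors to fp16.
    return t.half()


inputs = (torch.randn(2, 4), [torch.randn(3)], torch.tensor([7]))  # last tensor is integer
half_inputs = fp32_to_float16(inputs, to_half)   # fp32 tensors converted, others pass through
restored = float16_to_fp32(half_inputs)          # fp16/bf16 tensors cast back to fp32
```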
- class core.transformer.module.Float16Module(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- module: torch.nn.Module,
Bases:
core.transformer.module.MegatronModule
Float 16 Module.
- Attributes:
config (TransformerConfig) – Transformer config
fp16 (bool) – Specifies if the model runs in fp16 mode
bf16 (bool) – Specifies if the model runs in bf16 mode
- Parameters:
config (TransformerConfig) – The transformer config used to initialize the model
Initialization
- set_input_tensor(input_tensor)#
- forward(*inputs, fp32_output=True, **kwargs)#
Execute the wrapped module in model precision and optionally upcast outputs to fp32.
On the first pipeline stage, positional/keyword tensor inputs are converted to the module precision (fp16 or bf16) before invoking the wrapped module. The wrapped module is called with the provided inputs and keyword arguments. On the last pipeline stage only, outputs are upcast to fp32 if fp32_output is True; otherwise, outputs are returned in the model precision (fp16/bf16).
- Parameters:
*inputs – Positional inputs forwarded to the wrapped module (converted to fp16/bf16 on the pipeline first stage).
fp32_output (bool, keyword-only) – If True (default), upcast outputs to fp32 on the pipeline last stage. Has no effect on non-last stages. Set to False to keep outputs in model precision when downstream consumers expect half precision or to avoid extra casts.
**kwargs – Keyword arguments forwarded to the wrapped module.
- Returns:
The wrapped module's outputs, potentially upcast to fp32 depending on the pipeline stage and fp32_output.
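A hedged usage sketch; model, tokens, position_ids, and attention_mask are placeholders, and it assumes config.fp16 or config.bf16 is set and pipeline-parallel state is initialized, since forward checks the pipeline stage.

```python
from megatron.core.transformer.module import Float16Module

wrapped = Float16Module(config=config, module=model)

# Default: inputs cast to model precision, outputs upcast to fp32 on the last stage.
logits = wrapped(tokens, position_ids, attention_mask)

# Keep outputs in fp16/bf16, e.g. when a downstream consumer expects half precision.
half_logits = wrapped(tokens, position_ids, attention_mask, fp32_output=False)
```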
- state_dict(destination=None, prefix='', keep_vars=False)#
- state_dict_for_save_checkpoint(prefix='', keep_vars=False)#
Retrieve state_dict from the module being wrapped.
- sharded_state_dict(prefix='', *args, **kwargs)#
Retrieve sharded_state_dict from the module being wrapped.
- load_state_dict(state_dict, strict=True)#