core.transformer.utils#
Utilities for transformer layers.
Module Contents#
Functions#
| Function | Description |
|---|---|
| get_linear_layer | Simple linear layer with weight initialization. |
| get_default_causal_mask | Return the causal upper triangular mask for softmax input. |
| get_sliding_window_causal_mask | Create the equivalent attention mask for sliding window attention (SWA) in [sq, skv] shape. |
| gelu_impl | OpenAI’s gelu implementation. |
| make_sharded_tensors_for_checkpoint | Wraps tensors from transformer layers with ShardedTensor or ShardedObject. |
| make_sharded_object_for_checkpoint | Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group). |
| _get_extra_state_offsets | Turns ShardedTensor offsets into offsets suitable for ShardedObject. |
| ensure_metadata_has_dp_cp_group | Ensure metadata is a dict containing a dp_cp_group entry. |
| sharded_state_dict_default | Provides implementation for sharded_state_dict method for non-MegatronModules. |
| _init_sequence_parallel_cache | Initialize the cache of modules with sequence parallel attributes. Only needs to be called once; subsequent calls have no effect. |
| set_model_to_sequence_parallel | Set sequence parallel attributes for the model. |
| init_cuda_graph_cache | Initialize the cache of modules for CUDA graphs. |
| toggle_cuda_graphs | Toggle CUDA graph-related attributes for the model and its modules. |
Data#
API#
- core.transformer.utils.get_linear_layer(rows, columns, init_method, perform_initialization=True)#
Simple linear layer with weight initialization.
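A minimal usage sketch, assuming the module is importable as megatron.core.transformer.utils and that the returned layer behaves like a torch.nn.Linear whose weight was initialized by init_method (the sizes and the choice of xavier_normal_ are illustrative):

```python
import torch
from torch.nn.init import xavier_normal_

# Assumption: the documented module is importable at this path inside Megatron-LM.
from megatron.core.transformer.utils import get_linear_layer

# Build a 1024 -> 4096 projection; xavier_normal_ is just an example init_method
# callable that initializes a weight tensor in place.
proj = get_linear_layer(1024, 4096, init_method=xavier_normal_)
x = torch.randn(8, 1024)
y = proj(x)  # (8, 4096), assuming rows maps to in_features
```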
- core.transformer.utils.get_default_causal_mask(sq: int) → torch.Tensor#
Return the causal upper triangular mask for softmax input.
- core.transformer.utils.get_sliding_window_causal_mask(sq, skv, window_size)#
Create the equivalent attention mask for sliding window attention (SWA) in [sq, skv] shape.
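For intuition, a sketch of the kind of masks these two helpers describe, where True marks positions removed before softmax. This is an illustrative reconstruction (assuming queries and keys are aligned at position 0), not the functions' exact return values:

```python
import torch

def causal_mask_sketch(sq: int) -> torch.Tensor:
    # True marks positions masked out before softmax: each query attends only
    # to itself and earlier positions (the strictly upper triangular part is masked).
    return torch.triu(torch.ones(sq, sq), diagonal=1).bool()

def sliding_window_causal_mask_sketch(sq: int, skv: int, window: int) -> torch.Tensor:
    # Illustrative [sq, skv] mask: key k is visible to query q only if
    # k <= q and q - k < window; everything else is masked (True).
    q = torch.arange(sq).unsqueeze(1)
    k = torch.arange(skv).unsqueeze(0)
    visible = (k <= q) & (q - k < window)
    return ~visible

print(causal_mask_sketch(4).int())
print(sliding_window_causal_mask_sketch(4, 6, window=2).int())
```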
- core.transformer.utils.attention_mask_func(attention_scores, attention_mask)#
- core.transformer.utils.gelu_impl(x)#
OpenAI’s gelu implementation.
- core.transformer.utils.openai_gelu(x)#
- core.transformer.utils.erf_gelu(x)#
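For reference, a sketch of the two standard GELU formulations these helpers correspond to: the tanh approximation popularized in OpenAI's GPT code and the exact erf form. This is an illustrative re-derivation, not the module's source:

```python
import math
import torch

def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU used in OpenAI's GPT code.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

def gelu_erf(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU via the Gaussian error function.
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
```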
- core.transformer.utils.make_sharded_tensors_for_checkpoint(
- state_dict: megatron.core.dist_checkpointing.mapping.StateDict,
- prefix: str,
- tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
- sharded_offsets: Iterable[Tuple[int, int, int]] = (),
- extra_state_suffix: str = '_extra_state',
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
)#
Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
For a given state_dict, wraps:
- all _extra_states with ShardedObject
- all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- other values with DP sharded ShardedTensor
- Parameters:
state_dict (StateDict) – state_dict to convert
prefix (str) – prefix appended to keys in final state dict
tensor_parallel_layers_axis_map (Dict[str, int], optional) – dict mapping layer names to the axis for TP sharding
sharded_offsets (Iterable[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related), passed along to ShardedTensor
extra_state_suffix (str, default = '_extra_state') – layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor.
tp_group (Optional[torch.distributed.ProcessGroup], optional) – tensor parallel group. If None, defaults to parallel_state.get_tensor_model_parallel_group()
dp_cp_group (Optional[torch.distributed.ProcessGroup], optional) – data parallel group with context parallel. If None, defaults to parallel_state.get_data_parallel_group(with_context_parallel=True)
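A hedged usage sketch; the stand-in submodule, key prefix, and TP-sharding axis map below are illustrative assumptions, and in practice this runs inside an initialized Megatron parallel state (or with tp_group and dp_cp_group passed explicitly):

```python
import torch

# Assumption: the documented module is importable at this path inside Megatron-LM.
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

# Stand-in submodule; in a real model this would be a tensor-parallel linear layer.
fc1 = torch.nn.Linear(16, 64)
state_dict = fc1.state_dict(keep_vars=True)

sharded_sd = make_sharded_tensors_for_checkpoint(
    state_dict,
    prefix='decoder.layers.0.mlp.linear_fc1.',                  # illustrative key prefix
    tensor_parallel_layers_axis_map={'weight': 0, 'bias': 0},   # TP-shard these along dim 0
    sharded_offsets=(),                                         # no PP sharding applied by the caller
)
# Keys are now prefixed and mapped to ShardedTensor (and ShardedObject for any _extra_state).
```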
- core.transformer.utils.make_sharded_object_for_checkpoint(
- obj: Any,
- key: str,
- sharded_offsets: Iterable[Tuple[int, int, int]] = (),
- replica_id: Union[None, int, Tuple[int, ...]] = None,
- **kwargs,
)#
Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
- Parameters:
obj (object) – any object to be sharded
key (str) – unique identifier of the object
sharded_offsets (Iterable[Tuple[int, int, int]]) – offsets normally prepended to ShardedTensors, will be used as global offsets for ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]) – replica id
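A hedged sketch along the same lines; the wrapped object and key are illustrative, and replica_id is left at its default:

```python
# Assumption: the documented module is importable at this path inside Megatron-LM.
from megatron.core.transformer.utils import make_sharded_object_for_checkpoint

extra_state = {'some_non_tensor_state': 42}   # any picklable object (illustrative)
sharded_obj = make_sharded_object_for_checkpoint(
    extra_state,
    key='decoder.layers.0.self_attention._extra_state',   # illustrative key
    sharded_offsets=(),
)
```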
- core.transformer.utils._get_extra_state_offsets(sharded_offsets: Iterable[Tuple[int, int, int]])#
Turns ShardedTensor offsets into offsets suitable for ShardedObject.
- core.transformer.utils.ensure_metadata_has_dp_cp_group(metadata: Optional[dict])#
Ensure metadata is a dict containing a dp_cp_group entry.
If metadata is None, a new dict is returned with dp_cp_group set. If metadata is a dict and missing dp_cp_group, it is updated in-place. Otherwise, asserts that dp_cp_group exists.
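The documented contract, restated as a minimal sketch (an illustrative re-implementation, not the module's source; default_group stands in for the process group the real helper would resolve):

```python
from typing import Optional

def ensure_metadata_has_dp_cp_group_sketch(metadata: Optional[dict], default_group) -> dict:
    """Illustrative re-statement of the documented behaviour."""
    if metadata is None:
        # No metadata at all: return a fresh dict with the group set.
        return {'dp_cp_group': default_group}
    if 'dp_cp_group' not in metadata:
        # Dict without the entry: fill it in place.
        metadata['dp_cp_group'] = default_group
        return metadata
    # Otherwise the entry must already be present.
    assert 'dp_cp_group' in metadata
    return metadata
```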
- core.transformer.utils.sharded_state_dict_default(
- module: torch.nn.Module,
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int]] = (),
- metadata: Optional[dict] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#
Provides implementation for sharded_state_dict method for non-MegatronModules.
Tries to call module.sharded_state_dict when possible, otherwise uses regular state dict and assumes tensors are replicated across TP and DP. keep_vars=True is passed to module.state_dict so that optimizer states can be sharded later on.
- Parameters:
module (torch.nn.Module) – module which sharded state dict we want to obtain
prefix (str) – prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional) – metadata passed to module sharded_state_dict method
tp_group (Optional[torch.distributed.ProcessGroup], optional) – tensor parallel group. If None, defaults to parallel_state.get_tensor_model_parallel_group()
- Returns:
dictionary of state dict keys mapped to ShardedTensors
- Return type:
dict
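A hedged usage sketch for a plain (non-MegatronModule) submodule; it assumes an initialized Megatron parallel state (or an explicit tp_group) so the default TP group can be resolved, and the key prefix is illustrative:

```python
import torch

# Assumption: the documented module is importable at this path inside Megatron-LM.
from megatron.core.transformer.utils import sharded_state_dict_default

norm = torch.nn.LayerNorm(1024)   # a plain non-MegatronModule submodule
sharded_sd = sharded_state_dict_default(
    norm,
    prefix='decoder.layers.0.input_layernorm.',  # illustrative key prefix
    sharded_offsets=(),
    metadata=None,
)
# Falls back to norm.state_dict(keep_vars=True) and marks the tensors as replicated across TP and DP.
```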
- core.transformer.utils._sequence_parallel_attr_cache#
None
- core.transformer.utils._init_sequence_parallel_cache(model, exclude_modules)#
Initialize the cache of modules with sequence parallel attributes. Only needs to be called once; subsequent calls have no effect.
- Parameters:
model – model to change sequence parallelism attributes
exclude_modules – Modules to exclude from changing sequence parallelism
- core.transformer.utils.set_model_to_sequence_parallel(model, set_to=False, exclude_modules=None)#
Set sequence parallel attributes for the model.
- Parameters:
set_to – Value to set for sequence_parallel attributes
exclude_modules – Modules to exclude from changing sequence parallelism
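A minimal sketch of toggling sequence parallelism off and back on; model stands for an already-built Megatron model whose tensor-parallel submodules carry sequence_parallel attributes (an assumption for the example):

```python
# Assumption: the documented module is importable at this path inside Megatron-LM,
# and 'model' is an already-constructed Megatron model.
from megatron.core.transformer.utils import set_model_to_sequence_parallel

set_model_to_sequence_parallel(model, set_to=False)   # e.g. around an inference path
# ... run without sequence parallelism ...
set_model_to_sequence_parallel(model, set_to=True)    # restore the training setting
```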
- core.transformer.utils.cuda_graph_attr_cache#
None
- core.transformer.utils.init_cuda_graph_cache(model)#
Initialize the cache of modules for CUDA graphs.
- core.transformer.utils.toggle_cuda_graphs(model, set_to='none', reset_cuda_graphs=True)#
Toggle CUDA graph-related attributes for the model and its modules.
- Parameters:
set_to (str) – Value to set for CUDA graph-related attributes.
reset_cuda_graphs (bool) – If True, remake the CUDA graph; if False, use cached CUDA graph managers.
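A hedged sketch of disabling CUDA graph attributes; model stands for an already-built Megatron model, and values of set_to other than 'none' depend on the CUDA graph scope the model was configured with (not taken from this page):

```python
# Assumption: the documented module is importable at this path inside Megatron-LM,
# and 'model' is an already-constructed Megatron model.
from megatron.core.transformer.utils import toggle_cuda_graphs

toggle_cuda_graphs(model, set_to='none', reset_cuda_graphs=True)   # disable CUDA graphs
```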
- core.transformer.utils.is_layer_window_attention(window_size: Optional[Tuple[int, int]], window_attn_skip_freq: int | list, layer_number: int)#