core.transformer.utils#

Utilities for transformer layers.

Module Contents#

Functions#

  • get_linear_layer – Simple linear layer with weight initialization.

  • get_default_causal_mask – Return the causal upper triangular mask for softmax input.

  • get_sliding_window_causal_mask – Create the equivalent attention mask for sliding window attention (SWA) in [sq, skv] shape.

  • attention_mask_func

  • gelu_impl – OpenAI’s gelu implementation.

  • openai_gelu

  • erf_gelu

  • make_sharded_tensors_for_checkpoint – Wraps tensors from transformer layers with ShardedTensor or ShardedObject.

  • make_sharded_object_for_checkpoint – Helper for instantiating a non-sharded ShardedObject (replicated across the TP and DP groups).

  • _get_extra_state_offsets – Turns ShardedTensor offsets into offsets suitable for ShardedObject.

  • ensure_metadata_has_dp_cp_group – Ensure metadata is a dict containing a dp_cp_group entry.

  • sharded_state_dict_default – Provides an implementation of the sharded_state_dict method for non-MegatronModules.

  • _init_sequence_parallel_cache – Initialize the cache of modules with sequence parallel attributes. Only needs to be called once; subsequent calls have no effect.

  • set_model_to_sequence_parallel – Set sequence parallel attributes for the model.

  • init_cuda_graph_cache – Initialize the cache of modules for CUDA graphs.

  • toggle_cuda_graphs – Toggle CUDA graph-related attributes for the model and its modules.

  • is_layer_window_attention

Data#

  • _sequence_parallel_attr_cache

  • cuda_graph_attr_cache

API#

core.transformer.utils.get_linear_layer(
rows,
columns,
init_method,
perform_initialization=True,
)#

Simple linear layer with weight initialization.
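
Example (a hedged usage sketch; it assumes rows/columns map to the layer's input and output features and that init_method is a callable applied to the weight tensor):

    import torch
    from megatron.core.transformer.utils import get_linear_layer

    # Hypothetical 1024 -> 4096 projection whose weight is initialized in place.
    init_method = lambda weight: torch.nn.init.normal_(weight, mean=0.0, std=0.02)
    proj = get_linear_layer(1024, 4096, init_method)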

core.transformer.utils.get_default_causal_mask(sq: int) → torch.Tensor#

Return the causal upper triangular mask for softmax input.

core.transformer.utils.get_sliding_window_causal_mask(sq, skv, window_size)#

Create the equivalent attention mask for sliding window attention (SWA) in [sq, skv] shape.
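
Example (a hedged sketch; the mask semantics and the (left, right) window_size format are assumptions, shown only as the masks would feed a masked softmax):

    from megatron.core.transformer.utils import (
        get_default_causal_mask,
        get_sliding_window_causal_mask,
    )

    # Plain causal mask for an 8-token sequence (presumably True for future
    # positions that must not be attended to).
    causal_mask = get_default_causal_mask(8)

    # Equivalent mask for sliding window attention with an assumed (left, right)
    # window of (4, 0), query length 8, key/value length 8.
    swa_mask = get_sliding_window_causal_mask(8, 8, (4, 0))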

core.transformer.utils.attention_mask_func(attention_scores, attention_mask)#
core.transformer.utils.gelu_impl(x)#

OpenAI’s gelu implementation.

core.transformer.utils.openai_gelu(x)#
core.transformer.utils.erf_gelu(x)#
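
Judging by the names and the docstring above, gelu_impl/openai_gelu presumably implement the tanh-based GELU approximation while erf_gelu uses the error function. A minimal sketch of the two standard formulas (constants are the textbook ones, not copied from this module):

    import math
    import torch

    def tanh_gelu(x: torch.Tensor) -> torch.Tensor:
        # tanh approximation popularized by the original OpenAI GPT code
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

    def exact_gelu(x: torch.Tensor) -> torch.Tensor:
        # exact GELU via the Gaussian CDF expressed with erf
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
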
core.transformer.utils.make_sharded_tensors_for_checkpoint(
state_dict: megatron.core.dist_checkpointing.mapping.StateDict,
prefix: str,
tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
extra_state_suffix: str = '_extra_state',
tp_group: Optional[torch.distributed.ProcessGroup] = None,
dp_cp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Wraps tensors from transformer layers with ShardedTensor or ShardedObject.

For a given state_dict, wraps:

  • all _extra_states with ShardedObject

  • all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor

  • other values with DP sharded ShardedTensor

Parameters:
  • state_dict (StateDict) – state_dict to convert

  • prefix (str) – prefix prepended to keys in the final state dict

  • tensor_parallel_layers_axis_map (Dict[str, int], optional) – dict mapping layer names to the axis for TP sharding

  • sharded_offsets (Iterable[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related), passed along to ShardedTensor

  • extra_state_suffix (str, default = '_extra_state') – layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor.

  • tp_group (Optional[torch.distributed.ProcessGroup], optional) – tensor parallel group. If None, defaults to parallel_state.get_tensor_model_parallel_group()

  • dp_cp_group (Optional[torch.distributed.ProcessGroup], optional) – data parallel group with context parallel. If None, defaults to parallel_state.get_data_parallel_group(with_context_parallel=True)
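
Example (a hedged usage sketch; it assumes megatron.core.parallel_state has been initialized, and the module, layer names and TP axes in the axis map are purely illustrative):

    # Wrap an MLP block's tensors for distributed checkpointing: fc1/fc2 weights
    # become TP+DP sharded ShardedTensors along the given axes, other tensors
    # become DP sharded ShardedTensors, and any _extra_state entries become
    # ShardedObjects.
    sharded_sd = make_sharded_tensors_for_checkpoint(
        state_dict=mlp.state_dict(keep_vars=True),
        prefix='decoder.layers.0.mlp.',
        tensor_parallel_layers_axis_map={'linear_fc1.weight': 0, 'linear_fc2.weight': 1},
        sharded_offsets=(),
    )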

core.transformer.utils.make_sharded_object_for_checkpoint(
obj: Any,
key: str,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
replica_id: Union[None, int, Tuple[int, ...]] = None,
**kwargs,
)#

Helper for instantiating a non-sharded ShardedObject (replicated across the TP and DP groups).

Parameters:
  • obj (object) – any object to be sharded

  • key (str) – unique identifier of the object

  • sharded_offsets (Iterable[Tuple[int, int, int]]) – offsets normally prepended to ShardedTensors, will be used as global offsets for ShardedObject

  • replica_id (Union[None, int, Tuple[int, ...]]) – replica id
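
Example (a hedged sketch; the key and wrapped object are illustrative, the point is only that the object is stored replicated rather than sharded):

    # Wrap an opaque per-layer blob (e.g. an _extra_state payload) so it is saved
    # as a replicated ShardedObject in the distributed checkpoint.
    sharded_obj = make_sharded_object_for_checkpoint(
        obj=extra_state_payload,  # illustrative; any serializable object
        key='decoder.layers.0.self_attention.core_attention._extra_state',
        sharded_offsets=(),
        replica_id=None,
    )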

core.transformer.utils._get_extra_state_offsets(
sharded_offsets: Iterable[Tuple[int, int, int]],
) → Tuple[Tuple[int, ...], Tuple[int, ...]]#

Turns ShardedTensor offsets into offsets suitable for ShardedObject.

core.transformer.utils.ensure_metadata_has_dp_cp_group(
metadata: Optional[dict],
) → dict#

Ensure metadata is a dict containing a dp_cp_group entry.

If metadata is None, a new dict is returned with dp_cp_group set. If metadata is a dict and missing dp_cp_group, it is updated in-place. Otherwise, asserts that dp_cp_group exists.
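
Example (a hedged sketch of the first two cases; the group presumably defaults to the data parallel group with context parallel):

    # None -> a new dict with dp_cp_group filled in.
    metadata = ensure_metadata_has_dp_cp_group(None)

    # Existing dict without dp_cp_group -> updated in place.
    metadata = ensure_metadata_has_dp_cp_group({'some_flag': True})  # illustrative key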

core.transformer.utils.sharded_state_dict_default(
module: torch.nn.Module,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Provides an implementation of the sharded_state_dict method for non-MegatronModules.

Tries to call module.sharded_state_dict when possible, otherwise uses regular state dict and assumes tensors are replicated across TP and DP.

keep_vars=True is passed to module.state_dict so that optimizer states can be sharded later on.

Parameters:
  • module (torch.nn.Module) – module whose sharded state dict we want to obtain

  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor

  • metadata (dict, optional) – metadata passed to module sharded_state_dict method

  • tp_group (Optional[torch.distributed.ProcessGroup], optional) – tensor parallel group. If None, defaults to parallel_state.get_tensor_model_parallel_group()

Returns:

dictionary of state dict keys mapped to ShardedTensors

Return type:

dict
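
Example (a hedged usage sketch; it assumes megatron.core.parallel_state has been initialized and the prefix is illustrative):

    import torch
    from megatron.core.transformer.utils import sharded_state_dict_default

    # A plain nn.Module with no sharded_state_dict of its own: its tensors are
    # treated as replicated across TP and DP.
    norm = torch.nn.LayerNorm(1024)
    sharded_sd = sharded_state_dict_default(
        norm,
        prefix='decoder.layers.0.pre_mlp_layernorm.',
    )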

core.transformer.utils._sequence_parallel_attr_cache#

None

core.transformer.utils._init_sequence_parallel_cache(model, exclude_modules)#

Initialize the cache of modules with sequence parallel attributes. Only needs to be called once; subsequent calls have no effect.

Parameters:
  • model – model whose sequence parallelism attributes will be changed

  • exclude_modules – Modules to exclude from changing sequence parallelism

core.transformer.utils.set_model_to_sequence_parallel(
model,
set_to=False,
exclude_modules=None,
)#

Set sequence parallel attributes for the model.

Parameters:
  • set_to – Value to set for sequence_parallel attributes

  • exclude_modules – Modules to exclude from changing sequence parallelism
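
Example (a hedged sketch; model stands in for any Megatron module tree):

    # Temporarily turn sequence parallelism off, e.g. around an inference pass,
    # then restore it afterwards.
    set_model_to_sequence_parallel(model, set_to=False)
    # ... run inference ...
    set_model_to_sequence_parallel(model, set_to=True)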

core.transformer.utils.cuda_graph_attr_cache#

None

core.transformer.utils.init_cuda_graph_cache(model)#

Initialize the cache of modules for CUDA graphs.

core.transformer.utils.toggle_cuda_graphs(model, set_to='none', reset_cuda_graphs=True)#

Toggle CUDA graph-related attributes for the model and its modules.

Parameters:
  • set_to (str) – Value to set for CUDA graph-related attributes.

  • reset_cuda_graphs (bool) – If True, remake the CUDA graph; if False, use cached CUDA graph managers.
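
Example (a hedged sketch; model is illustrative, and the accepted values for set_to are not listed on this page, 'none' is simply the default from the signature):

    # Disable CUDA graph execution on the cached modules and force the graphs to
    # be re-captured the next time they are enabled.
    toggle_cuda_graphs(model, set_to='none', reset_cuda_graphs=True)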

core.transformer.utils.is_layer_window_attention(
window_size: Optional[Tuple[int, int]],
window_attn_skip_freq: int | list,
layer_number: int,
) → bool#
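
Example (a hedged sketch; the exact semantics of window_attn_skip_freq are an assumption, e.g. an integer presumably means that every Nth layer skips window attention and uses full attention instead):

    # Does layer 3 use sliding window attention, given a (4096, 0) window and an
    # assumed skip frequency of 4?
    uses_swa = is_layer_window_attention(
        window_size=(4096, 0),
        window_attn_skip_freq=4,
        layer_number=3,
    )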