fusions package#

This package provides modules implementing commonly fused operations. Fusing operations improves compute efficiency by increasing the amount of work done each time a tensor is read from memory. To perform the fusion, modules in this package either rely on PyTorch's just-in-time compilation functionality (torch.jit.script in older PyTorch versions, or torch.compile in recent versions) or call into custom kernels in external libraries such as Apex or TransformerEngine.
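For illustration, the following is a minimal sketch (not code from this package) of the JIT-fusion idea: a short elementwise chain scripted so it can execute as one kernel instead of several memory-bound ones.

    import torch
    import torch.nn.functional as F

    # Illustrative only: bias add + dropout + residual add, scripted so the
    # elementwise ops can be fused rather than each re-reading the tensor.
    @torch.jit.script
    def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor,
                         residual: torch.Tensor, prob: float) -> torch.Tensor:
        return F.dropout(x + bias, p=prob, training=True) + residual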

Submodules#

fusions.fused_bias_dropout module#

This module uses PyTorch JIT to fuse the bias add and dropout operations. Since dropout is not used during inference, different functions are used when in train mode and when in inference mode.

core.fusions.fused_bias_dropout.bias_dropout_add_fused_inference(
x_with_bias: Tuple[torch.Tensor, torch.Tensor | None],
residual: torch.Tensor,
prob: float,
) torch.Tensor#
core.fusions.fused_bias_dropout.bias_dropout_add_fused_train(
x_with_bias: Tuple[torch.Tensor, torch.Tensor | None],
residual: torch.Tensor,
prob: float,
) torch.Tensor#
core.fusions.fused_bias_dropout.bias_dropout_add_unfused(training)#
core.fusions.fused_bias_dropout.get_bias_dropout_add(training, fused)#
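A minimal usage sketch, assuming the megatron.core package layout; tensor shapes and the dropout probability are illustrative.

    import torch
    from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

    hidden = torch.randn(2048, 1, 1024, device="cuda")   # [sq, b, h], illustrative
    bias = torch.randn(1024, device="cuda")
    residual = torch.randn_like(hidden)

    # The returned callable takes ((x, bias), residual, prob), matching the
    # bias_dropout_add_fused_* signatures above.
    bias_dropout_add = get_bias_dropout_add(training=True, fused=True)
    out = bias_dropout_add((hidden, bias), residual, 0.1)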

fusions.fused_bias_gelu module#

This module uses PyTorch JIT to fuse the bias add and GeLU nonlinearity operations.

class core.fusions.fused_bias_gelu.GeLUFunction(*args: Any, **kwargs: Any)#

Bases: Function

classmethod apply(*args, **kwargs)#
static backward(ctx, grad_output)#
static forward(ctx, input, bias)#
core.fusions.fused_bias_gelu.bias_gelu(bias, y)#
core.fusions.fused_bias_gelu.bias_gelu_back(g, bias, y)#
core.fusions.fused_bias_gelu.bias_gelu_impl(*args, **kwargs)#
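A minimal usage sketch of the autograd-aware entry point, assuming the megatron.core package layout; shapes are illustrative.

    import torch
    from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl

    x = torch.randn(2048, 4096, device="cuda", requires_grad=True)
    bias = torch.randn(4096, device="cuda", requires_grad=True)

    # bias_gelu_impl dispatches to GeLUFunction, whose forward/backward use the
    # JIT-fused bias_gelu / bias_gelu_back helpers documented above.
    y = bias_gelu_impl(x, bias)
    y.sum().backward()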

fusions.fused_layer_norm module#

This module provides a wrapper around various fused LayerNorm implementations in Apex.

class core.fusions.fused_layer_norm.FusedLayerNorm(*args: Any, **kwargs: Any)#

Bases: Module

Layer Norm, fused into a single CUDA kernel.

Parameters:
  • hidden_size (int) – Transformer hidden dimension.

  • eps (float) – Epsilon added to denominator, for numerical stability.

  • persist_layer_norm (bool) – Use persistent fused layer norm kernel. This kernel supports only a set of hidden sizes. Please check persist_ln_hidden_sizes if your hidden size is supported.

  • zero_centered_gamma (bool) – Adjust LayerNorm weights such that they are centered around zero. This improves numerical stability.

  • config (TransformerConfig) – Transformer config. Include to match custom layer norm interfaces.

  • normalization (str) – Normalization type, used for Transformer Engine. Must equal 'LayerNorm' here.

forward(input: torch.Tensor) torch.Tensor#
reset_parameters()#
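A minimal construction sketch for FusedLayerNorm, assuming the megatron.core package layout, an installed Apex, and a TransformerConfig with only its required fields filled in; all sizes are illustrative.

    import torch
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    from megatron.core.transformer.transformer_config import TransformerConfig

    # Assumed minimal config; real configs set many more fields.
    config = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=16)

    ln = FusedLayerNorm(
        config=config,
        hidden_size=1024,
        eps=1e-5,
        persist_layer_norm=True,      # only a fixed set of hidden sizes is supported
        zero_centered_gamma=False,
    ).cuda()

    y = ln(torch.randn(2048, 1, 1024, device="cuda"))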

fusions.fused_softmax module#

This module provides wrappers around variations of Softmax in Apex.

class core.fusions.fused_softmax.FusedScaleMaskSoftmax(*args: Any, **kwargs: Any)#

Bases: Module

fused operation: scaling + mask + softmax

Parameters:
  • input_in_fp16 – flag to indicate whether the input is in fp16 data format.

  • input_in_bf16 – flag to indicate whether the input is in bf16 data format.

  • attn_mask_type – attention mask type (pad or causal).

  • scaled_masked_softmax_fusion – flag to indicate whether the user wants to use softmax fusion.

  • mask_func – mask function to be applied.

  • softmax_in_fp32 – if true, softmax is performed at fp32 precision.

  • scale – scaling factor used in input tensor scaling.

forward(
input: torch.Tensor,
mask: torch.Tensor | None,
softmax_offset: torch.Tensor | None = None,
)#

Forward pass of softmax with masked input.

When attn_mask_type is causal, the mask is generated internally and None can be passed. A user-defined mask is only needed when attn_mask_type is not causal.
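A minimal usage sketch with a causal mask type, assuming the megatron.core package layout and its AttnMaskType enum; the mask function and shapes are illustrative.

    import torch
    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
    from megatron.core.transformer.enums import AttnMaskType

    def attention_mask_func(scores, mask):
        # Illustrative mask_func: True entries in `mask` are masked out.
        return scores.masked_fill(mask, -10000.0)

    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.causal,
        scaled_masked_softmax_fusion=True,
        mask_func=attention_mask_func,
        softmax_in_fp32=True,
        scale=None,
    )

    scores = torch.randn(4, 16, 2048, 2048, device="cuda", dtype=torch.float16)  # [b, np, sq, sk]
    probs = softmax(scores, None)  # causal: the mask is generated internally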

forward_fused_softmax(input, mask)#

Compute softmax using fused CUDA kernels when available.

Parameters:
  • input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].

  • mask (Optional[torch.Tensor]) – Optional mask for non-causal attention.

Returns:

Attention probabilities of shape [b, np, sq, sk].

Return type:

torch.Tensor

forward_torch_softmax(
input,
mask,
softmax_offset=None,
)#

Fallback PyTorch implementation for masked softmax.

Applies optional scaling, constructs a causal or sliding-window mask if needed, applies the mask, and computes softmax in PyTorch. Optionally casts back to float16/bfloat16 when requested.

Parameters:
  • input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].

  • mask (Optional[torch.Tensor]) – Optional additive mask.

Returns:

Attention probabilities of shape [b, np, sq, sk].

Return type:

torch.Tensor

static get_batch_per_block(sq, sk, b, np)#

Return CUDA kernel’s batch-per-block parameter for masked softmax.

Parameters:
  • sq (int) – Query sequence length.

  • sk (int) – Key sequence length.

  • b (int) – Batch size.

  • np (int) – Number of attention heads per tensor-parallel partition.

Returns:

Batch-per-block value as computed by the CUDA extension.

Return type:

int

is_kernel_available(mask, b, np, sq, sk)#

Check whether the fused CUDA kernel can be used for the given shapes and settings.

Parameters:
  • mask (Optional[torch.Tensor]) – Attention mask or None.

  • b (int) – Batch size.

  • np (int) – Number of attention heads per tensor-parallel partition.

  • sq (int) – Query sequence length.

  • sk (int) – Key sequence length.

Returns:

True if the fused kernel constraints are satisfied; otherwise False.

Return type:

bool

class core.fusions.fused_softmax.ScaledMaskedSoftmax(*args: Any, **kwargs: Any)#

Bases: Function

Fused operation which performs the following three operations in sequence: 1. Scale the tensor. 2. Apply the mask. 3. Perform softmax.

static backward(ctx, output_grads)#

Backward pass for scaled masked softmax.

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient w.r.t inputs; None for mask and scale.

Return type:

Tuple[torch.Tensor, None, None]

static forward(ctx, inputs, mask, scale)#

Forward pass for scaled masked softmax.

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk].

  • mask (torch.Tensor) – Additive mask broadcastable to inputs.

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale and mask.

Return type:

torch.Tensor
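The fused kernel is only reachable through the autograd Function above; as a point of reference, a plain-PyTorch sketch of the same three steps (assuming a boolean mask where True marks positions to exclude) looks like this.

    import torch

    def scaled_masked_softmax_reference(inputs, mask, scale):
        # 1. scale, 2. apply the mask, 3. softmax over the key dimension.
        scores = inputs * scale
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=-1)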

class core.fusions.fused_softmax.ScaledSoftmax(*args: Any, **kwargs: Any)#

Bases: Function

Fused operation which performs the following two operations in sequence: 1. Scale the tensor. 2. Perform softmax.

static backward(ctx, output_grads)#

Backward pass for scaled softmax (no mask).

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient w.r.t inputs; None for unused args.

Return type:

Tuple[torch.Tensor, None, None]

static forward(ctx, inputs, scale)#

Forward pass for scaled softmax (no mask).

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk] or [attn_batches, sq, sk].

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale.

Return type:

torch.Tensor

class core.fusions.fused_softmax.ScaledUpperTriangMaskedSoftmax(*args: Any, **kwargs: Any)#

Bases: Function

Fused operation which performs the following three operations in sequence: 1. Scale the tensor. 2. Apply an upper triangular (causal) mask, as typically used in GPT models. 3. Perform softmax.

static backward(ctx, output_grads)#

Backward pass for scaled upper-triangular masked softmax.

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient with respect to inputs and None for scale.

Return type:

Tuple[torch.Tensor, None]

static forward(ctx, inputs, scale)#

Forward pass for scaled upper-triangular masked softmax.

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [attn_batches, sq, sk].

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale and causal upper-triangular mask.

Return type:

torch.Tensor
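As a point of reference, a plain-PyTorch sketch of the same semantics (not the fused kernel) for inputs of shape [attn_batches, sq, sk]:

    import torch

    def scaled_upper_triang_masked_softmax_reference(inputs, scale):
        sq, sk = inputs.shape[-2], inputs.shape[-1]
        # Mask out positions above the diagonal (future keys) before softmax.
        causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device=inputs.device), diagonal=1)
        scores = inputs * scale
        scores = scores.masked_fill(causal, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=-1)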

class core.fusions.fused_softmax.SoftmaxOne(*args: Any, **kwargs: Any)#

Bases: Module

Softmax-off-by-one function, as introduced in https://www.evanmiller.org/attention-is-off-by-one.html. Supports a fixed or learnable offset.

forward(x: torch.Tensor) torch.Tensor#

Forward pass.
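A plain-PyTorch sketch of the basic softmax-off-by-one formula, exp(x_i) / (1 + sum_j exp(x_j)); the module's fixed or learnable offset generalizes this and is not modeled here.

    import torch

    def softmax_one_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Equivalent to exp(x_i) / (1 + sum_j exp(x_j)): append an implicit zero
        # logit, take a regular softmax, then drop that extra position.
        zeros = torch.zeros_like(x.select(dim, 0)).unsqueeze(dim)
        padded = torch.cat([x, zeros], dim=dim)
        return torch.softmax(padded, dim=dim).narrow(dim, 0, x.size(dim))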

fusions.fused_cross_entropy_loss module#

This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches the communication calls.

core.fusions.fused_cross_entropy.calculate_cross_entropy_loss(
exp_logits: torch.Tensor,
predicted_logits_sum_exp_logits: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor]#

Calculates the final cross entropy loss for the tokens.

core.fusions.fused_cross_entropy.calculate_gradients(
softmax: torch.Tensor,
grad_output: torch.Tensor,
target_mask: torch.Tensor,
masked_target_1d: torch.Tensor,
) torch.Tensor#

Calculate the logits gradients scaled based on the CE loss

core.fusions.fused_cross_entropy.calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor]#

Calculates the maximum logits of the predicted tokens.

core.fusions.fused_cross_entropy.calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]#

Calculates the predicted logits for the tokens.

core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy(
vocab_parallel_logits,
target,
tp_group,
)#

Performs cross entropy loss when logits are split across tensor parallel ranks

Parameters:
  • vocab_parallel_logits – logits split across tensor parallel ranks; dimension is [sequence_length, batch_size, hidden_size]

  • target – correct vocab ids of dimension [sequence_length, micro_batch_size]

  • tp_group – the tensor parallel group over which to all-reduce
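A minimal usage sketch, assuming the megatron.core package layout and that model-parallel groups have already been initialized; all sizes are illustrative.

    import torch
    from megatron.core import parallel_state
    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy

    # Assumes parallel_state.initialize_model_parallel(...) has already run.
    tp_group = parallel_state.get_tensor_model_parallel_group()

    seq_len, micro_batch, vocab_shard = 2048, 1, 32000           # illustrative sizes
    logits = torch.randn(seq_len, micro_batch, vocab_shard,
                         device="cuda", requires_grad=True)
    target = torch.randint(0, vocab_shard, (seq_len, micro_batch), device="cuda")

    per_token_loss = fused_vocab_parallel_cross_entropy(logits, target, tp_group)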