core.fusions.fused_softmax#

Module Contents#

Classes#

ScaledUpperTriangMaskedSoftmax

Fused operation which performs the following three operations in sequence

ScaledMaskedSoftmax

Fused operation which performs the following three operations in sequence

ScaledSoftmax

Fused operation which performs the following two operations in sequence

SoftmaxOne

Softmax-off-by-one function, as introduced in https://www.evanmiller.org/attention-is-off-by-one.html. Supports a fixed or learnable offset.

FusedScaleMaskSoftmax

Fused operation: scaling + mask + softmax.

API#

class core.fusions.fused_softmax.ScaledUpperTriangMaskedSoftmax#

Bases: torch.autograd.Function

Fused operation which performs the following three operations in sequence:

  1. Scale the tensor.

  2. Apply an upper-triangular mask (typically used in GPT models).

  3. Perform softmax.

static forward(ctx, inputs, scale)#

Forward pass for scaled upper-triangular masked softmax.

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [attn_batches, sq, sk].

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale and causal upper-triangular mask.

Return type:

torch.Tensor

static backward(ctx, output_grads)#

Backward pass for scaled upper-triangular masked softmax.

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient with respect to inputs and None for scale.

Return type:

Tuple[torch.Tensor, None]
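
Since the fused kernel is provided by a CUDA extension, an unfused PyTorch reference of the three steps may help as illustration. This is a sketch only, not the fused implementation, and the function name is illustrative:

    import torch

    def scaled_upper_triang_masked_softmax_ref(inputs: torch.Tensor, scale: float) -> torch.Tensor:
        # inputs: [attn_batches, sq, sk]
        scores = inputs * scale                                 # 1. scale
        sq, sk = scores.shape[-2], scores.shape[-1]
        causal = torch.triu(
            torch.ones(sq, sk, dtype=torch.bool, device=scores.device), diagonal=1
        )
        scores = scores.masked_fill(causal, float("-inf"))      # 2. upper-triangular (causal) mask
        return torch.softmax(scores, dim=-1)                    # 3. softmax

The fused class itself is invoked through torch.autograd.Function.apply, e.g. ScaledUpperTriangMaskedSoftmax.apply(inputs, scale), provided the CUDA extension is built.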

class core.fusions.fused_softmax.ScaledMaskedSoftmax#

Bases: torch.autograd.Function

Fused operation which performs the following three operations in sequence:

  1. Scale the tensor.

  2. Apply the mask.

  3. Perform softmax.

static forward(ctx, inputs, mask, scale)#

Forward pass for scaled masked softmax.

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk].

  • mask (torch.Tensor) – Additive mask broadcastable to inputs.

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale and mask.

Return type:

torch.Tensor

static backward(ctx, output_grads)#

Backward pass for scaled masked softmax.

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient w.r.t. inputs; None for mask and scale.

Return type:

Tuple[torch.Tensor, None, None]
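
As an unfused PyTorch reference for the same three steps (a sketch only; the parameter description above calls the mask additive, so it is added here, whereas a boolean mask would instead be applied with masked_fill):

    import torch

    def scaled_masked_softmax_ref(inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
        # inputs: [b, np, sq, sk]; mask broadcastable to inputs
        scores = inputs * scale                # 1. scale
        scores = scores + mask                 # 2. apply the (additive) mask
        return torch.softmax(scores, dim=-1)   # 3. softmax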

class core.fusions.fused_softmax.ScaledSoftmax#

Bases: torch.autograd.Function

Fused operation which performs the following two operations in sequence:

  1. Scale the tensor.

  2. Perform softmax.

static forward(ctx, inputs, scale)#

Forward pass for scaled softmax (no mask).

Parameters:
  • ctx – Autograd context used to stash tensors for backward.

  • inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk] or [attn_batches, sq, sk].

  • scale (float) – Scaling factor applied prior to softmax.

Returns:

Softmax results after applying scale.

Return type:

torch.Tensor

static backward(ctx, output_grads)#

Backward pass for scaled softmax (no mask).

Parameters:
  • ctx – Autograd context containing saved tensors from forward.

  • output_grads (torch.Tensor) – Upstream gradients matching forward output shape.

Returns:

Gradient w.r.t. inputs; None for unused arguments.

Return type:

Tuple[torch.Tensor, None, None]
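
The unfused equivalent of the two steps is simply (illustrative reference, not the fused kernel):

    import torch

    def scaled_softmax_ref(inputs: torch.Tensor, scale: float) -> torch.Tensor:
        return torch.softmax(inputs * scale, dim=-1)   # 1. scale, 2. softmax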

class core.fusions.fused_softmax.SoftmaxOne(
dim: Optional[int] = None,
denominator_offset: Union[torch.Tensor, float] = 1.0,
)#

Bases: torch.nn.Module

Softmax-off-by-one function, as introduced in https://www.evanmiller.org/attention-is-off-by-one.html. Supports a fixed or learnable offset.

Initialization

forward(x: torch.Tensor) → torch.Tensor#

Forward pass of the softmax-off-by-one function.
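
For reference, softmax-off-by-one with a fixed offset can be written as the following unfused sketch (the module also supports a learnable offset via denominator_offset):

    import torch

    def softmax_one_ref(x: torch.Tensor, dim: int = -1, offset: float = 1.0) -> torch.Tensor:
        # softmax_one(x)_i = exp(x_i) / (offset + sum_j exp(x_j))
        # Max subtraction keeps the exponentials finite; the offset is rescaled to match.
        m = x.max(dim=dim, keepdim=True).values
        exp_x = torch.exp(x - m)
        return exp_x / (exp_x.sum(dim=dim, keepdim=True) + offset * torch.exp(-m))

With an offset of zero this reduces to the standard softmax; a positive offset allows a row to assign arbitrarily little total attention.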

class core.fusions.fused_softmax.FusedScaleMaskSoftmax(
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
window_size=None,
)#

Bases: torch.nn.Module

Fused operation: scaling + mask + softmax.

Parameters:
  • input_in_fp16 – flag to indicate if the input is in fp16 data format.

  • input_in_bf16 – flag to indicate if the input is in bf16 data format.

  • attn_mask_type – attention mask type (pad or causal).

  • scaled_masked_softmax_fusion – flag to indicate whether the user wants to use softmax fusion.

  • mask_func – mask function to be applied.

  • softmax_in_fp32 – if true, softmax is performed at fp32 precision.

  • scale – scaling factor used in input tensor scaling.

  • window_size – optional sliding window size; when set, a sliding-window mask is constructed in the unfused fallback path.

Initialization

forward(
input: torch.Tensor,
mask: Optional[torch.Tensor],
softmax_offset: Optional[torch.Tensor] = None,
)#

Forward pass of softmax with masked input.

If attn_mask_type is causal, the mask is generated internally and None can be passed. A user-defined mask is only needed when attn_mask_type is not causal.
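
A construction and usage sketch follows. The import paths and the AttnMaskType enum are assumptions about the surrounding package layout, and attention_mask_func is an illustrative stand-in for whatever mask_func the attention layer supplies; the fused path additionally requires a CUDA device and the fused-softmax extension, otherwise the module falls back to the PyTorch implementation:

    import torch
    from core.fusions.fused_softmax import FusedScaleMaskSoftmax
    from core.transformer.enums import AttnMaskType  # assumed location of the mask-type enum

    def attention_mask_func(scores, mask):
        # Illustrative mask function: fill masked positions with a large negative value.
        return scores.masked_fill(mask, -10000.0)

    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.causal,
        scaled_masked_softmax_fusion=True,
        mask_func=attention_mask_func,
        softmax_in_fp32=True,
        scale=None,
    )

    b, np_, sq, sk = 2, 8, 128, 128
    scores = torch.randn(b, np_, sq, sk, dtype=torch.float16, device="cuda")
    probs = softmax(scores, None)   # causal: the mask is generated internally, so None is passed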

is_kernel_available(mask, b, np, sq, sk)#

Check whether the fused CUDA kernel can be used for the given shapes and settings.

Parameters:
  • mask (Optional[torch.Tensor]) – Attention mask or None.

  • b (int) – Batch size.

  • np (int) – Number of attention heads per tensor-parallel partition.

  • sq (int) – Query sequence length.

  • sk (int) – Key sequence length.

Returns:

True if the fused kernel constraints are satisfied; otherwise False.

Return type:

bool
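
A short dispatch sketch showing where this check fits relative to the two documented execution paths (the names are placeholders; the module's own forward is expected to perform an equivalent check internally):

    import torch

    def masked_softmax(softmax_module, scores: torch.Tensor, mask):
        # Dispatch between the fused CUDA kernel and the PyTorch fallback,
        # mirroring the documented methods on FusedScaleMaskSoftmax.
        b, np_, sq, sk = scores.shape
        if softmax_module.is_kernel_available(mask, b, np_, sq, sk):
            return softmax_module.forward_fused_softmax(scores, mask)
        return softmax_module.forward_torch_softmax(scores, mask)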

forward_fused_softmax(input, mask)#

Compute softmax using fused CUDA kernels when available.

Parameters:
  • input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].

  • mask (Optional[torch.Tensor]) – Optional mask for non-causal attention.

Returns:

Attention probabilities of shape [b, np, sq, sk].

Return type:

torch.Tensor

forward_torch_softmax(input, mask, softmax_offset=None)#

Fallback PyTorch implementation for masked softmax.

Applies optional scaling, constructs a causal or sliding-window mask if needed, applies the mask, and computes softmax in PyTorch, casting the result back to float16/bfloat16 when requested.

Parameters:
  • input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].

  • mask (Optional[torch.Tensor]) – Optional additive mask.

  • softmax_offset (Optional[torch.Tensor]) – Optional softmax-off-by-one offset (see SoftmaxOne).

Returns:

Attention probabilities of shape [b, np, sq, sk].

Return type:

torch.Tensor
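
The steps described above correspond roughly to this simplified, unfused sketch (it omits the internal causal/sliding-window mask construction and the softmax_offset handling):

    import torch

    def torch_softmax_fallback_ref(scores, mask, mask_func, scale=None,
                                   softmax_in_fp32=True, out_dtype=torch.float16):
        # scores: [b, np, sq, sk]; the mask is applied through the supplied mask_func
        if softmax_in_fp32:
            scores = scores.float()
        if scale is not None:
            scores = scores * scale
        if mask is not None:
            scores = mask_func(scores, mask)
        probs = torch.softmax(scores, dim=-1)
        if softmax_in_fp32 and out_dtype in (torch.float16, torch.bfloat16):
            probs = probs.to(out_dtype)   # cast back to half precision when requested
        return probs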

static get_batch_per_block(sq, sk, b, np)#

Return the CUDA kernel’s batch-per-block parameter for masked softmax.

Parameters:
  • sq (int) – Query sequence length.

  • sk (int) – Key sequence length.

  • b (int) – Batch size.

  • np (int) – Number of attention heads per tensor-parallel partition.

Returns:

Batch-per-block value as computed by the CUDA extension.

Return type:

int
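
For example (requires the fused masked-softmax CUDA extension to be built; the shape values are arbitrary):

    from core.fusions.fused_softmax import FusedScaleMaskSoftmax

    batch_per_block = FusedScaleMaskSoftmax.get_batch_per_block(sq=128, sk=128, b=2, np=8)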