core.fusions.fused_softmax#
Module Contents#
Classes#
ScaledUpperTriangMaskedSoftmax – Fused operation which performs three operations in sequence: scale, upper-triangular (causal) mask, softmax.
ScaledMaskedSoftmax – Fused operation which performs three operations in sequence: scale, mask, softmax.
ScaledSoftmax – Fused operation which performs two operations in sequence: scale, softmax.
SoftmaxOne – Softmax-off-by-one function as introduced in https://www.evanmiller.org/attention-is-off-by-one.html; supports a fixed or learnable offset.
FusedScaleMaskSoftmax – Fused operation: scaling + mask + softmax.
API#
- class core.fusions.fused_softmax.ScaledUpperTriangMaskedSoftmax#
Bases: torch.autograd.Function
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply the upper triangular mask (typically used in GPT models).
3. Perform softmax.
- static forward(ctx, inputs, scale)#
Forward pass for scaled upper-triangular masked softmax.
- Parameters:
ctx – Autograd context used to stash tensors for backward.
inputs (torch.Tensor) – Input tensor of shape [attn_batches, sq, sk].
scale (float) – Scaling factor applied prior to softmax.
- Returns:
Softmax results after applying scale and causal upper-triangular mask.
- Return type:
torch.Tensor
- static backward(ctx, output_grads)#
Backward pass for scaled upper-triangular masked softmax.
- Parameters:
ctx – Autograd context containing saved tensors from forward.
output_grads (torch.Tensor) – Upstream gradients matching forward output shape.
- Returns:
Gradient with respect to inputs and None for scale.
- Return type:
Tuple[torch.Tensor, None]
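Because the fused kernel lives in a compiled CUDA extension, the following is only an unfused PyTorch sketch of the same scale → causal-mask → softmax sequence, useful as a reference for what the class computes; the function name is illustrative.

```python
import torch

def upper_triang_masked_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    """Unfused reference: scale, apply the upper-triangular (causal) mask, softmax.

    ``inputs`` is expected to have shape [attn_batches, sq, sk], matching the
    fused op's forward.
    """
    sq, sk = inputs.shape[-2], inputs.shape[-1]
    scores = inputs * scale
    # Positions strictly above the diagonal are masked out, so each query only
    # attends to keys at or before its own position.
    causal_mask = torch.triu(
        torch.ones(sq, sk, dtype=torch.bool, device=inputs.device), diagonal=1
    )
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

# The fused op itself is a torch.autograd.Function and would be invoked as:
#   probs = ScaledUpperTriangMaskedSoftmax.apply(inputs, scale)
```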
- class core.fusions.fused_softmax.ScaledMaskedSoftmax#
Bases: torch.autograd.Function
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
- static forward(ctx, inputs, mask, scale)#
Forward pass for scaled masked softmax.
- Parameters:
ctx – Autograd context used to stash tensors for backward.
inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk].
mask (torch.Tensor) – Additive mask broadcastable to inputs.
scale (float) – Scaling factor applied prior to softmax.
- Returns:
Softmax results after applying scale and mask.
- Return type:
torch.Tensor
- static backward(ctx, output_grads)#
Backward pass for scaled masked softmax.
- Parameters:
ctx – Autograd context containing saved tensors from forward.
output_grads (torch.Tensor) – Upstream gradients matching forward output shape.
- Returns:
Gradient w.r.t. inputs; None for mask and scale.
- Return type:
Tuple[torch.Tensor, None, None]
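An unfused PyTorch sketch of the same sequence with an explicit mask; per the parameter description above, the mask is treated as additive and broadcastable (the function name is illustrative):

```python
import torch

def scaled_masked_softmax_reference(
    inputs: torch.Tensor, mask: torch.Tensor, scale: float
) -> torch.Tensor:
    """Unfused reference: scale, add the broadcastable mask, softmax.

    ``inputs`` has shape [b, np, sq, sk]; ``mask`` broadcasts to that shape
    (e.g. [b, 1, sq, sk]), typically holding 0 for visible positions and a
    large negative value for masked ones.
    """
    scores = inputs * scale + mask
    return torch.softmax(scores, dim=-1)

# Fused invocation (requires the compiled CUDA extension):
#   probs = ScaledMaskedSoftmax.apply(inputs, mask, scale)
```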
- class core.fusions.fused_softmax.ScaledSoftmax#
Bases: torch.autograd.Function
Fused operation which performs the following two operations in sequence:
1. Scale the tensor.
2. Perform softmax.
- static forward(ctx, inputs, scale)#
Forward pass for scaled softmax (no mask).
- Parameters:
ctx – Autograd context used to stash tensors for backward.
inputs (torch.Tensor) – Input tensor of shape [b, np, sq, sk] or [attn_batches, sq, sk].
scale (float) – Scaling factor applied prior to softmax.
- Returns:
Softmax results after applying scale.
- Return type:
torch.Tensor
- static backward(ctx, output_grads)#
Backward pass for scaled softmax (no mask).
- Parameters:
ctx – Autograd context containing saved tensors from forward.
output_grads (torch.Tensor) – Upstream gradients matching forward output shape.
- Returns:
Gradient w.r.t. inputs; None for the unused arguments.
- Return type:
Tuple[torch.Tensor, None, None]
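The unmasked variant reduces to scaling followed by a plain softmax; a one-line reference sketch (illustrative name):

```python
import torch

def scaled_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    """Unfused reference for the no-mask case: scale, then softmax."""
    return torch.softmax(inputs * scale, dim=-1)

# Fused invocation: probs = ScaledSoftmax.apply(inputs, scale)
```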
- class core.fusions.fused_softmax.SoftmaxOne(
- dim: Optional[int] = None,
- denominator_offset: Union[torch.Tensor, float] = 1.0,
- )#
Bases: torch.nn.Module
Softmax-off-by-one function as introduced in https://www.evanmiller.org/attention-is-off-by-one.html. Supports a fixed or learnable offset.
Initialization
- forward(x: torch.Tensor) → torch.Tensor#
Forward pass.
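Softmax-off-by-one adds an offset to the softmax denominator, so the outputs no longer have to sum to one (allowing attention to "attend to nothing"). The sketch below assumes the offset is simply added to the denominator along `dim` and treats it as a fixed float; the actual module may store a learnable `torch.nn.Parameter` and differ in detail.

```python
import torch

def softmax_one_reference(
    x: torch.Tensor, dim: int = -1, denominator_offset: float = 1.0
) -> torch.Tensor:
    """Softmax-off-by-one: exp(x_i) / (offset + sum_j exp(x_j)).

    Subtracting the per-row max keeps the exponentials numerically stable;
    the offset is rescaled by the same factor so the result is unchanged.
    """
    x_max = x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(x - x_max)
    denom = exp_x.sum(dim=dim, keepdim=True) + denominator_offset * torch.exp(-x_max)
    return exp_x / denom
```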
- class core.fusions.fused_softmax.FusedScaleMaskSoftmax(
- input_in_fp16,
- input_in_bf16,
- attn_mask_type,
- scaled_masked_softmax_fusion,
- mask_func,
- softmax_in_fp32,
- scale,
- window_size=None,
- )#
Bases: torch.nn.Module
Fused operation: scaling + mask + softmax.
- Parameters:
input_in_fp16 – flag to indicate whether the input is in fp16 format.
input_in_bf16 – flag to indicate whether the input is in bf16 format.
attn_mask_type – attention mask type (pad or causal).
scaled_masked_softmax_fusion – flag to indicate whether the user wants to use softmax fusion.
mask_func – mask function to be applied.
softmax_in_fp32 – if true, softmax is performed at fp32 precision.
scale – scaling factor used in input tensor scaling.
Initialization
- forward(
- input: torch.Tensor,
- mask: Optional[torch.Tensor],
- softmax_offset: Optional[torch.Tensor] = None,
- )#
Forward pass of softmax with masked input.
When attn_mask_type is causal, the mask is generated internally and None can be passed. A user-defined mask is only needed when attn_mask_type is not causal.
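A hedged construction and usage sketch; the AttnMaskType enum, its import path, and the mask function shown here are assumptions about the surrounding codebase rather than part of this module's documented API.

```python
import torch
from core.fusions.fused_softmax import FusedScaleMaskSoftmax
from core.transformer.enums import AttnMaskType  # assumed import path

def example_mask_func(scores, mask):
    # Placeholder mask function: fill masked positions with a large negative value.
    return scores.masked_fill(mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,   # mask generated internally, so None may be passed
    scaled_masked_softmax_fusion=True,    # falls back to PyTorch when the kernel is unavailable
    mask_func=example_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# Requires a CUDA device; shape is [b, np, sq, sk].
scores = torch.randn(2, 8, 128, 128, dtype=torch.float16, device="cuda")
probs = softmax(scores, mask=None)  # causal: no user-defined mask needed
```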
- is_kernel_available(mask, b, np, sq, sk)#
Check whether the fused CUDA kernel can be used for the given shapes and settings.
- Parameters:
mask (Optional[torch.Tensor]) – Attention mask or None.
b (int) – Batch size.
np (int) – Number of attention heads per tensor-parallel partition.
sq (int) – Query sequence length.
sk (int) – Key sequence length.
- Returns:
True if the fused kernel constraints are satisfied; otherwise False.
- Return type:
bool
- forward_fused_softmax(input, mask)#
Compute softmax using fused CUDA kernels when available.
- Parameters:
input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].
mask (Optional[torch.Tensor]) – Optional mask for non-causal attention.
- Returns:
Attention probabilities of shape [b, np, sq, sk].
- Return type:
torch.Tensor
- forward_torch_softmax(input, mask, softmax_offset=None)#
Fallback PyTorch implementation for masked softmax.
Applies optional scaling, constructs a causal or sliding-window mask if needed, applies the mask, and computes softmax in PyTorch. Optionally casts back to float16/bfloat16 when requested.
- Parameters:
input (torch.Tensor) – Attention scores of shape [b, np, sq, sk].
mask (Optional[torch.Tensor]) – Optional additive mask.
- Returns:
Attention probabilities of shape [b, np, sq, sk].
- Return type:
torch.Tensor
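An illustrative sketch of what this fallback path does at a high level (ignoring the sliding-window case and other internal bookkeeping; the helper name and its parameters are illustrative, not the module's internals):

```python
import torch
from typing import Callable, Optional

def torch_softmax_fallback(
    scores: torch.Tensor,
    mask: Optional[torch.Tensor],
    scale: Optional[float],
    mask_func: Callable,
    softmax_in_fp32: bool,
) -> torch.Tensor:
    """Illustrative fallback: optional upcast, scale, mask, softmax, downcast."""
    orig_dtype = scores.dtype
    if softmax_in_fp32 and orig_dtype in (torch.float16, torch.bfloat16):
        scores = scores.float()
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        scores = mask_func(scores, mask)  # e.g. masked_fill with a large negative value
    probs = torch.softmax(scores, dim=-1)
    if softmax_in_fp32 and orig_dtype in (torch.float16, torch.bfloat16):
        probs = probs.to(orig_dtype)
    return probs
```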
- static get_batch_per_block(sq, sk, b, np)#
Return CUDA kernel’s batch-per-block parameter for masked softmax.
- Parameters:
sq (int) – Query sequence length.
sk (int) – Key sequence length.
b (int) – Batch size.
np (int) – Number of attention heads per tensor-parallel partition.
- Returns:
Batch-per-block value as computed by the CUDA extension.
- Return type:
int