core.fusions.fused_bias_swiglu#
Module Contents#
Classes#
- BiasSwiGLUFunction – Custom autograd function for SwiGLU activation with bias support.
- SwiGLUFunction – Custom autograd function for SwiGLU activation without bias.
Functions#
- swiglu – Performs the SwiGLU (Swish-Gated Linear Unit) activation function.
- bias_swiglu – Performs SwiGLU activation with bias addition.
- swiglu_back – Computes the gradient for the SwiGLU activation function.
- bias_swiglu_back – Computes the gradient for the biased SwiGLU activation function.
- bias_swiglu_impl – Implementation of biased SwiGLU that handles different input shapes.
- weighted_bias_swiglu_impl – Token-wise weighted bias SwiGLU fusion.
API#
- core.fusions.fused_bias_swiglu.swiglu(y)#
Performs the SwiGLU (Swish-Gated Linear Unit) activation function.
- Parameters:
y (torch.Tensor) – Input tensor to be split into two halves along the last dimension.
- Returns:
Result of SwiGLU activation: SiLU(y1) * y2, where y1, y2 are the split halves.
- Return type:
torch.Tensor
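For reference, the computation described above can be sketched in plain, unfused PyTorch as follows; the name swiglu_reference and the shapes are illustrative only and not part of this module's API:

```python
import torch
import torch.nn.functional as F

def swiglu_reference(y: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and gate: SiLU(y1) * y2.
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y1) * y2

y = torch.randn(4, 16)        # the last dimension must be even
out = swiglu_reference(y)     # shape (4, 8)
```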
- core.fusions.fused_bias_swiglu.bias_swiglu(y, bias)#
Performs SwiGLU activation with bias addition.
- Parameters:
y (torch.Tensor) – Input tensor.
bias (torch.Tensor) – Bias tensor to be added to input.
- Returns:
Result of bias addition followed by SwiGLU activation.
- Return type:
torch.Tensor
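In unfused form this is a bias add followed by the same gate; a minimal sketch (bias_swiglu_reference is an illustrative name, and the bias is assumed to broadcast over the last dimension):

```python
import torch
import torch.nn.functional as F

def bias_swiglu_reference(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Add the bias first, then apply the SwiGLU gate to the result.
    y = y + bias
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y1) * y2
```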
- core.fusions.fused_bias_swiglu.weighted_swiglu(y, weights)#
- core.fusions.fused_bias_swiglu.swiglu_back(g, y)#
Computes the gradient for the SwiGLU activation function.
- Parameters:
g (torch.Tensor) – Gradient tensor from the subsequent layer.
y (torch.Tensor) – Input tensor that was used in the forward pass.
- Returns:
Gradient with respect to the input tensor, computed using the chain rule and the derivative of the SiLU activation function.
- Return type:
torch.Tensor
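Spelled out, the chain rule the docstring refers to is: with out = SiLU(y1) * y2 and SiLU'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), the gradient of the first half is g * y2 * SiLU'(y1) and of the second half is g * SiLU(y1). A plain PyTorch sketch (swiglu_back_reference is an illustrative name):

```python
import torch
import torch.nn.functional as F

def swiglu_back_reference(g: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    y1, y2 = torch.chunk(y, 2, dim=-1)
    sig = torch.sigmoid(y1)
    dsilu = sig * (1 + y1 * (1 - sig))   # derivative of SiLU
    # Gradients for the two halves, concatenated back along the last dimension.
    return torch.cat([g * dsilu * y2, g * F.silu(y1)], dim=-1)
```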
- core.fusions.fused_bias_swiglu.bias_swiglu_back(g, y, bias)#
Computes the gradient for the biased SwiGLU activation function.
- Parameters:
g (torch.Tensor) – Gradient tensor from the subsequent layer.
y (torch.Tensor) – Input tensor that was used in the forward pass.
bias (torch.Tensor) – Bias tensor that was added in the forward pass.
- Returns:
Gradient with respect to the input tensor, computed after applying the bias addition.
- Return type:
torch.Tensor
- core.fusions.fused_bias_swiglu.weighted_swiglu_back(g, y, weights)#
- class core.fusions.fused_bias_swiglu.BiasSwiGLUFunction#
Bases: torch.autograd.Function
Custom autograd function for SwiGLU activation with bias support.
- static forward(ctx, input, bias, fp8_input_store, cpu_offload_input)#
Forward pass of biased SwiGLU activation.
- Parameters:
ctx – Autograd context object for saving tensors for backward pass.
input (torch.Tensor) – Input tensor to apply SwiGLU to.
bias (torch.Tensor) – Bias tensor to be added to input before SwiGLU.
fp8_input_store (bool) – If True, stores intermediate values in FP8 format.
- Returns:
Result of applying bias addition followed by SwiGLU activation.
- Return type:
torch.Tensor
- static backward(ctx, grad_output)#
Backward pass of biased SwiGLU activation.
- Parameters:
ctx – Autograd context object containing saved tensors from forward pass.
grad_output (torch.Tensor) – Gradient of the loss with respect to the output.
- Returns:
Tuple containing:
- Gradient with respect to the input tensor
- Gradient with respect to the bias tensor
- None for fp8_input_store parameter
- Return type:
tuple
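As with any torch.autograd.Function, the class is invoked through its apply method; a minimal usage sketch, assuming the module is importable under the documented path and with illustrative shapes:

```python
import torch
from core.fusions.fused_bias_swiglu import BiasSwiGLUFunction

x = torch.randn(8, 256, requires_grad=True)
bias = torch.randn(256, requires_grad=True)

# Positional arguments mirror forward(): input, bias, fp8_input_store, cpu_offload_input.
out = BiasSwiGLUFunction.apply(x, bias, False, False)   # last dim halves: (8, 128)
out.sum().backward()                                    # fills x.grad and bias.grad
```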
- class core.fusions.fused_bias_swiglu.SwiGLUFunction#
Bases: torch.autograd.Function
Custom autograd function for SwiGLU activation without bias.
- static forward(ctx, input, fp8_input_store, cpu_offload_input)#
Forward pass of SwiGLU activation.
- Parameters:
ctx – Autograd context object for saving tensors for backward pass.
input (torch.Tensor) – Input tensor to apply SwiGLU to.
fp8_input_store (bool) – If True, stores intermediate values in FP8 format.
- Returns:
Result of applying SwiGLU activation.
- Return type:
torch.Tensor
- static backward(ctx, grad_output)#
Backward pass of SwiGLU activation.
- Parameters:
ctx – Autograd context object containing saved tensors from forward pass.
grad_output (torch.Tensor) – Gradient of the loss with respect to the output.
- Returns:
Tuple containing:
- Gradient with respect to the input tensor
- None for fp8_input_store parameter
- Return type:
tuple
- class core.fusions.fused_bias_swiglu.WeightedSwiGLUFunction#
Bases: torch.autograd.Function
- static forward(ctx, input, weights, fp8_input_store)#
- static backward(ctx, grad_output)#
- core.fusions.fused_bias_swiglu.bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False)#
Implementation of biased SwiGLU that handles different input shapes.
This function reshapes the input if necessary, applies the SwiGLU activation (with or without bias), and restores the original shape.
- Parameters:
input (torch.Tensor) – Input tensor to apply SwiGLU activation.
bias (torch.Tensor, optional) – Bias tensor to be added to input. If None, uses the bias-free SwiGLU variant.
fp8_input_store (bool, optional) – Whether to store intermediate values in FP8 format. Defaults to False.
- Returns:
Result of biased SwiGLU activation.
- Return type:
torch.Tensor
- Raises:
AssertionError – If input tensor does not have 2 or 3 dimensions.
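A usage sketch of the shape handling described above, assuming the module is importable under the documented path; the [sequence, batch, hidden] layout and the sizes are illustrative:

```python
import torch
from core.fusions.fused_bias_swiglu import bias_swiglu_impl

x = torch.randn(128, 4, 512, requires_grad=True)   # 3D input is flattened to 2D internally
bias = torch.randn(512)

out = bias_swiglu_impl(x, bias)          # leading dims restored: (128, 4, 256)
out_no_bias = bias_swiglu_impl(x, None)  # bias=None uses the bias-free SwiGLU variant
```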
- core.fusions.fused_bias_swiglu.weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False)#
Token-wise weighted bias SwiGLU fusion.
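The weighting itself is not described further here. Under the assumption that weights is a per-token scale (for example shape [tokens, 1]) applied to the SwiGLU output, an unfused reading would be the sketch below; this is an assumption for illustration, not a statement of the fused kernel's exact behaviour:

```python
import torch
import torch.nn.functional as F

def weighted_bias_swiglu_reference(y, bias, weights):
    # Assumption: per-token weights broadcast over the hidden dimension
    # and scale the (bias-)SwiGLU output token-wise.
    y = y + bias
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y1) * y2 * weights
```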