core.fusions.fused_weighted_squared_relu
Module Contents
Classes
WeightedSquaredReLUFunction – Autograd wrapper around the weighted Squared-ReLU fused kernels.
Functions
weighted_squared_relu – Element-wise weight applied after Squared-ReLU.
_squared_relu_back – Gradient of Squared-ReLU.
weighted_squared_relu_back – Backward for weighted Squared-ReLU.
weighted_squared_relu_impl – Token-wise weighted Squared-ReLU fusion with optional FP8 storage.
API
- core.fusions.fused_weighted_squared_relu.weighted_squared_relu(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor
Element-wise weight applied after Squared-ReLU.
- Parameters:
x (torch.Tensor) – Input tensor.
weights (torch.Tensor) – Weight tensor that is broadcast-multiplied with the activation result. Typically of shape (B, 1) so that it broadcasts across the hidden dimension.
- Returns:
squared_relu(x) * weights, with the original dtype preserved.
- Return type:
torch.Tensor
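The forward computation described above can be sketched in plain, unfused PyTorch. This is a minimal reference under stated assumptions, not the fused kernel itself: the name `weighted_squared_relu_ref` is hypothetical, and the cast back to `x.dtype` is one way to honor the documented "original dtype preserved" behavior.

```python
import torch
import torch.nn.functional as F


def weighted_squared_relu_ref(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical unfused reference: squared_relu(x) * weights in x's dtype."""
    # The product may be promoted to a wider dtype when weights is wider than x.
    out = torch.pow(F.relu(x), 2) * weights
    # Cast back so the result keeps the dtype of the input.
    return out.to(x.dtype)


x = torch.tensor([[-1.0, 2.0, 3.0], [0.5, -4.0, 1.0]], dtype=torch.bfloat16)
w = torch.tensor([[0.5], [2.0]], dtype=torch.float32)  # shape (B, 1)
y = weighted_squared_relu_ref(x, w)
```

Here each row of `x` is scaled by its own weight, and `y` stays in `bfloat16` even though the intermediate product is computed in `float32`.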
- core.fusions.fused_weighted_squared_relu._squared_relu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor
Gradient of Squared-ReLU.
The derivative of (ReLU(x))^2 w.r.t. x is 2 * ReLU(x).
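The derivative stated above can be checked numerically. The following scalar sketch uses only the standard library and compares the closed form `2 * ReLU(x)` against a central finite difference; all helper names are illustrative and not part of the module. With an upstream gradient `g`, the chain rule then gives `g * 2 * ReLU(x)`.

```python
def relu(x: float) -> float:
    return max(x, 0.0)


def squared_relu(x: float) -> float:
    return relu(x) ** 2


def squared_relu_grad(x: float) -> float:
    # Closed-form derivative of (ReLU(x))^2 with respect to x.
    return 2.0 * relu(x)


def central_difference(f, x: float, eps: float = 1e-6) -> float:
    # Symmetric finite-difference approximation of f'(x).
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


points = (-2.0, -0.5, 0.5, 1.0, 3.0)
errors = [abs(central_difference(squared_relu, p) - squared_relu_grad(p)) for p in points]
```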
- core.fusions.fused_weighted_squared_relu.weighted_squared_relu_back(g: torch.Tensor, x: torch.Tensor, weights: torch.Tensor)
Backward for weighted Squared-ReLU.
Returns gradients w.r.t. x and weights.
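As a sketch of what these two gradients look like, assume `out = ReLU(x)^2 * weights` with `weights` of shape `(B, 1)`. The gradient w.r.t. `x` is then `g * weights * 2 * ReLU(x)`, and the gradient w.r.t. `weights` is `g * ReLU(x)^2` summed over the broadcast hidden dimension. The function name below is hypothetical, and the comparison against `torch.autograd.grad` is only a sanity check of the math, not the module's implementation.

```python
import torch
import torch.nn.functional as F


def weighted_squared_relu_back_ref(g, x, weights):
    """Hypothetical reference backward, assuming weights has shape (B, 1)."""
    relu_x = F.relu(x)
    grad_x = g * weights * 2 * relu_x
    # weights was broadcast over the hidden dimension, so reduce over it.
    grad_w = (g * relu_x.pow(2)).sum(dim=-1, keepdim=True)
    return grad_x, grad_w


torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
w = torch.rand(4, 1, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 8, dtype=torch.float64)

out = F.relu(x).pow(2) * w
auto_grad_x, auto_grad_w = torch.autograd.grad(out, (x, w), grad_outputs=g)
ref_grad_x, ref_grad_w = weighted_squared_relu_back_ref(g, x.detach(), w.detach())
```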
- class core.fusions.fused_weighted_squared_relu.WeightedSquaredReLUFunction
Bases: torch.autograd.Function
Autograd wrapper around the weighted Squared-ReLU fused kernels.
- static forward(ctx, input: torch.Tensor, weights: torch.Tensor)
Forward method for WeightedSquaredReLUFunction.
- Parameters:
ctx – context object to store intermediate tensors.
input (torch.Tensor) – input tensor.
weights (torch.Tensor) – weight tensor.
fp8_input_store (bool) – whether to store the input in FP8.
- static backward(ctx, grad_output: torch.Tensor)
Backward method for WeightedSquaredReLUFunction.
- Parameters:
ctx – context object to store intermediate tensors.
grad_output (torch.Tensor) – gradient of the output of the forward function.
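The forward/backward pairing can be illustrated with a small `torch.autograd.Function` subclass. This is a sketch under assumptions: the class name `WeightedSquaredReLUSketch` is hypothetical, the math is written in unfused PyTorch, `weights` is assumed to have shape `(B, 1)`, and FP8 storage of the input is not modeled.

```python
import torch
import torch.nn.functional as F


class WeightedSquaredReLUSketch(torch.autograd.Function):
    """Hypothetical unfused stand-in for the autograd wrapper."""

    @staticmethod
    def forward(ctx, input, weights):
        # Keep both tensors for backward (no FP8 storage in this sketch).
        ctx.save_for_backward(input, weights)
        return (F.relu(input).pow(2) * weights).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, weights = ctx.saved_tensors
        relu_in = F.relu(input)
        grad_input = grad_output * weights * 2 * relu_in
        # Reduce over the hidden dimension that weights was broadcast across.
        grad_weights = (grad_output * relu_in.pow(2)).sum(dim=-1, keepdim=True)
        return grad_input.to(input.dtype), grad_weights.to(weights.dtype)


torch.manual_seed(0)
x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
w = torch.rand(3, 1, dtype=torch.float64, requires_grad=True)
# gradcheck compares the hand-written backward with numerical gradients.
ok = torch.autograd.gradcheck(WeightedSquaredReLUSketch.apply, (x, w))
```

`backward` returns one gradient per `forward` input (excluding `ctx`), in the same order.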
- core.fusions.fused_weighted_squared_relu.weighted_squared_relu_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor
Token-wise weighted Squared-ReLU fusion with optional FP8 storage.
- Parameters:
input (torch.Tensor) – Input tensor of shape (B, *, hidden_size), where * can be the sequence dimension.
weights (torch.Tensor) – Per-token weights broadcastable to the output of squared_relu.
- Returns:
Output tensor with the same shape as input; the hidden dimension is unchanged.
- Return type:
torch.Tensor
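To make the shape contract concrete, the sketch below treats every position along the leading dimensions as one token, applies the weighted Squared-ReLU per token, and restores the original shape. The helper name, the flatten-and-restore approach, and the `(B, S, 1)` weight layout are assumptions for illustration; the actual function may handle shapes differently.

```python
import torch
import torch.nn.functional as F


def weighted_squared_relu_tokens(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten to (tokens, hidden), apply, restore shape."""
    hidden_size = input.shape[-1]
    flat_in = input.reshape(-1, hidden_size)  # (tokens, hidden_size)
    flat_w = weights.reshape(-1, 1)           # one weight per token
    flat_out = (F.relu(flat_in).pow(2) * flat_w).to(input.dtype)
    return flat_out.reshape(input.shape)


torch.manual_seed(0)
x = torch.randn(2, 3, 4)  # (B, S, hidden_size)
w = torch.rand(2, 3, 1)   # per-token weights
out = weighted_squared_relu_tokens(x, w)
direct = F.relu(x).pow(2) * w  # same result via plain broadcasting
```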