core.fusions.fused_weighted_squared_relu#

Module Contents#

Classes#

WeightedSquaredReLUFunction

Autograd wrapper around the weighted Squared-ReLU fused kernels.

Functions#

weighted_squared_relu

Element-wise weight applied after Squared-ReLU.

_squared_relu_back

Gradient of Squared-ReLU.

weighted_squared_relu_back

Backward for weighted Squared-ReLU.

weighted_squared_relu_impl

Token-wise weighted Squared-ReLU fusion with optional FP8 storage.

API#

core.fusions.fused_weighted_squared_relu.weighted_squared_relu(
x: torch.Tensor,
weights: torch.Tensor,
) torch.Tensor#

Element-wise weight applied after Squared-ReLU.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • weights (torch.Tensor) – Weight tensor that will be broadcast-multiplied with the activation result. Typically of shape (B, 1) so it can be broadcast across the hidden dimension.

Returns:

squared_relu(x) * weights, with the input's original dtype preserved.

Return type:

torch.Tensor

core.fusions.fused_weighted_squared_relu._squared_relu_back(g: torch.Tensor, x: torch.Tensor) torch.Tensor#

Gradient of Squared-ReLU.

The derivative of (ReLU(x))^2 w.r.t x is 2 * ReLU(x).
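The rule above can be cross-checked against autograd with a small unfused sketch (the name `squared_relu_back_ref` is hypothetical):

```python
import torch

def squared_relu_back_ref(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # d/dx (relu(x))^2 = 2 * relu(x), so the incoming gradient is scaled by it.
    return g * 2.0 * torch.relu(x)

# Cross-check against PyTorch autograd in double precision.
x = torch.randn(16, dtype=torch.float64, requires_grad=True)
torch.relu(x).pow(2).sum().backward()
manual = squared_relu_back_ref(torch.ones_like(x), x.detach())
assert torch.allclose(x.grad, manual)
```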

core.fusions.fused_weighted_squared_relu.weighted_squared_relu_back(
g: torch.Tensor,
x: torch.Tensor,
weights: torch.Tensor,
)#

Backward for weighted Squared-ReLU.

Returns gradients with respect to both x and weights.
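A plausible unfused sketch of this backward, verified against autograd (the reference name is hypothetical; the reduction axis assumes weights of shape (B, 1) broadcast over the hidden dimension):

```python
import torch

def weighted_squared_relu_back_ref(g, x, weights):
    r = torch.relu(x)
    grad_x = g * weights * 2.0 * r                   # chain rule through squared-ReLU
    grad_w = (g * r * r).sum(dim=-1, keepdim=True)   # reduce over the broadcast hidden dim
    return grad_x, grad_w

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
w = torch.rand(4, 1, dtype=torch.float64, requires_grad=True)
(torch.relu(x).pow(2) * w).sum().backward()
gx, gw = weighted_squared_relu_back_ref(
    torch.ones(4, 8, dtype=torch.float64), x.detach(), w.detach()
)
assert torch.allclose(x.grad, gx) and torch.allclose(w.grad, gw)
```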

class core.fusions.fused_weighted_squared_relu.WeightedSquaredReLUFunction#

Bases: torch.autograd.Function

Autograd wrapper around the weighted Squared-ReLU fused kernels.

static forward(ctx, input: torch.Tensor, weights: torch.Tensor)#

Forward pass of WeightedSquaredReLUFunction.

Parameters:
  • ctx – context object used to save tensors needed for the backward pass.

  • input (torch.Tensor) – input tensor.

  • weights (torch.Tensor) – weight tensor.

  • fp8_input_store (bool) – if True, the input saved for backward is stored in FP8 to reduce activation memory.

static backward(ctx, grad_output: torch.Tensor)#

Backward pass of WeightedSquaredReLUFunction.

Parameters:
  • ctx – context object holding the tensors saved during the forward pass.

  • grad_output (torch.Tensor) – gradient of the output of the forward function.
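A minimal unfused `torch.autograd.Function` mirroring the forward/backward pair above can make the contract concrete (class and variable names are hypothetical; the fused version additionally supports FP8 storage of the saved input):

```python
import torch

class WeightedSquaredReLURef(torch.autograd.Function):
    """Unfused stand-in sketch for WeightedSquaredReLUFunction."""

    @staticmethod
    def forward(ctx, input, weights):
        ctx.save_for_backward(input, weights)
        r = torch.relu(input)
        return (r * r) * weights

    @staticmethod
    def backward(ctx, grad_output):
        input, weights = ctx.saved_tensors
        r = torch.relu(input)
        grad_input = grad_output * weights * 2.0 * r
        grad_weights = (grad_output * r * r).sum(dim=-1, keepdim=True)
        return grad_input, grad_weights

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
w = torch.rand(3, 1, dtype=torch.float64, requires_grad=True)
out = WeightedSquaredReLURef.apply(x, w)
```

`torch.autograd.gradcheck` in double precision is a convenient way to confirm the hand-written backward matches numerical gradients.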

core.fusions.fused_weighted_squared_relu.weighted_squared_relu_impl(
input: torch.Tensor,
weights: torch.Tensor,
) torch.Tensor#

Token-wise weighted Squared-ReLU fusion with optional FP8 storage.

Parameters:
  • input (torch.Tensor) – Input tensor of shape (B, *, hidden_size) where * can be the sequence dimension.

  • weights (torch.Tensor) – Per-token weights broadcastable to the output of squared_relu.

Returns:

Output tensor with the same shape as input; in particular, the hidden dimension is unchanged.

Return type:

torch.Tensor
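The shape semantics of the docstring can be illustrated with an unfused equivalent on a (B, S, hidden_size) input with per-token weights (shapes chosen here for illustration):

```python
import torch

B, S, H = 2, 3, 8
x = torch.randn(B, S, H)
w = torch.rand(B, S, 1)            # per-token weights, broadcast over hidden_size
out = torch.relu(x).pow(2) * w     # unfused equivalent of the fused impl
assert out.shape == x.shape        # shape preserved, hidden dimension unchanged
```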