core.transformer.custom_layers.batch_invariant_kernels#

Module Contents#

Classes#

BatchInvariantTEGemmFn

Autograd function implementing batch-invariant TE GEMM.

BatchInvariantRMSNormFn

Autograd function implementing batch-invariant RMSNorm.

Functions#

_matmul_launch_metadata

Build launch metadata for Triton matmul kernels used in BIK matmul.

_compute_pid

matmul_kernel_persistent

Persistent matmul Triton kernel backing matmul_persistent.

get_compute_units

Returns the number of streaming multiprocessors (SMs) or equivalent compute units for the available accelerator. Assigns the value to NUM_SMS.

matmul_persistent

Persistent matmul kernel used by batch-invariant GEMM.

_log_softmax_kernel

Compute log_softmax along the last dimension of a 2D tensor. Each block handles one row of the input tensor.

log_softmax

Compute log_softmax using Triton kernel.

mean_kernel

Kernel for computing mean along a single dimension. Input is viewed as (M, N, K) where N is the dimension being reduced.

mean_dim

Triton implementation of torch.mean with single dimension reduction.

mm_batch_invariant

Batch-invariant replacement for aten::mm using a persistent matmul kernel.

addmm_batch_invariant

Batch-invariant replacement for aten::addmm using a persistent matmul kernel.

_log_softmax_batch_invariant

mean_batch_invariant

Batch-invariant replacement for aten::mean.dim over one or more dimensions.

get_batch_invariant_attention_block_size

Return the (block_m, block_n) tiling used for batch-invariant attention.

_import_module_if_available

_te_patch_for_batch_invariant

Patch Transformer Engine modules to use batch-invariant GEMM and RMSNorm.

_te_unpatch_for_batch_invariant

Restore original Transformer Engine functions if they were patched.

_extract_te_gemm_args

Utility to parse TE general_gemm flexible signature.

_is_supported_dtype_for_bik

_te_general_gemm_patched

Batch-invariant replacement for TE general_gemm. Returns a list of tensors to match TE’s API: (gemm_out, bias_grad, gelu_input, extra_output)

rmsnorm_batch_invariant

Batch-invariant RMSNorm wrapper that delegates to autograd-aware implementation.

_te_rmsnorm_forward_patched

Patched TE RMSNorm.forward that routes to batch-invariant implementation with autograd support.

is_batch_invariant_mode_enabled

Return True if global batch-invariant mode is currently enabled.

enable_batch_invariant_mode

Enable global batch-invariant mode and patch Aten/TE kernels.

disable_batch_invariant_mode

Disable global batch-invariant mode and restore original kernels.

set_batch_invariant_mode

Context manager to toggle global batch-invariant mode.

Data#

API#

core.transformer.custom_layers.batch_invariant_kernels.__all__#

['set_batch_invariant_mode', 'is_batch_invariant_mode_enabled', 'disable_batch_invariant_mode', 'enable_batch_invariant_mode']

core.transformer.custom_layers.batch_invariant_kernels._LOGGER#

'getLogger(...)'

core.transformer.custom_layers.batch_invariant_kernels._matmul_launch_metadata(
grid: collections.abc.Callable[..., Any],
kernel: Any,
args: Dict[str, Any],
) Dict[str, Any]#

Build launch metadata for Triton matmul kernels used in BIK matmul.

core.transformer.custom_layers.batch_invariant_kernels._compute_pid(
tile_id,
num_pid_in_group,
num_pid_m,
GROUP_SIZE_M,
NUM_SMS,
)#
core.transformer.custom_layers.batch_invariant_kernels.matmul_kernel_persistent(
a_ptr,
b_ptr,
c_ptr,
bias_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
NUM_SMS: triton.language.constexpr,
A_LARGE: triton.language.constexpr,
B_LARGE: triton.language.constexpr,
C_LARGE: triton.language.constexpr,
HAS_BIAS: triton.language.constexpr,
)#

Persistent matmul Triton kernel backing matmul_persistent.

core.transformer.custom_layers.batch_invariant_kernels.get_compute_units()#

Returns the number of streaming multiprocessors (SMs) or equivalent compute units for the available accelerator. Assigns the value to NUM_SMS.
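As an illustrative sketch (not taken from this module), the SM count used here can typically be obtained on CUDA devices from the device properties; the use of torch.cuda below is an assumption about the target accelerator.

```python
import torch

# Hedged sketch: query the streaming-multiprocessor count for the current
# CUDA device, which is the kind of value this helper reports as NUM_SMS.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    num_sms = props.multi_processor_count
    print(f"Compute units available for persistent kernels: {num_sms}")
```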

core.transformer.custom_layers.batch_invariant_kernels.matmul_persistent(
a: torch.Tensor,
b: torch.Tensor,
bias: torch.Tensor | None = None,
)#

Persistent matmul kernel used by batch-invariant GEMM.
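A minimal usage sketch follows; the (M, K) x (K, N) layout, CUDA device, bf16 dtype, and the tolerance used for comparison are assumptions of this example rather than documented guarantees.

```python
import torch

# Illustrative call to the persistent Triton GEMM with an optional bias,
# compared against the eager PyTorch result.
a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(512, device="cuda", dtype=torch.bfloat16)

c = matmul_persistent(a, b, bias=bias)   # batch-invariant persistent GEMM
ref = a @ b + bias                       # eager reference
print(torch.allclose(c.float(), ref.float(), atol=1e-2, rtol=1e-2))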

core.transformer.custom_layers.batch_invariant_kernels._log_softmax_kernel(
input_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
BLOCK_SIZE: triton.language.constexpr,
)#

Compute log_softmax along the last dimension of a 2D tensor. Each block handles one row of the input tensor.

core.transformer.custom_layers.batch_invariant_kernels.log_softmax(input: torch.Tensor, dim: int = -1) torch.Tensor#

Compute log_softmax using Triton kernel.

Parameters:
  • input – Input tensor

  • dim – Dimension along which to compute log_softmax (only -1 or last dim supported)


Returns:

Tensor with log_softmax applied along the specified dimension
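A short usage sketch, assuming a 2D float32 input on a CUDA device and reducing along the last dimension (the only supported case per the parameters above):

```python
import torch

# Illustrative comparison of the Triton log_softmax against the PyTorch
# reference along the last dimension of a 2D tensor.
x = torch.randn(32, 50257, device="cuda", dtype=torch.float32)
y = log_softmax(x, dim=-1)
ref = torch.log_softmax(x, dim=-1)
print(torch.allclose(y, ref, atol=1e-5))
```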

core.transformer.custom_layers.batch_invariant_kernels.mean_kernel(
input_ptr,
output_ptr,
input_stride0,
input_stride1,
input_stride2,
output_stride0,
output_stride1,
M,
N,
K,
BLOCK_SIZE: triton.language.constexpr,
)#

Kernel for computing mean along a single dimension. Input is viewed as (M, N, K) where N is the dimension being reduced.

core.transformer.custom_layers.batch_invariant_kernels.mean_dim(
input: torch.Tensor,
dim: int,
keepdim: bool = False,
dtype: torch.dtype | None = None,
) torch.Tensor#

Triton implementation of torch.mean with single dimension reduction.

Parameters:
  • input – Input tensor

  • dim – Single dimension along which to compute mean

  • keepdim – Whether to keep the reduced dimension

  • dtype – Output dtype. If None, uses input dtype (or float32 for integer inputs)

Returns:

Tensor with mean values along specified dimension
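The sketch below shows a single-dimension reduction with keepdim and an explicit output dtype; shapes, device, and tolerance are assumptions of the example:

```python
import torch

# Illustrative single-dim mean over dim=1, keeping the reduced dimension and
# promoting the output to float32.
x = torch.randn(8, 1024, 4096, device="cuda", dtype=torch.bfloat16)
m = mean_dim(x, dim=1, keepdim=True, dtype=torch.float32)
print(m.shape)  # torch.Size([8, 1, 4096])
print(torch.allclose(m, x.float().mean(dim=1, keepdim=True), atol=1e-3))
```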

core.transformer.custom_layers.batch_invariant_kernels.mm_batch_invariant(a, b)#

Batch-invariant replacement for aten::mm using a persistent matmul kernel.

core.transformer.custom_layers.batch_invariant_kernels.addmm_batch_invariant(bias, a, b)#

Batch-invariant replacement for aten::addmm using a persistent matmul kernel.
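Since these mirror the aten::mm and aten::addmm signatures, they can be called like their torch counterparts; the device and dtype below are assumptions of this sketch.

```python
import torch

# Illustrative calls with aten-compatible argument order.
a = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(32, device="cuda", dtype=torch.bfloat16)

out_mm = mm_batch_invariant(a, b)              # mirrors torch.mm(a, b)
out_addmm = addmm_batch_invariant(bias, a, b)  # mirrors torch.addmm(bias, a, b)
```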

core.transformer.custom_layers.batch_invariant_kernels._log_softmax_batch_invariant(input, dim, _half_to_float)#
core.transformer.custom_layers.batch_invariant_kernels.mean_batch_invariant(
input,
dim,
keepdim=False,
dtype: torch.dtype | None = None,
)#

Batch-invariant replacement for aten::mean.dim over one or more dimensions.

core.transformer.custom_layers.batch_invariant_kernels.AttentionBlockSize#

'namedtuple(...)'

core.transformer.custom_layers.batch_invariant_kernels.get_batch_invariant_attention_block_size() core.transformer.custom_layers.batch_invariant_kernels.AttentionBlockSize#

Return the (block_m, block_n) tiling used for batch-invariant attention.
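A minimal usage sketch; the field names block_m and block_n are inferred from the description of the AttentionBlockSize namedtuple and are not otherwise confirmed here.

```python
# Illustrative access to the attention tiling, either by unpacking the
# namedtuple or by (assumed) field name.
block = get_batch_invariant_attention_block_size()
block_m, block_n = block
print(block.block_m, block.block_n)
```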

core.transformer.custom_layers.batch_invariant_kernels._batch_invariant_MODE#

False

core.transformer.custom_layers.batch_invariant_kernels._batch_invariant_LIB#

None

core.transformer.custom_layers.batch_invariant_kernels._TE_GENERAL_GEMM_ORIG#

None

core.transformer.custom_layers.batch_invariant_kernels._TE_RMSNORM_ORIG_FWD#

None

core.transformer.custom_layers.batch_invariant_kernels._MEG_TE_GENERAL_GEMM_ORIG#

None

core.transformer.custom_layers.batch_invariant_kernels._TE_RMSNORM_FUNC_ORIGS: Dict[str, Any]#

None

core.transformer.custom_layers.batch_invariant_kernels._TE_GEMM_FUNC_ORIGS: Dict[str, Any]#

None

core.transformer.custom_layers.batch_invariant_kernels._import_module_if_available(name: str)#
core.transformer.custom_layers.batch_invariant_kernels._te_patch_for_batch_invariant()#

Patch Transformer Engine modules to use batch-invariant GEMM and RMSNorm.

This monkey-patches TE’s GEMM and RMSNorm entry points to dispatch to the batch-invariant implementations when batch-invariant mode is enabled. Safe no-op if TE is unavailable.

core.transformer.custom_layers.batch_invariant_kernels._te_unpatch_for_batch_invariant()#

Restore original Transformer Engine functions if they were patched.

core.transformer.custom_layers.batch_invariant_kernels._extract_te_gemm_args(
args: tuple,
kwargs: Dict[str, Any],
)#

Utility to parse TE general_gemm flexible signature.

Returns (A, B, out_dtype, layout, out, bias, grad).

core.transformer.custom_layers.batch_invariant_kernels._is_supported_dtype_for_bik(t: torch.dtype) bool#
class core.transformer.custom_layers.batch_invariant_kernels.BatchInvariantTEGemmFn#

Bases: torch.autograd.Function

Autograd function implementing batch-invariant TE GEMM.

static forward(
ctx,
A: torch.Tensor,
B: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
layout: str,
)#

Forward pass computing batch-invariant TE GEMM.

Respects TE’s flexible layout semantics, flattens leading dimensions of the input as needed, applies optional bias, and casts to out_dtype.

static backward(ctx, grad_output: torch.Tensor)#

Backward pass for batch-invariant TE GEMM.

Computes gradients w.r.t. A, B, and optional bias while mirroring the reshaping/layout logic used in the forward pass.

core.transformer.custom_layers.batch_invariant_kernels._te_general_gemm_patched(*args, **kwargs) List[torch.Tensor]#

Batch-invariant replacement for TE general_gemm. Returns a list of tensors to match TE’s API: (gemm_out, bias_grad, gelu_input, extra_output)

class core.transformer.custom_layers.batch_invariant_kernels.BatchInvariantRMSNormFn#

Bases: torch.autograd.Function

Autograd function implementing batch-invariant RMSNorm.

static forward(
ctx,
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
zero_centered_gamma: bool,
)#

Forward pass for batch-invariant RMSNorm.

Normalizes x using an RMSNorm-style statistic computed via mean_dim, applies affine weight, and stores intermediate rsigma for backward.

static backward(ctx, grad_output: torch.Tensor)#

Backward pass for batch-invariant RMSNorm.

Computes gradients w.r.t. input and weight while matching TE’s fp32 accumulation and reduction behavior for numerical stability.

core.transformer.custom_layers.batch_invariant_kernels.rmsnorm_batch_invariant(
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
) torch.Tensor#

Batch-invariant RMSNorm wrapper that delegates to autograd-aware implementation.

This provides a simple functional interface while delegating to the optimized BatchInvariantRMSNormFn, which has better numerics (fp32 precision in the forward and backward passes).
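The sketch below compares the wrapper against a plain-PyTorch RMSNorm reference; it assumes the conventional formula x * rsqrt(mean(x^2) + eps) * weight without zero-centered gamma, and the shapes, dtypes, and tolerances are illustrative only.

```python
import torch

# Illustrative forward/backward usage of the batch-invariant RMSNorm wrapper.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = rmsnorm_batch_invariant(x, weight, eps=1e-5)

# Pure-PyTorch reference in fp32 for comparison.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5) * weight.float()
print(torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2))

y.sum().backward()  # gradients flow through BatchInvariantRMSNormFn
```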

core.transformer.custom_layers.batch_invariant_kernels._te_rmsnorm_forward_patched(self, x: torch.Tensor) torch.Tensor#

Patched TE RMSNorm.forward that routes to batch-invariant implementation with autograd support.

core.transformer.custom_layers.batch_invariant_kernels.is_batch_invariant_mode_enabled()#

Return True if global batch-invariant mode is currently enabled.

core.transformer.custom_layers.batch_invariant_kernels.enable_batch_invariant_mode()#

Enable global batch-invariant mode and patch Aten/TE kernels.

core.transformer.custom_layers.batch_invariant_kernels.disable_batch_invariant_mode()#

Disable global batch-invariant mode and restore original kernels.

core.transformer.custom_layers.batch_invariant_kernels.set_batch_invariant_mode(enabled: bool = True)#

Context manager to toggle global batch-invariant mode.

When enabled is True, batch-invariant kernels are enabled for the duration of the context; when False, they are disabled for the duration. This implementation is re-entrant and correctly restores the previous state even under nesting.
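A usage sketch of the context manager follows; it assumes the mode starts in its default disabled state and relies on the re-entrant restore behavior described above.

```python
# Illustrative nesting of the batch-invariant mode context manager.
assert not is_batch_invariant_mode_enabled()

with set_batch_invariant_mode():
    assert is_batch_invariant_mode_enabled()
    # torch.mm / torch.addmm / mean / RMSNorm and TE GEMM now dispatch to the
    # batch-invariant kernels for the duration of this block.
    with set_batch_invariant_mode(False):
        assert not is_batch_invariant_mode_enabled()
    assert is_batch_invariant_mode_enabled()

assert not is_batch_invariant_mode_enabled()
```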