core.transformer.custom_layers.batch_invariant_kernels#
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| BatchInvariantTEGemmFn | Autograd function implementing batch-invariant TE GEMM. |
| BatchInvariantRMSNormFn | Autograd function implementing batch-invariant RMSNorm. |
Functions#
| Function | Description |
| --- | --- |
| _matmul_launch_metadata | Build launch metadata for Triton matmul kernels used in BIK matmul. |
| matmul_kernel_persistent | Persistent matmul Triton kernel backing matmul_persistent. |
| get_compute_units | Returns the number of streaming multiprocessors (SMs) or equivalent compute units for the available accelerator. Assigns the value to NUM_SMS. |
| matmul_persistent | Persistent matmul kernel used by batch-invariant GEMM. |
| _log_softmax_kernel | Compute log_softmax along the last dimension of a 2D tensor. Each block handles one row of the input tensor. |
| log_softmax | Compute log_softmax using Triton kernel. |
| mean_kernel | Kernel for computing mean along a single dimension. Input is viewed as (M, N, K) where N is the dimension being reduced. |
| mean_dim | Triton implementation of torch.mean with single dimension reduction. |
| mm_batch_invariant | Batch-invariant replacement for aten::mm using a persistent matmul kernel. |
| addmm_batch_invariant | Batch-invariant replacement for aten::addmm using a persistent matmul kernel. |
| mean_batch_invariant | Batch-invariant replacement for aten::mean.dim over one or more dimensions. |
| get_batch_invariant_attention_block_size | Return the (block_m, block_n) tiling used for batch-invariant attention. |
| _te_patch_for_batch_invariant | Patch Transformer Engine modules to use batch-invariant GEMM and RMSNorm. |
| _te_unpatch_for_batch_invariant | Restore original Transformer Engine functions if they were patched. |
| _extract_te_gemm_args | Utility to parse TE general_gemm flexible signature. |
| _te_general_gemm_patched | Batch-invariant replacement for TE general_gemm. Returns a list of tensors to match TE’s API: (gemm_out, bias_grad, gelu_input, extra_output). |
| rmsnorm_batch_invariant | Batch-invariant RMSNorm wrapper that delegates to autograd-aware implementation. |
| _te_rmsnorm_forward_patched | Patched TE RMSNorm.forward that routes to batch-invariant implementation with autograd support. |
| is_batch_invariant_mode_enabled | Return True if global batch-invariant mode is currently enabled. |
| enable_batch_invariant_mode | Enable global batch-invariant mode and patch Aten/TE kernels. |
| disable_batch_invariant_mode | Disable global batch-invariant mode and restore original kernels. |
| set_batch_invariant_mode | Context manager to toggle global batch-invariant mode. |
Data#
API#
- core.transformer.custom_layers.batch_invariant_kernels.__all__#
[‘set_batch_invariant_mode’, ‘is_batch_invariant_mode_enabled’, ‘disable_batch_invariant_mode’, ‘ena…
- core.transformer.custom_layers.batch_invariant_kernels._LOGGER#
‘getLogger(…)’
- core.transformer.custom_layers.batch_invariant_kernels._matmul_launch_metadata(
- grid: collections.abc.Callable[..., Any],
- kernel: Any,
- args: Dict[str, Any],
Build launch metadata for Triton matmul kernels used in BIK matmul.
- core.transformer.custom_layers.batch_invariant_kernels._compute_pid(
- tile_id,
- num_pid_in_group,
- num_pid_m,
- GROUP_SIZE_M,
- NUM_SMS,
- core.transformer.custom_layers.batch_invariant_kernels.matmul_kernel_persistent(
- a_ptr,
- b_ptr,
- c_ptr,
- bias_ptr,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
- NUM_SMS: triton.language.constexpr,
- A_LARGE: triton.language.constexpr,
- B_LARGE: triton.language.constexpr,
- C_LARGE: triton.language.constexpr,
- HAS_BIAS: triton.language.constexpr,
Persistent matmul Triton kernel backing matmul_persistent.
- core.transformer.custom_layers.batch_invariant_kernels.get_compute_units()#
Returns the number of streaming multiprocessors (SMs) or equivalent compute units for the available accelerator. Assigns the value to NUM_SMS.
- core.transformer.custom_layers.batch_invariant_kernels.matmul_persistent(
- a: torch.Tensor,
- b: torch.Tensor,
- bias: torch.Tensor | None = None,
Persistent matmul kernel used by batch-invariant GEMM.
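A minimal usage sketch (not part of the generated reference): it assumes a CUDA device with Triton available, and the shapes and bfloat16 dtype are illustrative only.
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import matmul_persistent

a = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

# The persistent kernel uses a fixed tiling chosen at launch time, so a given
# output row does not depend on how many other rows are present in the batch.
c = matmul_persistent(a, b, bias=bias)  # shape (256, 1024)
```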
- core.transformer.custom_layers.batch_invariant_kernels._log_softmax_kernel(
- input_ptr,
- output_ptr,
- input_row_stride,
- output_row_stride,
- n_cols,
- BLOCK_SIZE: triton.language.constexpr,
Compute log_softmax along the last dimension of a 2D tensor. Each block handles one row of the input tensor.
- core.transformer.custom_layers.batch_invariant_kernels.log_softmax(input: torch.Tensor, dim: int = -1) torch.Tensor#
Compute log_softmax using Triton kernel.
- Parameters:
input – Input tensor
dim – Dimension along which to compute log_softmax (only -1 or last dim supported)
- Returns:
Tensor with log_softmax applied along the specified dimension
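A minimal usage sketch (assuming a CUDA device with Triton available; the logits shape is illustrative):
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import log_softmax

# Only reduction over the last dimension is supported (dim=-1).
logits = torch.randn(4, 32000, device="cuda")
log_probs = log_softmax(logits, dim=-1)

reference = torch.log_softmax(logits, dim=-1)  # PyTorch reference for comparison
print((log_probs - reference).abs().max())     # expected to be small
```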
- core.transformer.custom_layers.batch_invariant_kernels.mean_kernel(
- input_ptr,
- output_ptr,
- input_stride0,
- input_stride1,
- input_stride2,
- output_stride0,
- output_stride1,
- M,
- N,
- K,
- BLOCK_SIZE: triton.language.constexpr,
Kernel for computing mean along a single dimension. Input is viewed as (M, N, K) where N is the dimension being reduced.
- core.transformer.custom_layers.batch_invariant_kernels.mean_dim(
- input: torch.Tensor,
- dim: int,
- keepdim: bool = False,
- dtype: torch.dtype | None = None,
Triton implementation of torch.mean with single dimension reduction.
- Parameters:
input – Input tensor
dim – Single dimension along which to compute mean
keepdim – Whether to keep the reduced dimension
dtype – Output dtype. If None, uses input dtype (or float32 for integer inputs)
- Returns:
Tensor with mean values along specified dimension
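A minimal usage sketch (assuming a CUDA device with Triton available; shapes are illustrative):
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import mean_dim

x = torch.randn(8, 128, 64, device="cuda")
# Reducing dim=1 corresponds to viewing the input as (M, N, K) = (8, 128, 64),
# with N being the reduced dimension.
m = mean_dim(x, dim=1, keepdim=True)
print(m.shape)  # torch.Size([8, 1, 64])
```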
- core.transformer.custom_layers.batch_invariant_kernels.mm_batch_invariant(a, b)#
Batch-invariant replacement for aten::mm using a persistent matmul kernel.
- core.transformer.custom_layers.batch_invariant_kernels.addmm_batch_invariant(bias, a, b)#
Batch-invariant replacement for aten::addmm using a persistent matmul kernel.
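A minimal sketch of calling the two ATen replacements directly; in normal use they are installed globally by enable_batch_invariant_mode rather than called by hand. Assumes a CUDA device; shapes are illustrative.
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import (
    addmm_batch_invariant,
    mm_batch_invariant,
)

a = torch.randn(16, 64, device="cuda")
b = torch.randn(64, 32, device="cuda")
bias = torch.randn(32, device="cuda")

out_mm = mm_batch_invariant(a, b)              # drop-in for torch.mm(a, b)
out_addmm = addmm_batch_invariant(bias, a, b)  # drop-in for torch.addmm(bias, a, b)
```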
- core.transformer.custom_layers.batch_invariant_kernels._log_softmax_batch_invariant(input, dim, _half_to_float)#
- core.transformer.custom_layers.batch_invariant_kernels.mean_batch_invariant(
- input,
- dim,
- keepdim=False,
- dtype: torch.dtype | None = None,
Batch-invariant replacement for aten::mean.dim over one or more dimensions.
- core.transformer.custom_layers.batch_invariant_kernels.AttentionBlockSize#
‘namedtuple(…)’
- core.transformer.custom_layers.batch_invariant_kernels.get_batch_invariant_attention_block_size() core.transformer.custom_layers.batch_invariant_kernels.AttentionBlockSize#
Return the (block_m, block_n) tiling used for batch-invariant attention.
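A small sketch of querying the tiling; it assumes the AttentionBlockSize namedtuple exposes block_m and block_n fields, as the summary suggests.
```python
from core.transformer.custom_layers.batch_invariant_kernels import (
    get_batch_invariant_attention_block_size,
)

block = get_batch_invariant_attention_block_size()
# Fixed (block_m, block_n) tiling, independent of batch size or sequence length.
print(block.block_m, block.block_n)
```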
- core.transformer.custom_layers.batch_invariant_kernels._batch_invariant_MODE#
False
- core.transformer.custom_layers.batch_invariant_kernels._batch_invariant_LIB#
None
- core.transformer.custom_layers.batch_invariant_kernels._TE_GENERAL_GEMM_ORIG#
None
- core.transformer.custom_layers.batch_invariant_kernels._TE_RMSNORM_ORIG_FWD#
None
- core.transformer.custom_layers.batch_invariant_kernels._MEG_TE_GENERAL_GEMM_ORIG#
None
- core.transformer.custom_layers.batch_invariant_kernels._TE_RMSNORM_FUNC_ORIGS: Dict[str, Any]#
None
- core.transformer.custom_layers.batch_invariant_kernels._TE_GEMM_FUNC_ORIGS: Dict[str, Any]#
None
- core.transformer.custom_layers.batch_invariant_kernels._import_module_if_available(name: str)#
- core.transformer.custom_layers.batch_invariant_kernels._te_patch_for_batch_invariant()#
Patch Transformer Engine modules to use batch-invariant GEMM and RMSNorm.
This monkey-patches TE’s GEMM and RMSNorm entry points to dispatch to the batch-invariant implementations when batch-invariant mode is enabled. Safe no-op if TE is unavailable.
- core.transformer.custom_layers.batch_invariant_kernels._te_unpatch_for_batch_invariant()#
Restore original Transformer Engine functions if they were patched.
- core.transformer.custom_layers.batch_invariant_kernels._extract_te_gemm_args(
- args: tuple,
- kwargs: Dict[str, Any],
Utility to parse TE general_gemm flexible signature.
Returns (A, B, out_dtype, layout, out, bias, grad).
- core.transformer.custom_layers.batch_invariant_kernels._is_supported_dtype_for_bik(t: torch.dtype) bool#
- class core.transformer.custom_layers.batch_invariant_kernels.BatchInvariantTEGemmFn#
Bases: torch.autograd.Function
Autograd function implementing batch-invariant TE GEMM.
- static forward(
- ctx,
- A: torch.Tensor,
- B: torch.Tensor,
- bias: Optional[torch.Tensor],
- out_dtype: Optional[torch.dtype],
- layout: str,
Forward pass computing batch-invariant TE GEMM.
Respects TE’s flexible layout semantics, flattens leading dimensions of the input as needed, applies optional bias, and casts to out_dtype.
- static backward(ctx, grad_output: torch.Tensor)#
Backward pass for batch-invariant TE GEMM.
Computes gradients w.r.t. A, B, and optional bias while mirroring the reshaping/layout logic used in the forward pass.
- core.transformer.custom_layers.batch_invariant_kernels._te_general_gemm_patched(*args, **kwargs) List[torch.Tensor]#
Batch-invariant replacement for TE general_gemm. Returns a list of tensors to match TE’s API: (gemm_out, bias_grad, gelu_input, extra_output)
- class core.transformer.custom_layers.batch_invariant_kernels.BatchInvariantRMSNormFn#
Bases: torch.autograd.Function
Autograd function implementing batch-invariant RMSNorm.
- static forward(
- ctx,
- x: torch.Tensor,
- weight: torch.Tensor,
- eps: float,
- zero_centered_gamma: bool,
Forward pass for batch-invariant RMSNorm.
Normalizes x using an RMSNorm-style statistic computed via mean_dim, applies the affine weight, and stores the intermediate rsigma for backward.
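A minimal reference sketch of the statistic described above, assuming standard RMSNorm semantics and TE's usual zero_centered_gamma convention (the stored weight is gamma - 1); the actual forward uses mean_dim and keeps rsigma for the backward pass.
```python
import torch

def rmsnorm_reference(x, weight, eps, zero_centered_gamma=False):
    # rsigma = 1 / sqrt(mean(x^2) + eps), accumulated in fp32 over the last dim.
    rsigma = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    # Assumed convention: with zero_centered_gamma the stored weight is (gamma - 1).
    gamma = weight.float() + 1.0 if zero_centered_gamma else weight.float()
    return (x.float() * rsigma * gamma).to(x.dtype)
```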
- static backward(ctx, grad_output: torch.Tensor)#
Backward pass for batch-invariant RMSNorm.
Computes gradients w.r.t. input and weight while matching TE’s fp32 accumulation and reduction behavior for numerical stability.
- core.transformer.custom_layers.batch_invariant_kernels.rmsnorm_batch_invariant(
- x: torch.Tensor,
- weight: torch.Tensor,
- eps: float,
Batch-invariant RMSNorm wrapper that delegates to autograd-aware implementation.
This provides a simple functional interface while using the optimized BatchInvariantRMSNormFn, which computes the forward and backward passes in fp32 for better numerics.
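A minimal usage sketch (assuming a CUDA device; shapes are illustrative):
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import rmsnorm_batch_invariant

hidden = torch.randn(4, 1024, device="cuda", requires_grad=True)
weight = torch.ones(1024, device="cuda", requires_grad=True)

out = rmsnorm_batch_invariant(hidden, weight, eps=1e-5)  # delegates to BatchInvariantRMSNormFn
out.sum().backward()  # gradients for hidden and weight come from the custom backward
```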
- core.transformer.custom_layers.batch_invariant_kernels._te_rmsnorm_forward_patched(self, x: torch.Tensor) torch.Tensor#
Patched TE RMSNorm.forward that routes to batch-invariant implementation with autograd support.
- core.transformer.custom_layers.batch_invariant_kernels.is_batch_invariant_mode_enabled()#
Return True if global batch-invariant mode is currently enabled.
- core.transformer.custom_layers.batch_invariant_kernels.enable_batch_invariant_mode()#
Enable global batch-invariant mode and patch Aten/TE kernels.
- core.transformer.custom_layers.batch_invariant_kernels.disable_batch_invariant_mode()#
Disable global batch-invariant mode and restore original kernels.
- core.transformer.custom_layers.batch_invariant_kernels.set_batch_invariant_mode(enabled: bool = True)#
Context manager to toggle global batch-invariant mode.
When enabled is True, batch-invariant kernels are enabled for the duration of the context; when False, they are disabled for the duration. This implementation is re-entrant and correctly restores the previous state even under nesting.
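A usage sketch of the context manager (assumes a CUDA device and that the mode starts out disabled, as the module-level default _batch_invariant_MODE = False indicates):
```python
import torch

from core.transformer.custom_layers.batch_invariant_kernels import (
    is_batch_invariant_mode_enabled,
    set_batch_invariant_mode,
)

a = torch.randn(128, 256, device="cuda")
b = torch.randn(256, 512, device="cuda")

with set_batch_invariant_mode(True):
    assert is_batch_invariant_mode_enabled()
    out = a @ b  # matmuls now dispatch to the batch-invariant persistent kernel

# On exit the previous state is restored (disabled here, since the default is off).
assert not is_batch_invariant_mode_enabled()
```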