Source code for nemo_automodel.components._peft.lora_kernel

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import torch
from packaging import version

from nemo_automodel.shared.import_utils import MISSING_TRITON_MSG, null_decorator

try:
    import triton
    import triton.language as tl

    if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available():
        HAVE_TRITON = False
    else:
        HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0"))
except ImportError:
    HAVE_TRITON = False

if not HAVE_TRITON:
    triton = MagicMock()
    triton.jit = null_decorator
    triton.autotune = null_decorator
    triton.heuristics = null_decorator
    tl = MagicMock()
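# Note (descriptive, added for clarity): with these MagicMock stand-ins the module stays
# importable when Triton is missing -- the decorators below become no-ops -- while the
# wrapper functions still raise ImportError(MISSING_TRITON_MSG) when actually called.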


def forward_autotune_configs():
    """
    Method for generating Triton configs for lora_forward_kernel.
    """
    out = list()
    for blk_m in [16, 32, 64]:
        for blk_k in [128, 256, 512]:
            for blk_l in [128, 256, 512]:
                out.append(
                    triton.Config(
                        {"BLOCK_SIZE_M": blk_m, "BLOCK_SIZE_K": blk_k, "BLOCK_SIZE_L": blk_l, "GROUP_SIZE_M": 8},
                        num_stages=4,
                        num_warps=4,
                    )
                )
    return out


@triton.jit
def get_pid_coords(M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """
    Converts one-dimensional Triton pids into two dimensions.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


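# Illustrative sketch (hypothetical helper, not used by the kernels): a plain-Python
# replica of the grouped pid mapping in get_pid_coords above, handy for inspecting the
# launch order on the host. For example, with 4 x 4 output blocks and group_size = 2,
# pids 0..7 visit (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3):
# a group of rows is swept column by column, which tends to improve L2 reuse of tiles.
def _pid_coords_on_host(pid, num_pid_m, num_pid_n, group_size):
    num_pid_in_group = group_size * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size
    group_size_m = min(num_pid_m - first_pid_m, group_size)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

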
@triton.jit
def inner_kernel(
    pid_m,
    pid_n,
    a_ptr,
    b_ptr,
    M,
    K,
    N,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    scale,
):
    """
    Performs the matrix multiplication AB.

    A is an M x K matrix and B is a K x N matrix. The result is returned
    to be stored by the calling method.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
        b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        ab += tl.dot(a, b, out_dtype=tl.float32)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    return scale * ab


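# For reference (descriptive note on inner_kernel above, added for clarity): the loop is the
# tiled equivalent of scale * (A[rows, :] @ B[:, cols]) for the BLOCK_SIZE_M x BLOCK_SIZE_N
# output block owned by (pid_m, pid_n). K is consumed BLOCK_SIZE_K elements at a time,
# accumulation happens in float32, and out-of-bounds elements are masked and loaded as zero.

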
@triton.jit
def block_vector_mul(
    pid_m,
    pid_n,
    ab_result,
    c_ptr,
    d_ptr,
    M,
    N,
    L,
    stride_cn,
    stride_cl,
    stride_dm,
    stride_dl,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_L: tl.constexpr,
):
    """
    Multiplies an M x N block AB by an N x L matrix C and stores the result in the output matrix D.

    N is assumed to be no larger than BLOCK_SIZE_N.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_l = tl.arange(0, BLOCK_SIZE_L)
    offs_dm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    c_ptrs = c_ptr + (offs_cn[:, None] * stride_cn + offs_l[None, :] * stride_cl)
    d_ptrs = d_ptr + stride_dm * offs_dm[:, None] + stride_dl * offs_l[None, :]

    d_mask = (offs_dm[:, None] < M) & (offs_l[None, :] < L)
    c_mask = (offs_cn[:, None] < N) & (offs_l[None, :] < L)

    for lx in tl.range(0, tl.cdiv(L, BLOCK_SIZE_L)):
        d_mask = (offs_dm[:, None] < M) & (offs_l[None, :] < L - lx * BLOCK_SIZE_L)
        c_mask = (offs_cn[:, None] < N) & (offs_l[None, :] < L - lx * BLOCK_SIZE_L)
        c = tl.load(c_ptrs, mask=c_mask, other=0.0)
        abc = tl.dot(ab_result, c)
        tl.store(d_ptrs, abc, mask=d_mask)
        c_ptrs += BLOCK_SIZE_L * stride_cl
        d_ptrs += BLOCK_SIZE_L * stride_dl


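# For reference (descriptive note on block_vector_mul above, added for clarity): the loop
# walks the L dimension BLOCK_SIZE_L columns at a time and is the tiled equivalent of
# D[rows, :] = AB @ C for the BLOCK_SIZE_M rows owned by pid_m. Because N fits inside a
# single BLOCK_SIZE_N block, no accumulation across program ids is needed.

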
@triton.autotune(
    configs=forward_autotune_configs(),
    key=["N", "K", "L"],
)
# This optimization exploits that N is the LoRA dimension and thus we only need one block.
@triton.heuristics(values={"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["N"]), 16)})
@triton.jit
def lora_forward_kernel(
    x_ptr,
    la_ptr,
    lb_ptr,
    res_ptr,
    M,
    N,
    K,
    L,
    stride_x_m,
    stride_x_k,
    stride_la_k,
    stride_la_n,
    stride_lb_n,
    stride_lb_l,
    stride_res_m,
    stride_res_l,
    # scale factor
    scale,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_L: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Kernel for computing the matmul D = A x B x C.

    A has shape (M, K), B has shape (K, N), C has shape (N, L), and D has shape (M, L).
    N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    pid_m, pid_n = get_pid_coords(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)

    ab_result = inner_kernel(
        pid_m,
        pid_n,
        x_ptr,
        la_ptr,
        M,
        K,
        N,
        stride_x_m,
        stride_x_k,
        stride_la_k,
        stride_la_n,
        BLOCK_SIZE_M,
        BLOCK_SIZE_K,
        BLOCK_SIZE_N,
        scale,
    )
    ab_result = ab_result.to(lb_ptr.dtype.element_ty)

    block_vector_mul(
        pid_m,
        pid_n,
        ab_result,
        lb_ptr,
        res_ptr,
        M,
        N,
        L,
        stride_lb_n,
        stride_lb_l,
        stride_res_m,
        stride_res_l,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_L,
    )


def lora_forward_wrapper(x, lora_A, lora_B, res, scale, dtype=torch.float32):
    """
    Computes LoRA forward pass.

    Args:
        x: input activations, (M x K)
        lora_A: LoRA A weights (K x N)
        lora_B: LoRA B weights (N x L)
        res (Optional[torch.Tensor]): output tensor
        scale: LoRA scale factor (scalar)
        dtype: dtype for output
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    assert x.shape[1] == lora_A.shape[0], "Incompatible X and LoRA A dimensions"
    assert lora_A.shape[1] == lora_B.shape[0], "Incompatible LoRA dimensions"
    if res is not None:
        assert x.shape[0] == res.shape[0], "Incompatible X and output dimensions"
        assert lora_B.shape[1] == res.shape[1], "Incompatible LoRA B and output dimensions"

    M, K = x.shape
    K, N = lora_A.shape
    N, L = lora_B.shape

    if res is None:
        res = torch.empty((M, L), device=x.device, dtype=dtype)

    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)  # noqa: E731
    lora_forward_kernel[grid](
        x,
        lora_A,
        lora_B,
        res,
        M,
        N,
        K,
        L,
        x.stride(0),
        x.stride(1),
        lora_A.stride(0),
        lora_A.stride(1),
        lora_B.stride(0),
        lora_B.stride(1),
        res.stride(0),
        res.stride(1),
        scale,
    )
    return res


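# Illustrative usage sketch for lora_forward_wrapper above (hypothetical helper name and
# sizes, not part of the module's API; assumes Triton and a CUDA device are available).
# Reading the kernels, the result should agree with the eager PyTorch expression
# scale * x @ lora_A @ lora_B up to floating-point accumulation differences.
def _example_lora_forward():
    # Hypothetical sizes; N is the LoRA rank and must fit in one BLOCK_SIZE_N block.
    M, K, N, L = 64, 256, 16, 128
    x = torch.randn(M, K, device="cuda", dtype=torch.float32)
    lora_A = torch.randn(K, N, device="cuda", dtype=torch.float32)
    lora_B = torch.randn(N, L, device="cuda", dtype=torch.float32)
    scale = 0.5
    out = lora_forward_wrapper(x, lora_A, lora_B, None, scale)
    # Eager reference; tolerances are loose because tl.dot may use TF32.
    ref = scale * x @ lora_A @ lora_B
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1)

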
def da_dx_autotune_configs():
    """
    Method for generating Triton configs for lora_da_dx_kernel.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    out = list()
    for blk_k in [64, 128]:
        for blk_l in [64, 128, 256]:
            for blk_m in [64, 128]:
                out.append(
                    triton.Config(
                        {"BLOCK_SIZE_K": blk_k, "BLOCK_SIZE_L": blk_l, "BLOCK_SIZE_M": blk_m, "GROUP_SIZE_M": 8},
                        num_stages=4,
                        num_warps=4,
                    )
                )
    return out


@triton.autotune(
    configs=da_dx_autotune_configs() if HAVE_TRITON else list(),
    key=["N", "K", "L"],
)
@triton.heuristics(values={"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["N"]), 16)})
@triton.jit
def lora_da_dx_kernel(
    dy_ptr,
    b_ptr,
    a_ptr,
    dx_ptr,
    dyb_ptr,
    M,
    K,
    N,
    L,
    stride_dy_m,
    stride_dy_k,
    stride_lorab_k,
    stride_lorab_n,
    stride_loraa_n,
    stride_loraa_l,
    stride_dx_m,
    stride_dx_l,
    stride_dyb_m,
    stride_dyb_n,
    scale,
    BLOCK_SIZE_M: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_L: tl.constexpr,
):
    """
    Kernel for computing the matmuls DYB = DY x B and DX = DY x B x A.

    XT has shape (S, M), DY has shape (M, K), B has shape (K, N), and A has shape (N, L).
    N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N.
    The result returned by this kernel is reduced in the wrapper.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    pid_m, pid_n = get_pid_coords(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)

    dyb = inner_kernel(
        pid_m,
        pid_n,
        dy_ptr,
        b_ptr,
        M,
        K,
        N,
        stride_dy_m,
        stride_dy_k,
        stride_lorab_k,
        stride_lorab_n,
        BLOCK_SIZE_M,
        BLOCK_SIZE_K,
        BLOCK_SIZE_N,
        scale,
    )
    dyb = dyb.to(a_ptr.dtype.element_ty)

    offs_la_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_dx_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_l = tl.arange(0, BLOCK_SIZE_L)
    dx_ptrs = dx_ptr + stride_dx_m * offs_dx_m[:, None] + stride_dx_l * offs_l[None, :]
    la_ptrs = a_ptr + stride_loraa_n * offs_la_n[:, None] + stride_loraa_l * offs_l[None, :]

    for lx in tl.range(0, tl.cdiv(L, BLOCK_SIZE_L)):
        dx_mask = (offs_dx_m[:, None] < M) & (offs_l[None, :] < L - lx * BLOCK_SIZE_L)
        la_mask = (offs_la_n[:, None] < N) & (offs_l[None, :] < L - lx * BLOCK_SIZE_L)
        lora_a = tl.load(la_ptrs, mask=la_mask, other=0.0)
        dx = tl.dot(dyb, lora_a)
        dx = dx.to(a_ptr.dtype.element_ty)
        tl.store(dx_ptrs, dx, mask=dx_mask)
        la_ptrs += BLOCK_SIZE_L * stride_loraa_l
        dx_ptrs += BLOCK_SIZE_L * stride_dx_l

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dyb_ptrs = dyb_ptr + stride_dyb_m * offs_cm[:, None] + stride_dyb_n * offs_cn[None, :]
    dyb_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(dyb_ptrs, dyb, mask=dyb_mask)


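# For reference (descriptive note on lora_da_dx_kernel above, added for clarity): each
# program computes its block of DYB = scale * DY @ B once, streams A in BLOCK_SIZE_L-wide
# column tiles to write its rows of DX = DYB @ A, and finally stores the DYB block so the
# wrapper can finish the dlora_A update on the host via XT @ DYB.

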
def lora_da_dx_update_wrapper(xt, dy, lora_B, lora_A, scale, dtype=torch.float32):
    """Computes dlora_A and dx.

    xt: input activation weights, transposed (S x M)
    dy: gradients (M x K)
    lora_B: LoRA B weights (K x N)
    lora_A: LoRA A weights (N x L)
    scale: LoRA scale factor (scalar)
    dtype: dtype for output
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    assert xt.shape[1] == dy.shape[0], "Incompatible X and dY dimensions"
    assert dy.shape[1] == lora_B.shape[0], "Incompatible dY and B dimensions"
    assert lora_B.shape[1] == lora_A.shape[0], "LoRA dimensions must match"

    _, M = xt.shape
    M, K = dy.shape
    K, N = lora_B.shape
    N, L = lora_A.shape

    dx = torch.empty((M, L), device=xt.device, dtype=dtype)
    dyb = torch.empty((M, N), device=xt.device, dtype=dtype)

    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)  # noqa: E731
    lora_da_dx_kernel[grid](
        dy,
        lora_B,
        lora_A,
        dx,
        dyb,
        M,
        K,
        N,
        L,
        dy.stride(0),
        dy.stride(1),
        lora_B.stride(0),
        lora_B.stride(1),
        lora_A.stride(0),
        lora_A.stride(1),
        dx.stride(0),
        dx.stride(1),
        dyb.stride(0),
        dyb.stride(1),
        scale,
    )

    dlora_A = torch.matmul(xt, dyb)
    return dlora_A, dx


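# Illustrative usage sketch for lora_da_dx_update_wrapper above (hypothetical helper name
# and sizes, not part of the module's API; assumes Triton and a CUDA device are available).
# Reading the code, the outputs should match the eager computation
#     dyb = scale * dy @ lora_B;  dx = dyb @ lora_A;  dlora_A = xt @ dyb
# up to floating-point accumulation differences.
def _example_lora_da_dx():
    # Hypothetical sizes, following the wrapper's docstring shapes; N is the LoRA rank.
    S, M, K, N, L = 96, 64, 256, 16, 96
    xt = torch.randn(S, M, device="cuda", dtype=torch.float32)
    dy = torch.randn(M, K, device="cuda", dtype=torch.float32)
    lora_B = torch.randn(K, N, device="cuda", dtype=torch.float32)
    lora_A = torch.randn(N, L, device="cuda", dtype=torch.float32)
    scale = 0.25
    dlora_A, dx = lora_da_dx_update_wrapper(xt, dy, lora_B, lora_A, scale)
    # Eager reference; tolerances are loose because tl.dot may use TF32.
    dyb_ref = scale * dy @ lora_B
    torch.testing.assert_close(dx, dyb_ref @ lora_A, rtol=1e-2, atol=1e-1)
    torch.testing.assert_close(dlora_A, xt @ dyb_ref, rtol=1e-2, atol=1e-1)

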
def db_autotune_configs():
    """
    Method for generating Triton configs for lora_db_kernel.
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    out = list()
    for blk_n in [32, 64, 128]:
        for blk_k in [32, 64, 128]:
            for blk_m in [64, 128]:
                out.append(
                    triton.Config(
                        {"BLOCK_SIZE_N": blk_n, "BLOCK_SIZE_K": blk_k, "BLOCK_SIZE_M": blk_m, "GROUP_SIZE_M": 8},
                        num_stages=4,
                        num_warps=4,
                    )
                )
    return out


@triton.autotune(
    configs=db_autotune_configs() if HAVE_TRITON else list(),
    key=["S", "M", "K"],
)
@triton.jit
def lora_db_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    K,
    N,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    scale,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Kernel for computing the matmul AXT = A x X^T.

    A has shape (M, K), X has shape (N, K).
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    pid_m, pid_n = get_pid_coords(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)

    ab = inner_kernel(
        pid_m,
        pid_n,
        a_ptr,
        b_ptr,
        M,
        K,
        N,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        BLOCK_SIZE_M,
        BLOCK_SIZE_K,
        BLOCK_SIZE_N,
        scale,
    )

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, ab, mask=c_mask)


def lora_db_update_wrapper(lora_A, xt, dy, scale, dtype=torch.float32):
    """Computes d_lora_B.

    lora_A: LoRA A weights (M x K)
    xt: input activation weights, transposed (K x N)
    dy: gradients (N x S)
    scale: LoRA scale factor (scalar)
    dtype: dtype for output
    """
    if not HAVE_TRITON:
        raise ImportError(MISSING_TRITON_MSG)
    assert xt.shape[1] == dy.shape[0], "Incompatible X and dY dimensions"
    assert lora_A.shape[1] == xt.shape[0], "Incompatible X and A dimensions"

    M, K = lora_A.shape
    K, N = xt.shape
    N, _ = dy.shape

    axt = torch.empty((M, N), device=dy.device, dtype=dtype)

    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
    lora_db_kernel[grid](
        lora_A,
        xt,
        axt,
        M,
        K,
        N,
        lora_A.stride(0),
        lora_A.stride(1),
        xt.stride(0),
        xt.stride(1),
        axt.stride(0),
        axt.stride(1),
        scale,
    )
    return torch.matmul(axt, dy).t()
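
# Illustrative usage sketch for lora_db_update_wrapper above (hypothetical helper name and
# sizes, not part of the module's API; assumes Triton and a CUDA device are available).
# Reading the code, the result should match the eager expression
#     (scale * lora_A @ xt @ dy).t()
# up to floating-point accumulation differences.
def _example_lora_db():
    # Hypothetical sizes, following the wrapper's docstring shapes; M is the LoRA rank here.
    M, K, N, S = 16, 256, 64, 96
    lora_A = torch.randn(M, K, device="cuda", dtype=torch.float32)
    xt = torch.randn(K, N, device="cuda", dtype=torch.float32)
    dy = torch.randn(N, S, device="cuda", dtype=torch.float32)
    scale = 0.25
    d_lora_B = lora_db_update_wrapper(lora_A, xt, dy, scale)
    # Eager reference; tolerances are loose because tl.dot may use TF32.
    ref = (scale * lora_A @ xt @ dy).t()
    torch.testing.assert_close(d_lora_B, ref, rtol=1e-2, atol=1e-1)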