nemo_automodel.components._peft.lora_kernel#

Module Contents#

Functions#

forward_autotune_configs

Generates Triton autotune configs for lora_forward_kernel.

get_pid_coords

Converts a one-dimensional Triton program ID into two-dimensional block coordinates.

inner_kernel

Performs the matrix multiplication AB.

block_vector_mul

Multiplies an M x N matrix AB by an N x L matrix C and adds the result to the output matrix D.

lora_forward_kernel

Kernel for computing the matmul D = A x B x C.

lora_forward_wrapper

Computes the LoRA forward pass.

da_dx_autotune_configs

Generates Triton autotune configs for lora_da_dx_kernel.

lora_da_dx_kernel

Kernel for computing the matmuls DYB = DY x B and DX = DY x B x A.

lora_da_dx_update_wrapper

Computes dlora_A and dx.

db_autotune_configs

Generates Triton autotune configs for lora_db_kernel.

lora_db_kernel

Kernel for computing the matmul AXT = A x X^T.

lora_db_update_wrapper

Computes d_lora_B.

API#

nemo_automodel.components._peft.lora_kernel.forward_autotune_configs()[source]#

Generates Triton autotune configs for lora_forward_kernel.

nemo_automodel.components._peft.lora_kernel.get_pid_coords(
M,
N,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
)#

Converts a one-dimensional Triton program ID into two-dimensional block coordinates.
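This is typically the grouped ordering used in Triton matmul kernels, which keeps neighboring programs working on nearby rows for better L2 reuse. A pure-Python sketch of that scheme (an assumption for illustration; the actual kernel may differ in detail):

```python
def pid_to_coords(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Hypothetical pure-Python equivalent of a grouped pid -> (pid_m, pid_n) mapping."""
    num_pid_m = -(-M // BLOCK_SIZE_M)  # ceil-div: number of row blocks
    num_pid_n = -(-N // BLOCK_SIZE_N)  # ceil-div: number of column blocks
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M row blocks.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```

Every pid maps to a unique block coordinate, so a launch grid of `num_pid_m * num_pid_n` programs covers the output exactly once.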

nemo_automodel.components._peft.lora_kernel.inner_kernel(
pid_m,
pid_n,
a_ptr,
b_ptr,
M,
K,
N,
stride_am,
stride_ak,
stride_bk,
stride_bn,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
scale,
)#

Performs the matrix multiplication AB.

A is an M x K matrix and B is a K x N matrix. The result is returned to be stored by the calling method.

nemo_automodel.components._peft.lora_kernel.block_vector_mul(
pid_m,
pid_n,
ab_result,
c_ptr,
d_ptr,
M,
N,
L,
stride_cn,
stride_cl,
stride_dm,
stride_dl,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_L: triton.language.constexpr,
)#

Multiplies an M x N matrix AB by an N x L matrix C and adds the result to the output matrix D.

N is assumed to be no larger than BLOCK_SIZE_N.

nemo_automodel.components._peft.lora_kernel.lora_forward_kernel(
x_ptr,
la_ptr,
lb_ptr,
res_ptr,
M,
N,
K,
L,
stride_x_m,
stride_x_k,
stride_la_k,
stride_la_n,
stride_lb_n,
stride_lb_l,
stride_res_m,
stride_res_l,
scale,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
BLOCK_SIZE_L: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
)#

Kernel for computing the matmul D = A x B x C.

A has shape (M, K), B has shape (K, N), C has shape (N, L), and D has shape (M, L). N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N.

nemo_automodel.components._peft.lora_kernel.lora_forward_wrapper(
x,
lora_A,
lora_B,
res,
scale,
dtype=torch.float32,
)[source]#

Computes the LoRA forward pass.

Parameters:
  • x – input activations (M x K)

  • lora_A – LoRA A weights (K x N)

  • lora_B – LoRA B weights (N x L)

  • res (optional(torch.Tensor)) – output tensor

  • scale – LoRA scale factor (scalar)

  • dtype – dtype for output
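Given the shapes above, a plain-PyTorch reference for what the fused kernel computes is `scale * x @ lora_A @ lora_B`. A minimal sketch (the dimension values are illustrative; the actual wrapper expects CUDA tensors so the Triton kernel can run):

```python
import torch

M, K, N, L = 64, 128, 8, 128  # N is the LoRA rank, kept <= BLOCK_SIZE_N
x = torch.randn(M, K)
lora_A = torch.randn(K, N)
lora_B = torch.randn(N, L)
scale = 0.5

# Unfused reference for the kernel's output D = scale * (x @ lora_A) @ lora_B.
ref = scale * (x @ lora_A @ lora_B)
```

The fused kernel avoids materializing the intermediate (M x N) activation in global memory, which is the main win when N (the LoRA rank) is small.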

nemo_automodel.components._peft.lora_kernel.da_dx_autotune_configs()[source]#

Generates Triton autotune configs for lora_da_dx_kernel.

nemo_automodel.components._peft.lora_kernel.lora_da_dx_kernel(
dy_ptr,
b_ptr,
a_ptr,
dx_ptr,
dyb_ptr,
M,
K,
N,
L,
stride_dy_m,
stride_dy_k,
stride_lorab_k,
stride_lorab_n,
stride_loraa_n,
stride_loraa_l,
stride_dx_m,
stride_dx_l,
stride_dyb_m,
stride_dyb_n,
scale,
BLOCK_SIZE_M: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
BLOCK_SIZE_L: triton.language.constexpr,
)#

Kernel for computing the matmuls DYB = DY x B and DX = DY x B x A.

XT has shape (S, M), DY has shape (M, K), B has shape (K, N), and A has shape (N, L). N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N. The result returned by this kernel is reduced in the wrapper.

nemo_automodel.components._peft.lora_kernel.lora_da_dx_update_wrapper(
xt,
dy,
lora_B,
lora_A,
scale,
dtype=torch.float32,
)[source]#

Computes dlora_A and dx.

Parameters:
  • xt – input activations, transposed (S x M)

  • dy – gradients (M x K)

  • lora_B – LoRA B weights (K x N)

  • lora_A – LoRA A weights (N x L)

  • scale – LoRA scale factor (scalar)

  • dtype – dtype for output
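The dx path stated in the kernel docs (DYB = DY x B, DX = DY x B x A) can be sketched in plain PyTorch as follows; the d_lora_A product, which additionally involves xt and is reduced inside the wrapper, is omitted here, and applying scale to dx is an assumption based on the kernel's scale argument:

```python
import torch

M, K, N, L, S = 16, 32, 4, 32, 8  # N is the LoRA rank
dy = torch.randn(M, K)      # gradients
lora_B = torch.randn(K, N)  # LoRA B weights
lora_A = torch.randn(N, L)  # LoRA A weights
scale = 0.5

dyb = dy @ lora_B            # DYB = DY x B, shape (M, N)
dx = scale * (dyb @ lora_A)  # DX = DY x B x A, shape (M, L); scale assumed
```

Computing DYB once and reusing it for both outputs is what makes fusing the two matmuls into one kernel worthwhile.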

nemo_automodel.components._peft.lora_kernel.db_autotune_configs()[source]#

Generates Triton autotune configs for lora_db_kernel.

nemo_automodel.components._peft.lora_kernel.lora_db_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
K,
N,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
scale,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
)#

Kernel for computing the matmul AXT = A x X^T.

A has shape (M, K), X has shape (N, K).

nemo_automodel.components._peft.lora_kernel.lora_db_update_wrapper(lora_A, xt, dy, scale, dtype=torch.float32)[source]#

Computes d_lora_B.

Parameters:
  • lora_A – LoRA A weights (M x K)

  • xt – input activations, transposed (K x N)

  • dy – gradients (N x S)

  • scale – LoRA scale factor (scalar)

  • dtype – dtype for output
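With the shapes above, the product the kernel computes is AXT = lora_A x xt (the xt argument is already transposed). How the wrapper combines AXT with dy and scale to form d_lora_B is not spelled out in these docs, so the final line below is an assumption sketched for shape-checking only:

```python
import torch

M, K, N, S = 4, 32, 16, 8
lora_A = torch.randn(M, K)  # LoRA A weights
xt = torch.randn(K, N)      # input activations, transposed
dy = torch.randn(N, S)      # gradients
scale = 0.5

axt = lora_A @ xt              # AXT = lora_A x xt, shape (M, N)
d_lora_B = scale * (axt @ dy)  # assumed combination, shape (M, S)
```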