nemo_automodel.components._peft.lora_kernel

Module Contents

Functions
- Method for generating Triton configs for lora_forward_kernel.
- Converts one-dimensional Triton program IDs into two dimensions.
- Performs the matrix multiplication AB.
- Multiplies an M x N matrix AB by an N x L matrix C and adds the result to the output matrix D.
- Kernel for computing the matmul D = A x B x C.
- Computes the LoRA forward pass.
- Method for generating Triton configs for lora_da_dx_kernel.
- Kernel for computing the matmuls DYB = DY x B and DX = DY x B x A.
- Computes dlora_A and dx.
- Method for generating Triton configs for lora_db_kernel.
- Kernel for computing the matmul AXT = A x X^T.
- Computes d_lora_B.
API
- nemo_automodel.components._peft.lora_kernel.forward_autotune_configs()[source]
Method for generating Triton configs for lora_forward_kernel.
- nemo_automodel.components._peft.lora_kernel.get_pid_coords(
- M,
- N,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
- )
Converts one-dimensional Triton program IDs into two dimensions.
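For reference, the standard grouped 1-D-to-2-D program-ID mapping used by Triton matmul kernels can be sketched in plain Python (the helper name below is illustrative, and this kernel's exact grouping scheme may differ in detail):

```python
def get_pid_coords_ref(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    # Map a linear program id to (pid_m, pid_n) in grouped order:
    # consecutive pids walk down up to GROUP_SIZE_M rows before moving
    # to the next column, improving L2 reuse of the A tiles.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```

Every linear pid maps to a unique (pid_m, pid_n) tile, so the full M x N grid of blocks is covered exactly once.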
- nemo_automodel.components._peft.lora_kernel.inner_kernel(
- pid_m,
- pid_n,
- a_ptr,
- b_ptr,
- M,
- K,
- N,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- scale,
- )
Performs the matrix multiplication AB.
A is an M x K matrix and B is an N x K matrix. The result is returned to be stored by the calling method.
- nemo_automodel.components._peft.lora_kernel.block_vector_mul(
- pid_m,
- pid_n,
- ab_result,
- c_ptr,
- d_ptr,
- M,
- N,
- L,
- stride_cn,
- stride_cl,
- stride_dm,
- stride_dl,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_L: triton.language.constexpr,
- )
Multiplies an M x N matrix AB by an N x L matrix C and adds the result to the output matrix D.
N is assumed to be smaller than BLOCK_SIZE_N.
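In NumPy terms, the per-block update this helper performs amounts to the following (a sketch of the math only, not the Triton implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, L = 8, 4, 6
ab = rng.standard_normal((M, N))  # AB tile produced by inner_kernel
c = rng.standard_normal((N, L))   # the C operand
d = np.zeros((M, L))              # output accumulator D

d += ab @ c                       # D += (A x B) x C for this block
```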
- nemo_automodel.components._peft.lora_kernel.lora_forward_kernel(
- x_ptr,
- la_ptr,
- lb_ptr,
- res_ptr,
- M,
- N,
- K,
- L,
- stride_x_m,
- stride_x_k,
- stride_la_k,
- stride_la_n,
- stride_lb_n,
- stride_lb_l,
- stride_res_m,
- stride_res_l,
- scale,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- BLOCK_SIZE_L: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
- )
Kernel for computing the matmul D = A x B x C.
A has shape (M, K), B has shape (K, N), C has shape (N, L), and D has shape (M, L). N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N.
- nemo_automodel.components._peft.lora_kernel.lora_forward_wrapper(
- x,
- lora_A,
- lora_B,
- res,
- scale,
- dtype=torch.float32,
- )
Computes LoRA forward pass.
- Parameters:
x – input activations, (M x K)
lora_A – LoRA A weights (K x N)
lora_B – LoRA B weights (N x L)
res (torch.Tensor, optional) – output tensor
scale – LoRA scale factor (scalar)
dtype – dtype for output
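The computation the wrapper performs can be sketched as a NumPy reference (the function name is illustrative; accumulating into res is an assumption based on the optional output-tensor parameter):

```python
import numpy as np

def lora_forward_ref(x, lora_A, lora_B, scale, res=None):
    # x: (M, K), lora_A: (K, N), lora_B: (N, L) -> (M, L)
    out = scale * (x @ lora_A) @ lora_B
    if res is not None:
        out = res + out  # add the scaled LoRA delta to the base output
    return out
```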
- nemo_automodel.components._peft.lora_kernel.da_dx_autotune_configs()[source]
Method for generating Triton configs for lora_da_dx_kernel.
- nemo_automodel.components._peft.lora_kernel.lora_da_dx_kernel(
- dy_ptr,
- b_ptr,
- a_ptr,
- dx_ptr,
- dyb_ptr,
- M,
- K,
- N,
- L,
- stride_dy_m,
- stride_dy_k,
- stride_lorab_k,
- stride_lorab_n,
- stride_loraa_n,
- stride_loraa_l,
- stride_dx_m,
- stride_dx_l,
- stride_dyb_m,
- stride_dyb_n,
- scale,
- BLOCK_SIZE_M: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- BLOCK_SIZE_L: triton.language.constexpr,
- )
Kernel for computing the matmuls DYB = DY x B and DX = DY x B x A.
XT has shape (S, M), DY has shape (M, K), B has shape (K, N), and A has shape (N, L). N, the LoRA dimension, must be less than or equal to BLOCK_SIZE_N. The result returned by this kernel is reduced in the wrapper.
- nemo_automodel.components._peft.lora_kernel.lora_da_dx_update_wrapper(
- xt,
- dy,
- lora_B,
- lora_A,
- scale,
- dtype=torch.float32,
- )
Computes dlora_A and dx.
- Parameters:
xt – input activations, transposed (S x M)
dy – gradients (M x K)
lora_B – LoRA B weights (K x N)
lora_A – LoRA A weights (N x L)
scale – LoRA scale factor (scalar)
dtype – dtype for output
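The two products the kernel computes (DYB = DY x B and DX = DY x B x A) can be sketched in NumPy under the shapes listed above; forming dlora_A from xt and the intermediate DYB happens in the wrapper's reduction and is omitted here:

```python
import numpy as np

def lora_da_dx_ref(dy, lora_B, lora_A, scale):
    # dy: (M, K), lora_B: (K, N), lora_A: (N, L)
    dyb = dy @ lora_B            # (M, N) intermediate, DY x B
    dx = scale * (dyb @ lora_A)  # (M, L) gradient w.r.t. the input
    return dyb, dx
```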
- nemo_automodel.components._peft.lora_kernel.db_autotune_configs()[source]
Method for generating Triton configs for lora_db_kernel.
- nemo_automodel.components._peft.lora_kernel.lora_db_kernel(
- a_ptr,
- b_ptr,
- c_ptr,
- M,
- K,
- N,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- scale,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
- )
Kernel for computing the matmul AXT = A x X^T.
A has shape (M, K), X has shape (N, K).
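The product AXT = A x X^T in NumPy terms (a sketch of the math only; whether scale is applied inside the kernel or by a calling wrapper is an assumption based on the signature):

```python
import numpy as np

def lora_db_ref(a, x, scale):
    # a: (M, K), x: (N, K) -> (M, N)
    return scale * (a @ x.T)
```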