nemo_automodel.components.datasets.dllm.corruption#
Data corruption utilities for diffusion LLM (dLLM) training.
Provides masking strategies for dLLM SFT:

- `corrupt_uniform`: uniform per-sequence corruption
- `corrupt_blockwise`: per-block weighted corruption with exponential position bias
Module Contents#
Functions#
| Function | Description |
|---|---|
| `gumbel_topk` | Return a bool mask of length `len(log_w)` with exactly `k` `True` entries. |
| `_batched_gumbel_topk` | Batched variable-k Gumbel top-k selection. |
| `corrupt_uniform` | Per-sequence uniform corruption for MDLM training. |
| `corrupt_blockwise` | Two-stage corruption with optional per-block sampling. |
API#
- nemo_automodel.components.datasets.dllm.corruption.gumbel_topk(log_w: torch.Tensor, k: int) → torch.Tensor#
Return a bool mask of length `len(log_w)` with exactly `k` `True` entries.

Uses the Gumbel-max trick for stochastic top-k selection.
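The Gumbel-max trick samples `k` indices without replacement, with probabilities proportional to `exp(log_w)`, by adding i.i.d. Gumbel(0, 1) noise to the log-weights and keeping the `k` highest scores. A minimal sketch consistent with the documented behavior (not the library's actual implementation):

```python
import torch


def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
    """Stochastic top-k via the Gumbel-max trick.

    Returns a bool mask of length len(log_w) with exactly k True entries.
    """
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1).
    gumbel = -torch.log(-torch.log(torch.rand_like(log_w)))
    scores = log_w + gumbel
    # Mark the k highest-scoring positions.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask[scores.topk(k).indices] = True
    return mask
```

Setting an entry of `log_w` to `-inf` guarantees it is never selected, which is how forbidden positions are excluded downstream.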
- nemo_automodel.components.datasets.dllm.corruption._batched_gumbel_topk(log_w: torch.Tensor, k: torch.Tensor) → torch.Tensor#
Batched variable-k Gumbel top-k selection.
Vectorised replacement for calling :func:`gumbel_topk` in a Python loop. Each row `i` of `log_w` gets exactly `k[i]` `True` entries selected via the Gumbel-max trick.

Algebraically identical to per-row `gumbel_topk`: both select the k indices with the highest `log_w + Gumbel` score. `torch.sort` gives a full ranking; `positions < k` picks the top-k in sorted order; `scatter_` maps them back to original positions.

- Parameters:
  - log_w – Log-weights, shape `[N, D]`. Positions that should never be selected must be set to `-inf`.
  - k – Number of positions to select per row, shape `[N]`. Rows with `k=0` produce an all-False mask.
- Returns:
  Boolean mask of shape `[N, D]`.
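The sort-then-scatter recipe described above can be sketched as follows. This is a standalone illustration, not the library source (the name drops the leading underscore for readability):

```python
import torch


def batched_gumbel_topk(log_w: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Row i of the result has exactly k[i] True entries, chosen by Gumbel-max."""
    # Perturb log-weights with i.i.d. Gumbel(0, 1) noise.
    gumbel = -torch.log(-torch.log(torch.rand_like(log_w)))
    scores = log_w + gumbel
    # Full descending ranking per row: order[n, j] = column of the j-th largest score.
    order = scores.argsort(dim=-1, descending=True)
    # In sorted order, select ranks 0..k[n]-1.
    ranks = torch.arange(log_w.shape[-1], device=log_w.device).expand_as(log_w)
    selected_sorted = ranks < k.unsqueeze(-1)
    # Scatter the sorted-order selection back to original column positions.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask.scatter_(-1, order, selected_sorted)
    return mask
```

Rows with `k=0` yield an all-False mask, and `-inf` positions rank last, so they are never selected as long as `k[i]` does not exceed the number of finite entries in row `i`.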
- nemo_automodel.components.datasets.dllm.corruption.corrupt_uniform(input_ids: torch.Tensor, loss_mask: torch.Tensor, mask_token_id: int, eps: float = 0.001)#
Per-sequence uniform corruption for MDLM training.
For each sequence, sample `t ~ U[0, 1]` and derive a masking probability `p = (1 - eps) * t + eps`. Each token at a supervised position (where `loss_mask == 1`) is independently replaced with `mask_token_id` with probability `p`.

- Parameters:
  - input_ids – Token IDs, shape `[B, L]`.
  - loss_mask – Binary mask indicating supervised positions, shape `[B, L]`.
  - mask_token_id – The token ID used for masking.
  - eps – Minimum corruption ratio.
- Returns:
  Tuple of `(noisy_input_ids, noise_mask, p_mask)`, each of shape `[B, L]`.
  - `noisy_input_ids`: `input_ids` with masked positions replaced.
  - `noise_mask`: bool mask of which positions were corrupted.
  - `p_mask`: per-position masking probability (float32).
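The per-sequence uniform scheme is short enough to sketch in full. The following is a minimal re-derivation from the documented contract, not the library source:

```python
import torch


def corrupt_uniform(input_ids, loss_mask, mask_token_id, eps=1e-3):
    """MDLM-style uniform corruption: one masking level t per sequence."""
    B, L = input_ids.shape
    # Per-sequence corruption level: p = (1 - eps) * t + eps, t ~ U[0, 1].
    t = torch.rand(B, 1)
    p_mask = ((1 - eps) * t + eps).expand(B, L)
    # Independently mask each supervised position with probability p.
    noise_mask = (torch.rand(B, L) < p_mask) & loss_mask.bool()
    noisy = torch.where(
        noise_mask, torch.full_like(input_ids, mask_token_id), input_ids
    )
    return noisy, noise_mask, p_mask.to(torch.float32)
```

Because `p_mask` is returned per position, the training loss can reweight each masked token by `1 / p` as is standard for MDLM objectives.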
- nemo_automodel.components.datasets.dllm.corruption.corrupt_blockwise(input_ids: torch.Tensor, loss_mask: torch.Tensor, mask_token_id: int, block_size: int | None = None, eps: float = 0.001, half_life_ratio: float = 0.25)#
Two-stage corruption with optional per-block sampling.
This function combines three independent concerns that could be separated if a future model needs a different mix (e.g., blocks without position bias, or sequence-level with bias):

1. Sampling scope: per-sequence vs. per-block `m` sampling (controlled by `block_size`).
2. Selection method: Gumbel-max top-k for exact-`k` masking.
3. Position bias: exponential weighting via `half_life_ratio`.

Stage 1: Sample `m ~ U(eps, 1)` per sequence (or per block), compute `k = round(m * length)` positions to mask.

Stage 2: Sample exactly `k` positions using exponentially weighted probabilities `w_i(m) = exp[lambda * (1 - m) * i]`, which bias toward later positions when `m` is small (few masks → mask later tokens) and become uniform when `m` is large (many masks).

If `block_size` is given, stages 1 and 2 operate independently within each contiguous block of that length. All operations are fully vectorised (no Python loops over batch or blocks) via :func:`_batched_gumbel_topk`.

- Parameters:
  - input_ids – Token IDs, shape `[B, L]`.
  - loss_mask – Binary mask indicating supervised positions, shape `[B, L]`.
  - mask_token_id – The token ID used for masking.
  - block_size – If not None, operate block-wise with per-block `m` sampling.
  - eps – Minimum corruption ratio.
  - half_life_ratio – Controls the steepness of the positional bias as `m → 0`.
- Returns:
  Tuple of `(noisy_input_ids, noise_mask, p_mask)`, each of shape `[B, L]`.
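Putting the two stages together, here is a simplified per-sequence sketch (the `block_size=None` path only). The mapping from `half_life_ratio` to `lambda` below is an assumption for illustration, since the source does not spell it out, and the real function additionally returns a per-position `p_mask`:

```python
import torch


def corrupt_exact_k(input_ids, loss_mask, mask_token_id, eps=1e-3, half_life_ratio=0.25):
    """Per-sequence two-stage corruption: sample m, then mask exactly k positions."""
    B, L = input_ids.shape
    supervised = loss_mask.bool()
    length = supervised.sum(dim=-1)  # supervised tokens per sequence
    # Stage 1: m ~ U(eps, 1); mask k = round(m * length) positions.
    m = eps + (1 - eps) * torch.rand(B)
    k = torch.round(m * length).long()
    # Stage 2: exponential position bias w_i(m) = exp(lambda * (1 - m) * i).
    # Assumed parameterisation: lambda chosen so weights double every
    # half_life_ratio * L positions when m -> 0.
    lam = torch.log(torch.tensor(2.0)) / (half_life_ratio * L)
    i = torch.arange(L, dtype=torch.float32)
    log_w = lam * (1 - m).unsqueeze(-1) * i
    log_w = log_w.masked_fill(~supervised, float("-inf"))  # never mask unsupervised
    # Batched Gumbel top-k: row n keeps its k[n] highest noisy scores.
    scores = log_w - torch.log(-torch.log(torch.rand(B, L)))
    order = scores.argsort(dim=-1, descending=True)
    ranks = torch.arange(L).expand(B, L)
    noise_mask = torch.zeros(B, L, dtype=torch.bool)
    noise_mask.scatter_(-1, order, ranks < k.unsqueeze(-1))
    noisy = torch.where(
        noise_mask, torch.full_like(input_ids, mask_token_id), input_ids
    )
    return noisy, noise_mask, m
```

The blockwise variant applies the same two stages within each contiguous `block_size` chunk, which is why the batched (variable-k) Gumbel top-k is needed: every block contributes one row with its own `k`.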