nemo_automodel.components.datasets.dllm.corruption
Data corruption utilities for diffusion LLM (dLLM) training.
Provides masking strategies for dLLM SFT:
- corrupt_uniform: uniform per-sequence corruption
- corrupt_blockwise: per-block weighted corruption with exponential position bias
- corrupt_uniform_random: per-block random-token (D3PM-uniform) corruption
Module Contents
Functions
API
Batched variable-k Gumbel top-k selection.
Vectorised replacement for calling gumbel_topk in a Python loop.
Each row i of log_w gets exactly k[i] True entries selected
via the Gumbel-max trick.
Algebraically identical to per-row gumbel_topk: both select the k
indices with the highest log_w + Gumbel score. torch.sort gives a full
descending ranking of each row; comparing each sorted position against
k[i] keeps the top k[i] entries; scatter_ maps that selection back to
the original positions.
Parameters:
log_w: Log-weights, shape [N, D]. Positions that should never
be selected must be set to -inf.
k: Number of positions to select per row, shape [N].
Rows with k=0 produce an all-False mask.
Returns: torch.Tensor
Boolean mask of shape [N, D].
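The sort / compare / scatter steps described above can be sketched as follows. This is an illustrative re-implementation written from the description, not the library's source; the function name batched_gumbel_topk_sketch and the clamping constant are assumptions.

```python
import torch


def batched_gumbel_topk_sketch(log_w: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: select k[i] entries in row i via Gumbel top-k."""
    # Gumbel(0, 1) noise is -log(-log(U)); clamp U away from 0 to avoid log(0).
    u = torch.rand_like(log_w).clamp_min(1e-20)
    scores = log_w - torch.log(-torch.log(u))  # -inf entries stay -inf

    # Full descending ranking of every row.
    order = scores.argsort(dim=-1, descending=True)

    # Sorted position j of row i is kept when j < k[i].
    ranks = torch.arange(log_w.shape[-1], device=log_w.device)
    keep_sorted = ranks.unsqueeze(0) < k.unsqueeze(1)

    # Map the kept sorted positions back to their original columns.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask.scatter_(1, order, keep_sorted)
    return mask


log_w = torch.zeros(4, 6)
log_w[:, 0] = float("-inf")          # column 0 must never be selected
k = torch.tensor([0, 1, 3, 5])
mask = batched_gumbel_topk_sketch(log_w, k)
print(mask.sum(dim=1))               # per-row counts equal k
```

In this sketch, if k[i] exceeds the number of finite entries in a row, some -inf positions are selected as well, so callers would need to keep k within the number of eligible positions.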
Two-stage corruption with optional per-block sampling.
This function combines three independent concerns that could be separated if a future model needs a different mix (e.g., blocks without position bias, or sequence-level with bias):
- Sampling scope: per-sequence vs per-block m sampling (controlled by block_size).
- Selection method: Gumbel-max top-k for exact-k masking.
- Position bias: exponential weighting via half_life_ratio.
Stage 1: Sample m ~ U(eps, 1) per sequence (or per block), compute
k = round(m * length) positions to mask.
Stage 2: Sample exactly k positions using exponentially weighted
probabilities w_i(m) = exp[lambda * (1-m) * i] which bias toward
later positions when m is small (few masks → mask later tokens) and
become uniform when m is large (many masks).
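To make the effect of the weighting concrete, the snippet below normalises w_i(m) = exp[lambda * (1-m) * i] over a short sequence for a small and a large m. It is a sketch under assumptions: lam stands in for lambda with an arbitrary value (how lambda is derived from half_life_ratio is not shown here), and i is taken to be the raw 0-based position index.

```python
import math


def position_probs(length: int, m: float, lam: float) -> list[float]:
    """Normalised weights w_i(m) = exp(lam * (1 - m) * i) for i in [0, length)."""
    logits = [lam * (1.0 - m) * i for i in range(length)]
    top = max(logits)                       # subtract the max for numerical stability
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]


lam = 0.5  # hypothetical value, chosen only for illustration
few_masks = position_probs(8, m=0.1, lam=lam)   # mass concentrates on later positions
many_masks = position_probs(8, m=1.0, lam=lam)  # exponent is 0, so weights are uniform
print([round(p, 3) for p in few_masks])
print([round(p, 3) for p in many_masks])
```

These normalised values are the probabilities of the first pick only; the Gumbel top-k selection then draws the remaining positions without replacement.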
If block_size is given, stages 1 and 2 operate independently within
each contiguous block of that length.
All operations are fully vectorised (no Python loops over batch or
blocks) via _batched_gumbel_topk.
Parameters:
input_ids: Token IDs, shape [B, L].
loss_mask: Binary mask indicating supervised positions, shape [B, L].
mask_token_id: The token ID used for masking.
block_size: If not None, operate block-wise with per-block m sampling.
eps: Minimum corruption ratio.
half_life_ratio: Controls steepness of positional bias when m → 0.
Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].
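Stage 1 in the per-block case can be sketched as below. The helper name sample_block_mask_counts is hypothetical, and two details are assumptions rather than documented behaviour: "length" is taken to mean the number of supervised positions in each block, and L is required to be a multiple of block_size.

```python
import torch


def sample_block_mask_counts(loss_mask: torch.Tensor, block_size: int, eps: float = 1e-3):
    """Hypothetical sketch of stage 1: per-block m ~ U(eps, 1) and k = round(m * length)."""
    B, L = loss_mask.shape
    assert L % block_size == 0, "sketch assumes L is a multiple of block_size"
    n_blocks = L // block_size

    # Supervised positions per block (assumed meaning of "length").
    lengths = loss_mask.reshape(B, n_blocks, block_size).sum(dim=-1)

    # One corruption ratio per (sequence, block).
    m = eps + (1.0 - eps) * torch.rand(B, n_blocks, device=loss_mask.device)

    # Exact number of positions to mask in each block.
    k = torch.round(m * lengths).long()
    return m, k


loss_mask = torch.ones(2, 8, dtype=torch.long)
loss_mask[0, :3] = 0                      # first three tokens of row 0 are unsupervised
m, k = sample_block_mask_counts(loss_mask, block_size=4)
print(m.shape, k.shape)                   # both [B, n_blocks]
```

The resulting k, flattened to one row per block, is the shape of input that a batched top-k selector over [N, D] log-weights expects.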
Per-sequence uniform corruption for MDLM training.
For each sequence, sample t ~ U[0, 1] and derive a masking probability
p = (1 - eps) * t + eps. Each token at a supervised position (where
loss_mask == 1) is independently replaced with mask_token_id with
probability p.
Parameters:
input_ids: Token IDs, shape [B, L].
loss_mask: Binary mask indicating supervised positions, shape [B, L].
mask_token_id: The token ID used for masking.
eps: Minimum corruption ratio.
Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].
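A minimal sketch of this procedure, written from the description above rather than taken from the library: the name corrupt_uniform_sketch and the eps default are assumptions, and p_mask is simply p broadcast over the whole row, which may differ from what the real function stores at unsupervised positions.

```python
import torch


def corrupt_uniform_sketch(input_ids, loss_mask, mask_token_id, eps=1e-3):
    """Hypothetical sketch of per-sequence uniform [MASK] corruption."""
    B, L = input_ids.shape
    device = input_ids.device

    # One t ~ U[0, 1] per sequence, mapped to p = (1 - eps) * t + eps.
    t = torch.rand(B, 1, device=device)
    p_mask = ((1.0 - eps) * t + eps).expand(B, L).clone()

    # Independent Bernoulli(p) draw per token, restricted to supervised positions.
    noise_mask = (torch.rand(B, L, device=device) < p_mask) & loss_mask.bool()

    # Replace the selected tokens with the mask token.
    masked = torch.full_like(input_ids, mask_token_id)
    noisy_input_ids = torch.where(noise_mask, masked, input_ids)
    return noisy_input_ids, noise_mask, p_mask


input_ids = torch.arange(12).reshape(2, 6)
loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1]])
noisy, noise_mask, p_mask = corrupt_uniform_sketch(input_ids, loss_mask, mask_token_id=99)
print(noisy)
```

Unlike the exact-k selection in corrupt_blockwise, the number of masked tokens here is random: it only matches p times the supervised length in expectation.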
Per-block uniform random-token (D3PM-uniform) corruption for block diffusion.
This is the diffusion_gemma corruption kernel. Unlike the MDLM/LLaDA
absorbing-[MASK] corruption (corrupt_uniform), corrupted
positions are replaced with a random token sampled uniformly over the full
vocabulary, so there is no mask_token_id. This matches the checkpoint’s
own canvas initialisation / renoising (torch.randint over the vocab).
For each block (or the whole sequence when block_size is None) a single
corruption level t ~ U(eps, 1) is sampled; every supervised position in
that block is then independently replaced with probability t. No noise
schedule and no positional bias are applied (“everything is uniform”).
The returned p_mask is all ones: the uniform kernel uses a flat loss
(plain mean cross-entropy over corrupted tokens, no 1/t reweighting), so
there is no per-position probability to divide by. p_mask is kept in the
return signature only for plumbing compatibility with the MDLM loss path.
Parameters:
input_ids: Clean token IDs, shape [B, L].
loss_mask: Binary mask of supervised positions, shape [B, L]. Only
supervised positions are ever corrupted.
vocab_size: Vocabulary size; replacement tokens are drawn from
[0, vocab_size).
block_size: If given, sample t per contiguous block of this length;
otherwise sample one t per sequence.
eps: Minimum corruption level (lower bound of t).
Optional torch.Generator (on input_ids.device) used for
ALL random draws (t, the corruption mask, the replacement tokens).
Pass a step-seeded generator so the corruption is a deterministic
function of the training step and reproduces exactly on checkpoint
resume; None falls back to the global RNG (not resume-safe).
Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Tuple of (noisy_input_ids, noise_mask, p_mask), each shape [B, L].
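The behaviour described above, including the step-seeded generator, might look like the sketch below. It is not the library's implementation: the function name, the keyword name generator, the eps default, and the seed value are all assumptions made for illustration.

```python
import torch


def corrupt_uniform_random_sketch(input_ids, loss_mask, vocab_size,
                                  block_size=None, eps=1e-3, generator=None):
    """Hypothetical sketch of per-block random-token (D3PM-uniform) corruption."""
    B, L = input_ids.shape
    dev = input_ids.device

    if block_size is None:
        # One corruption level per sequence.
        t = torch.rand(B, 1, device=dev, generator=generator).expand(B, L)
    else:
        # One corruption level per contiguous block, repeated across the block.
        n_blocks = (L + block_size - 1) // block_size
        t = torch.rand(B, n_blocks, device=dev, generator=generator)
        t = t.repeat_interleave(block_size, dim=1)[:, :L]
    t = eps + (1.0 - eps) * t  # t ~ U(eps, 1)

    # Independent Bernoulli(t) draw per supervised position.
    draws = torch.rand(B, L, device=dev, generator=generator)
    noise_mask = (draws < t) & loss_mask.bool()

    # Replacement tokens drawn uniformly from [0, vocab_size).
    random_ids = torch.randint(0, vocab_size, (B, L), device=dev,
                               dtype=input_ids.dtype, generator=generator)
    noisy_input_ids = torch.where(noise_mask, random_ids, input_ids)

    p_mask = torch.ones(B, L, device=dev)  # flat loss: nothing to divide by
    return noisy_input_ids, noise_mask, p_mask


step = 1234  # hypothetical training step used as the seed
gen = torch.Generator(device="cpu").manual_seed(step)
input_ids = torch.randint(0, 50, (2, 8))
loss_mask = torch.ones(2, 8, dtype=torch.long)
noisy, noise_mask, p_mask = corrupt_uniform_random_sketch(
    input_ids, loss_mask, vocab_size=50, block_size=4, generator=gen)
```

Re-creating the generator with the same seed reproduces the same corruption, which is the property that makes resuming from a checkpoint deterministic. In this sketch a replacement token can coincide with the original token while noise_mask is still True at that position.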
Return a bool mask of length len(log_w) with exactly k True entries.
Uses the Gumbel-max trick for stochastic top-k selection.
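For a single row, the Gumbel-max selection can be sketched in a few lines. The name gumbel_topk_sketch is hypothetical and the code is an illustration of the technique, not the library's source.

```python
import torch


def gumbel_topk_sketch(log_w: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical sketch: bool mask over a 1-D log_w with exactly k True entries."""
    # Perturb the log-weights with Gumbel(0, 1) noise, then keep the k largest.
    u = torch.rand_like(log_w).clamp_min(1e-20)
    scores = log_w - torch.log(-torch.log(u))
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    if k > 0:
        mask[scores.topk(k).indices] = True
    return mask


log_w = torch.tensor([0.0, 1.0, float("-inf"), 2.0, 0.5])
print(gumbel_topk_sketch(log_w, k=2))  # index 2 is never selected
```

Taking the top k of the perturbed scores is equivalent to sampling k indices without replacement with probabilities proportional to exp(log_w).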