nemo_automodel.components.datasets.dllm.corruption


Data corruption utilities for diffusion LLM (dLLM) training.

Provides masking strategies for dLLM SFT:

  • corrupt_uniform: uniform per-sequence corruption
  • corrupt_blockwise: per-block weighted corruption with exponential position bias
  • corrupt_uniform_random: per-block random-token (D3PM-uniform) corruption

Module Contents

Functions

Name                      Description
_batched_gumbel_topk      Batched variable-k Gumbel top-k selection.
corrupt_blockwise         Two-stage corruption with optional per-block sampling.
corrupt_uniform           Per-sequence uniform corruption for MDLM training.
corrupt_uniform_random    Per-block uniform random-token (D3PM-uniform) corruption for block diffusion.
gumbel_topk               Return a bool mask of length len(log_w) with exactly k True entries.

API

nemo_automodel.components.datasets.dllm.corruption._batched_gumbel_topk(
log_w: torch.Tensor,
k: torch.Tensor
) -> torch.Tensor

Batched variable-k Gumbel top-k selection.

Vectorised replacement for calling gumbel_topk in a Python loop. Each row i of log_w gets exactly k[i] True entries selected via the Gumbel-max trick.

Algebraically identical to per-row gumbel_topk: both select the k indices with the highest log_w + Gumbel score. torch.sort gives a full ranking; positions < k picks the top-k in sorted order; scatter_ maps them back to original positions.

Parameters:

log_w
torch.Tensor

Log-weights, shape [N, D]. Positions that should never be selected must be set to -inf.

k
torch.Tensor

Number of positions to select per row, shape [N]. Rows with k=0 produce an all-False mask.

Returns: torch.Tensor

Boolean mask of shape [N, D].
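
The sort / positions < k / scatter_ steps described above can be sketched as follows. This is a minimal illustration, not the library's code: the name batched_gumbel_topk_sketch and the clamp constant are assumptions made for the example.

```python
import torch


def batched_gumbel_topk_sketch(log_w: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: row i of the result has exactly k[i] True entries."""
    # Gumbel(0, 1) noise is -log(-log(U)); clamp U away from 0 to avoid log(0).
    u = torch.rand_like(log_w).clamp_(min=1e-20)
    scores = log_w - torch.log(-torch.log(u))  # -inf log-weights stay -inf

    # Full descending ranking of every row.
    _, order = torch.sort(scores, dim=1, descending=True)  # [N, D]

    # In sorted order, the first k[i] ranks of row i are the selected ones.
    ranks = torch.arange(log_w.shape[1], device=log_w.device).unsqueeze(0)  # [1, D]
    keep_sorted = ranks < k.unsqueeze(1)  # [N, D]

    # Map the selection back to the original positions.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask.scatter_(1, order, keep_sorted)
    return mask


log_w = torch.zeros(3, 6)
log_w[0, :2] = float("-inf")  # these two positions must never be selected
k = torch.tensor([2, 0, 6])
mask = batched_gumbel_topk_sketch(log_w, k)
```

In this sketch a row only avoids its -inf positions as long as k[i] does not exceed the number of finite entries in that row.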

nemo_automodel.components.datasets.dllm.corruption.corrupt_blockwise(
input_ids: torch.Tensor,
loss_mask: torch.Tensor,
mask_token_id: int,
block_size: int | None = None,
eps: float = 0.001,
half_life_ratio: float = 0.25
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Two-stage corruption with optional per-block sampling.

This function combines three independent concerns that could be separated if a future model needs a different mix (e.g., blocks without position bias, or sequence-level with bias):

  1. Sampling scope — per-sequence vs per-block m sampling (controlled by block_size).
  2. Selection method — Gumbel-max top-k for exact-k masking.
  3. Position bias — exponential weighting via half_life_ratio.

Stage 1: Sample m ~ U(eps, 1) per sequence (or per block), compute k = round(m * length) positions to mask.

Stage 2: Sample exactly k positions using exponentially weighted probabilities w_i(m) = exp[lambda * (1-m) * i] which bias toward later positions when m is small (few masks → mask later tokens) and become uniform when m is large (many masks).

If block_size is given, stages 1 and 2 operate independently within each contiguous block of that length.

All operations are fully vectorised (no Python loops over batch or blocks) via _batched_gumbel_topk.

Parameters:

input_ids
torch.Tensor

Token IDs, shape [B, L].

loss_mask
torch.Tensor

Binary mask indicating supervised positions, shape [B, L].

mask_token_id
int

The token ID used for masking.

block_size
int | None, defaults to None

If not None, operate block-wise with per-block m sampling.

eps
float, defaults to 0.001

Minimum corruption ratio.

half_life_ratio
float, defaults to 0.25

Controls steepness of positional bias when m → 0.

Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].
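
The two stages can be illustrated on a [N, D] tensor whose rows are independent scopes (whole sequences or single blocks). This is a rough sketch under stated assumptions, not the library's implementation: the page does not define lambda, so the formula lam = ln(2) / (half_life_ratio * D) is a guess made for illustration, as is taking "length" to mean the number of supervised positions in the row. The function name is hypothetical.

```python
import math

import torch


def blockwise_mask_sketch(
    loss_mask: torch.Tensor, eps: float = 1e-3, half_life_ratio: float = 0.25
) -> torch.Tensor:
    """Hypothetical sketch: returns a bool noise mask of shape [N, D]."""
    n, d = loss_mask.shape
    supervised = loss_mask.bool()

    # Stage 1: m ~ U(eps, 1) per row, then k = round(m * length).
    # ASSUMPTION: "length" is the count of supervised positions in the row.
    m = eps + (1.0 - eps) * torch.rand(n, 1)
    k = torch.round(m * supervised.sum(dim=1, keepdim=True)).long()  # [N, 1]

    # Stage 2: log w_i(m) = lam * (1 - m) * i, flat when m is close to 1.
    # ASSUMPTION: this definition of lam is illustrative only.
    lam = math.log(2.0) / (half_life_ratio * d)
    pos = torch.arange(d, dtype=torch.float32).unsqueeze(0)  # [1, D]
    log_w = (lam * (1.0 - m) * pos).masked_fill(~supervised, float("-inf"))

    # Exact-k selection with Gumbel noise, sort, and scatter.
    u = torch.rand(n, d).clamp_(min=1e-20)
    scores = log_w - torch.log(-torch.log(u))
    _, order = torch.sort(scores, dim=1, descending=True)
    keep_sorted = torch.arange(d).unsqueeze(0) < k  # first k[i] ranks per row
    return torch.zeros(n, d, dtype=torch.bool).scatter_(1, order, keep_sorted)


loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]])
noise_mask = blockwise_mask_sketch(loss_mask)
```

For per-block sampling, and assuming L is divisible by block_size, a [B, L] loss mask could be viewed as loss_mask.view(-1, block_size) before the call and the result reshaped back to [B, L].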

nemo_automodel.components.datasets.dllm.corruption.corrupt_uniform(
input_ids: torch.Tensor,
loss_mask: torch.Tensor,
mask_token_id: int,
eps: float = 0.001
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Per-sequence uniform corruption for MDLM training.

For each sequence, sample t ~ U[0, 1] and derive a masking probability p = (1 - eps) * t + eps. Each token at a supervised position (where loss_mask == 1) is independently replaced with mask_token_id with probability p.

Parameters:

input_ids
torch.Tensor

Token IDs, shape [B, L].

loss_mask
torch.Tensor

Binary mask indicating supervised positions, shape [B, L].

mask_token_id
int

The token ID used for masking.

eps
float, defaults to 0.001

Minimum corruption ratio.

Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].
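
A minimal sketch of the per-sequence procedure described above. The function name is hypothetical, and the choice to return p broadcast to [B, L] as p_mask is an assumption; the page only states the shape of p_mask, not its contents.

```python
import torch


def corrupt_uniform_sketch(
    input_ids: torch.Tensor,
    loss_mask: torch.Tensor,
    mask_token_id: int,
    eps: float = 1e-3,
):
    """Hypothetical sketch of per-sequence absorbing-mask corruption."""
    b, l = input_ids.shape
    device = input_ids.device

    # One t per sequence, mapped to a masking probability p in [eps, 1).
    t = torch.rand(b, 1, device=device)
    p = (1.0 - eps) * t + eps

    # ASSUMPTION: p_mask is simply p repeated along the sequence dimension.
    p_mask = p.expand(b, l)

    # Independent Bernoulli(p) draw per token, restricted to supervised positions.
    draw = torch.rand(b, l, device=device) < p_mask
    noise_mask = draw & loss_mask.bool()

    mask_ids = torch.full_like(input_ids, mask_token_id)
    noisy_input_ids = torch.where(noise_mask, mask_ids, input_ids)
    return noisy_input_ids, noise_mask, p_mask


input_ids = torch.randint(0, 100, (2, 8))
loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1, 1, 1]])
noisy, noise_mask, p_mask = corrupt_uniform_sketch(input_ids, loss_mask, mask_token_id=103)
```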

nemo_automodel.components.datasets.dllm.corruption.corrupt_uniform_random(
input_ids: torch.Tensor,
loss_mask: torch.Tensor,
vocab_size: int,
block_size: int | None = None,
eps: float = 0.001,
generator: torch.Generator | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Per-block uniform random-token (D3PM-uniform) corruption for block diffusion.

This is the diffusion_gemma corruption kernel. Unlike MDLM/LLaDA absorbing-[MASK] corruption (corrupt_uniform), corrupted positions are replaced with a random token sampled uniformly over the full vocabulary, so there is no mask_token_id. This matches the checkpoint’s own canvas initialisation / renoising (torch.randint over the vocab).

For each block (or the whole sequence when block_size is None) a single corruption level t ~ U(eps, 1) is sampled; every supervised position in that block is then independently replaced with probability t. No noise schedule and no positional bias are applied (“everything is uniform”).

The returned p_mask is all ones: the uniform kernel uses a flat loss (plain mean cross-entropy over corrupted tokens, no 1/t reweighting), so there is no per-position probability to divide by. p_mask is kept in the return signature only for plumbing compatibility with the MDLM loss path.

Parameters:

input_ids
torch.Tensor

Clean token IDs, shape [B, L].

loss_mask
torch.Tensor

Binary mask of supervised positions, shape [B, L]. Only supervised positions are ever corrupted.

vocab_size
int

Vocabulary size; replacement tokens are drawn from [0, vocab_size).

block_size
int | None, defaults to None

If given, sample t per contiguous block of this length; otherwise sample one t per sequence.

eps
float, defaults to 0.001

Minimum corruption level (lower bound of t).

generator
torch.Generator | None, defaults to None

Optional torch.Generator (on input_ids.device) used for ALL random draws (t, the corruption mask, the replacement tokens). Pass a step-seeded generator so the corruption is a deterministic function of the training step and reproduces exactly on checkpoint resume; None falls back to the global RNG (not resume-safe).

Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Tuple of (noisy_input_ids, noise_mask, p_mask), each shape [B, L].
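
The following sketch illustrates the behaviour described above, including why a seeded generator makes the corruption reproducible. It is an approximation written for this page, not the library's code: the function name, the handling of a trailing partial block, and the order of the random draws are all assumptions.

```python
import torch


def corrupt_uniform_random_sketch(
    input_ids, loss_mask, vocab_size, block_size=None, eps=1e-3, generator=None
):
    """Hypothetical sketch of per-block random-token corruption."""
    b, l = input_ids.shape
    device = input_ids.device
    bs = l if block_size is None else block_size
    n_blocks = (l + bs - 1) // bs  # ASSUMPTION: a trailing partial block is allowed

    # One corruption level t ~ U(eps, 1) per (sequence, block), spread over tokens.
    t_block = eps + (1.0 - eps) * torch.rand(b, n_blocks, generator=generator, device=device)
    t = t_block.repeat_interleave(bs, dim=1)[:, :l]

    # Independent draw per token; only supervised positions may be corrupted.
    draw = torch.rand(b, l, generator=generator, device=device) < t
    noise_mask = draw & loss_mask.bool()

    # Replacement tokens are uniform over [0, vocab_size); a replacement can
    # coincide with the original token while noise_mask is still True.
    random_ids = torch.randint(
        0, vocab_size, (b, l), generator=generator, device=device, dtype=input_ids.dtype
    )
    noisy_input_ids = torch.where(noise_mask, random_ids, input_ids)

    p_mask = torch.ones(b, l, device=device)  # flat loss: nothing to divide by
    return noisy_input_ids, noise_mask, p_mask


ids = torch.randint(0, 100, (2, 8))
lm = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1, 1, 1]])
# Two generators seeded identically (e.g. from the training step) give the same corruption.
out_a = corrupt_uniform_random_sketch(ids, lm, 100, block_size=4, generator=torch.Generator().manual_seed(1234))
out_b = corrupt_uniform_random_sketch(ids, lm, 100, block_size=4, generator=torch.Generator().manual_seed(1234))
```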

nemo_automodel.components.datasets.dllm.corruption.gumbel_topk(
log_w: torch.Tensor,
k: int
) -> torch.Tensor

Return a bool mask of length len(log_w) with exactly k True entries.

Uses the Gumbel-max trick for stochastic top-k selection.
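
As an illustration of the Gumbel-max trick for a single row, the sketch below adds Gumbel(0, 1) noise to the log-weights and keeps the k highest scores. The function name and the clamp constant are assumptions made for the example; it is not the library's implementation.

```python
import torch


def gumbel_topk_sketch(log_w: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical sketch: bool mask over a 1-D log_w with exactly k True entries."""
    # Gumbel(0, 1) noise is -log(-log(U)); clamp U away from 0 to avoid log(0).
    u = torch.rand_like(log_w).clamp_(min=1e-20)
    scores = log_w - torch.log(-torch.log(u))  # -inf log-weights stay -inf

    mask = torch.zeros_like(log_w, dtype=torch.bool)
    if k > 0:
        mask[torch.topk(scores, k).indices] = True
    return mask


log_w = torch.tensor([0.0, 0.0, float("-inf"), 1.0, 2.0])
mask = gumbel_topk_sketch(log_w, k=2)
```

Taking the top k of the perturbed scores samples k positions without replacement, with positions of larger log_w more likely to be chosen.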