nemo_automodel.components.datasets.dllm.corruption

Data corruption utilities for diffusion LLM (dLLM) training.

Provides masking strategies for dLLM SFT:

  • corrupt_uniform: uniform per-sequence corruption

  • corrupt_blockwise: per-block weighted corruption with exponential position bias

Module Contents

Functions

gumbel_topk

Return a bool mask of length len(log_w) with exactly k True entries.

_batched_gumbel_topk

Batched variable-k Gumbel top-k selection.

corrupt_uniform

Per-sequence uniform corruption for MDLM training.

corrupt_blockwise

Two-stage corruption with optional per-block sampling.

API

nemo_automodel.components.datasets.dllm.corruption.gumbel_topk(log_w: torch.Tensor, k: int) → torch.Tensor

Return a bool mask of length len(log_w) with exactly k True entries.

Uses the Gumbel-max trick for stochastic top-k selection.
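As a concrete illustration of the trick (the code below is a sketch of the documented behaviour, with illustrative names, not the library's implementation):

```python
import torch

def gumbel_topk_sketch(log_w: torch.Tensor, k: int) -> torch.Tensor:
    """Return a bool mask with exactly k True entries, sampled via Gumbel-max."""
    # Perturb each log-weight with i.i.d. Gumbel(0, 1) noise: g = -log(-log(u)).
    u = torch.rand_like(log_w).clamp_min(1e-20)  # guard against u == 0
    scores = log_w - torch.log(-torch.log(u))
    # The k largest perturbed scores form the sample; positions with
    # log_w = -inf keep score -inf, so they rank last and are avoided
    # whenever k does not exceed the number of finite weights.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask[scores.topk(k).indices] = True
    return mask
```

Uniform weights (log_w all zero) reduce this to sampling k positions uniformly without replacement; for a different k per row, the batched variant below avoids a Python loop.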

nemo_automodel.components.datasets.dllm.corruption._batched_gumbel_topk(
log_w: torch.Tensor,
k: torch.Tensor,
) → torch.Tensor

Batched variable-k Gumbel top-k selection.

Vectorised replacement for calling gumbel_topk in a Python loop. Each row i of log_w gets exactly k[i] True entries selected via the Gumbel-max trick.

Algebraically identical to per-row gumbel_topk: both select the k indices with the highest log_w + Gumbel score. torch.sort gives a full per-row ranking; comparing each position's rank against k[i] picks the top-k in sorted order; scatter_ maps those selections back to their original positions.

Parameters:
  • log_w – Log-weights, shape [N, D]. Positions that should never be selected must be set to -inf.

  • k – Number of positions to select per row, shape [N]. Rows with k=0 produce an all-False mask.

Returns:

Boolean mask of shape [N, D].
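The rank-and-scatter recipe described above can be sketched as follows (illustrative code, not the library source):

```python
import torch

def batched_gumbel_topk_sketch(log_w: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Row i of the result has exactly k[i] True entries (illustrative sketch)."""
    # Gumbel-perturbed scores; -inf log-weights stay -inf and rank last.
    u = torch.rand_like(log_w).clamp_min(1e-20)
    scores = log_w - torch.log(-torch.log(u))
    # Full descending ranking of every row.
    order = scores.argsort(dim=-1, descending=True)
    # A position is selected iff its rank is below the row's k.
    ranks = torch.arange(log_w.shape[-1], device=log_w.device).expand_as(log_w)
    picked = ranks < k.unsqueeze(-1)          # [N, D], in rank order
    # Scatter the rank-order selections back to original column positions.
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask.scatter_(-1, order, picked)
    return mask
```

One full sort per row costs O(D log D) but removes the Python-level loop, which is the trade the docstring describes.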

nemo_automodel.components.datasets.dllm.corruption.corrupt_uniform(
input_ids: torch.Tensor,
loss_mask: torch.Tensor,
mask_token_id: int,
eps: float = 0.001,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Per-sequence uniform corruption for MDLM training.

For each sequence, sample t ~ U[0, 1] and derive a masking probability p = (1 - eps) * t + eps. Each token at a supervised position (where loss_mask == 1) is independently replaced with mask_token_id with probability p.

Parameters:
  • input_ids – Token IDs, shape [B, L].

  • loss_mask – Binary mask indicating supervised positions, shape [B, L].

  • mask_token_id – The token ID used for masking.

  • eps – Minimum corruption ratio.

Returns:

Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].

  • noisy_input_ids: input_ids with masked positions replaced.

  • noise_mask: bool mask of which positions were corrupted.

  • p_mask: per-position masking probability (float32).
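A short sketch that reproduces the documented contract (an illustration of the described behaviour, not the library source):

```python
import torch

def corrupt_uniform_sketch(input_ids, loss_mask, mask_token_id, eps=1e-3):
    """Mask supervised tokens with a shared per-sequence probability p."""
    B, L = input_ids.shape
    # One noise level per sequence: t ~ U[0, 1], p = (1 - eps) * t + eps.
    p = (1 - eps) * torch.rand(B, 1) + eps
    p_mask = p.expand(B, L).to(torch.float32)
    # Independent Bernoulli(p) draws, restricted to supervised positions.
    noise_mask = (torch.rand(B, L) < p_mask) & loss_mask.bool()
    noisy = torch.where(noise_mask,
                        torch.full_like(input_ids, mask_token_id), input_ids)
    return noisy, noise_mask, p_mask
```

Positions with loss_mask == 0 (e.g. the prompt) are never corrupted; eps keeps every sequence at least slightly corrupted so the denoising loss is never degenerate.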

nemo_automodel.components.datasets.dllm.corruption.corrupt_blockwise(
input_ids: torch.Tensor,
loss_mask: torch.Tensor,
mask_token_id: int,
block_size: int | None = None,
eps: float = 0.001,
half_life_ratio: float = 0.25,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Two-stage corruption with optional per-block sampling.

This function combines three independent concerns that could be separated if a future model needs a different mix (e.g., blocks without position bias, or sequence-level with bias):

  1. Sampling scope — per-sequence vs per-block m sampling (controlled by block_size).

  2. Selection method — Gumbel-max top-k for exact-k masking.

  3. Position bias — exponential weighting via half_life_ratio.

Stage 1: Sample m ~ U(eps, 1) per sequence (or per block) and compute k = round(m * length), the number of positions to mask.

Stage 2: Sample exactly k positions using exponentially weighted probabilities w_i(m) = exp[lambda * (1 - m) * i], which bias selection toward later positions when m is small (few masks → mask later tokens) and become uniform as m approaches 1 (many masks).

If block_size is given, stages 1 and 2 operate independently within each contiguous block of that length.

All operations are fully vectorised (no Python loops over batch or blocks) via _batched_gumbel_topk.

Parameters:
  • input_ids – Token IDs, shape [B, L].

  • loss_mask – Binary mask indicating supervised positions, shape [B, L].

  • mask_token_id – The token ID used for masking.

  • block_size – If not None, operate block-wise with per-block m sampling.

  • eps – Minimum corruption ratio.

  • half_life_ratio – Controls steepness of positional bias when m → 0.

Returns:

Tuple of (noisy_input_ids, noise_mask, p_mask) each of shape [B, L].
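Putting the two stages together for the per-sequence path (block_size=None), a sketch of the documented scheme might look as follows. This is illustrative, not the library source; in particular, the exact mapping from half_life_ratio to the rate lambda is an assumption here (chosen so the weight doubles every half_life_ratio * L positions at m = 0).

```python
import math
import torch

def corrupt_blockwise_sketch(input_ids, loss_mask, mask_token_id,
                             eps=1e-3, half_life_ratio=0.25):
    """Per-sequence path (block_size=None) of the two-stage scheme."""
    B, L = input_ids.shape
    supervised = loss_mask.bool()
    # Stage 1: masking ratio m ~ U(eps, 1); exact count k per sequence.
    m = eps + (1 - eps) * torch.rand(B, 1)
    k = (m.squeeze(-1) * supervised.sum(-1)).round().long()
    # Stage 2: exponential position bias log w_i = lam * (1 - m) * i.
    # Assumption: lam doubles the weight every half_life_ratio * L positions
    # at m = 0; the library's exact mapping may differ.
    lam = math.log(2) / (half_life_ratio * L)
    pos = torch.arange(L, dtype=torch.float32).expand(B, L)
    log_w = (lam * (1 - m)) * pos
    log_w = log_w.masked_fill(~supervised, float("-inf"))  # never mask prompts
    # Batched Gumbel-max top-k: exactly k[i] True entries per row.
    g = -torch.log(-torch.log(torch.rand_like(log_w).clamp_min(1e-20)))
    order = (log_w + g).argsort(-1, descending=True)
    picked = torch.arange(L).expand(B, L) < k.unsqueeze(-1)
    noise_mask = torch.zeros(B, L, dtype=torch.bool)
    noise_mask.scatter_(-1, order, picked)
    noisy = torch.where(noise_mask,
                        torch.full_like(input_ids, mask_token_id), input_ids)
    p_mask = m.expand(B, L).to(torch.float32)
    return noisy, noise_mask, p_mask
```

The block-wise variant would repeat stages 1 and 2 independently inside each contiguous block of block_size tokens, with the position index i restarting at the block boundary.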