nemo_rl.algorithms.x_token.loss_utils#

Shared utilities for cross-tokenizer distillation.

Hosts pieces that are used by both :mod:token_aligner (in this package) and :mod:nemo_rl.algorithms.loss.loss_functions:

- :class:`Fp32SparseMM` — FP32 sparse-dense matmul autograd Function
  that ignores the surrounding BF16 autocast (PyTorch has no BF16
  sparse-mm kernel).
- :func:`chunk_average_log_probs`, :func:`valid_chunk_mask` —
  chunk-aggregation helpers for the cross-tokenizer KL paths.
- :func:`dp_all_reduce_sum` — sum-reduce a scalar count across the
  data-parallel group so the chunk-KL denominator is the global
  valid-chunk count rather than a per-rank mean.
- :func:`parse_projection_file` — single source of truth for
  reading the on-disk projection matrix file (both the dense top-k
  format and the sparse ``dict[(s, t)] -> count`` format) into COO
  components. Callers retain their own validation / sizing rules.
- :func:`get_sparse_projection_matrix`, :func:`get_topk_projection`
  — process-local lazy caches for the materialized projection
  matrix on a given device. Driver processes never trigger a fill;
  each Ray worker populates its own cache on the first loss call.
- :func:`build_exact_token_map` — derived common/uncommon vocab
  partition for the gold-loss path. Cached per
  ``(path, device, xtoken_loss, teacher_vocab_size)`` because the
  partition depends on those four inputs.
- :func:`alignment_from_flat_batch` — rehydrate the flat
  ``alignment_*`` transport keys on the loss data dict into a single
  :class:`AlignmentBatch` so the loss bodies access alignment via
  attributes instead of repeating flat field names.

Module Contents#

Classes#

Fp32SparseMM

FP32 M.t() @ dense (sparse-dense matmul) ignoring surrounding autocast.

Functions#

alignment_from_flat_batch

Rebuild :class:AlignmentBatch from the flat alignment_* keys.

rebuild_teacher_full_logits_from_ipc

View-only rebuild of the microbatch’s teacher-logits slice from IPC.

chunk_average_log_probs

Average log_probs over the chunks defined by chunk_id.

valid_chunk_mask

Per-chunk validity gate: both sides non-empty and pair is valid.

dp_all_reduce_sum

Sum-reduce a scalar count across the data-parallel group.

parse_projection_file

Parse a projection-matrix file into COO components.

get_sparse_projection_matrix

Return the sparse-COO projection matrix on device (cached).

get_topk_projection

Return the dense top-k (indices, likelihoods) projection on device (cached).

build_exact_token_map

Build the common/uncommon vocab partition for the gold path (cached).

Data#

API#

nemo_rl.algorithms.x_token.loss_utils.alignment_from_flat_batch(
data: Mapping[str, Any],
) nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch#

Rebuild :class:AlignmentBatch from the flat alignment_* keys.

The field set is driven off :class:AlignmentBatch so the helper can’t drift from the schema.
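As an illustration of the schema-driven rebuild described above, here is a minimal sketch. It assumes the batch type is a dataclass; `ToyAlignmentBatch` and its three fields are hypothetical stand-ins, since the real `AlignmentBatch` schema is not shown on this page.

```python
from dataclasses import dataclass, fields
from typing import Any, Mapping


@dataclass
class ToyAlignmentBatch:
    # Hypothetical fields; the real AlignmentBatch schema may differ.
    student_chunk_id: Any
    teacher_chunk_id: Any
    pair_valid: Any


def toy_alignment_from_flat_batch(data: Mapping[str, Any]) -> ToyAlignmentBatch:
    # Iterate over the dataclass fields so that adding a field to the schema
    # automatically adds the matching "alignment_<name>" lookup.
    kwargs = {f.name: data[f"alignment_{f.name}"] for f in fields(ToyAlignmentBatch)}
    return ToyAlignmentBatch(**kwargs)


flat = {
    "alignment_student_chunk_id": [0, 0, 1],
    "alignment_teacher_chunk_id": [0, 1, 1],
    "alignment_pair_valid": [True, True],
    "input_ids": [5, 6, 7],  # unrelated keys are ignored
}
batch = toy_alignment_from_flat_batch(flat)
```

Because the key list is derived from the field list, a missing `alignment_*` key surfaces as a `KeyError` at rebuild time rather than as a silent mismatch later in the loss body.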

nemo_rl.algorithms.x_token.loss_utils.rebuild_teacher_full_logits_from_ipc(
data: Mapping[str, Any],
) torch.Tensor#

View-only rebuild of the microbatch’s teacher-logits slice from IPC.

The producer maintains a persistent IPC buffer on its GPU sized [B_r, T_t, V_t]; the buffer (and the IPC handle it was captured with) survives across training steps, with fresh logits .copy_()-ed in each step. Because the producer never frees the buffer between steps, holding a view into the IPC-imported storage is safe: the lifetime of the producer-side allocation does not depend on the consumer’s refcount; the buffer simply stays alive for the worker’s lifetime.

Every per-sample entry in teacher_full_logits_ipc carries the same stable rank-level handle plus its rank-local sample_idx_within_rank. We rebuild that single handle once and slice [mb_start:mb_end] for the current microbatch – zero allocation on the consumer, dtype preserved (the loss fn casts if/where it needs fp32).

Parameters:

data – The loss data dict, carrying teacher_full_logits_ipc – a list of per-sample IPC handle dicts produced by Policy.get_full_logits_ipc.

Returns:

A [mb_B, T_t, V_t] view into the producer’s GPU memory (no copy).

class nemo_rl.algorithms.x_token.loss_utils.Fp32SparseMM#

Bases: torch.autograd.Function

FP32 M.t() @ dense (sparse-dense matmul) ignoring surrounding autocast.

addmm_sparse_cuda has no BF16 kernel on either forward or backward. The worker wraps forward + loss + backward in autocast(BF16), so a plain with autocast(enabled=False): around the forward call is not enough — loss.backward() runs inside the outer autocast and the sparse-mm backward kernel is still dispatched as BF16. The custom_fwd(cast_inputs=torch.float32) / custom_bwd decorators are PyTorch’s official escape: they force FP32 inputs on forward and run the backward as if autocast were disabled.

For out = M.t() @ dense, the gradient w.r.t. dense is M @ grad_out, which is what the backward returns. The gradient w.r.t. the sparse argument isn’t needed (the projection matrix is frozen), so it’s returned as None.

static forward(
ctx: Any,
sparse_M: torch.Tensor,
dense: torch.Tensor,
) torch.Tensor#

static backward(
ctx: Any,
grad_out: torch.Tensor,
) tuple[None, torch.Tensor]#
nemo_rl.algorithms.x_token.loss_utils.chunk_average_log_probs(
log_probs: torch.Tensor,
chunk_id: torch.Tensor,
max_chunks: int,
) tuple[torch.Tensor, torch.Tensor]#

Average log_probs over the chunks defined by chunk_id.

Builds a one-hot chunk mask from chunk_id (-1 means “no chunk”, contributes to no bucket), then bmm-aggregates and divides by chunk sizes.

Parameters:
  • log_probs – [B, T, V] log-probabilities.

  • chunk_id – [B, T] long tensor, values in [-1, max_chunks).

  • max_chunks – number of chunk buckets.

Returns:

(chunk_log_probs, chunk_sizes). chunk_log_probs: [B, max_chunks, V] averaged log-probs. chunk_sizes: [B, max_chunks] float tensor of bucket sizes.

Return type:

tuple[torch.Tensor, torch.Tensor]
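
The one-hot plus bmm aggregation can be sketched as below. `chunk_average_sketch` is an illustrative name, and the handling of empty buckets (dividing by a size clamped to 1, which yields zeros) is an assumption rather than documented behaviour.

```python
import torch
import torch.nn.functional as F


def chunk_average_sketch(log_probs, chunk_id, max_chunks):
    # One-hot over chunk buckets; positions with chunk_id == -1 are zeroed
    # so they contribute to no bucket.
    valid = chunk_id >= 0
    one_hot = F.one_hot(chunk_id.clamp(min=0), max_chunks).to(log_probs.dtype)
    one_hot = one_hot * valid.unsqueeze(-1)                  # [B, T, C]
    chunk_sizes = one_hot.sum(dim=1)                         # [B, C]
    summed = torch.bmm(one_hot.transpose(1, 2), log_probs)   # [B, C, V]
    # Assumption: empty buckets divide by 1 and therefore stay at zero.
    averaged = summed / chunk_sizes.clamp(min=1).unsqueeze(-1)
    return averaged, chunk_sizes


log_probs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]])
chunk_id = torch.tensor([[0, 0, 1, -1]])
avg, sizes = chunk_average_sketch(log_probs, chunk_id, max_chunks=3)
# avg[0] is [[2, 3], [5, 6], [0, 0]] and sizes[0] is [2, 1, 0]
```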

nemo_rl.algorithms.x_token.loss_utils.valid_chunk_mask(
s_sizes: torch.Tensor,
t_sizes: torch.Tensor,
pair_valid: torch.Tensor,
) torch.Tensor#

Per-chunk validity gate: both sides non-empty and pair is valid.
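
A minimal sketch of this gate, assuming `s_sizes` and `t_sizes` are the per-chunk sizes returned by `chunk_average_log_probs` for the student and teacher sides, and that `pair_valid` broadcasts against them (its exact shape is not stated here):

```python
import torch


def valid_chunk_mask_sketch(s_sizes, t_sizes, pair_valid):
    # A chunk counts only if both sides have at least one token
    # and the aligned pair itself is marked valid.
    return (s_sizes > 0) & (t_sizes > 0) & pair_valid.bool()


s_sizes = torch.tensor([[2.0, 0.0, 1.0]])
t_sizes = torch.tensor([[1.0, 3.0, 0.0]])
pair_valid = torch.tensor([[True, True, True]])
mask = valid_chunk_mask_sketch(s_sizes, t_sizes, pair_valid)
```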

nemo_rl.algorithms.x_token.loss_utils.dp_all_reduce_sum(local: torch.Tensor) torch.Tensor#

Sum-reduce a scalar count across the data-parallel group.

Used to compute global_valid_chunks from each rank’s local chunk count, so the chunk-KL denominator matches the sum(global_valid_chunk_kl) / sum(global_valid_chunks) objective (the same convention CE follows via global_valid_toks). The cross-tokenizer setup asserts tensor_parallel_size=1 and context_parallel_size=1 in xtoken_off_policy_distillation.setup, so the default process group equals the DP group — calling all-reduce on the default group therefore sums across DP only.

Returns a fresh float32 scalar; the input tensor is not modified. Falls back to a copy of the local value when distributed is not initialized (unit tests).
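
The behaviour described above can be sketched as follows. `dp_all_reduce_sum_sketch` is an illustrative name; detaching the input is an assumption made here because the value is a count.

```python
import torch
import torch.distributed as dist


def dp_all_reduce_sum_sketch(local):
    # Work on a fresh float32 copy so the caller's tensor is never modified.
    total = local.detach().clone().to(torch.float32)
    if dist.is_available() and dist.is_initialized():
        # Default group; equals the DP group when TP = CP = 1.
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total


local_count = torch.tensor(3)
global_count = dp_all_reduce_sum_sketch(local_count)
```

Outside a distributed run the result is simply the local count as a float32 scalar, which matches the unit-test fallback described above.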

nemo_rl.algorithms.x_token.loss_utils.parse_projection_file(
path: Union[str, os.PathLike],
) Tuple[torch.Tensor, torch.Tensor, int, int]#

Parse a projection-matrix file into COO components.

Detects either the dense top-k format (dict["indices"] / dict["likelihoods"]) or the sparse multi-token format (dict[(student_id, teacher_id)] -> count) and converts both to a uniform COO representation.

The function does not apply any sizing or validity policy: the -1 sentinel used by _exact_map_remapped projection files is preserved in the returned indices, and the inferred vocab sizes are derived from the file alone (caller may override them upward against tokenizer / config knowledge). This keeps a single parser while letting :mod:token_aligner and the loss fn keep their own clipping rules.

Parameters:

path – Path to a projection-matrix file written with torch.save.

Returns:

(indices, values, v_student_inferred, v_teacher_inferred), where:

  • indices – LongTensor[2, nnz] of (student_idx, teacher_idx).

  • values – FloatTensor[nnz].

  • v_student_inferred – int. Dense format: row count; sparse format: max(student_idx) + 1.

  • v_teacher_inferred – int. max(positive teacher_idx) + 1 (0 if no positive entries exist).

Return type:

Tuple[torch.Tensor, torch.Tensor, int, int]

Raises:
  • FileNotFoundError – path does not exist.

  • ValueError – the file is not in a recognized format.
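
The conversion of both formats to COO can be sketched on an already-loaded object, as below. `projection_to_coo_sketch` is an illustrative name, file loading is left out, and "positive teacher_idx" is read here as "not the -1 sentinel" (that is, `>= 0`), which is an assumption.

```python
import torch


def projection_to_coo_sketch(obj):
    if isinstance(obj, dict) and "indices" in obj and "likelihoods" in obj:
        # Dense top-k format: row i holds top_k (teacher_id, weight) pairs.
        idx = torch.as_tensor(obj["indices"], dtype=torch.long)
        lik = torch.as_tensor(obj["likelihoods"], dtype=torch.float32)
        v_s, top_k = idx.shape
        rows = torch.arange(v_s).repeat_interleave(top_k)
        indices = torch.stack([rows, idx.reshape(-1)])
        values = lik.reshape(-1)
    elif isinstance(obj, dict) and obj and all(
        isinstance(k, tuple) and len(k) == 2 for k in obj
    ):
        # Sparse format: dict[(student_id, teacher_id)] -> count.
        indices = torch.tensor(list(obj.keys()), dtype=torch.long).t()
        values = torch.tensor(list(obj.values()), dtype=torch.float32)
        v_s = int(indices[0].max()) + 1
    else:
        raise ValueError("unrecognized projection format")
    # The -1 sentinel stays in `indices`; it is only skipped for sizing.
    real_teacher = indices[1][indices[1] >= 0]
    v_t = int(real_teacher.max()) + 1 if real_teacher.numel() else 0
    return indices, values, v_s, v_t


dense = {"indices": [[0, -1], [2, 1]], "likelihoods": [[1.0, 0.0], [0.6, 0.4]]}
d_idx, d_val, d_vs, d_vt = projection_to_coo_sketch(dense)
sparse = {(0, 1): 2.0, (3, 0): 1.0}
s_idx, s_val, s_vs, s_vt = projection_to_coo_sketch(sparse)
```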

nemo_rl.algorithms.x_token.loss_utils._SPARSE_PROJECTION_CACHE: dict[Tuple[str, torch.device, int, int], torch.Tensor]#

None

nemo_rl.algorithms.x_token.loss_utils._TOPK_PROJECTION_CACHE: dict[Tuple[str, torch.device], Tuple[torch.Tensor, torch.Tensor]]#

None

nemo_rl.algorithms.x_token.loss_utils.get_sparse_projection_matrix(
path: Union[str, os.PathLike],
device: torch.device,
*,
student_vocab_size: int,
teacher_vocab_size: int,
) torch.Tensor#

Return the sparse-COO projection matrix on device (cached).

On a cache miss, parses the file via :func:parse_projection_file, drops -1 teacher sentinels (illegal in sparse-COO), sizes V_s = max(student_vocab_size, max_observed_student_idx + 1) and V_t = max(teacher_vocab_size, max_observed_teacher_idx + 1), and builds a coalesced torch.sparse_coo_tensor on device. Subsequent calls with the same (path, device, student_vocab_size, teacher_vocab_size) return the cached tensor — no disk I/O, no re-materialization.

Both vocab sizes are keyword-only to prevent a positional swap (two same-magnitude ints, no error if confused).

Parameters:
  • path – Path to a projection-matrix file written with torch.save.

  • device – Device the sparse tensor must live on.

  • student_vocab_size – Minimum width of the student-side axis.

  • teacher_vocab_size – Minimum width of the teacher-side axis.

Returns:

torch.sparse_coo_tensor of shape (V_s, V_t), coalesced, dtype=float32.
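
The cache-miss path (drop sentinels, size the axes, coalesce) can be sketched from COO components like this. `build_sparse_projection_sketch` is an illustrative name, and caching and file parsing are omitted.

```python
import torch


def build_sparse_projection_sketch(
    indices, values, device, *, student_vocab_size, teacher_vocab_size
):
    # Drop -1 teacher sentinels: negative indices are illegal in sparse COO.
    keep = indices[1] >= 0
    idx = indices[:, keep]
    val = values[keep].to(torch.float32)
    has_entries = idx.shape[1] > 0
    max_s = int(idx[0].max()) + 1 if has_entries else 0
    max_t = int(idx[1].max()) + 1 if has_entries else 0
    # The configured sizes are minimums; observed indices can widen them.
    v_s = max(student_vocab_size, max_s)
    v_t = max(teacher_vocab_size, max_t)
    return torch.sparse_coo_tensor(idx, val, (v_s, v_t), device=device).coalesce()


indices = torch.tensor([[0, 1, 1], [2, -1, 0]])
values = torch.tensor([1.0, 5.0, 3.0])
proj = build_sparse_projection_sketch(
    indices, values, torch.device("cpu"), student_vocab_size=2, teacher_vocab_size=2
)
```

In this example the teacher axis widens from the requested 2 to 3 because teacher id 2 appears in the data, and the entry with the -1 sentinel is dropped.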

nemo_rl.algorithms.x_token.loss_utils.get_topk_projection(
path: Union[str, os.PathLike],
device: torch.device,
) Tuple[torch.Tensor, torch.Tensor]#

Return the dense top-k (indices, likelihoods) projection on device (cached).

Used by the gold-loss exact-map builder, which needs the per-row top-k weights — the sparse dict[(s, t)] -> count projection format doesn’t carry those, so this loader rejects it.

Parameters:
  • path – Path to a projection-matrix file written with torch.save.

  • device – Device the returned tensors must live on.

Returns:

(indices, likelihoods) – LongTensor[V_s, top_k] and FloatTensor[V_s, top_k] on device.

Raises:
  • FileNotFoundError – path does not exist.

  • ValueError – the file is not in the dense top-k format.
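
A process-local lazy cache with dense-format validation might look like the sketch below. The names `_TOPK_CACHE_SKETCH` and `get_topk_projection_sketch` are illustrative, and the `load_fn` parameter is a hypothetical injection point added so the example runs without a file on disk.

```python
import os

import torch

# Module-level dict: each process (for example, each Ray worker) fills its own copy.
_TOPK_CACHE_SKETCH = {}


def get_topk_projection_sketch(path, device, load_fn=torch.load):
    key = (os.fspath(path), torch.device(device))
    if key not in _TOPK_CACHE_SKETCH:
        obj = load_fn(path)
        # Reject anything that is not the dense top-k format.
        if not (isinstance(obj, dict) and "indices" in obj and "likelihoods" in obj):
            raise ValueError(f"{path}: not in the dense top-k format")
        _TOPK_CACHE_SKETCH[key] = (
            torch.as_tensor(obj["indices"], dtype=torch.long).to(device),
            torch.as_tensor(obj["likelihoods"], dtype=torch.float32).to(device),
        )
    return _TOPK_CACHE_SKETCH[key]


calls = []


def fake_load(path):
    # Stand-in loader that records how often the "disk" is touched.
    calls.append(path)
    return {"indices": [[0, -1]], "likelihoods": [[1.0, 0.0]]}


first = get_topk_projection_sketch("proj.pt", "cpu", load_fn=fake_load)
second = get_topk_projection_sketch("proj.pt", torch.device("cpu"), load_fn=fake_load)
```

The second call returns the cached tuple without invoking the loader again, which is the no-disk-I/O property the cached getters rely on.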

nemo_rl.algorithms.x_token.loss_utils._EXACT_TOKEN_MAP_CACHE: dict[Tuple[str, torch.device, bool, int], Dict[str, torch.Tensor]]#

None

nemo_rl.algorithms.x_token.loss_utils.build_exact_token_map(
path: Union[str, os.PathLike],
device: torch.device,
*,
xtoken_loss: bool,
teacher_vocab_size: int,
) Dict[str, torch.Tensor]#

Build the common/uncommon vocab partition for the gold path (cached).

Reads the dense projection arrays via :func:get_topk_projection, sorts each student row’s projection weights descending, then picks an exact-token map per the xtoken_loss flag:

  • xtoken_loss=False (strict): has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1). On collision (multiple students mapping to the same teacher id), the earliest (lowest) student index wins.

  • xtoken_loss=True (relaxed): has_exact_map = sorted_values[:, 0] >= 0.6. On collision, the student with the highest first-projection weight wins; ties are broken by lowest student index.

Both branches are vectorized via scatter_reduce so the build is O(V_s) and happens once per (path, device, xtoken_loss, teacher_vocab_size) for the run.

Parameters:
  • path – Path to a projection-matrix file written with torch.save (dense top-k format).

  • device – Device the returned tensors must live on.

  • xtoken_loss – Selects strict vs relaxed exact-map rule (see above).

  • teacher_vocab_size – Width of the teacher-side vocab axis. The partition is bounded by this — teacher ids outside the range are dropped.

Returns:

Dict with keys common_student, common_teacher (paired), uncommon_student, uncommon_teacher (each independently sorted). All [long] tensors on device.
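
The strict rule (xtoken_loss=False) and its lowest-student-index collision policy can be sketched with `scatter_reduce_` as below. This is an illustration only: `strict_exact_token_map_sketch` is a made-up name, it works on in-memory arrays rather than a file, and ordering the common pairs by teacher id is an assumption.

```python
import torch


def strict_exact_token_map_sketch(indices, likelihoods, *, teacher_vocab_size):
    v_s = indices.shape[0]
    sorted_vals, order = likelihoods.sort(dim=1, descending=True)
    # Teacher id carrying each student's largest projection weight.
    top_teacher = indices.gather(1, order[:, :1]).squeeze(1)
    has_exact = (sorted_vals[:, 0] == 1.0) & (indices[:, 1] == -1)
    in_range = (top_teacher >= 0) & (top_teacher < teacher_vocab_size)
    cand_student = torch.nonzero(has_exact & in_range).squeeze(1)
    cand_teacher = top_teacher[cand_student]
    # Per teacher id, keep the smallest candidate student index.
    # The fill value v_s means "no student claimed this teacher id".
    winner = torch.full((teacher_vocab_size,), v_s, dtype=torch.long)
    winner.scatter_reduce_(0, cand_teacher, cand_student, reduce="amin")
    common_teacher = torch.nonzero(winner < v_s).squeeze(1)
    common_student = winner[common_teacher]
    s_mask = torch.ones(v_s, dtype=torch.bool)
    s_mask[common_student] = False
    t_mask = torch.ones(teacher_vocab_size, dtype=torch.bool)
    t_mask[common_teacher] = False
    return {
        "common_student": common_student,
        "common_teacher": common_teacher,
        "uncommon_student": torch.nonzero(s_mask).squeeze(1),
        "uncommon_teacher": torch.nonzero(t_mask).squeeze(1),
    }


indices = torch.tensor([[2, -1], [2, -1], [0, 1], [3, -1]])
likelihoods = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.7, 0.3], [1.0, 0.0]])
token_map = strict_exact_token_map_sketch(indices, likelihoods, teacher_vocab_size=4)
```

Students 0 and 1 both map exactly to teacher 2, so student 0 wins the collision and student 1 falls into the uncommon set. The relaxed rule would replace the `has_exact` test with the 0.6 threshold and resolve collisions by highest first-projection weight instead.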