nemo_rl.algorithms.x_token.loss_utils
Shared utilities for cross-tokenizer distillation.
Hosts pieces that are used by both :mod:`token_aligner` (in this package)
and :mod:`nemo_rl.algorithms.loss.loss_functions`:
- :class:`Fp32SparseMM` — FP32 sparse-dense matmul autograd Function
that ignores the surrounding BF16 autocast (PyTorch has no BF16
sparse-mm kernel).
- :func:`chunk_average_log_probs`, :func:`valid_chunk_mask` —
chunk-aggregation helpers for the cross-tokenizer KL paths.
- :func:`dp_all_reduce_sum` — sum-reduce a scalar count across the
data-parallel group so the chunk-KL denominator is the global
valid-chunk count rather than a per-rank mean.
- :func:`parse_projection_file` — single source of truth for
reading the on-disk projection matrix file (both the dense top-k
format and the sparse ``dict[(s, t)] -> count`` format) into COO
components. Callers retain their own validation / sizing rules.
- :func:`get_sparse_projection_matrix`, :func:`get_topk_projection`
— process-local lazy caches for the materialized projection
matrix on a given device. Driver processes never trigger a fill;
each Ray worker populates its own cache on the first loss call.
- :func:`build_exact_token_map` — derived common/uncommon vocab
partition for the gold-loss path. Cached per
``(path, device, xtoken_loss, teacher_vocab_size)`` because the
partition depends on those four inputs.
- :func:`alignment_from_flat_batch` — rehydrate the flat
``alignment_*`` transport keys on the loss data dict into a single
:class:`AlignmentBatch` so the loss bodies access alignment via
attributes instead of repeating flat field names.
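Module Contents
Classes
| Name | Description |
|---|---|
| :class:`Fp32SparseMM` | FP32 ``M.t() @ dense`` (sparse-dense matmul) ignoring surrounding autocast. |
Functions
| Name | Description |
|---|---|
| :func:`alignment_from_flat_batch` | Rebuild :class:`AlignmentBatch` from the flat ``alignment_*`` keys. |
| :func:`rebuild_teacher_full_logits_from_ipc` | View-only rebuild of the microbatch’s teacher-logits slice from IPC. |
| :func:`chunk_average_log_probs` | Average ``log_probs`` over the chunks defined by ``chunk_id``. |
| :func:`valid_chunk_mask` | Per-chunk validity gate: both sides non-empty and pair is valid. |
| :func:`dp_all_reduce_sum` | Sum-reduce a scalar count across the data-parallel group. |
| :func:`parse_projection_file` | Parse a projection-matrix file into COO components. |
| :func:`get_sparse_projection_matrix` | Return the sparse-COO projection matrix on ``device`` (cached). |
| :func:`get_topk_projection` | Return the dense top-k ``(indices, likelihoods)`` projection on ``device`` (cached). |
| :func:`build_exact_token_map` | Build the common/uncommon vocab partition for the gold path (cached). |
Data
| Name | Description |
|---|---|
| ``_SPARSE_PROJECTION_CACHE`` | Process-local cache behind :func:`get_sparse_projection_matrix`. |
| ``_TOPK_PROJECTION_CACHE`` | Process-local cache behind :func:`get_topk_projection`. |
| ``_EXACT_TOKEN_MAP_CACHE`` | Process-local cache behind :func:`build_exact_token_map`. |
API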
- nemo_rl.algorithms.x_token.loss_utils.alignment_from_flat_batch(data: Mapping[str, Any])
Rebuild :class:`AlignmentBatch` from the flat ``alignment_*`` keys.
The field set is driven off :class:`AlignmentBatch`, so the helper can’t drift from the schema.
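The schema-driven idea can be sketched with the standard-library ``dataclasses`` module: iterating ``fields()`` of the dataclass means a newly added field is picked up without touching the helper. This is a hypothetical sketch, not the module’s source; the ``AlignmentBatch`` fields below and the ``alignment_<field_name>`` key convention are assumptions chosen for illustration.

```python
from dataclasses import dataclass, fields
from typing import Any, Mapping


@dataclass
class AlignmentBatch:
    # Hypothetical fields for illustration only; the real schema may differ.
    chunk_id_student: Any
    chunk_id_teacher: Any
    pair_valid: Any


def alignment_from_flat_batch_sketch(data: Mapping[str, Any]) -> AlignmentBatch:
    # Read exactly one flat key per declared dataclass field, so the set of
    # transport keys is derived from the schema instead of being hand-listed.
    kwargs = {f.name: data[f"alignment_{f.name}"] for f in fields(AlignmentBatch)}
    return AlignmentBatch(**kwargs)


# Usage: unrelated keys in the loss data dict are simply ignored.
flat = {
    "alignment_chunk_id_student": [0, 0, 1],
    "alignment_chunk_id_teacher": [0, 1, 1],
    "alignment_pair_valid": [True, True],
    "input_ids": [11, 12, 13],
}
batch = alignment_from_flat_batch_sketch(flat)
```

In this sketch a key missing from ``data`` raises ``KeyError`` rather than silently producing a partial batch.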
- nemo_rl.algorithms.x_token.loss_utils.rebuild_teacher_full_logits_from_ipc(data: Mapping[str, Any])
View-only rebuild of the microbatch’s teacher-logits slice from IPC.
The producer maintains a persistent IPC buffer on its GPU sized ``[B_r, T_t, V_t]``. The buffer (and the IPC handle it was captured with) survives across training steps; each step, fresh logits are ``copy_()``-ed into it. Because the producer never frees the buffer between steps, holding a view into the IPC-imported storage is safe: the producer-side allocation isn’t fighting the consumer’s refcount, it’s simply alive for the worker’s lifetime.
Every per-sample entry in ``teacher_full_logits_ipc`` carries the same stable rank-level handle plus its rank-local ``sample_idx_within_rank``. We rebuild that single handle once and slice ``[mb_start:mb_end]`` for the current microbatch: zero allocation on the consumer, and the dtype is preserved (the loss fn casts if and where it needs fp32).
- Parameters:
data – The loss data dict, carrying ``teacher_full_logits_ipc``, a list of per-sample IPC handle dicts produced by ``Policy.get_full_logits_ipc``.
- Returns:
A ``[mb_B, T_t, V_t]`` view into the producer’s GPU memory (no copy).
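Basic slicing along the batch axis returns a view that shares storage with the persistent buffer, which is why the consumer allocates nothing and always sees the producer’s latest ``copy_()``. The sketch below shows this aliasing on a CPU tensor standing in for the IPC-imported storage; the CUDA IPC handle rebuild itself is not shown, and ``mb_start``/``mb_end`` are illustrative values.

```python
import torch

B_r, T_t, V_t = 4, 3, 5  # rank-level buffer shape [B_r, T_t, V_t]

# Stand-in for the producer's persistent buffer (really IPC-imported GPU storage).
persistent_buffer = torch.zeros(B_r, T_t, V_t, dtype=torch.bfloat16)

# Consumer: slice the current microbatch's rows; this is a view, not a copy.
mb_start, mb_end = 1, 3
mb_view = persistent_buffer[mb_start:mb_end]

# Producer: write fresh logits into the same storage in place.
fresh_logits = torch.randn(B_r, T_t, V_t).to(torch.bfloat16)
persistent_buffer.copy_(fresh_logits)

# mb_view now exposes the new values for rows 1..2 in the producer's dtype,
# without any additional allocation on the consumer side.
```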
- class nemo_rl.algorithms.x_token.loss_utils.Fp32SparseMM
Bases: ``torch.autograd.Function``
FP32 ``M.t() @ dense`` (sparse-dense matmul) ignoring surrounding autocast.
``addmm_sparse_cuda`` has no BF16 kernel on either forward or backward. The worker wraps forward + loss + backward in ``autocast(BF16)``, so a plain ``with autocast(enabled=False):`` around the forward call is not enough: ``loss.backward()`` runs inside the outer autocast, and the sparse-mm backward kernel is still dispatched as BF16. The ``custom_fwd(cast_inputs=torch.float32)`` / ``custom_bwd`` decorators are PyTorch’s official escape hatch: they force FP32 inputs on forward and run the backward as if autocast were disabled.
The backward returns ``M @ grad_out`` as the gradient w.r.t. ``dense``, which is the same quantity autograd’s built-in sparse-mm backward computes. The gradient w.r.t. the sparse argument isn’t needed (the projection matrix is frozen), so it is returned as ``None``.
- static forward(ctx: Any, sparse_M: torch.Tensor, dense: torch.Tensor)
- static backward(ctx: Any, grad_out: torch.Tensor)
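A minimal sketch of this pattern follows. It assumes PyTorch 2.4 or later, where ``torch.amp.custom_fwd``/``custom_bwd`` take a ``device_type`` argument, and falls back to the older ``torch.cuda.amp`` decorators otherwise. The class name ``Fp32SparseMMSketch`` is hypothetical; it mirrors the documented contract (FP32 forward of ``M.t() @ dense``, ``None`` for the frozen sparse matrix, ``M @ grad_out`` for ``dense``) rather than reproducing the module’s source.

```python
import torch

try:  # PyTorch >= 2.4: device-generic decorators
    from torch.amp import custom_bwd, custom_fwd

    _FWD = custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    _BWD = custom_bwd(device_type="cuda")
except ImportError:  # older PyTorch: CUDA-specific decorators
    from torch.cuda.amp import custom_bwd as _legacy_bwd, custom_fwd as _legacy_fwd

    _FWD = _legacy_fwd(cast_inputs=torch.float32)
    _BWD = _legacy_bwd


class Fp32SparseMMSketch(torch.autograd.Function):
    """Computes sparse_M.t() @ dense in FP32, even inside a BF16 CUDA autocast region."""

    @staticmethod
    @_FWD
    def forward(ctx, sparse_M: torch.Tensor, dense: torch.Tensor) -> torch.Tensor:
        # Under an active CUDA autocast, the decorator has already cast the
        # inputs to float32 and disabled autocast for this body.
        ctx.sparse_M = sparse_M  # frozen projection matrix; never needs a gradient
        return torch.sparse.mm(sparse_M.t(), dense)

    @staticmethod
    @_BWD
    def backward(ctx, grad_out: torch.Tensor):
        # out = M^T @ dense, so d(loss)/d(dense) = M @ grad_out.
        grad_dense = torch.sparse.mm(ctx.sparse_M, grad_out)
        return None, grad_dense  # no gradient for the frozen sparse argument


# Usage: a 3x2 frozen projection applied to a 3x4 dense input.
M_dense = torch.tensor([[0.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
M = M_dense.to_sparse()
dense = torch.randn(3, 4, requires_grad=True)
out = Fp32SparseMMSketch.apply(M, dense)  # shape [2, 4]
out.sum().backward()
```

Outside an autocast region the decorators are pass-throughs, so the same class also runs on CPU tensors as shown.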
- nemo_rl.algorithms.x_token.loss_utils.chunk_average_log_probs(log_probs: torch.Tensor, chunk_id: torch.Tensor, max_chunks: int)
Average ``log_probs`` over the chunks defined by ``chunk_id``.
Builds a one-hot chunk mask from ``chunk_id`` (``-1`` means “no chunk” and contributes to no bucket), then ``bmm``-aggregates and divides by the chunk sizes.
- Parameters:
log_probs – ``[B, T, V]`` log-probabilities.
chunk_id – ``[B, T]`` long tensor with values in ``[-1, max_chunks)``.
max_chunks – Number of chunk buckets.
- Returns:
A pair ``(chunk_log_probs, chunk_sizes)``: ``chunk_log_probs`` is the ``[B, max_chunks, V]`` tensor of averaged log-probs, and ``chunk_sizes`` is a ``[B, max_chunks]`` float tensor of bucket sizes.
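The one-hot plus ``bmm`` aggregation can be sketched as below. This is a hypothetical re-implementation of the described behavior, not the module’s code: shifting ids by one sends ``-1`` to class 0, which is then dropped. How the real function handles empty buckets is not stated; the sketch clamps the divisor to 1 (yielding zeros), and such buckets would presumably be gated out by :func:`valid_chunk_mask` anyway.

```python
import torch
import torch.nn.functional as F


def chunk_average_log_probs_sketch(log_probs, chunk_id, max_chunks):
    # Shift ids so -1 ("no chunk") becomes class 0, then drop that class:
    # tokens with chunk_id == -1 contribute to no bucket.
    onehot = F.one_hot(chunk_id + 1, num_classes=max_chunks + 1)[..., 1:]  # [B, T, C]
    onehot = onehot.to(log_probs.dtype)
    chunk_sizes = onehot.sum(dim=1)  # [B, C]: tokens per bucket
    # [B, C, T] @ [B, T, V] -> [B, C, V]: per-bucket sums of log-probs.
    chunk_sums = torch.bmm(onehot.transpose(1, 2), log_probs)
    # Assumption: empty buckets divide by 1 (giving 0) instead of producing NaN.
    chunk_log_probs = chunk_sums / chunk_sizes.clamp(min=1).unsqueeze(-1)
    return chunk_log_probs, chunk_sizes


# Usage: tokens 0 and 1 form chunk 0, token 2 belongs to no chunk.
lp = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])  # [B=1, T=3, V=2]
cid = torch.tensor([[0, 0, -1]])
avg, sizes = chunk_average_log_probs_sketch(lp, cid, max_chunks=2)
```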
- nemo_rl.algorithms.x_token.loss_utils.valid_chunk_mask(s_sizes: torch.Tensor, t_sizes: torch.Tensor, pair_valid: torch.Tensor)
Per-chunk validity gate: both sides non-empty and pair is valid.
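The gate reduces to an elementwise conjunction. A minimal sketch, assuming ``s_sizes``/``t_sizes`` are per-chunk size tensors like the ``chunk_sizes`` output of :func:`chunk_average_log_probs` and that ``pair_valid`` is boolean-convertible; the function name is hypothetical.

```python
import torch


def valid_chunk_mask_sketch(s_sizes, t_sizes, pair_valid):
    # A chunk counts only if the student side and the teacher side each cover
    # at least one token and the (student, teacher) pair is flagged valid.
    return (s_sizes > 0) & (t_sizes > 0) & pair_valid.bool()


# Usage: only the first chunk survives (empty student side, empty teacher
# side, and an invalid pair knock out the other three).
mask = valid_chunk_mask_sketch(
    torch.tensor([[2.0, 0.0, 1.0, 3.0]]),
    torch.tensor([[1.0, 1.0, 0.0, 2.0]]),
    torch.tensor([[True, True, True, False]]),
)
```

Summing such a mask presumably yields the per-rank valid-chunk count that is then reduced across ranks.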
- nemo_rl.algorithms.x_token.loss_utils.dp_all_reduce_sum(local: torch.Tensor) → torch.Tensor
Sum-reduce a scalar count across the data-parallel group.
Used to compute ``global_valid_chunks`` from each rank’s local chunk count, so that the chunk-KL denominator matches the ``sum(global_valid_chunk_kl) / sum(global_valid_chunks)`` objective (the same convention CE follows via ``global_valid_toks``). The cross-tokenizer setup asserts ``tensor_parallel_size=1`` and ``context_parallel_size=1`` in ``xtoken_off_policy_distillation.setup``, so the default process group equals the DP group; calling all-reduce on the default group therefore sums across DP only.
Returns a fresh ``float32`` scalar; the input tensor is not modified. Falls back to a copy of the local value when ``torch.distributed`` is not initialized (unit tests).
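A minimal sketch of this contract, using only public ``torch.distributed`` calls (``is_available``, ``is_initialized``, ``all_reduce``); the function name is hypothetical. With NCCL the tensor would have to live on the rank’s CUDA device; the sketch keeps whatever device the input already has.

```python
import torch
import torch.distributed as dist


def dp_all_reduce_sum_sketch(local: torch.Tensor) -> torch.Tensor:
    # Fresh float32 copy, so the caller's tensor is never mutated in place.
    total = local.detach().clone().to(torch.float32)
    if dist.is_available() and dist.is_initialized():
        # Default group == DP group under the documented TP == CP == 1 setup.
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    # Not initialized (e.g. unit tests): the copy of the local value is returned.
    return total


# Usage: without an initialized process group, the local count comes back as float32.
local_valid = torch.tensor(3)
global_valid = dp_all_reduce_sum_sketch(local_valid)
```

Dividing each rank’s summed chunk KL by this global count makes the per-rank losses add up to the global ratio when gradients are summed across DP ranks.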
- nemo_rl.algorithms.x_token.loss_utils.parse_projection_file(path: Union[str, os.PathLike])
Parse a projection-matrix file into COO components.
Detects either the dense top-k format (``dict["indices"]`` / ``dict["likelihoods"]``) or the sparse multi-token format (``dict[(student_id, teacher_id)] -> count``) and converts both to a uniform COO representation.
The function does not apply any sizing or validity policy: the ``-1`` sentinel used by ``_exact_map_remapped`` projection files is preserved in the returned ``indices``, and the inferred vocab sizes are derived from the file alone (the caller may override them upward using tokenizer / config knowledge). This keeps a single parser while letting :mod:`token_aligner` and the loss fn keep their own clipping rules.
- Parameters:
path – Path to a ``torch.save``-d projection-matrix file.
- Returns:
A tuple ``(indices, values, v_student_inferred, v_teacher_inferred)``:
``indices``: ``LongTensor[2, nnz]`` with rows ``(student_idx, teacher_idx)``.
``values``: ``FloatTensor[nnz]``.
``v_student_inferred``: ``int``; the row count for the dense format, ``max(student_idx) + 1`` for the sparse format.
``v_teacher_inferred``: ``int``; ``max(positive teacher_idx) + 1`` (``0`` if no positive entries exist).
- Raises:
FileNotFoundError – ``path`` does not exist.
ValueError – The file is not in a recognized format.
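Format detection and the conversion to COO can be sketched as follows. This hypothetical helper operates on an object already returned by ``torch.load(path)``, so it skips the ``FileNotFoundError`` path. Two assumptions are labeled in the comments: “positive teacher_idx” is read as any non-sentinel id (``>= 0``), and sparse-format counts are kept as raw floats, since the docstring does not say whether the real parser normalizes them.

```python
import torch


def parse_projection_obj_sketch(obj):
    """Convert a loaded projection object to (indices, values, v_s, v_t)."""
    if isinstance(obj, dict) and "indices" in obj and "likelihoods" in obj:
        # Dense top-k format: row s holds top_k teacher ids and their weights.
        topk_idx = torch.as_tensor(obj["indices"], dtype=torch.long)  # [V_s, top_k]
        topk_w = torch.as_tensor(obj["likelihoods"], dtype=torch.float32)
        v_s, top_k = topk_idx.shape
        student = torch.arange(v_s).repeat_interleave(top_k)
        teacher = topk_idx.reshape(-1)  # -1 sentinels are preserved
        values = topk_w.reshape(-1)
        v_student = v_s
    elif isinstance(obj, dict) and obj and all(isinstance(k, tuple) and len(k) == 2 for k in obj):
        # Sparse multi-token format: {(student_id, teacher_id): count}.
        pairs = torch.tensor(list(obj.keys()), dtype=torch.long)  # [nnz, 2]
        student, teacher = pairs[:, 0], pairs[:, 1]
        # Assumption: counts are kept as raw floats (no normalization).
        values = torch.tensor([float(c) for c in obj.values()], dtype=torch.float32)
        v_student = int(student.max()) + 1
    else:
        raise ValueError("unrecognized projection-matrix format")
    indices = torch.stack([student, teacher])  # [2, nnz]
    # Assumption: "positive" means non-sentinel, i.e. teacher id >= 0.
    real = teacher[teacher >= 0]
    v_teacher = int(real.max()) + 1 if real.numel() > 0 else 0
    return indices, values, v_student, v_teacher


# Usage on both formats.
dense_obj = {"indices": torch.tensor([[1, -1], [0, 2]]), "likelihoods": torch.tensor([[1.0, 0.0], [0.5, 0.5]])}
idx, vals, vs, vt = parse_projection_obj_sketch(dense_obj)
idx2, vals2, vs2, vt2 = parse_projection_obj_sketch({(0, 1): 3, (2, 0): 1})
```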
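- nemo_rl.algorithms.x_token.loss_utils._SPARSE_PROJECTION_CACHE: dict[Tuple[str, torch.device, int, int], torch.Tensor]
Process-local cache for :func:`get_sparse_projection_matrix`, keyed by ``(path, device, student_vocab_size, teacher_vocab_size)``.
- nemo_rl.algorithms.x_token.loss_utils._TOPK_PROJECTION_CACHE: dict[Tuple[str, torch.device], Tuple[torch.Tensor, torch.Tensor]]
Process-local cache for :func:`get_topk_projection`, keyed by ``(path, device)``.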
- nemo_rl.algorithms.x_token.loss_utils.get_sparse_projection_matrix(path: Union[str, os.PathLike], device: torch.device, *, student_vocab_size: int, teacher_vocab_size: int)
Return the sparse-COO projection matrix on ``device`` (cached).
On a cache miss, parses the file via :func:`parse_projection_file`, drops ``-1`` teacher sentinels (illegal in sparse-COO), sizes ``V_s = max(student_vocab_size, max_observed_student_idx + 1)`` and ``V_t = max(teacher_vocab_size, max_observed_teacher_idx + 1)``, and builds a coalesced ``torch.sparse_coo_tensor`` on ``device``. Subsequent calls with the same ``(path, device, student_vocab_size, teacher_vocab_size)`` return the cached tensor: no disk I/O, no re-materialization.
Both vocab sizes are keyword-only to prevent a positional swap (two same-magnitude ints, no error if confused).
- Parameters:
path – Path to a ``torch.save``-d projection-matrix file.
device – Device the sparse tensor must live on.
student_vocab_size – Minimum width of the student-side axis.
teacher_vocab_size – Minimum width of the teacher-side axis.
- Returns:
A coalesced ``torch.sparse_coo_tensor`` of shape ``(V_s, V_t)`` with ``dtype=float32``.
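The cache-miss path (drop sentinels, size the axes, build a coalesced COO tensor, memoize) can be sketched as below. To stay self-contained, the hypothetical ``get_sparse_projection_sketch`` takes a ``parse`` callable standing in for :func:`parse_projection_file` instead of reading a file; the cache name is also hypothetical.

```python
from typing import Callable, Tuple

import torch

# Hypothetical process-local cache keyed like _SPARSE_PROJECTION_CACHE.
_SPARSE_CACHE_SKETCH: dict[Tuple[str, torch.device, int, int], torch.Tensor] = {}


def get_sparse_projection_sketch(
    path: str,
    device: torch.device,
    *,  # keyword-only vocab sizes: a positional swap would otherwise pass silently
    student_vocab_size: int,
    teacher_vocab_size: int,
    parse: Callable,  # stand-in for parse_projection_file
) -> torch.Tensor:
    key = (str(path), torch.device(device), student_vocab_size, teacher_vocab_size)
    if key in _SPARSE_CACHE_SKETCH:
        return _SPARSE_CACHE_SKETCH[key]  # hit: no parsing, no re-materialization
    indices, values, _, _ = parse(path)
    keep = indices[1] >= 0  # -1 teacher sentinels are illegal in sparse COO
    indices, values = indices[:, keep], values[keep].to(torch.float32)
    nnz = indices.shape[1]
    v_s = max(student_vocab_size, int(indices[0].max()) + 1 if nnz else 0)
    v_t = max(teacher_vocab_size, int(indices[1].max()) + 1 if nnz else 0)
    matrix = torch.sparse_coo_tensor(indices, values, size=(v_s, v_t), device=device).coalesce()
    _SPARSE_CACHE_SKETCH[key] = matrix
    return matrix


# Usage with a stub parser that records how often it is called.
parse_calls = []


def fake_parse(path):
    parse_calls.append(path)
    # Row 0: student ids; row 1: teacher ids (one -1 sentinel).
    return torch.tensor([[0, 1, 1], [2, -1, 0]]), torch.tensor([0.5, 1.0, 0.25]), 2, 3


cpu = torch.device("cpu")
first = get_sparse_projection_sketch("proj.pt", cpu, student_vocab_size=4, teacher_vocab_size=3, parse=fake_parse)
second = get_sparse_projection_sketch("proj.pt", cpu, student_vocab_size=4, teacher_vocab_size=3, parse=fake_parse)
```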
- nemo_rl.algorithms.x_token.loss_utils.get_topk_projection(path: Union[str, os.PathLike], device: torch.device)
Return the dense top-k ``(indices, likelihoods)`` projection on ``device`` (cached).
Used by the gold-loss exact-map builder, which needs the per-row top-k weights. The sparse ``dict[(s, t)] -> count`` projection format doesn’t carry those, so this loader rejects it.
- Parameters:
path – Path to a ``torch.save``-d projection-matrix file.
device – Device the returned tensors must live on.
- Returns:
``(indices, likelihoods)``: ``LongTensor[V_s, top_k]`` and ``FloatTensor[V_s, top_k]`` on ``device``.
- Raises:
FileNotFoundError – ``path`` does not exist.
ValueError – The file is not in the dense top-k format.
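A sketch of the format check plus the ``(path, device)`` cache. The ``load`` callable is a hypothetical stand-in for ``torch.load(path)`` so the example needs no file on disk; the function and cache names are likewise hypothetical.

```python
from typing import Any, Callable, Tuple

import torch

# Hypothetical cache keyed like _TOPK_PROJECTION_CACHE: (path, device).
_TOPK_CACHE_SKETCH: dict[Tuple[str, torch.device], Tuple[torch.Tensor, torch.Tensor]] = {}


def get_topk_projection_sketch(path: str, device: torch.device, load: Callable[[str], Any]):
    key = (str(path), torch.device(device))
    if key not in _TOPK_CACHE_SKETCH:
        obj = load(path)  # stand-in for torch.load(path)
        if not (isinstance(obj, dict) and "indices" in obj and "likelihoods" in obj):
            # e.g. the sparse {(s, t): count} format, which has no per-row top-k weights.
            raise ValueError(f"{path} is not in the dense top-k format")
        indices = torch.as_tensor(obj["indices"], dtype=torch.long).to(device)
        likelihoods = torch.as_tensor(obj["likelihoods"], dtype=torch.float32).to(device)
        _TOPK_CACHE_SKETCH[key] = (indices, likelihoods)
    return _TOPK_CACHE_SKETCH[key]


# Usage: the dense format loads; the sparse format is rejected.
cpu = torch.device("cpu")
idx, w = get_topk_projection_sketch(
    "dense.pt", cpu, load=lambda p: {"indices": [[3, -1]], "likelihoods": [[1.0, 0.0]]}
)
try:
    get_topk_projection_sketch("sparse.pt", cpu, load=lambda p: {(0, 3): 2})
    sparse_rejected = False
except ValueError:
    sparse_rejected = True
```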
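- nemo_rl.algorithms.x_token.loss_utils._EXACT_TOKEN_MAP_CACHE: dict[Tuple[str, torch.device, bool, int], Dict[str, torch.Tensor]]
Process-local cache for :func:`build_exact_token_map`, keyed by ``(path, device, xtoken_loss, teacher_vocab_size)``.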
- nemo_rl.algorithms.x_token.loss_utils.build_exact_token_map(path: Union[str, os.PathLike], device: torch.device, *, xtoken_loss: bool, teacher_vocab_size: int)
Build the common/uncommon vocab partition for the gold path (cached).
Reads the dense projection arrays via :func:`get_topk_projection`, sorts each student row’s projection weights in descending order, then picks an exact-token map according to the ``xtoken_loss`` flag:
- ``xtoken_loss=False`` (strict): ``has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1)``. On a collision (multiple students mapping to the same teacher id), the earliest (lowest) student index wins.
- ``xtoken_loss=True`` (relaxed): ``has_exact_map = sorted_values[:, 0] >= 0.6``. On a collision, the student with the highest first-projection weight wins; ties are broken by the lowest student index.
Both branches are vectorized via ``scatter_reduce``, so the build is O(V_s) and happens once per ``(path, device, xtoken_loss, teacher_vocab_size)`` for the run.
- Parameters:
path – Path to a ``torch.save``-d projection-matrix file (dense top-k format).
device – Device the returned tensors must live on.
xtoken_loss – Selects the strict vs. relaxed exact-map rule (see above).
teacher_vocab_size – Width of the teacher-side vocab axis. The partition is bounded by it: teacher ids outside ``[0, teacher_vocab_size)`` are dropped.
- Returns:
Dict with keys ``common_student`` and ``common_teacher`` (paired), and ``uncommon_student`` and ``uncommon_teacher`` (each independently sorted). All values are ``long`` tensors on ``device``.
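The two exact-map rules and their collision handling can be sketched with ``scatter_reduce_`` as below. This hypothetical ``exact_token_partition_sketch`` works on in-memory ``(indices, likelihoods)`` tensors, omits the file loading and caching, and needs ``top_k >= 2``. It applies the strict rule exactly as the docstring writes it (checking the unsorted ``indices[:, 1]``). Relaxed collisions are resolved in two passes: keep the highest weight per teacher via ``amax``, then the lowest student index via ``amin``.

```python
import torch


def exact_token_partition_sketch(indices, likelihoods, *, xtoken_loss, teacher_vocab_size):
    """indices: LongTensor[V_s, top_k]; likelihoods: FloatTensor[V_s, top_k]."""
    v_s, device = indices.shape[0], indices.device
    # Sort each row's weights descending (stable, so tied columns keep their order).
    sorted_values, order = torch.sort(likelihoods, dim=1, descending=True, stable=True)
    top_weight = sorted_values[:, 0]
    top_teacher = indices.gather(1, order[:, :1]).squeeze(1)  # teacher id of the top weight
    if xtoken_loss:  # relaxed rule
        has_exact = top_weight >= 0.6
    else:  # strict rule, as written in the docstring
        has_exact = (top_weight == 1.0) & (indices[:, 1] == -1)
    # Drop teacher ids outside [0, teacher_vocab_size).
    has_exact = has_exact & (top_teacher >= 0) & (top_teacher < teacher_vocab_size)

    students = torch.arange(v_s, device=device)[has_exact]
    teachers = top_teacher[has_exact]
    if xtoken_loss:
        # Relaxed collisions: keep only candidates holding the highest weight per teacher.
        weights = top_weight[has_exact]
        best = torch.full((teacher_vocab_size,), float("-inf"), dtype=weights.dtype, device=device)
        best.scatter_reduce_(0, teachers, weights, reduce="amax")
        keep = weights == best[teachers]
        students, teachers = students[keep], teachers[keep]
    # Remaining collisions (both rules): the lowest student index wins.
    winner = torch.full((teacher_vocab_size,), v_s, dtype=torch.long, device=device)
    winner.scatter_reduce_(0, teachers, students, reduce="amin")
    won = winner[teachers] == students
    common_student, common_teacher = students[won], teachers[won]

    student_is_common = torch.zeros(v_s, dtype=torch.bool, device=device)
    student_is_common[common_student] = True
    teacher_is_common = torch.zeros(teacher_vocab_size, dtype=torch.bool, device=device)
    teacher_is_common[common_teacher] = True
    return {
        "common_student": common_student,
        "common_teacher": common_teacher,
        "uncommon_student": (~student_is_common).nonzero().squeeze(1),
        "uncommon_teacher": (~teacher_is_common).nonzero().squeeze(1),
    }


# Usage: 4 student tokens, top_k = 2, teacher vocab of 6.
indices = torch.tensor([[5, -1], [5, -1], [2, 3], [7, -1]])
likelihoods = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.7, 0.3], [1.0, 0.0]])
strict = exact_token_partition_sketch(indices, likelihoods, xtoken_loss=False, teacher_vocab_size=6)
relaxed = exact_token_partition_sketch(indices, likelihoods, xtoken_loss=True, teacher_vocab_size=6)
```

Here students 0 and 1 both map exactly to teacher 5, and student 0 wins under both rules. Student 2 (a 0.7/0.3 split) qualifies only under the relaxed rule, and student 3 points at teacher 7, outside ``teacher_vocab_size=6``, so it is dropped.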