nemo_rl.algorithms.x_token.token_aligner
Cross-tokenizer token alignment.
The DP kernel, canonicalization helpers, anchor optimization, and post-processing in this module implement the cross-tokenizer alignment consumed by the loss. Everything not needed to run cross-tokenizer alignment for off-policy distillation (rule tracking, learnable projection, MSE loss, multiple compute_loss variants, accuracy, translation) has been dropped.
Public surface:
- AlignmentPair: per-pair record produced by the DP / anchor pipeline; replaces the loose (s_tokens, t_tokens, s_start, s_end, t_start, t_end, is_correct) tuples that the helpers used to pass around.
- AlignmentBatch: dense-padded per-batch alignment payload that covers all three loss modes (P-KL, gold_loss, xtoken_loss).
- TokenAligner: holds the two tokenizers and the projection-matrix path, and exposes align for the collator.
Module Contents
Classes
| AlignmentPair | One aligned span between student and teacher token sequences. |
| AlignmentBatch | Per-batch alignment payload covering all three loss modes. |
| TokenAligner | Aligns student and teacher tokenizations of the same source text. |
Functions
| canonical_token | Return a canonical representation of a tokenizer token. |
| _canonicalize_sequence | Canonicalize every token in a sequence, including byte-merging. |
| _merge_encoding_artifacts | Merge known multi-token mojibake patterns into single tokens. |
| _get_byte_value | Return the byte value (0..255) for a single character, or None. |
| _merge_consecutive_bytes | Merge consecutive byte-fallback tokens back into Unicode characters. |
| _try_merge_byte_buffer | Decode 2-4 buffered byte tokens as a single UTF-8 character. |
| _strings_equal_flexible | Compare two strings, optionally after canonicalization. |
Data
| VISUAL_BYTE_MAP |
| _MULTI_TOKEN_ARTIFACT_FIXES |
| _UNICODE_FIXES |
| _SPECIAL_TOKEN_MAP |
API
- nemo_rl.algorithms.x_token.token_aligner.VISUAL_BYTE_MAP
- nemo_rl.algorithms.x_token.token_aligner._MULTI_TOKEN_ARTIFACT_FIXES
[(['ĠâĪ', 'ĳ'], ['Ġ∑']), (['âĪ', 'ĳ'], ['∑']), (['ĠâĪ', 'ı'], ['Ġ∏']), (['âĪ', 'ı'], ['∏']), (['ĠâĪ'…
- nemo_rl.algorithms.x_token.token_aligner._UNICODE_FIXES
- nemo_rl.algorithms.x_token.token_aligner._SPECIAL_TOKEN_MAP
- class nemo_rl.algorithms.x_token.token_aligner.AlignmentPair
One aligned span between student and teacher token sequences.
The DP / anchor / post-process helpers construct these as they trace the alignment; _align_single then fills in is_correct from the canonicalized-text comparison. Insertions/deletions use -1 for the empty side's start/end indices.
- s_tokens: List[str]
Student tokens covered by this pair.
- t_tokens: List[str]
Teacher tokens covered by this pair.
- s_start: int
Inclusive start index into the student token sequence (-1 for teacher-only insertions).
- s_end: int
Exclusive end index into the student token sequence (-1 for teacher-only insertions).
- t_start: int
Inclusive start index into the teacher token sequence (-1 for student-only insertions).
- t_end: int
Exclusive end index into the teacher token sequence (-1 for student-only insertions).
- is_correct: bool = False
True when the canonicalized student span text matches the canonicalized teacher span text. Defaults to False so the DP / anchor stages can build pairs without computing the mask up front.
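A minimal sketch of the record, assuming it is a plain @dataclass with the fields and default listed above. The token strings are made-up illustrations, not output of the real aligner; they show a 2-to-1 span and the -1 sentinel on the empty side of a student-only insertion.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AlignmentPair:
    # Field layout mirrors the documented attributes; the real class may differ in details.
    s_tokens: List[str]
    t_tokens: List[str]
    s_start: int  # inclusive; -1 for teacher-only insertions
    s_end: int    # exclusive; -1 for teacher-only insertions
    t_start: int  # inclusive; -1 for student-only insertions
    t_end: int    # exclusive; -1 for student-only insertions
    is_correct: bool = False  # filled in later from the canonicalized-text comparison


# Two student tokens spell one teacher token: a 2-to-1 span.
merged = AlignmentPair(["hel", "lo"], ["hello"], 0, 2, 0, 1)

# A student token with no teacher counterpart: the teacher side uses the -1 sentinel.
student_only = AlignmentPair(["!"], [], 2, 3, -1, -1)
```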
- class nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch
Per-batch alignment payload covering all three loss modes.
The collator hands this dataclass directly to the loss fn alongside the tokenized batch. Tensors are dense-padded to the batch maximum so DTensor V2 can shard on dim 0 without knowing about cross-tokenizer specifics.
- pair_valid: torch.Tensor
[B, max_pairs] bool. False on padding entries.
- pair_is_correct: torch.Tensor
[B, max_pairs] bool. True when the canonicalized student span text matches the canonicalized teacher span text.
- student_exact_partition_mask: torch.Tensor
[B, T_s] bool. True at student tokens that sit on a 1-1 exact-match pair (gold_loss partition).
- teacher_exact_partition_mask: torch.Tensor
[B, T_t] bool. Teacher-side counterpart of student_exact_partition_mask.
- student_chunk_id: torch.Tensor
[B, T_s] long. Chunk index (= pair index) the student token belongs to; -1 if the token is not in any chunk (insertion-only pair on the student side).
- teacher_chunk_id: torch.Tensor
[B, T_t] long. Teacher-side counterpart of student_chunk_id.
- num_chunks: torch.Tensor
[B] long. Number of valid chunks in each sample.
- class nemo_rl.algorithms.x_token.token_aligner.TokenAligner(
    student_tokenizer,
    teacher_tokenizer,
    projection_matrix_path: str,
    max_comb_len: int = 4,
  )
Aligns student and teacher tokenizations of the same source text.
The alignment algorithm is a Needleman-Wunsch DP over canonicalized token strings, augmented with multi-token combination scoring (one student token can match a span of teacher tokens and vice versa, up to max_comb_len tokens) and anchor-based segmentation for long sequences.
Parameters:
- student_tokenizer – HF tokenizer for the student model.
- teacher_tokenizer – HF tokenizer for the teacher model.
- projection_matrix_path – Path retained on the aligner for downstream callers (e.g. the loss fn) that materialize the projection on their training device via nemo_rl.algorithms.x_token.loss_utils.get_sparse_projection_matrix or nemo_rl.algorithms.x_token.loss_utils.get_topk_projection.
- max_comb_len – Maximum span length considered when matching one token on one side against multiple tokens on the other.
- align(
    student_ids: torch.Tensor,
    teacher_ids: torch.Tensor,
    *,
    student_attention_mask: torch.Tensor | None = None,
    teacher_attention_mask: torch.Tensor | None = None,
  )
Align a batch of student/teacher token id tensors.
Parameters:
- student_ids – [B, T_s] long tensor.
- teacher_ids – [B, T_t] long tensor.
- student_attention_mask – optional [B, T_s] mask (1 = real token, 0 = padding). When given, padded positions are forced to the chunk_id = -1 / partition-False sentinels so tokenizer padding never forms a valid chunk. align runs the DP over the fully padded ids, so without this mask the pad run on each side can be aligned into chunks that survive nemo_rl.algorithms.x_token.loss_utils.valid_chunk_mask.
- teacher_attention_mask – optional [B, T_t] counterpart.
Returns:
An AlignmentBatch with all fields populated for the three loss modes.
- static _drop_padding(
    batch: nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch,
    *,
    student_attention_mask: torch.Tensor | None,
    teacher_attention_mask: torch.Tensor | None,
  )
Strip tokenizer padding out of the chunk-id / partition tensors.
Mutates batch in place. For every position the attention mask marks as padding, resets *_chunk_id to -1 and *_exact_partition_mask to False. Gating per position (rather than trimming a contiguous span) keeps this correct under either left- or right-padding. A pair whose tokens are entirely padding on one side then has size 0 there and is dropped by nemo_rl.algorithms.x_token.loss_utils.valid_chunk_mask; a pair straddling the real/pad boundary shrinks to its real tokens.
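The per-position gating can be sketched with plain nested lists standing in for the [B, T] tensors; with real tensors the same effect would come from something like chunk_id.masked_fill_(attention_mask == 0, -1). drop_padding is a hypothetical name. The example puts one right-padded and one left-padded row through it, where the DP had placed pad tokens into chunks.

```python
from typing import List


def drop_padding(
    chunk_id: List[List[int]],
    exact_mask: List[List[bool]],
    attention_mask: List[List[int]],
) -> None:
    """Reset padded positions in place: chunk id -> -1, partition mask -> False."""
    for b, row in enumerate(attention_mask):
        for pos, keep in enumerate(row):
            if keep == 0:  # gate per position, so left and right padding both work
                chunk_id[b][pos] = -1
                exact_mask[b][pos] = False


# Row 0 is right-padded, row 1 is left-padded.
chunk_id = [[0, 1, 2, 2], [0, 0, 1, 2]]
exact = [[True, True, True, False], [True, False, True, True]]
mask = [[1, 1, 0, 0], [0, 0, 1, 1]]
drop_padding(chunk_id, exact, mask)
```

After the call, chunk 2 of row 0 and chunk 0 of row 1 no longer own any position, so a downstream validity check would see them as empty.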
- static _pairs_to_batch(
    per_sample_pairs: List[List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]],
    *,
    b: int,
    t_s: int,
    t_t: int,
  )
Pack per-sample alignment lists into dense-padded tensors.
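A dependency-free sketch of the packing, with nested lists standing in for tensors and a hypothetical Pair stand-in for AlignmentPair. It assumes chunk ids are assigned only for pairs with tokens on both sides (following the student_chunk_id description) and that the exact-partition masks mark 1-1 pairs whose is_correct is True. num_chunks is left out because this page does not say whether insertion-only pairs count toward it.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Pair:  # hypothetical stand-in for AlignmentPair (token lists omitted)
    s_start: int
    s_end: int
    t_start: int
    t_end: int
    is_correct: bool


def pairs_to_batch(per_sample_pairs: List[List[Pair]], *, b: int, t_s: int, t_t: int) -> Dict[str, list]:
    max_pairs = max((len(p) for p in per_sample_pairs), default=0)
    out = {
        "pair_valid": [[False] * max_pairs for _ in range(b)],
        "pair_is_correct": [[False] * max_pairs for _ in range(b)],
        "student_chunk_id": [[-1] * t_s for _ in range(b)],
        "teacher_chunk_id": [[-1] * t_t for _ in range(b)],
        "student_exact_partition_mask": [[False] * t_s for _ in range(b)],
        "teacher_exact_partition_mask": [[False] * t_t for _ in range(b)],
    }
    for i, pairs in enumerate(per_sample_pairs):
        for p, pair in enumerate(pairs):
            out["pair_valid"][i][p] = True  # padding entries stay False
            out["pair_is_correct"][i][p] = pair.is_correct
            if pair.s_start == -1 or pair.t_start == -1:
                continue  # insertion-only pair: its tokens keep chunk id -1
            for s in range(pair.s_start, pair.s_end):
                out["student_chunk_id"][i][s] = p  # chunk index = pair index
            for t in range(pair.t_start, pair.t_end):
                out["teacher_chunk_id"][i][t] = p
            one_to_one = pair.s_end - pair.s_start == 1 and pair.t_end - pair.t_start == 1
            if one_to_one and pair.is_correct:
                out["student_exact_partition_mask"][i][pair.s_start] = True
                out["teacher_exact_partition_mask"][i][pair.t_start] = True
    return out


batch = pairs_to_batch(
    [
        [Pair(0, 2, 0, 1, True), Pair(2, 3, 1, 2, True), Pair(3, 4, -1, -1, False)],
        [Pair(0, 1, 0, 1, False)],
    ],
    b=2,
    t_s=4,
    t_t=2,
)
```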
- _align_single(
    student_tokens: List[str],
    teacher_tokens: List[str],
    exact_match_score: float = 3.0,
    combination_score_multiplier: float = 1.5,
    gap_penalty: float = -1.5,
    anchor_lengths: Tuple[int, ...] = (3,),
  )
Run canonicalize -> anchor-DP -> post-process for one sample.
Returns:
A list of AlignmentPair. Insertions/deletions use -1 for the empty side's start/end. Pair start/end indices address the original token sequences (not canonical space), so they can be written straight into the chunk-id tensors in _pairs_to_batch.
- static _remap_pairs_to_original(
    pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
    *,
    student_canon_to_orig: List[Tuple[int, int]],
    teacher_canon_to_orig: List[Tuple[int, int]],
  )
Remap pair start/end indices from canonical to original space.
Mutates pairs in place. Insertion/deletion pairs (where the empty side already has -1 start/end) keep the sentinel. s_end - 1 / t_end - 1 indexes the last canonical token in the span; that token's orig_end becomes the new exclusive end.
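A sketch of the remapping on flattened (s_start, s_end, t_start, t_end) tuples rather than AlignmentPair objects, returning new tuples instead of mutating in place. The end rule follows the description above; taking the new start from canon_to_orig[start][0] is an assumption that mirrors it. In the hypothetical example, three student byte tokens were merged into one canonical ∑ token, so canonical pair 0 widens to original student tokens [0, 3).

```python
from typing import List, Tuple

Span = Tuple[int, int, int, int]  # (s_start, s_end, t_start, t_end); -1 marks an empty side
Ranges = List[Tuple[int, int]]


def remap_to_original(pairs: List[Span], s_map: Ranges, t_map: Ranges) -> List[Span]:
    def remap(start: int, end: int, ranges: Ranges) -> Tuple[int, int]:
        if start == -1:
            return -1, -1  # insertion/deletion sentinel is kept
        # Assumed start rule: first canonical token's orig_start.
        # Documented end rule: the last canonical token (end - 1) gives the exclusive orig_end.
        return ranges[start][0], ranges[end - 1][1]

    return [(*remap(s0, s1, s_map), *remap(t0, t1, t_map)) for s0, s1, t0, t1 in pairs]


# Student ["â", "Ī", "ĳ", " x"] became ["∑", " x"]; the teacher needed no merging.
s_map = [(0, 3), (3, 4)]
t_map = [(0, 1), (1, 2), (2, 3)]
canonical_pairs = [(0, 1, 0, 1), (1, 2, 1, 2), (-1, -1, 2, 3)]
original_pairs = remap_to_original(canonical_pairs, s_map, t_map)
```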
- _align_with_anchors(
    student_tokens: List[str],
    teacher_tokens: List[str],
    anchor_lengths: Tuple[int, ...] = (3,),
    *,
    exact_match_score: float,
    combination_score_multiplier: float,
    gap_penalty: float,
    max_combination_len: int,
    ignore_leading_char_diff: bool,
  )
Optimize long alignments by pinning unique n-gram matches as anchors.
Falls back to plain DP when no anchors exist or when anchor_lengths is empty.
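A sketch of the anchoring idea, assuming anchors are n-grams that occur exactly once in both sequences and are picked greedily so they stay monotone and non-overlapping; the real selection strategy may differ, and unique_anchors and segments are hypothetical names. The gaps between anchors are what plain DP would then align, with each segment's pairs offset by its start (the role _shift_pairs plays). With no anchors, segments returns one segment covering both sequences, matching the documented fallback to plain DP.

```python
from collections import Counter
from typing import List, Tuple


def unique_anchors(s: List[str], t: List[str], n: int = 3) -> List[Tuple[int, int]]:
    """Return (s_pos, t_pos) of n-grams that occur exactly once in each sequence."""
    s_grams = [tuple(s[i:i + n]) for i in range(len(s) - n + 1)]
    t_grams = [tuple(t[j:j + n]) for j in range(len(t) - n + 1)]
    s_count, t_count = Counter(s_grams), Counter(t_grams)
    t_pos = {g: j for j, g in enumerate(t_grams) if t_count[g] == 1}
    anchors, s_floor, t_floor = [], 0, 0
    for i, g in enumerate(s_grams):
        if s_count[g] == 1 and g in t_pos:
            j = t_pos[g]
            # Greedy: keep anchors monotone and non-overlapping on both sides.
            if i >= s_floor and j >= t_floor:
                anchors.append((i, j))
                s_floor, t_floor = i + n, j + n
    return anchors


def segments(s_len: int, t_len: int, anchors: List[Tuple[int, int]], n: int = 3):
    """Half-open (student, teacher) gaps between anchors, each to be aligned by plain DP."""
    segs, s_lo, t_lo = [], 0, 0
    for i, j in anchors:
        segs.append(((s_lo, i), (t_lo, j)))
        s_lo, t_lo = i + n, j + n  # the anchor itself is a run of 1-1 exact matches
    segs.append(((s_lo, s_len), (t_lo, t_len)))
    return segs


s, t = list("abcXdefY"), list("abcZZdefW")
anchors = unique_anchors(s, t)
segs = segments(len(s), len(t), anchors)
```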
- static _align_dp(
    student_tokens: List[str],
    teacher_tokens: List[str],
    *,
    exact_match_score: float,
    combination_score_multiplier: float,
    gap_penalty: float,
    max_combination_len: int,
    ignore_leading_char_diff: bool,
  )
Needleman-Wunsch DP with token spans of up to max_combination_len tokens.
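A minimal sketch of a Needleman-Wunsch DP with combination spans, under stated assumptions: a 1-k or k-1 move is rewarded only when the k tokens on one side concatenate to the single token on the other, it scores exact_match_score * combination_score_multiplier, and a 1-1 mismatch costs a hypothetical mismatch_score. None of these scoring rules are documented on this page, and ignore_leading_char_diff is not modeled. Empty sides use the -1 sentinel, as in AlignmentPair.

```python
from typing import List, Tuple

Span = Tuple[int, int, int, int]  # (s_start, s_end, t_start, t_end); -1 marks an empty side


def align_dp(
    s: List[str],
    t: List[str],
    *,
    exact_match_score: float = 3.0,
    combination_score_multiplier: float = 1.5,
    gap_penalty: float = -1.5,
    mismatch_score: float = -1.0,  # hypothetical: not documented on this page
    max_combination_len: int = 4,
) -> List[Span]:
    n, m = len(s), len(t)
    score = [[float("-inf")] * (m + 1) for _ in range(n + 1)]
    back = [[(0, 0)] * (m + 1) for _ in range(n + 1)]
    score[0][0] = 0.0
    combo = exact_match_score * combination_score_multiplier  # assumed combination reward
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 and j == 0:
                continue
            moves = []  # (student tokens consumed, teacher tokens consumed, gain)
            if i > 0:
                moves.append((1, 0, gap_penalty))  # student-only insertion
            if j > 0:
                moves.append((0, 1, gap_penalty))  # teacher-only insertion
            if i > 0 and j > 0:
                moves.append((1, 1, exact_match_score if s[i - 1] == t[j - 1] else mismatch_score))
            for k in range(2, max_combination_len + 1):
                if i >= k and j >= 1 and "".join(s[i - k:i]) == t[j - 1]:
                    moves.append((k, 1, combo))  # k student tokens spell one teacher token
                if j >= k and i >= 1 and "".join(t[j - k:j]) == s[i - 1]:
                    moves.append((1, k, combo))  # k teacher tokens spell one student token
            for di, dj, gain in moves:
                cand = score[i - di][j - dj] + gain
                if cand > score[i][j]:
                    score[i][j], back[i][j] = cand, (di, dj)
    # Trace back from the bottom-right corner to recover the spans.
    pairs, i, j = [], n, m
    while i > 0 or j > 0:
        di, dj = back[i][j]
        s_span = (i - di, i) if di else (-1, -1)
        t_span = (j - dj, j) if dj else (-1, -1)
        pairs.append((*s_span, *t_span))
        i, j = i - di, j - dj
    pairs.reverse()
    return pairs
```

For example, align_dp(["hel", "lo", " world"], ["hello", " world"]) prefers the 2-to-1 combination over gaps and returns [(0, 2, 0, 1), (2, 3, 1, 2)].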
- static _shift_pairs(
    pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
    shift_s: int,
    shift_t: int,
  )
Offset start/end indices of pairs after segment-level alignment.
- static _alignment_mask(
    aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
  )
Compute is_correct for each pair using canonicalized text comparison.
- static _post_process_alignment(
    aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
    *,
    exact_match_score: float,
    combination_score_multiplier: float,
    gap_penalty: float,
    max_combination_len: int,
    end_mismatch_threshold: float = 0.2,
  )
Post-process: combine misaligned consecutive pairs and re-align bad spans.
- static _build_pair_strings(
    aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
  )
Precompute (s_str, t_str, is_match) for each pair.
- static _combine_consecutive_misaligned(
    aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
    pair_strings: List[Tuple[str, str, bool]],
    end_mismatch_threshold: float,
  )
Combine consecutive misaligned pairs into single multi-token chunks.
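A simplified sketch of the run-merging step, assuming pairs arrive in sequence order as (s_start, s_end, t_start, t_end) tuples alongside a precomputed is_match flag (the third element of _build_pair_strings' output). Each maximal run of non-matching pairs collapses into one span, in the spirit of _flatten_chunk. The end_mismatch_threshold logic and the re-alignment of bad spans are not modeled, and is_correct for merged spans would need recomputing afterwards; both function names are hypothetical.

```python
from typing import List, Tuple

Span = Tuple[int, int, int, int]  # (s_start, s_end, t_start, t_end); -1 marks an empty side


def flatten_chunk(chunk: List[Span]) -> Span:
    """Union of the non-empty spans on each side of an in-order run of pairs."""
    s = [(a, b) for a, b, _, _ in chunk if a != -1]
    t = [(c, d) for _, _, c, d in chunk if c != -1]
    s_span = (s[0][0], s[-1][1]) if s else (-1, -1)
    t_span = (t[0][0], t[-1][1]) if t else (-1, -1)
    return (*s_span, *t_span)


def combine_consecutive_misaligned(pairs: List[Span], is_match: List[bool]) -> List[Span]:
    out: List[Span] = []
    run: List[Span] = []
    for pair, ok in zip(pairs, is_match):
        if ok:
            if run:  # close the pending run of misaligned pairs as one chunk
                out.append(flatten_chunk(run))
                run = []
            out.append(pair)
        else:
            run.append(pair)
    if run:  # a run that reaches the end of the sequence
        out.append(flatten_chunk(run))
    return out
```

For instance, three misaligned pairs covering student tokens [1, 4) and teacher tokens [1, 4), one of them a student-only insertion, collapse into the single span (1, 4, 1, 4).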
- static _flatten_chunk(
    chunk: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
  )
Concatenate tokens and collect span indices across a chunk of pairs.
- nemo_rl.algorithms.x_token.token_aligner.canonical_token(token: str, *, enabled: bool = True) → str
Return a canonical representation of a tokenizer token.
Public helper consumed by both the alignment pipeline and the projection-prep CLIs in tools/x_token/. The enabled flag is a passthrough toggle: when False the input is returned unchanged, which lets CLI call sites gate canonicalization through a single flag instead of branching at every call site.
- nemo_rl.algorithms.x_token.token_aligner._canonicalize_sequence(
    seq: List[str],
  )
Canonicalize every token in a sequence, including byte-merging.
Returns:
(canon, canon_to_orig). canon_to_orig[k] is a half-open [orig_start, orig_end) range giving the original-token positions that canonical token k was built from. Ranges are non-overlapping, strictly increasing, and jointly cover range(len(seq)). This is required so that DP-output indices over canon can be remapped to positions on the original input-id axis (see TokenAligner._remap_pairs_to_original).
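The range invariant can be checked mechanically. check_canon_to_orig is a hypothetical helper, not part of the module; it reads "non-overlapping, strictly increasing, and jointly cover" as: every range is non-empty and starts exactly where the previous one ended, and the last one ends at len(seq).

```python
from typing import List, Tuple


def check_canon_to_orig(canon_to_orig: List[Tuple[int, int]], orig_len: int) -> bool:
    """True iff the half-open ranges tile range(orig_len) in order with no gaps or overlaps."""
    expected_start = 0
    for start, end in canon_to_orig:
        if start != expected_start or end <= start:
            return False  # gap, overlap, out-of-order, or empty range
        expected_start = end
    return expected_start == orig_len


# Hypothetical: original ["â", "Ī", "ĳ", " x"] canonicalized to ["∑", " x"].
canon_to_orig = [(0, 3), (3, 4)]
orig_positions_of_first = list(range(*canon_to_orig[0]))  # [0, 1, 2]
```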
- nemo_rl.algorithms.x_token.token_aligner._merge_encoding_artifacts(
    tokens: List[str],
  )
Merge known multi-token mojibake patterns into single tokens.
Returns:
(merged, ranges) with one (orig_start, orig_end) entry per output token. Every entry in _MULTI_TOKEN_ARTIFACT_FIXES rewrites to a single replacement token, so each merge contributes exactly one range covering the matched pattern.
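A sketch of the merge using the four fixes shown for _MULTI_TOKEN_ARTIFACT_FIXES above; the full table is longer, and first-match-wins scanning is an assumption. The tokens are byte-level BPE renderings: "âĪ" followed by U+0133 ("ĳ") encodes the UTF-8 bytes E2 88 91 of ∑, which is why the code spells them with escapes.

```python
from typing import List, Tuple

# The first four entries of _MULTI_TOKEN_ARTIFACT_FIXES, written with escapes:
# \u0120 is the byte-level space marker "Ġ", \u00e2\u012a is "âĪ" (bytes E2 88),
# \u0133 is "ĳ" (byte 91) and \u0131 is "ı" (byte 8F).
FIXES = [
    (["\u0120\u00e2\u012a", "\u0133"], ["\u0120\u2211"]),  # " ∑"
    (["\u00e2\u012a", "\u0133"], ["\u2211"]),  # "∑"
    (["\u0120\u00e2\u012a", "\u0131"], ["\u0120\u220f"]),  # " ∏"
    (["\u00e2\u012a", "\u0131"], ["\u220f"]),  # "∏"
]


def merge_encoding_artifacts(tokens: List[str]) -> Tuple[List[str], List[Tuple[int, int]]]:
    merged, ranges, i = [], [], 0
    while i < len(tokens):
        for pattern, (replacement,) in FIXES:
            if tokens[i:i + len(pattern)] == pattern:
                merged.append(replacement)
                ranges.append((i, i + len(pattern)))  # one range covering the whole match
                i += len(pattern)
                break
        else:  # no pattern starts here: pass the token through with a unit range
            merged.append(tokens[i])
            ranges.append((i, i + 1))
            i += 1
    return merged, ranges
```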
- nemo_rl.algorithms.x_token.token_aligner._get_byte_value(token_char: str) → int | None
Return the byte value (0..255) for a single character, or None.
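A sketch assuming VISUAL_BYTE_MAP is the GPT-2 style bytes_to_unicode table used by byte-level BPE tokenizers. That assumption is consistent with the artifact fixes above, where "â", "Ī" and "ĳ" stand for the bytes 0xE2, 0x88 and 0x91. bytes_to_unicode, CHAR_TO_BYTE and get_byte_value are illustrative names, not the module's.

```python
from typing import Dict, Optional


def bytes_to_unicode() -> Dict[int, str]:
    """GPT-2 style byte -> printable character table."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("\u00a1"), ord("\u00ac") + 1))
        + list(range(ord("\u00ae"), ord("\u00ff") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)  # shift unprintable bytes past Latin-1
            n += 1
    return dict(zip(bs, map(chr, cs)))


CHAR_TO_BYTE = {ch: b for b, ch in bytes_to_unicode().items()}


def get_byte_value(token_char: str) -> Optional[int]:
    if len(token_char) != 1:
        return None  # only single characters can stand for one byte
    return CHAR_TO_BYTE.get(token_char)
```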
- nemo_rl.algorithms.x_token.token_aligner._merge_consecutive_bytes(
    tokens: List[str],
    in_ranges: List[Tuple[int, int]],
  )
Merge consecutive byte-fallback tokens back into Unicode characters.
Propagates in_ranges parallel to tokens: when a byte buffer collapses to one character, its parallel range slice is collapsed to a single (start, end); otherwise ranges pass through unchanged.
- nemo_rl.algorithms.x_token.token_aligner._try_merge_byte_buffer(
    byte_tokens: List[str],
    byte_ranges: List[Tuple[int, int]],
  )
Decode 2-4 buffered byte tokens as a single UTF-8 character.
Returns the merged single-character token plus a single collapsed range covering the whole buffer, or the unchanged buffer and ranges when no merge is possible.
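A sketch of the two byte-merging steps, under the same GPT-2 byte-map assumption as _get_byte_value. It further assumes a byte-fallback token is a single character that maps to a byte of at least 0x80, and that the buffer is flushed whenever a non-byte token arrives; both rules are guesses, and all names are hypothetical. Ranges travel alongside tokens, so a successful merge collapses its range slice into one (start, end).

```python
from typing import List, Optional, Tuple

Range = Tuple[int, int]


def _gpt2_char_to_byte() -> dict:
    bs = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256))
    cs, n = bs[:], 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {chr(c): b for b, c in zip(bs, cs)}


CHAR_TO_BYTE = _gpt2_char_to_byte()


def byte_value(tok: str) -> Optional[int]:
    return CHAR_TO_BYTE.get(tok) if len(tok) == 1 else None


def try_merge_byte_buffer(byte_tokens: List[str], byte_ranges: List[Range]):
    values = [byte_value(t) for t in byte_tokens]
    if 2 <= len(values) <= 4 and None not in values:
        try:
            text = bytes(values).decode("utf-8")
        except UnicodeDecodeError:
            text = ""
        if len(text) == 1:  # exactly one Unicode character: collapse tokens and ranges
            return [text], [(byte_ranges[0][0], byte_ranges[-1][1])]
    return list(byte_tokens), list(byte_ranges)  # no merge possible


def merge_consecutive_bytes(tokens: List[str], in_ranges: List[Range]):
    out_t, out_r, buf_t, buf_r = [], [], [], []

    def flush():
        merged_t, merged_r = try_merge_byte_buffer(buf_t, buf_r)
        out_t.extend(merged_t)
        out_r.extend(merged_r)
        buf_t.clear()
        buf_r.clear()

    for tok, rng in zip(tokens, in_ranges):
        v = byte_value(tok)
        if v is not None and v >= 0x80:  # assumed: a lone non-ASCII byte character
            buf_t.append(tok)
            buf_r.append(rng)
        else:
            flush()
            out_t.append(tok)
            out_r.append(rng)
    flush()
    return out_t, out_r
```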
- nemo_rl.algorithms.x_token.token_aligner._strings_equal_flexible(
    s1: str,
    s2: str,
    ignore_leading_char_diff: bool,
  )
Compare two strings, optionally after canonicalization.