nemo_rl.algorithms.x_token.token_aligner#

Cross-tokenizer token alignment.

The DP kernel, canonicalization helpers, anchor optimization, and post-processing implement the cross-tokenizer alignment consumed by the loss. Anything unrelated to running cross-tokenizer alignment for off-policy distillation (rule tracking, learnable projection, MSE loss, multiple compute_loss variants, accuracy, translation) is dropped.

Public surface:
  • :class:AlignmentPair – per-pair record produced by the DP / anchor pipeline; replaces the loose (s_tokens, t_tokens, s_start, s_end, t_start, t_end, is_correct) tuples that the helpers used to pass around.

  • :class:AlignmentBatch – dense-padded per-batch alignment payload that covers all three loss modes (P-KL, gold_loss, xtoken_loss).

  • :class:TokenAligner – owns the two tokenizers and the projection matrix, exposes :meth:align for the collator.

Module Contents#

Classes#

AlignmentPair

One aligned span between student and teacher token sequences.

AlignmentBatch

Per-batch alignment payload covering all three loss modes.

TokenAligner

Aligns student and teacher tokenizations of the same source text.

Functions#

canonical_token

Return a canonical representation of a tokenizer token.

_canonicalize_sequence

Canonicalize every token in a sequence, including byte-merging.

_merge_encoding_artifacts

Merge known multi-token mojibake patterns into single tokens.

_get_byte_value

Return the byte value (0..255) for a single character, or None.

_merge_consecutive_bytes

Merge consecutive byte-fallback tokens back into Unicode characters.

_try_merge_byte_buffer

Decode 2-4 buffered byte tokens as a single UTF-8 character.

_strings_equal_flexible

Compare two strings, optionally after canonicalization.

Data#

API#

nemo_rl.algorithms.x_token.token_aligner.VISUAL_BYTE_MAP#

None

nemo_rl.algorithms.x_token.token_aligner._MULTI_TOKEN_ARTIFACT_FIXES#

[(['ĠâĪ', 'ij'], ['Ġ∑']), (['âĪ', 'ij'], ['∑']), (['ĠâĪ', 'ı'], ['Ġ∏']), (['âĪ', 'ı'], ['∏']), (['ĠâĪ'…

nemo_rl.algorithms.x_token.token_aligner._UNICODE_FIXES#

None

nemo_rl.algorithms.x_token.token_aligner._SPECIAL_TOKEN_MAP#

None

class nemo_rl.algorithms.x_token.token_aligner.AlignmentPair#

One aligned span between student and teacher token sequences.

The DP / anchor / post-process helpers construct these as they trace the alignment; _align_single then fills in is_correct from the canonicalized-text comparison. Insertions/deletions use -1 for the empty side’s start/end indices.

.. attribute:: s_tokens

Student tokens covered by this pair.

.. attribute:: t_tokens

Teacher tokens covered by this pair.

.. attribute:: s_start

Inclusive start index into the student token sequence (-1 for teacher-only insertions).

.. attribute:: s_end

Exclusive end index into the student token sequence (-1 for teacher-only insertions).

.. attribute:: t_start

Inclusive start index into the teacher token sequence (-1 for student-only insertions).

.. attribute:: t_end

Exclusive end index into the teacher token sequence (-1 for student-only insertions).

.. attribute:: is_correct

True when the canonicalized student span text matches the canonicalized teacher span text. Defaults to False so DP / anchor stages can build pairs without computing the mask up front.

s_tokens: List[str]#

None

t_tokens: List[str]#

None

s_start: int#

None

s_end: int#

None

t_start: int#

None

t_end: int#

None

is_correct: bool#

False
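
As a rough illustration of the record's shape, the stand-in dataclass below mirrors the documented fields and the -1 sentinel convention. It is a sketch rather than the library's definition: the class name AlignmentPairSketch and the sample tokens are invented for the example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AlignmentPairSketch:
    """Stand-in mirroring the documented fields; not the library class."""

    s_tokens: List[str]
    t_tokens: List[str]
    s_start: int
    s_end: int
    t_start: int
    t_end: int
    is_correct: bool = False  # left False until the text comparison runs


# One student token covering two teacher tokens (half-open index ranges).
matched = AlignmentPairSketch(["hello"], ["hel", "lo"], 0, 1, 0, 2)

# A teacher-only pair: the empty student side carries the -1 sentinels.
teacher_only = AlignmentPairSketch([], ["!"], -1, -1, 2, 3)
```

Because the end indices are exclusive, `s_end - s_start` and `t_end - t_start` give the span lengths directly on any non-empty side.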

class nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch#

Per-batch alignment payload covering all three loss modes.

The collator hands this dataclass directly to the loss fn alongside the tokenized batch. Tensors are dense-padded to the batch maximum so DTensor V2 can shard on dim 0 without knowing about cross-tokenizer specifics.

.. attribute:: pair_valid

[B, max_pairs] bool. False on padding entries.

.. attribute:: pair_is_correct

[B, max_pairs] bool. True when canonicalized student span text matches canonicalized teacher span text.

.. attribute:: student_exact_partition_mask

[B, T_s] bool. True at student tokens that sit on a 1-1 exact-match pair (gold_loss partition).

.. attribute:: teacher_exact_partition_mask

[B, T_t] bool. Counterpart.

.. attribute:: student_chunk_id

[B, T_s] long. Chunk index (= pair index) the student token belongs to; -1 if not in any chunk (insertion-only pair on student side).

.. attribute:: teacher_chunk_id

[B, T_t] long. Counterpart.

.. attribute:: num_chunks

[B] long. Number of valid chunks in each sample.

pair_valid: torch.Tensor#

None

pair_is_correct: torch.Tensor#

None

student_exact_partition_mask: torch.Tensor#

None

teacher_exact_partition_mask: torch.Tensor#

None

student_chunk_id: torch.Tensor#

None

teacher_chunk_id: torch.Tensor#

None

num_chunks: torch.Tensor#

None

class nemo_rl.algorithms.x_token.token_aligner.TokenAligner(
student_tokenizer,
teacher_tokenizer,
projection_matrix_path: str,
max_comb_len: int = 4,
)#

Aligns student and teacher tokenizations of the same source text.

The alignment algorithm is a Needleman-Wunsch DP over canonicalized token strings, augmented with multi-token combination scoring (one student token can match a span of teacher tokens and vice versa, up to max_comb_len) and anchor-based segmentation for long sequences.

Parameters:
  • student_tokenizer – HF tokenizer for the student model.

  • teacher_tokenizer – HF tokenizer for the teacher model.

  • projection_matrix_path – Path retained on the aligner for downstream callers (e.g. the loss fn) that materialize the projection on their training device via :func:nemo_rl.algorithms.x_token.loss_utils.get_sparse_projection_matrix or :func:nemo_rl.algorithms.x_token.loss_utils.get_topk_projection.

  • max_comb_len – Maximum span length considered when matching one token on one side against multiple tokens on the other.

Initialization

align(
student_ids: torch.Tensor,
teacher_ids: torch.Tensor,
*,
student_attention_mask: torch.Tensor | None = None,
teacher_attention_mask: torch.Tensor | None = None,
) nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch#

Align a batch of student/teacher token id tensors.

Parameters:
  • student_ids – [B, T_s] long tensor.

  • teacher_ids – [B, T_t] long tensor.

  • student_attention_mask – optional [B, T_s] mask (1 = real token, 0 = padding). When given, padded positions are forced to the chunk_id = -1 / partition-False sentinels so tokenizer padding never forms a valid chunk. align runs the DP over the fully padded ids, so without this mask the pad run on each side can be aligned into chunks that survive :func:valid_chunk_mask.

  • teacher_attention_mask – optional [B, T_t] counterpart.

Returns:

An :class:AlignmentBatch with all fields populated for the three loss modes.

static _drop_padding(
batch: nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch,
*,
student_attention_mask: torch.Tensor | None,
teacher_attention_mask: torch.Tensor | None,
) None#

Strip tokenizer padding out of the chunk-id / partition tensors.

Mutates batch in place. For every position the attention mask marks as padding, resets *_chunk_id to -1 and *_exact_partition_mask to False. Gating per position (rather than trimming a contiguous span) keeps this correct under either left- or right-padding. A pair whose tokens are entirely padding on one side then has size 0 there and is dropped by :func:nemo_rl.algorithms.x_token.loss_utils.valid_chunk_mask; a pair straddling the real/pad boundary shrinks to its real tokens.
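
The per-position gating can be sketched as follows. This is a minimal illustration that uses nested Python lists in place of tensors; the function name drop_padding_sketch and its argument layout are assumptions made for the example, not the method's real signature.

```python
from typing import List


def drop_padding_sketch(
    chunk_id: List[List[int]],
    exact_mask: List[List[bool]],
    attention_mask: List[List[int]],
) -> None:
    """Reset padded positions in place (nested lists stand in for tensors)."""
    for row_ids, row_exact, row_attn in zip(chunk_id, exact_mask, attention_mask):
        for pos, is_real in enumerate(row_attn):
            if not is_real:
                row_ids[pos] = -1        # padding belongs to no chunk
                row_exact[pos] = False   # and to no exact-match partition


# Left-padded sample: the first two positions are padding.
chunk_id = [[0, 0, 1, 2]]
exact_mask = [[True, True, False, True]]
drop_padding_sketch(chunk_id, exact_mask, [[0, 0, 1, 1]])
```

Since each position is tested on its own, the same loop handles right-padded rows without any change.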

static _pairs_to_batch(
per_sample_pairs: List[List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]],
*,
b: int,
t_s: int,
t_t: int,
) nemo_rl.algorithms.x_token.token_aligner.AlignmentBatch#

Pack per-sample alignment lists into dense-padded tensors.
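
To make the dense-padded layout concrete, here is a small sketch of the packing step. It assumes pairs are plain (s_start, s_end, t_start, t_end, is_correct) tuples, uses nested lists instead of tensors, takes num_chunks to be the number of pairs per sample, and leaves out the exact-partition masks; all of these are simplifications for illustration.

```python
from typing import Dict, List, Tuple

# (s_start, s_end, t_start, t_end, is_correct); -1 marks an empty side.
Pair = Tuple[int, int, int, int, bool]


def pairs_to_batch_sketch(
    per_sample_pairs: List[List[Pair]], t_s: int, t_t: int
) -> Dict[str, list]:
    """Pack per-sample pair lists into padded nested lists (tensor stand-ins)."""
    max_pairs = max((len(p) for p in per_sample_pairs), default=0)
    keys = ("pair_valid", "pair_is_correct", "student_chunk_id",
            "teacher_chunk_id", "num_chunks")
    batch: Dict[str, list] = {k: [] for k in keys}
    for pairs in per_sample_pairs:
        pad = max_pairs - len(pairs)
        batch["pair_valid"].append([True] * len(pairs) + [False] * pad)
        batch["pair_is_correct"].append([p[4] for p in pairs] + [False] * pad)
        s_chunk, t_chunk = [-1] * t_s, [-1] * t_t
        for k, (s0, s1, t0, t1, _) in enumerate(pairs):
            if s0 >= 0:  # skip the empty side of an insertion/deletion
                s_chunk[s0:s1] = [k] * (s1 - s0)
            if t0 >= 0:
                t_chunk[t0:t1] = [k] * (t1 - t0)
        batch["student_chunk_id"].append(s_chunk)
        batch["teacher_chunk_id"].append(t_chunk)
        batch["num_chunks"].append(len(pairs))
    return batch


batch = pairs_to_batch_sketch(
    [[(0, 1, 0, 2, True), (1, 2, 2, 3, False)], [(0, 2, 0, 1, True)]],
    t_s=3,
    t_t=3,
)
```

The second sample has one pair fewer than the first, so its pair_valid row ends in a False padding entry, and any token position outside every pair keeps chunk id -1.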

_align_single(
student_tokens: List[str],
teacher_tokens: List[str],
exact_match_score: float = 3.0,
combination_score_multiplier: float = 1.5,
gap_penalty: float = -1.5,
anchor_lengths: Tuple[int, ...] = (3,),
) List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]#

Run canonicalize -> anchor-DP -> post-process for one sample.

Returns:

A list of :class:AlignmentPair. Insertions/deletions use -1 for the empty side’s start/end. Pair start/end indices address the original token sequences (not canonical space), so they can be written straight into the chunk-id tensors in :meth:_pairs_to_batch.

static _remap_pairs_to_original(
pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
*,
student_canon_to_orig: List[Tuple[int, int]],
teacher_canon_to_orig: List[Tuple[int, int]],
) None#

Remap pair start/end indices from canonical to original space.

Mutates pairs in place. Insertion/deletion pairs (where the empty side already has -1 start/end) keep the sentinel. s_end - 1 / t_end - 1 indexes the last canonical token in the span; its orig_end is the new exclusive end.
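
The index arithmetic described above can be sketched for one side of a pair. The helper name remap_span is hypothetical, and unlike the method it returns new indices instead of mutating pairs.

```python
from typing import List, Tuple


def remap_span(
    start: int, end: int, canon_to_orig: List[Tuple[int, int]]
) -> Tuple[int, int]:
    """Map a half-open canonical span to a half-open original span."""
    if start < 0:
        # Empty side of an insertion/deletion: keep the -1 sentinels.
        return start, end
    orig_start = canon_to_orig[start][0]   # start of the first canonical token
    orig_end = canon_to_orig[end - 1][1]   # orig_end of the last canonical token
    return orig_start, orig_end


# Canonical token 1 was merged from original tokens 1, 2 and 3.
canon_to_orig = [(0, 1), (1, 4), (4, 5)]
merged_span = remap_span(1, 2, canon_to_orig)
sentinel_span = remap_span(-1, -1, canon_to_orig)
```

Here the single canonical token at index 1 widens to three original positions, which is why the remap has to happen before indices are written into the chunk-id tensors.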

_align_with_anchors(
student_tokens: List[str],
teacher_tokens: List[str],
anchor_lengths: Tuple[int, ...] = (3,),
*,
exact_match_score: float,
combination_score_multiplier: float,
gap_penalty: float,
max_combination_len: int,
ignore_leading_char_diff: bool,
) Tuple[List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair], float]#

Optimize long alignments by pinning unique n-gram matches as anchors.

Falls back to plain DP when no anchors exist or when anchor_lengths is empty.
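
One plausible way to pick anchors is sketched below: collect n-grams that occur exactly once on each side, then keep a left-to-right subset whose positions advance on both sides. The function name find_unique_anchors and the greedy selection rule are assumptions for illustration; the source only states that unique n-gram matches are pinned as anchors.

```python
from collections import Counter
from typing import Dict, List, Tuple


def find_unique_anchors(
    student: List[str], teacher: List[str], n: int = 3
) -> List[Tuple[int, int]]:
    """Return (s_idx, t_idx) starts of n-grams unique on both sides."""

    def unique_positions(tokens: List[str]) -> Dict[tuple, int]:
        grams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        counts = Counter(grams)
        return {g: i for i, g in enumerate(grams) if counts[g] == 1}

    s_pos = unique_positions(student)
    t_pos = unique_positions(teacher)
    candidates = sorted((s_pos[g], t_pos[g]) for g in s_pos if g in t_pos)

    # Greedily keep anchors that start past the previous anchor on both
    # sides, so the segments between anchors never overlap.
    anchors: List[Tuple[int, int]] = []
    next_s = next_t = 0
    for s_i, t_i in candidates:
        if s_i >= next_s and t_i >= next_t:
            anchors.append((s_i, t_i))
            next_s, next_t = s_i + n, t_i + n
    return anchors


student = ["a", "b", "c", "x", "d", "e", "f"]
teacher = ["a", "b", "c", "y", "z", "d", "e", "f"]
anchors = find_unique_anchors(student, teacher)
```

With anchors in hand, only the gaps between them (here `["x"]` against `["y", "z"]`) need a full DP pass. An empty anchor list corresponds to the plain-DP fallback.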

static _align_dp(
student_tokens: List[str],
teacher_tokens: List[str],
*,
exact_match_score: float,
combination_score_multiplier: float,
gap_penalty: float,
max_combination_len: int,
ignore_leading_char_diff: bool,
) Tuple[List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair], float]#

Needleman-Wunsch DP with up-to-max_combination_len token spans.
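
The recurrence can be illustrated with a compact sketch. It makes several assumptions that the source does not confirm: a span matches only when the concatenated token strings are equal, a multi-token match scores exact_match_score times combination_score_multiplier, mismatches are modelled purely as gaps, and ignore_leading_char_diff is omitted. Spans are returned as plain tuples rather than AlignmentPair objects.

```python
from typing import List, Tuple

# (s_start, s_end, t_start, t_end); half-open, -1 on the empty side of a gap.
Span = Tuple[int, int, int, int]


def align_dp_sketch(
    s: List[str],
    t: List[str],
    *,
    exact_match_score: float = 3.0,
    combination_score_multiplier: float = 1.5,
    gap_penalty: float = -1.5,
    max_combination_len: int = 4,
) -> Tuple[List[Span], float]:
    n, m = len(s), len(t)
    score = [[float("-inf")] * (m + 1) for _ in range(n + 1)]
    back = [[(0, 0)] * (m + 1) for _ in range(n + 1)]
    score[0][0] = 0.0

    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 and j == 0:
                continue
            # Candidate moves: (score, tokens consumed on s, tokens consumed on t).
            moves = []
            if i > 0:
                moves.append((score[i - 1][j] + gap_penalty, 1, 0))
            if j > 0:
                moves.append((score[i][j - 1] + gap_penalty, 0, 1))
            for di in range(1, min(i, max_combination_len) + 1):
                for dj in range(1, min(j, max_combination_len) + 1):
                    if di > 1 and dj > 1:
                        continue  # only 1-1, 1-to-many and many-to-1 spans
                    if "".join(s[i - di:i]) != "".join(t[j - dj:j]):
                        continue
                    gain = exact_match_score
                    if di + dj > 2:
                        gain *= combination_score_multiplier  # assumed rule
                    moves.append((score[i - di][j - dj] + gain, di, dj))
            best = max(moves, key=lambda mv: mv[0])
            score[i][j], back[i][j] = best[0], (best[1], best[2])

    # Trace back from the bottom-right cell to recover the spans.
    spans: List[Span] = []
    i, j = n, m
    while i > 0 or j > 0:
        di, dj = back[i][j]
        s_span = (i - di, i) if di else (-1, -1)
        t_span = (j - dj, j) if dj else (-1, -1)
        spans.append(s_span + t_span)
        i, j = i - di, j - dj
    spans.reverse()
    return spans, score[n][m]


spans, total = align_dp_sketch(["hello", " world"], ["hel", "lo", " world"])
```

In this example the student token `hello` absorbs the two teacher tokens `hel` and `lo` as one combined span, followed by a 1-1 match on ` world`.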

static _shift_pairs(
pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
shift_s: int,
shift_t: int,
) List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]#

Offset start/end indices of pairs after segment-level alignment.

static _alignment_mask(
aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
) List[bool]#

Compute is_correct for each pair using canonicalized text comparison.

static _post_process_alignment(
aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
*,
exact_match_score: float,
combination_score_multiplier: float,
gap_penalty: float,
max_combination_len: int,
end_mismatch_threshold: float = 0.2,
) List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]#

Post-process: combine misaligned consecutive pairs and re-align bad spans.

static _build_pair_strings(
aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
) List[Tuple[str, str, bool]]#

Precompute (s_str, t_str, is_match) for each pair.

static _combine_consecutive_misaligned(
aligned_pairs: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
pair_strings: List[Tuple[str, str, bool]],
end_mismatch_threshold: float,
) List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair]#

Combine consecutive misaligned pairs into single multi-token chunks.

static _flatten_chunk(
chunk: List[nemo_rl.algorithms.x_token.token_aligner.AlignmentPair],
) Tuple[List[str], List[str], List[int], List[int]]#

Concatenate tokens and collect span indices across a chunk of pairs.

nemo_rl.algorithms.x_token.token_aligner.canonical_token(token: str, *, enabled: bool = True) str#

Return a canonical representation of a tokenizer token.

Public helper consumed by the alignment pipeline AND by the projection-prep CLIs in tools/x_token/. The enabled flag is a passthrough toggle: when False the input is returned unchanged (lets CLI call sites gate canonicalization via a single flag without branching at every site).

nemo_rl.algorithms.x_token.token_aligner._canonicalize_sequence(
seq: List[str],
) Tuple[List[str], List[Tuple[int, int]]]#

Canonicalize every token in a sequence, including byte-merging.

Returns:

(canon, canon_to_orig). canon_to_orig[k] is a half-open [orig_start, orig_end) range giving the original-token positions that canonical token k was built from. Ranges are non-overlapping, strictly increasing, and jointly cover range(len(seq)) — required so that DP-output indices over canon can be remapped to positions on the original input-id axis (see :meth:TokenAligner._remap_pairs_to_original).
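
The three range invariants can be checked mechanically. The helper below is a hypothetical validator written for illustration (the name check_canon_to_orig is not part of the module); it walks the ranges once and confirms they tile the original positions in order.

```python
from typing import List, Tuple


def check_canon_to_orig(canon_to_orig: List[Tuple[int, int]], n_orig: int) -> bool:
    """True when the ranges tile range(n_orig) in order, with no gap or overlap."""
    cursor = 0
    for start, end in canon_to_orig:
        # Each range must begin where the previous one ended and be non-empty.
        if start != cursor or end <= start:
            return False
        cursor = end
    return cursor == n_orig


ok = check_canon_to_orig([(0, 1), (1, 4), (4, 5)], 5)
has_gap = check_canon_to_orig([(0, 1), (2, 3)], 3)
overlaps = check_canon_to_orig([(0, 2), (1, 3)], 3)
```

A check like this is cheap enough to run in a unit test over real tokenizer output, where a violated invariant would otherwise surface only as misplaced chunk ids.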

nemo_rl.algorithms.x_token.token_aligner._merge_encoding_artifacts(
tokens: List[str],
) Tuple[List[str], List[Tuple[int, int]]]#

Merge known multi-token mojibake patterns into single tokens.

Returns:

(merged, ranges) with one (orig_start, orig_end) entry per output token. Every entry in :data:_MULTI_TOKEN_ARTIFACT_FIXES rewrites to a single replacement token, so each merge contributes exactly one range covering the matched pattern.
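
A left-to-right scan is one straightforward way to produce that output shape. The sketch below hard-codes only the first two table entries visible in the truncated listing of _MULTI_TOKEN_ARTIFACT_FIXES and stores each replacement as a plain string; the function name and the first-match-wins scanning order are assumptions.

```python
from typing import List, Tuple

# Two entries copied from the truncated table shown earlier; the real table
# is longer. Each pattern rewrites to a single replacement token.
ARTIFACT_FIXES = [
    (["ĠâĪ", "ij"], "Ġ∑"),
    (["âĪ", "ij"], "∑"),
]


def merge_artifacts_sketch(
    tokens: List[str],
) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Replace known multi-token patterns and record the source ranges."""
    merged: List[str] = []
    ranges: List[Tuple[int, int]] = []
    i = 0
    while i < len(tokens):
        for pattern, replacement in ARTIFACT_FIXES:
            if tokens[i:i + len(pattern)] == pattern:
                merged.append(replacement)
                ranges.append((i, i + len(pattern)))  # one range per merge
                i += len(pattern)
                break
        else:
            # No pattern starts here: pass the token through unchanged.
            merged.append(tokens[i])
            ranges.append((i, i + 1))
            i += 1
    return merged, ranges


merged, ranges = merge_artifacts_sketch(["x", "âĪ", "ij", "y"])
```

The two-token pattern collapses into one output token whose range spans both original positions, while the untouched tokens keep unit-length ranges.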

nemo_rl.algorithms.x_token.token_aligner._get_byte_value(token_char: str) int | None#

Return the byte value (0..255) for a single character, or None.

nemo_rl.algorithms.x_token.token_aligner._merge_consecutive_bytes(
tokens: List[str],
in_ranges: List[Tuple[int, int]],
) Tuple[List[str], List[Tuple[int, int]]]#

Merge consecutive byte-fallback tokens back into Unicode characters.

Propagates in_ranges parallel to tokens: when a byte buffer collapses to one character, its parallel range slice is collapsed to a single (start, end); otherwise ranges pass through unchanged.

nemo_rl.algorithms.x_token.token_aligner._try_merge_byte_buffer(
byte_tokens: List[str],
byte_ranges: List[Tuple[int, int]],
) Tuple[List[str], List[Tuple[int, int]]]#

Decode 2-4 buffered byte tokens as a single UTF-8 character.

Returns the merged single-character token plus a single collapsed range covering the whole buffer, or the unchanged buffer + ranges when no merge is possible.
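
The merge-or-pass-through behaviour might look like the sketch below. The byte lookup here is a deliberately simple placeholder that treats code points 0..255 as raw byte values; the module's own _get_byte_value may use a different mapping, so treat both helper names and the lookup rule as assumptions.

```python
from typing import List, Optional, Tuple


def get_byte_value_sketch(ch: str) -> Optional[int]:
    """Hypothetical lookup: treat code points 0..255 as raw byte values."""
    if len(ch) == 1 and ord(ch) < 256:
        return ord(ch)
    return None


def try_merge_byte_buffer_sketch(
    byte_tokens: List[str], byte_ranges: List[Tuple[int, int]]
) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Collapse 2-4 byte tokens into one character when they decode as UTF-8."""
    if not 2 <= len(byte_tokens) <= 4:
        return byte_tokens, byte_ranges
    values = [get_byte_value_sketch(tok) for tok in byte_tokens]
    if any(v is None for v in values):
        return byte_tokens, byte_ranges
    try:
        text = bytes(values).decode("utf-8")
    except UnicodeDecodeError:
        return byte_tokens, byte_ranges
    if len(text) != 1:
        return byte_tokens, byte_ranges  # e.g. two separate ASCII bytes
    # One character out, one range spanning the whole buffer.
    return [text], [(byte_ranges[0][0], byte_ranges[-1][1])]


# 0xC3 0xA9 is the UTF-8 encoding of "é".
merged_tokens, merged_ranges = try_merge_byte_buffer_sketch(
    ["\u00c3", "\u00a9"], [(3, 4), (4, 5)]
)
```

Returning the inputs untouched on every failure path keeps the token list and the range list the same length, whether or not a merge happened.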

nemo_rl.algorithms.x_token.token_aligner._strings_equal_flexible(
s1: str,
s2: str,
ignore_leading_char_diff: bool,
) bool#

Compare two strings, optionally after canonicalization.