nemo_rl.data.cross_tokenizer_collate

Collator that tokenizes raw text twice (once with the student tokenizer, once with the teacher tokenizer) and aligns the two tokenizations.

The collator runs inside DataLoader worker processes and performs three steps:

  1. Tokenizes the same source text once with the student tokenizer and once with the teacher tokenizer (no chat template, no special handling).

  2. Calls TokenAligner.align to produce a dense-padded AlignmentBatch covering all three loss modes (P-KL, gold_loss, xtoken_loss).

  3. Returns a BatchedDataDict with the keys Policy.train expects (input_ids, input_lengths, token_mask, sample_mask) plus teacher tensors and alignment tensors.

Projection-matrix work on the loss side happens inside the loss function; no KL or cross-entropy (CE) math runs in the collator.
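The three steps above can be illustrated with a toy, dependency-free sketch. Everything in it is a stand-in: the two "tokenizers" are hypothetical word and chunk splitters, the alignment uses simple character-span overlap (this page does not describe the algorithm TokenAligner actually uses), and the keys `teacher_input_lengths` and `alignment` are illustrative names, not the keys the real collator emits.

```python
from typing import Dict, List, Tuple

Span = Tuple[int, int]  # (start, end) character offsets into the source text


def whitespace_spans(text: str) -> List[Span]:
    """Toy 'student' tokenizer: one token per whitespace-separated word."""
    spans, pos = [], 0
    for word in text.split():
        start = text.index(word, pos)
        spans.append((start, start + len(word)))
        pos = start + len(word)
    return spans


def chunk_spans(text: str, size: int = 3) -> List[Span]:
    """Toy 'teacher' tokenizer: split every word into pieces of at most `size` chars."""
    out = []
    for start, end in whitespace_spans(text):
        for i in range(start, end, size):
            out.append((i, min(i + size, end)))
    return out


def align_by_overlap(student: List[Span], teacher: List[Span]) -> List[Tuple[int, int]]:
    """Stand-in aligner: pair tokens whose character spans overlap."""
    pairs = []
    for si, (s0, s1) in enumerate(student):
        for ti, (t0, t1) in enumerate(teacher):
            if s0 < t1 and t0 < s1:
                pairs.append((si, ti))
    return pairs


def toy_collate(texts: List[str]) -> Dict[str, list]:
    """Tokenize twice, align once, return one flat dict for the whole batch."""
    batch: Dict[str, list] = {
        "input_lengths": [],          # student token counts
        "teacher_input_lengths": [],  # hypothetical key name
        "alignment": [],              # hypothetical key name
    }
    for text in texts:
        student = whitespace_spans(text)   # step 1a: student tokenization
        teacher = chunk_spans(text)        # step 1b: teacher tokenization
        batch["input_lengths"].append(len(student))
        batch["teacher_input_lengths"].append(len(teacher))
        batch["alignment"].append(align_by_overlap(student, teacher))  # step 2
    return batch                           # step 3: flat batch


if __name__ == "__main__":
    print(toy_collate(["hello world"]))
```

The point of the sketch is the shape of the pipeline: both tokenizations start from the same raw string, so alignment can be computed once per sample in the worker process, and the loss code later receives only ready-made tensors.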

Module Contents

Classes

CrossTokenizerCollator

Tokenize twice, align once, return a flat tensor batch.

API

class nemo_rl.data.cross_tokenizer_collate.CrossTokenizerCollator(
    *,
    student_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    teacher_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    aligner: nemo_rl.algorithms.x_token.token_aligner.TokenAligner,
    ctx_length_student: int,
    ctx_length_teacher: int,
    make_seq_div_by_student: int = 1,
    make_seq_div_by_teacher: int = 1,
)

Tokenize twice, align once, return a flat tensor batch.

Parameters:
  • student_tokenizer – HF tokenizer matching the student model.

  • teacher_tokenizer – HF tokenizer matching the teacher model.

  • aligner – Pre-constructed TokenAligner.

  • ctx_length_student – Hard tokenization length cap on the student side (also the padded sequence length of the student tensor).

  • ctx_length_teacher – Same on the teacher side.

  • make_seq_div_by_student – Round student sequence length up to a multiple of this value (typically TP * CP * 2 for DTensor V2).

  • make_seq_div_by_teacher – Same for the teacher side.
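The rounding performed by the two `make_seq_div_by_*` parameters amounts to rounding up to the next multiple. A minimal sketch of that arithmetic follows; the helper name `round_up_to_multiple` is hypothetical, and how the collator combines this rounding with the `ctx_length_*` cap is not stated on this page.

```python
def round_up_to_multiple(length: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= `length`."""
    if multiple < 1:
        raise ValueError("multiple must be a positive integer")
    return ((length + multiple - 1) // multiple) * multiple


if __name__ == "__main__":
    # With tensor parallelism TP = 2 and context parallelism CP = 2,
    # the "TP * CP * 2" rule from the parameter description gives 8.
    tp, cp = 2, 2
    divisor = tp * cp * 2
    print(round_up_to_multiple(1021, divisor))  # 1024
    print(round_up_to_multiple(1024, divisor))  # 1024 (already a multiple)
```

With the default value of 1 the length is returned unchanged, which matches the parameter defaults in the signature above.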

__call__(
    batch: List[nemo_rl.data.interfaces.DatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

static _tokenize_batch(
    texts: List[str],
    tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    ctx_length: int,
    make_seq_div_by: int,
) → tuple[torch.Tensor, torch.Tensor]

Tokenize a batch and pad to a multiple of make_seq_div_by.
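As a rough illustration of what "tokenize and pad to a multiple" can look like, here is a pure-Python sketch with plain lists in place of tensors. It rests on several assumptions that this page does not confirm: that the two returned tensors are padded ids and per-sample lengths, that rows are truncated to `ctx_length` and padded to `ctx_length` rounded up to a multiple of `make_seq_div_by`, and that padding goes on the right. The `encode` callable and `pad_id` argument are hypothetical stand-ins for a Hugging Face tokenizer.

```python
from typing import Callable, List, Tuple


def tokenize_and_pad(
    texts: List[str],
    encode: Callable[[str], List[int]],  # hypothetical stand-in for a tokenizer
    ctx_length: int,
    make_seq_div_by: int,
    pad_id: int = 0,                     # assumed pad token id
) -> Tuple[List[List[int]], List[int]]:
    """Truncate to ctx_length, then right-pad every row to a common length."""
    # Assumed target: ctx_length rounded up to a multiple of make_seq_div_by.
    target = ((ctx_length + make_seq_div_by - 1) // make_seq_div_by) * make_seq_div_by
    ids: List[List[int]] = []
    lengths: List[int] = []
    for text in texts:
        row = encode(text)[:ctx_length]  # hard cap on the token count
        lengths.append(len(row))         # unpadded length
        ids.append(row + [pad_id] * (target - len(row)))
    return ids, lengths


def toy_encode(text: str) -> List[int]:
    """Toy encoder: one id per word, equal to the word's length."""
    return [len(word) for word in text.split()]


if __name__ == "__main__":
    ids, lengths = tokenize_and_pad(["a bb ccc", "dddd"], toy_encode, 5, 4)
    print(ids)      # two rows of length 8
    print(lengths)  # [3, 1]
```

Returning the unpadded lengths next to the padded ids lets downstream code build masks without rescanning for pad tokens.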