nemo_rl.data.cross_tokenizer_collate
Collator that tokenizes raw text twice (once for the student, once for the teacher) and aligns the two token sequences.
The collator runs inside DataLoader worker processes. It performs three steps:
1. Tokenizes the same source text once with the student tokenizer and once with the teacher tokenizer (no chat template, no special handling).
2. Calls TokenAligner.align to produce a dense-padded AlignmentBatch covering all three loss modes (P-KL, gold_loss, xtoken_loss).
3. Returns a BatchedDataDict with the keys Policy.train expects (input_ids, input_lengths, token_mask, sample_mask) plus teacher tensors and alignment tensors.
Loss-side projection-matrix work happens inside the loss fn; nothing related to KL/CE math runs here.
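The tokenize-twice-then-align flow can be illustrated with a toy sketch. Everything in it is hypothetical: word_spans and chunk_spans stand in for the student and teacher tokenizers, and align pairs tokens by overlapping character spans, which is one plausible alignment criterion and not necessarily what TokenAligner.align does.

```python
from typing import List, Tuple

Span = Tuple[int, int]  # (start, end) character offsets, end exclusive


def word_spans(text: str) -> List[Span]:
    """Toy 'student' tokenizer: one token per whitespace-separated word."""
    spans, pos = [], 0
    for word in text.split():
        start = text.index(word, pos)
        spans.append((start, start + len(word)))
        pos = start + len(word)
    return spans


def chunk_spans(text: str, size: int = 3) -> List[Span]:
    """Toy 'teacher' tokenizer: split each word into chunks of at most `size` chars."""
    spans = []
    for start, end in word_spans(text):
        for i in range(start, end, size):
            spans.append((i, min(i + size, end)))
    return spans


def align(student: List[Span], teacher: List[Span]) -> List[Tuple[int, int]]:
    """Pair every (student index, teacher index) whose character spans overlap."""
    return [
        (si, ti)
        for si, (s0, s1) in enumerate(student)
        for ti, (t0, t1) in enumerate(teacher)
        if s0 < t1 and t0 < s1
    ]


text = "cross tokenizer"
student = word_spans(text)    # same text, first tokenization
teacher = chunk_spans(text)   # same text, second tokenization
pairs = align(student, teacher)
print(pairs)  # [(0, 0), (0, 1), (1, 2), (1, 3), (1, 4)]
```

Here one student token maps to several teacher tokens, which is the situation a dense-padded alignment structure has to represent.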
Module Contents
Classes
CrossTokenizerCollator | Tokenize twice, align once, return a flat tensor batch.
API
- class nemo_rl.data.cross_tokenizer_collate.CrossTokenizerCollator(
- *,
- student_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
- teacher_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
- aligner: nemo_rl.algorithms.x_token.token_aligner.TokenAligner,
- ctx_length_student: int,
- ctx_length_teacher: int,
- make_seq_div_by_student: int = 1,
- make_seq_div_by_teacher: int = 1,
- )
Tokenize twice, align once, return a flat tensor batch.
- Parameters:
student_tokenizer – HF tokenizer matching the student model.
teacher_tokenizer – HF tokenizer matching the teacher model.
aligner – Pre-constructed TokenAligner.
ctx_length_student – Hard tokenization length cap on the student side (also the padded sequence length of the student tensor).
ctx_length_teacher – Same on the teacher side.
make_seq_div_by_student – Round student sequence length up to a multiple of this value (typically TP * CP * 2 for DTensor V2).
make_seq_div_by_teacher – Same for the teacher side.
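As a rough illustration of the divisibility parameters, the arithmetic below rounds a sequence length up to a multiple derived from TP * CP * 2. The helper round_up and the parallel sizes are hypothetical values chosen for the example, not taken from the library.

```python
def round_up(length: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= `length`."""
    return ((length + multiple - 1) // multiple) * multiple


tp, cp = 2, 2            # hypothetical tensor- and context-parallel sizes
div_by = tp * cp * 2     # 8, the "TP * CP * 2" rule of thumb
print(round_up(1021, div_by))  # 1024
print(round_up(1024, div_by))  # 1024 (already divisible, unchanged)
```

With the default value of 1, every length is already a multiple, so no extra padding is added.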
Initialization
- __call__(
- batch: List[nemo_rl.data.interfaces.DatumSpec],
- )
- static _tokenize_batch(
- texts: List[str],
- tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
- ctx_length: int,
- make_seq_div_by: int,
- )
Tokenize a batch and pad to a multiple of make_seq_div_by.
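A minimal sketch of the truncate-and-pad step follows, operating on already-tokenized ID lists so it runs without a real tokenizer. It assumes, based on the parameter descriptions, that sequences are cut at ctx_length and padded to ctx_length rounded up to a multiple of make_seq_div_by. The function name pad_token_batch, the pad_id default, and the returned dictionary layout are hypothetical and may differ from what _tokenize_batch actually returns.

```python
from typing import Dict, List


def pad_token_batch(
    token_ids: List[List[int]],
    ctx_length: int,
    make_seq_div_by: int,
    pad_id: int = 0,  # assumed pad token ID
) -> Dict[str, list]:
    # Hard cap: drop anything past ctx_length tokens.
    truncated = [ids[:ctx_length] for ids in token_ids]
    lengths = [len(ids) for ids in truncated]
    # Ceiling division, then scale back up to get the padded width.
    target = -(-ctx_length // make_seq_div_by) * make_seq_div_by
    input_ids = [ids + [pad_id] * (target - len(ids)) for ids in truncated]
    return {"input_ids": input_ids, "input_lengths": lengths}


batch = pad_token_batch(
    [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]],
    ctx_length=6,
    make_seq_div_by=4,
)
print(batch["input_lengths"])  # [3, 6]
print(batch["input_ids"][0])   # [5, 6, 7, 0, 0, 0, 0, 0]
```

The second sequence loses its seventh token to the cap, and both rows end up 8 wide because 6 is rounded up to the next multiple of 4.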