nemo_automodel.components.datasets.llm.prefix_tree
Shared-prefix rollout folding for multi-turn prefix-tree attention (cp=1).
Folds a group of rollouts that share one prompt prefix (one prompt -> N sampled
completions) into a single deduplicated flat token layout plus the prefix-tree
structure (node_lengths / sample_paths). The shared prompt is stored
once; every completion attends FULL to the prompt and CAUSAL to itself.
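The folding step can be pictured with a small pure-Python sketch. It is illustrative only and is not this module's code. The helper name fold_layout is hypothetical. The encoding is also an assumption: node 0 is the shared prompt, node i + 1 is completion i, and each sample path lists node indices from the root to one leaf.

```python
from typing import List, Tuple


def fold_layout(
    prompt: List[int], completions: List[List[int]]
) -> Tuple[List[int], List[int], List[List[int]]]:
    """Hypothetical sketch: store the prompt once, then append each completion."""
    flat = list(prompt)
    node_lengths = [len(prompt)]  # node 0 is the shared prompt
    sample_paths = []
    for i, comp in enumerate(completions):
        flat.extend(comp)
        node_lengths.append(len(comp))  # node i + 1 is completion i (a leaf)
        sample_paths.append([0, i + 1])  # assumed root-to-leaf node indices
    return flat, node_lengths, sample_paths


prompt = [10, 11, 12]
completions = [[20, 21], [30, 31, 32]]
flat, node_lengths, sample_paths = fold_layout(prompt, completions)
print(flat)          # [10, 11, 12, 20, 21, 30, 31, 32]
print(node_lengths)  # [3, 2, 3]
print(sample_paths)  # [[0, 1], [0, 2]]
# Deduplicated: 8 tokens instead of 3 * 2 + 2 + 3 = 11.
```

The saving grows with N. An unfolded layout pays for the prompt N times, while the folded layout pays for it once.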
This is the verl RFC #6401 / Automodel #2385 shared-prefix RL training layout,
restricted to the cp=1 path (no context-parallel dispatch). The collate function
carries the structure on the batch, and the magi backend builds the AttnMaskSpec
from it and activates it (the datasets layer must not import components.distributed).
Enable it with model.backend.attn: magi.
Branch-point note: the shared prompt’s final position is stored once, so it can
predict only one next token. The N completions diverge there, so the
prompt -> first-completion-token transition is left unsupervised (label
-100 on the last prompt token); each completion is supervised causally from
its own first token onward. This is inherent to deduplicating the shared prefix.
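The cost of the branch point can be counted directly. The sketch below assumes a non-deduplicated baseline in which each completion is packed after its own private copy of the prompt, and only completion tokens are targets. In that baseline, the last token of each prompt copy can predict that completion's first token. The function name supervised_counts is hypothetical.

```python
from typing import List, Tuple


def supervised_counts(completion_lens: List[int]) -> Tuple[int, int]:
    """Return (folded, unfolded) numbers of supervised positions for one group.

    folded: the prompt is stored once and fully masked (including the branch
    point), so a completion of length n supervises n - 1 transitions.
    unfolded (assumed baseline): a private prompt copy lets its last token
    predict the completion's first token, giving n transitions.
    """
    folded = sum(n - 1 for n in completion_lens)
    unfolded = sum(completion_lens)
    return folded, unfolded


print(supervised_counts([4, 2, 3]))  # (6, 9): one lost transition per completion
```

Exactly N transitions are lost per group of N completions, one per completion. Every other transition is unchanged.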
Module Contents
API
class FoldedRollouts
Deduplicated flat layout + prefix-tree structure for one shared-prefix group.
The attention mask itself is built in the magi backend from node_lengths /
sample_paths (via AttnMaskSpec.prefix_tree); the datasets layer must not
import components.distributed, so it only carries the structure.
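One way to see what node_lengths / sample_paths encode is to expand them into a dense boolean reference mask. This is a sketch for intuition only. It does not use or reproduce AttnMaskSpec or the magi backend. The function name is hypothetical, and it assumes that each sample path lists node indices from the root to a leaf. Each query token may attend causally to the tokens along its own root-to-leaf path. As a result, a completion sees the whole prompt and its own earlier tokens, but never a sibling completion.

```python
from typing import List


def dense_prefix_tree_mask(
    node_lengths: List[int], sample_paths: List[List[int]]
) -> List[List[bool]]:
    """Reference mask (sketch): mask[q][k] is True if query q may attend to key k."""
    # Start offset of every node in the flat layout.
    starts = []
    offset = 0
    for n in node_lengths:
        starts.append(offset)
        offset += n
    total = offset
    mask = [[False] * total for _ in range(total)]
    for path in sample_paths:
        # Flat positions along this root-to-leaf path, in sequence order.
        positions = [
            p
            for node in path
            for p in range(starts[node], starts[node] + node_lengths[node])
        ]
        for i, q in enumerate(positions):
            for k in positions[: i + 1]:  # causal along the path
                mask[q][k] = True
    return mask


# Prompt of 2 tokens, completions of 1 and 2 tokens.
for row in dense_prefix_tree_mask([2, 1, 2], [[0, 1], [0, 2]]):
    print([int(x) for x in row])
```

A dense T x T mask is quadratic in memory. It is useful only as a correctness check on small groups.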
fold_shared_prefix_rollouts
Fold one shared-prefix rollout group into a deduplicated prefix-tree layout.
Labels follow this repo’s next-token convention: they are pre-shifted (the
loss does not shift), so position p carries the id of the token the model
should predict at p. Within each completion, token t predicts token
t + 1; the completion’s last token has no in-layout successor and is
masked. The shared prompt is masked entirely, including its last position:
that position is the branch point where the N completions diverge, so it
cannot supervise any single first-completion token (the cost of deduplicating
the prefix is that each completion’s first token is unsupervised).
Parameters:
the shared prompt tokens (node 0). May be empty.
one token-id list per sampled completion (one leaf each). Must be non-empty and every completion must be non-empty.
label value for unsupervised positions (default -100).
Returns: FoldedRollouts
A FoldedRollouts with the flat tokens, labels, position ids and the prefix-tree structure (node_lengths / sample_paths).
Raises:
ValueError: if completions is empty or any completion is empty.
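The labelling contract above can be checked against a minimal sketch. Several details are assumptions and are not taken from this module: the names fold_sketch and FoldedSketch, the field names input_ids / labels / position_ids, and the ignore_index parameter name. The position-id scheme is also assumed: each completion continues from the prompt length, as if it followed the prompt directly. The labels follow the documented rules. They are pre-shifted, the whole prompt is masked (including the branch point), and the last token of each completion is masked.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FoldedSketch:
    input_ids: List[int]
    labels: List[int]
    position_ids: List[int]
    node_lengths: List[int]
    sample_paths: List[List[int]]


def fold_sketch(
    prompt_ids: List[int], completions: List[List[int]], ignore_index: int = -100
) -> FoldedSketch:
    """Hypothetical re-statement of the folding contract, not the real function."""
    if not completions or any(len(c) == 0 for c in completions):
        raise ValueError("completions must be non-empty and every completion non-empty")
    p = len(prompt_ids)
    input_ids = list(prompt_ids)
    labels = [ignore_index] * p  # whole prompt masked, incl. branch point p - 1
    position_ids = list(range(p))
    node_lengths = [p]
    sample_paths = []
    for i, comp in enumerate(completions):
        comp = list(comp)
        input_ids.extend(comp)
        # Pre-shifted: token j is labelled with token j + 1; the last has no successor.
        labels.extend(comp[1:] + [ignore_index])
        # Assumption: every completion continues the prompt's positions.
        position_ids.extend(range(p, p + len(comp)))
        node_lengths.append(len(comp))
        sample_paths.append([0, i + 1])
    return FoldedSketch(input_ids, labels, position_ids, node_lengths, sample_paths)


f = fold_sketch([1, 2], [[5, 6, 7], [8]])
print(f.labels)        # [-100, -100, 6, 7, -100, -100]
print(f.position_ids)  # [0, 1, 2, 3, 4, 2]
```

Because several completions reuse the same position ids, position ids alone cannot separate the branches. The prefix-tree mask is what keeps sibling completions from seeing each other.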
Collate one shared-prefix rollout group into a model-ready batch (cp=1).
Folds the group with fold_shared_prefix_rollouts and emits the flat
tokens plus the prefix-tree structure. Only local_batch_size == 1 is
supported: each group already packs many completions into one flat sequence,
and the mask is per group. The prefix_tree entry is popped by the magi
backend (MagiState.prepare_llm_batch), which builds and activates the
AttnMaskSpec from it.
Parameters:
a length-1 list holding one rollout group dict with keys
prompt_ids and completions.
Returns: dict
Dict with input_ids, labels, position_ids (each [1, T]) and a prefix_tree entry carrying the prefix-tree structure.
Raises:
ValueError: if batch does not hold exactly one rollout group.
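The collate contract can be sketched as follows. The sketch makes several assumptions. The function name collate_sketch is hypothetical. Plain nested lists stand in for [1, T] tensors. The layout of the prefix_tree entry (a dict with node_lengths and sample_paths) is assumed, not taken from this module. So is the position-id scheme, in which each completion continues from the prompt length. The sketch enforces the documented local_batch_size == 1 restriction, and it refolds the group inline so the block is self-contained.

```python
from typing import Any, Dict, List

IGNORE = -100


def collate_sketch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate: one shared-prefix group -> [1, T] rows + prefix_tree."""
    if len(batch) != 1:
        raise ValueError(f"expected exactly one rollout group, got {len(batch)}")
    group = batch[0]
    prompt, completions = list(group["prompt_ids"]), group["completions"]
    if not completions or any(len(c) == 0 for c in completions):
        raise ValueError("completions must be non-empty and every completion non-empty")
    p = len(prompt)
    ids, labels, pos = list(prompt), [IGNORE] * p, list(range(p))
    for comp in completions:
        comp = list(comp)
        ids += comp
        labels += comp[1:] + [IGNORE]  # pre-shifted; last token has no successor
        pos += list(range(p, p + len(comp)))  # assumed position scheme
    return {
        # Leading dimension of 1: one group per batch (local_batch_size == 1).
        "input_ids": [ids],
        "labels": [labels],
        "position_ids": [pos],
        # Structure only; an attention backend would build the mask from it.
        "prefix_tree": {
            "node_lengths": [p] + [len(c) for c in completions],
            "sample_paths": [[0, i + 1] for i in range(len(completions))],
        },
    }


out = collate_sketch([{"prompt_ids": [1, 2], "completions": [[5, 6], [7]]}])
print(out["input_ids"])  # [[1, 2, 5, 6, 7]]
```

Keeping the structure as plain data in the batch mirrors the layering described above. The datasets side only describes the tree, and the mask is built elsewhere.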