nemo_automodel.components.datasets.llm.prefix_tree


Shared-prefix rollout folding for multi-turn prefix-tree attention (cp=1).

Folds a group of rollouts that share one prompt prefix (one prompt -> N sampled completions) into a single deduplicated flat token layout plus the prefix-tree structure (node_lengths / sample_paths). The shared prompt is stored once; every completion attends FULL to the prompt and CAUSAL to itself.
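The saving from storing the shared prompt once can be checked with simple arithmetic. The helper below is illustrative only (`layout_lengths` is a hypothetical name, not part of this module): it compares the token count of N independent prompt + completion sequences with the count of the folded layout.

```python
def layout_lengths(prompt_len: int, completion_lens: list[int]) -> tuple[int, int]:
    """Return (unfolded_tokens, folded_tokens) for one rollout group."""
    # Unfolded: every sample repeats the full prompt before its completion.
    unfolded = sum(prompt_len + n for n in completion_lens)
    # Folded: the prompt is stored once, followed by all completions.
    folded = prompt_len + sum(completion_lens)
    return unfolded, folded


unfolded, folded = layout_lengths(prompt_len=512, completion_lens=[128, 64, 256, 32])
print(unfolded, folded)  # 2528 992
```

The longer the prompt relative to the completions, and the more completions per prompt, the larger the reduction.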

This is the verl RFC #6401 / Automodel #2385 shared-prefix RL training layout, restricted to the cp=1 path (no context-parallel dispatch). The collate carries the structure on the batch; the magi backend builds the AttnMaskSpec from it and activates it (the datasets layer must not import components.distributed). Enable it with model.backend.attn: magi.

Branch-point note: the shared prompt’s final position is stored once, so it can predict only one next token. The N completions diverge there, so the prompt -> first-completion-token transition is left unsupervised (label -100 on the last prompt token). Each completion is supervised causally starting at its own first position, whose label is the completion’s second token, so a completion’s first token is never a prediction target. This is inherent to deduplicating the shared prefix.

Module Contents

Classes

| Name | Description |
| --- | --- |
| FoldedRollouts | Deduplicated flat layout + prefix-tree structure for one shared-prefix group. |

Functions

| Name | Description |
| --- | --- |
| fold_shared_prefix_rollouts | Fold one shared-prefix rollout group into a deduplicated prefix-tree layout. |
| prefix_tree_collate_fn | Collate one shared-prefix rollout group into a model-ready batch (cp=1). |

Data

CROSS_ENTROPY_IGNORE_IDX

API

class nemo_automodel.components.datasets.llm.prefix_tree.FoldedRollouts(
input_ids: list[int],
labels: list[int],
position_ids: list[int],
node_lengths: list[int],
sample_paths: list[list[int]]
)
Dataclass

Deduplicated flat layout + prefix-tree structure for one shared-prefix group.

The attention mask itself is built in the magi backend from node_lengths / sample_paths (via AttnMaskSpec.prefix_tree); the datasets layer must not import components.distributed, so it only carries the structure.
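To make the mask semantics concrete, here is a minimal sketch that expands node_lengths / sample_paths into a dense boolean matrix. It is not how the magi backend or AttnMaskSpec.prefix_tree works internally (that code is not shown here). It assumes, hypothetically, that node_lengths lists the token count of each node in flat-layout order and that each entry of sample_paths lists node indices from root to leaf; `dense_prefix_tree_mask` is an illustrative name.

```python
def dense_prefix_tree_mask(
    node_lengths: list[int], sample_paths: list[list[int]]
) -> list[list[bool]]:
    """mask[q][k] is True when query token q may attend to key token k."""
    # Map every flat token position to the node that owns it.
    node_of: list[int] = []
    for node, length in enumerate(node_lengths):
        node_of.extend([node] * length)
    total = len(node_of)

    # A node's ancestors are the nodes that precede it on any root-to-leaf path.
    ancestors: dict[int, set[int]] = {n: set() for n in range(len(node_lengths))}
    for path in sample_paths:
        for depth, node in enumerate(path):
            ancestors[node].update(path[:depth])

    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            if node_of[q] == node_of[k]:
                mask[q][k] = k <= q  # causal within the same node
            else:
                mask[q][k] = node_of[k] in ancestors[node_of[q]]  # full to ancestors
    return mask


# Assumed encoding: node 0 = 2 prompt tokens, node 1 = 1 token, node 2 = 2 tokens.
mask = dense_prefix_tree_mask([2, 1, 2], [[0, 1], [0, 2]])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

In the printed grid, the two completion nodes see both prompt tokens but never each other, which is the "full to the prompt, causal to itself" pattern described above. A dense matrix costs O(T^2) memory, so it is useful only for inspecting small examples.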

input_ids
list[int]
labels
list[int]
node_lengths
list[int]
position_ids
list[int]
sample_paths
list[list[int]]
nemo_automodel.components.datasets.llm.prefix_tree.fold_shared_prefix_rollouts(
prompt_ids: list[int],
completions: list[list[int]],
ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX
) -> nemo_automodel.components.datasets.llm.prefix_tree.FoldedRollouts

Fold one shared-prefix rollout group into a deduplicated prefix-tree layout.

Labels follow this repo’s next-token convention: they are pre-shifted (the loss does not shift), so position p carries the id of the token the model should predict at p. Within each completion, token t predicts token t + 1; the completion’s last token has no in-layout successor and is masked. The shared prompt is masked entirely, including its last position: that position is the branch point where the N completions diverge, so it cannot supervise any single first-completion token (the cost of deduplicating the prefix is that each completion’s first token is unsupervised).

Parameters:

prompt_ids
list[int]

the shared prompt tokens (node 0). May be empty.

completions
list[list[int]]

one token-id list per sampled completion (one leaf each). Must be non-empty and every completion must be non-empty.

ignore_idx
int, defaults to CROSS_ENTROPY_IGNORE_IDX

label value for unsupervised positions (default -100).

Returns: FoldedRollouts

A FoldedRollouts with the flat tokens, labels, position ids and the prefix-tree structure (node_lengths / sample_paths).

Raises:

  • ValueError: if completions is empty or any completion is empty.
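
The folding described above can be sketched as follows. This is a hedged re-implementation for illustration, not the module's source: `FoldedSketch` and `fold_sketch` are hypothetical names, and several details are assumptions not confirmed by this page, namely that node 0 is the prompt and nodes 1..N are the completions, that each sample path is `[0, leaf]`, and that each completion's position ids continue from the prompt length.

```python
from dataclasses import dataclass

IGNORE_IDX = -100  # same value as CROSS_ENTROPY_IGNORE_IDX


@dataclass
class FoldedSketch:
    input_ids: list[int]
    labels: list[int]
    position_ids: list[int]
    node_lengths: list[int]
    sample_paths: list[list[int]]


def fold_sketch(
    prompt_ids: list[int], completions: list[list[int]], ignore_idx: int = IGNORE_IDX
) -> FoldedSketch:
    if not completions or any(len(c) == 0 for c in completions):
        raise ValueError("completions must be non-empty and contain no empty completion")

    prompt_len = len(prompt_ids)
    input_ids = list(prompt_ids)
    labels = [ignore_idx] * prompt_len  # prompt fully masked, branch point included
    position_ids = list(range(prompt_len))
    node_lengths = [prompt_len]  # assumption: node 0 is the prompt
    sample_paths: list[list[int]] = []

    for leaf, completion in enumerate(completions, start=1):
        input_ids.extend(completion)
        # Pre-shifted labels: token t predicts token t + 1; the last token is masked.
        labels.extend(list(completion[1:]) + [ignore_idx])
        # Assumption: every completion continues counting from the prompt length.
        position_ids.extend(range(prompt_len, prompt_len + len(completion)))
        node_lengths.append(len(completion))
        sample_paths.append([0, leaf])  # assumption: root-to-leaf node indices

    return FoldedSketch(input_ids, labels, position_ids, node_lengths, sample_paths)


folded = fold_sketch([1, 2, 3], [[10, 11], [20, 21, 22]])
print(folded.input_ids)  # [1, 2, 3, 10, 11, 20, 21, 22]
print(folded.labels)     # [-100, -100, -100, 11, -100, 21, 22, -100]
```

Note that tokens 10 and 20, the first token of each completion, never appear in the labels: that is the branch-point cost described above.
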
nemo_automodel.components.datasets.llm.prefix_tree.prefix_tree_collate_fn(
batch: list[dict]
) -> dict

Collate one shared-prefix rollout group into a model-ready batch (cp=1).

Folds the group with fold_shared_prefix_rollouts and emits the flat tokens plus the prefix-tree structure. Only local_batch_size == 1 is supported: each group already packs many completions into one flat sequence, and the mask is per group. The prefix_tree entry is popped by the magi backend (MagiState.prepare_llm_batch), which builds and activates the AttnMaskSpec from it.

Parameters:

batch
list[dict]

a length-1 list holding one rollout group dict with keys prompt_ids and completions.

Returns: dict

Dict with input_ids, labels, position_ids (each [1, T]) and a prefix_tree entry carrying the prefix-tree structure.

Raises:

  • ValueError: if batch does not hold exactly one rollout group.
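
The collate contract (a length-1 batch in, `[1, T]` fields plus a prefix_tree entry out) can be sketched as below. This is an illustration under stated assumptions, not the module's code: `collate_sketch` and `stub_fold` are hypothetical, nested Python lists stand in for whatever array type the real function returns, and the layout of the prefix_tree entry (a dict with node_lengths and sample_paths) is a guess. The folding step is injected as a callable so the sketch stays self-contained.

```python
from typing import Callable


def collate_sketch(
    batch: list[dict], fold: Callable[[list[int], list[list[int]]], dict]
) -> dict:
    if len(batch) != 1:
        raise ValueError(f"expected exactly one rollout group, got {len(batch)}")
    group = batch[0]
    folded = fold(group["prompt_ids"], group["completions"])
    return {
        # Wrapping in an outer list adds the batch dimension: shape [1, T].
        "input_ids": [folded["input_ids"]],
        "labels": [folded["labels"]],
        "position_ids": [folded["position_ids"]],
        # Assumed layout of the entry that the attention backend later pops.
        "prefix_tree": {
            "node_lengths": folded["node_lengths"],
            "sample_paths": folded["sample_paths"],
        },
    }


def stub_fold(prompt_ids: list[int], completions: list[list[int]]) -> dict:
    # Placeholder: hard-coded fold of prompt [1, 2] with completions [[10, 11], [20]].
    return {
        "input_ids": [1, 2, 10, 11, 20],
        "labels": [-100, -100, 11, -100, -100],
        "position_ids": [0, 1, 2, 3, 2],
        "node_lengths": [2, 2, 1],
        "sample_paths": [[0, 1], [0, 2]],
    }


group = {"prompt_ids": [1, 2], "completions": [[10, 11], [20]]}
batch = collate_sketch([group], stub_fold)
print(len(batch["input_ids"]), len(batch["input_ids"][0]))  # 1 5
```
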
nemo_automodel.components.datasets.llm.prefix_tree.CROSS_ENTROPY_IGNORE_IDX = -100