nat.plugins.customizer.dpo.trajectory_builder#

DPO (Direct Preference Optimization) Trajectory Builder.

This module provides a trajectory builder that collects preference data from workflows that produce TTC_END intermediate steps with TTCEventData.

The builder:

1. Runs evaluation to collect intermediate steps
2. Filters for TTC_END steps with the configured name
3. Extracts data from TTCEventData (turn_id, candidate_index, score, input, output)
4. Groups candidates by turn_id
5. Generates preference pairs based on score differences
6. Builds trajectories with DPOItem episodes for DPO training

Attributes#

Classes#

CandidateStep

Parsed candidate from a TTC intermediate step.

PreferencePair

A preference pair for DPO training.

DPOTrajectoryBuilder

Trajectory builder for DPO (Direct Preference Optimization) training.

Module Contents#

logger#
PromptType#
class CandidateStep#

Parsed candidate from a TTC intermediate step.

Represents a single candidate response that was generated and scored for a particular turn in the workflow.

example_id: str#

Unique identifier for the dataset example.

turn_id: str#

Identifier for the turn (groups candidates competing for the same prompt).

candidate_index: int#

Index of this candidate within the turn.

prompt: PromptType#

Input prompt that produced this response (string or list of OpenAIMessage).

response: str#

Model’s response/completion.

score: float#

Score assigned to this candidate (higher is better).

raw_metadata: dict[str, Any]#

Original metadata from the intermediate step.

class PreferencePair#

A preference pair for DPO training.

Represents a single (prompt, chosen, rejected) triple where the chosen response has a higher score than the rejected response.

example_id: str#

Unique identifier for the dataset example.

turn_id: str#

Identifier for the turn.

prompt: PromptType#

Input prompt (same for both responses).

chosen_response: str#

Response that was preferred (higher score).

rejected_response: str#

Response that was not preferred (lower score).

chosen_score: float#

Score of the chosen response.

rejected_score: float#

Score of the rejected response.

score_diff: float#

Difference between chosen and rejected scores.

chosen_index: int#

Candidate index of the chosen response.

rejected_index: int#

Candidate index of the rejected response.

metadata: dict[str, Any]#

Additional metadata for the pair.
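The relationship between the two records can be sketched with illustrative stand-ins (plain dataclasses here, not the real `nat` classes; prompts are shown as strings for brevity):

```python
from dataclasses import dataclass, field
from typing import Any


# Illustrative stand-ins for CandidateStep and PreferencePair; the real
# classes live in nat.plugins.customizer.dpo.trajectory_builder.
@dataclass
class CandidateStep:
    example_id: str
    turn_id: str
    candidate_index: int
    prompt: str  # PromptType may also be a list of OpenAIMessage
    response: str
    score: float
    raw_metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class PreferencePair:
    example_id: str
    turn_id: str
    prompt: str
    chosen_response: str
    rejected_response: str
    chosen_score: float
    rejected_score: float
    score_diff: float
    chosen_index: int
    rejected_index: int
    metadata: dict[str, Any] = field(default_factory=dict)


def make_pair(chosen: CandidateStep, rejected: CandidateStep) -> PreferencePair:
    """Build a pair from two candidates of the same turn (chosen scored higher)."""
    assert chosen.turn_id == rejected.turn_id and chosen.score > rejected.score
    return PreferencePair(
        example_id=chosen.example_id,
        turn_id=chosen.turn_id,
        prompt=chosen.prompt,
        chosen_response=chosen.response,
        rejected_response=rejected.response,
        chosen_score=chosen.score,
        rejected_score=rejected.score,
        score_diff=chosen.score - rejected.score,
        chosen_index=chosen.candidate_index,
        rejected_index=rejected.candidate_index,
    )


a = CandidateStep("ex1", "t0", 0, "Q?", "good answer", 0.9)
b = CandidateStep("ex1", "t0", 1, "Q?", "weak answer", 0.6)
pair = make_pair(a, b)
```

Note that `score_diff` is derived, not supplied: it is always `chosen_score - rejected_score` for a valid pair.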

class DPOTrajectoryBuilder(
trajectory_builder_config: nat.plugins.customizer.dpo.config.DPOTrajectoryBuilderConfig,
)#

Bases: nat.finetuning.interfaces.trajectory_builder.TrajectoryBuilder

Trajectory builder for DPO (Direct Preference Optimization) training.

This builder collects preference pairs from workflows that produce TTC_END intermediate steps with TTCEventData. It uses the structured data model to extract turn_id, candidate_index, score, input (prompt), and output.

Key features:

- Uses the TTCEventData model directly (no brittle dictionary key configuration)
- Supports prompts as strings or lists of OpenAIMessage
- Exhaustive or best-vs-worst pair generation modes
- Configurable score-difference filtering
- Grouping by example for curriculum learning
- Builds trajectories with DPOItem episodes

Example workflow integration:

trajectory_builders:
  dpo_builder:
    _type: dpo_traj_builder
    ttc_step_name: dpo_candidate_move
    exhaustive_pairs: true
    min_score_diff: 0.05

Initialize the DPO Trajectory Builder.

Args:

trajectory_builder_config: Configuration for the builder.

config: nat.plugins.customizer.dpo.config.DPOTrajectoryBuilderConfig#
evaluation_runs: dict[str, asyncio.Task[nat.eval.config.EvaluationRunOutput]]#
_metrics: dict[str, Any]#
async start_run(run_id: str, meta: dict | None = None) → None#

Start a single evaluation run to collect intermediate steps.

Args:

run_id: Unique identifier for this run.
meta: Optional metadata for the run.

Raises:

ValueError: If a run with this ID is already in progress.

async finalize(
run_id: str,
meta: dict | None = None,
) → nat.data_models.finetuning.TrajectoryCollection#

Wait for evaluation, collect TTC steps, and build DPO trajectories.

This method:

1. Waits for the evaluation run to complete
2. Collects and groups candidates by turn_id using TTCEventData
3. Generates preference pairs
4. Builds trajectories with DPOItem episodes
5. Groups trajectories by example for curriculum learning

Args:

run_id: Unique identifier for the run.
meta: Optional metadata for the run.

Returns:

TrajectoryCollection with DPO preference trajectories.

Raises:

ValueError: If no run with this ID exists.

log_progress(
run_id: str,
metrics: dict[str, Any],
output_dir: str | None = None,
) → None#

Log trajectory building progress.

Args:

run_id: The training run ID.
metrics: Dictionary of metrics to log.
output_dir: Optional output directory override.

_collect_candidates(
eval_result: nat.eval.config.EvaluationRunOutput,
) → dict[str, list[CandidateStep]]#

Extract TTC_END intermediate steps and group by turn_id.

This method:

1. Iterates through all evaluation input items
2. Filters for TTC_END steps with the configured name
3. Extracts data from the TTCEventData model directly
4. Groups candidates by (example_id, turn_id)

Args:

eval_result: The evaluation run output.

Returns:

Dictionary mapping turn keys to lists of candidates.
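The grouping step can be sketched as follows; candidates are shown as plain dicts and the turn-key format (`example_id/turn_id`) is hypothetical, chosen only to illustrate keying on both fields:

```python
from collections import defaultdict

# Illustrative candidates; the real builder parses CandidateStep objects
# out of TTC_END intermediate steps.
candidates = [
    {"example_id": "ex1", "turn_id": "t0", "candidate_index": 0, "score": 0.9},
    {"example_id": "ex1", "turn_id": "t0", "candidate_index": 1, "score": 0.5},
    {"example_id": "ex1", "turn_id": "t1", "candidate_index": 0, "score": 0.7},
]

# Group candidates competing for the same prompt under one turn key.
by_turn: dict[str, list[dict]] = defaultdict(list)
for cand in candidates:
    key = f"{cand['example_id']}/{cand['turn_id']}"  # hypothetical key format
    by_turn[key].append(cand)
```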

_is_target_step(
step: nat.data_models.intermediate_step.IntermediateStep,
) → bool#

Check if an intermediate step is a target TTC step.

Args:

step: The intermediate step to check.

Returns:

True if this is a TTC_END step with the configured name.

_parse_candidate(
example_id: str,
step: nat.data_models.intermediate_step.IntermediateStep,
) → CandidateStep | None#

Parse a CandidateStep from a TTC intermediate step using TTCEventData.

Args:

example_id: The example ID this step belongs to.
step: The intermediate step to parse.

Returns:

CandidateStep if parsing succeeds, None otherwise.

_extract_prompt(input_data: Any) → PromptType#

Extract prompt from TTCEventData.input.

Handles both string prompts and list of OpenAIMessage.

Args:

input_data: The input field from TTCEventData.

Returns:

String prompt or list of OpenAIMessage.
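A sketch of the two input shapes this method must handle; plain role/content dicts stand in for OpenAIMessage to keep the example self-contained:

```python
from typing import Any


def extract_prompt(input_data: Any):
    """Pass string prompts through; normalize message lists to
    role/content dicts (a stand-in for OpenAIMessage)."""
    if isinstance(input_data, str):
        return input_data
    if isinstance(input_data, list):
        return [{"role": m["role"], "content": m["content"]} for m in input_data]
    raise TypeError(f"Unsupported prompt type: {type(input_data)!r}")


p1 = extract_prompt("What is 2+2?")
p2 = extract_prompt([{"role": "user", "content": "hi"}])
```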

_generate_preference_pairs(
candidates_by_turn: dict[str, list[CandidateStep]],
) → list[PreferencePair]#

Generate preference pairs from grouped candidates.

If exhaustive_pairs=True:

For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
Pairs: (A>B), (A>C), (B>C) - all pairwise comparisons

If exhaustive_pairs=False:

For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
Pairs: (A>C) only - best vs worst

Args:

candidates_by_turn: Dictionary mapping turn keys to candidate lists.

Returns:

List of preference pairs.
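The two modes, plus the `min_score_diff` filter from the configuration example above, can be sketched over simple `(name, score)` tuples; the function below is illustrative, not the real implementation:

```python
from itertools import combinations


def generate_pairs(scored, exhaustive=True, min_score_diff=0.0):
    """Sketch of the two pairing modes over (name, score) tuples."""
    # Sort candidates by score, best first.
    ranked = sorted(scored, key=lambda c: c[1], reverse=True)
    if exhaustive:
        # All (chosen, rejected) pairs where chosen strictly outscores rejected.
        pairs = [(a, b) for a, b in combinations(ranked, 2) if a[1] > b[1]]
        # Largest score gap first, matching the documented ordering.
        pairs.sort(key=lambda p: p[0][1] - p[1][1], reverse=True)
    else:
        # Single best-vs-worst pair, if the scores actually differ.
        pairs = [(ranked[0], ranked[-1])] if ranked[0][1] > ranked[-1][1] else []
    # Drop pairs whose score gap is below the configured threshold.
    return [(a, b) for a, b in pairs if a[1] - b[1] >= min_score_diff]


cands = [("A", 0.9), ("B", 0.7), ("C", 0.5)]
all_pairs = generate_pairs(cands, exhaustive=True)
best_worst = generate_pairs(cands, exhaustive=False)
filtered = generate_pairs(cands, exhaustive=True, min_score_diff=0.3)
```

With three candidates, exhaustive mode yields three pairs, best-vs-worst yields one, and a 0.3 threshold keeps only the (A, C) pair with gap 0.4.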

_generate_exhaustive_pairs(
sorted_candidates: list[CandidateStep],
) → list[PreferencePair]#

Generate all pairwise comparisons where score(chosen) > score(rejected).

Args:

sorted_candidates: Candidates sorted by score (descending).

Returns:

List of preference pairs, sorted by score difference (descending).

_generate_best_vs_worst_pair(
sorted_candidates: list[CandidateStep],
) → list[PreferencePair]#

Generate a single pair: best candidate vs worst candidate.

Args:

sorted_candidates: Candidates sorted by score (descending).

Returns:

List with at most one preference pair.

_build_trajectories(
pairs: list[PreferencePair],
) → list[nat.data_models.finetuning.Trajectory]#

Convert preference pairs to Trajectory format with DPOItem episodes.

Each trajectory contains:

- episode: [DPOItem] with prompt, chosen_response, rejected_response
- reward: score_diff (if reward_from_score_diff) or chosen_score
- metadata: Contains pair information for tracking

Args:

pairs: List of preference pairs.

Returns:

List of trajectories with DPOItem episodes.
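The pair-to-trajectory conversion and the two reward choices can be sketched with plain dicts standing in for Trajectory and DPOItem (the field names follow PreferencePair above; the dict shapes are illustrative):

```python
def build_trajectory(pair, reward_from_score_diff=True):
    """Sketch: convert one preference pair into a trajectory-like dict.
    Plain dicts stand in for the real Trajectory and DPOItem models."""
    # Reward is either the score gap or the chosen candidate's absolute score.
    reward = pair["score_diff"] if reward_from_score_diff else pair["chosen_score"]
    return {
        "episode": [{
            "prompt": pair["prompt"],
            "chosen_response": pair["chosen_response"],
            "rejected_response": pair["rejected_response"],
        }],
        "reward": reward,
        "metadata": {"example_id": pair["example_id"], "turn_id": pair["turn_id"]},
    }


pair = {
    "prompt": "Q?", "chosen_response": "good", "rejected_response": "bad",
    "chosen_score": 0.9, "rejected_score": 0.6, "score_diff": 0.3,
    "example_id": "ex1", "turn_id": "t0",
}
traj = build_trajectory(pair)
```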

_group_by_example(
trajectories: list[nat.data_models.finetuning.Trajectory],
) → list[list[nat.data_models.finetuning.Trajectory]]#

Group trajectories by example ID for curriculum learning.

This grouping enables:

- Filtering by average reward per example
- Expansion from easy to hard examples

Args:

trajectories: List of trajectories to group.

Returns:

List of trajectory lists, where each inner list contains trajectories for one example.
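A sketch of the grouping and of the per-example average reward it enables; trajectories are reduced to `(example_id, reward)` tuples for brevity:

```python
from collections import defaultdict

# Illustrative trajectories as (example_id, reward) records.
trajectories = [("ex1", 0.4), ("ex2", 0.1), ("ex1", 0.2), ("ex2", 0.3)]

# One inner list per example ID.
groups: dict[str, list] = defaultdict(list)
for ex_id, reward in trajectories:
    groups[ex_id].append((ex_id, reward))
grouped = list(groups.values())

# Average reward per example, e.g. for easy-to-hard curriculum ordering.
avg_reward = {g[0][0]: sum(r for _, r in g) / len(g) for g in grouped}
```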