nemo_automodel.components.eval.tool_call_evaluator


Generation-based evaluator for tool-call accuracy during agent SFT.

The loss-only validation that ships with the training recipe cannot distinguish between “loss going down because the model learned the format” and “loss going down because the model is overfitting the response style while emitting wrong tool names”. This evaluator closes that gap: it runs model.generate() on held-out prompts that end right before an assistant tool-call turn, parses the generated text with the nemo_automodel.components.eval.tool_call_parser module, and compares the result against the ground-truth tool calls extracted from the dataset.

The evaluator is intentionally framework-agnostic: it operates on any HuggingFace-style model with a .generate() method and a tokenizer that supports apply_chat_template(..., tools=...). Distributed sharding and the all-reduce of metrics are left to the caller (the training recipe), which already has access to the distributed environment.

Module Contents

Classes

Name | Description
ToolCallAccuracyEvaluator | Generation-based tool-call accuracy evaluator for agent SFT.

Data

logger

API

class nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator(
dataset_name: typing.Optional[str] = None,
path: typing.Optional[typing.Union[str, typing.List[str]]] = None,
split: str = 'train',
limit_dataset_samples: typing.Optional[int] = None,
max_eval_samples: typing.Optional[int] = None,
max_new_tokens: int = 256,
max_prompt_tokens: typing.Optional[int] = None,
do_sample: bool = False,
metric_prefix: str = 'tool_call',
sample_shard: typing.Optional[tuple] = None,
raise_on_cuda_oom: bool = True,
run_on_fsdp2: bool = False
)

Generation-based tool-call accuracy evaluator for agent SFT.

The evaluator lazily loads a list of eval samples (one per assistant tool-call position in the source dataset). On each call to evaluate() it renders each sample’s prompt_messages and tools through the tokenizer’s chat template, generates a continuation, parses any tool calls out of the generated text, and aggregates per-sample metrics into a corpus-level dict.
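The parse-and-score step can be illustrated with a minimal sketch. It assumes Qwen/Hermes-style output where each call is wrapped as <tool_call>{json}</tool_call>. The function names, the metric names (name_match, exact_match), and the "prefix/key" joining are all hypothetical; the real parsing lives in nemo_automodel.components.eval.tool_call_parser, whose API is not shown here.

```python
import json
import re
from typing import Any, Dict, List

# Assumption: tool calls are emitted as <tool_call>{"name": ..., "arguments": ...}</tool_call>.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Hypothetical parser: extract {"name", "arguments"} dicts from generated text."""
    calls = []
    for payload in _TOOL_CALL_RE.findall(text):
        try:
            obj = json.loads(payload)
        except json.JSONDecodeError:
            continue  # malformed JSON is treated as "no call emitted"
        if isinstance(obj, dict) and "name" in obj:
            calls.append({"name": obj["name"], "arguments": obj.get("arguments", {})})
    return calls


def score_sample(pred: List[Dict[str, Any]], gold: List[Dict[str, Any]]) -> Dict[str, float]:
    """Hypothetical per-sample metrics: ordered name match and exact (name + args) match."""
    names_match = [c["name"] for c in pred] == [c["name"] for c in gold]
    exact = names_match and [c["arguments"] for c in pred] == [c["arguments"] for c in gold]
    return {"name_match": float(names_match), "exact_match": float(exact)}


def aggregate(per_sample: List[Dict[str, float]], prefix: str = "tool_call") -> Dict[str, float]:
    """Average per-sample metrics into a corpus-level dict, plus a sample count."""
    if not per_sample:
        return {f"{prefix}/_count": 0.0}
    n = len(per_sample)
    out = {f"{prefix}/{k}": sum(s[k] for s in per_sample) / n for k in per_sample[0]}
    out[f"{prefix}/_count"] = float(n)
    return out
```

Scoring names and arguments separately keeps the two failure modes apart: a model that picks the right tool but fills in wrong arguments still earns name credit.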

Constructor args (all keyword-only):
- dataset_name: HF Hub dataset id to load eval samples from. Mutually exclusive with path.
- path: Local JSON/JSONL file (or list of files) to load eval samples from. Mutually exclusive with dataset_name.
- split: Dataset split (only used with dataset_name).
- limit_dataset_samples: Cap on dialogues read before expansion.
- max_eval_samples: Cap on total expanded eval samples.
- max_new_tokens: Generation budget per sample.
- max_prompt_tokens: If set, prompts longer than this many tokens are skipped (logged once). Prevents OOM on degenerate samples.
- do_sample: Generation sampling toggle. Defaults to greedy decoding for reproducibility across validation checkpoints.
- metric_prefix: Prefix applied to all returned metric keys.
- sample_shard: Optional (rank, world_size) tuple. When set, only every world_size-th sample starting at rank is processed; the caller is responsible for all-reducing the returned _count and weighted-summed metrics.

METRIC_KEYS
_samples_cache: Optional[List[Dict[str, Any]]] = None
metric_prefix = metric_prefix.rstrip('/')
sample_shard: Optional[tuple]

(rank, world_size) shard, or None to score every sample.

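The training recipe sets this so that each data-parallel rank scores a disjoint subset, but only when the model is replicated on every rank (DDP). Sharded strategies such as FSDP must keep every rank on the same samples: each sharded forward pass issues collectives (parameter all-gathers) that every rank has to join, so ranks running different numbers of forward passes can deadlock.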
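The (rank, world_size) rule and the caller-side reduction can be simulated in pure Python. Both function names are hypothetical. combine_rank_means assumes each rank reports a sample count plus a per-rank mean, which the caller converts into count-weighted sums before an all-reduce (for example torch.distributed.all_reduce with the default SUM op).

```python
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shard_samples(samples: Sequence[T], sample_shard: Optional[Tuple[int, int]]) -> List[T]:
    """Keep every world_size-th sample starting at rank (all samples if unsharded)."""
    if sample_shard is None:
        return list(samples)
    rank, world_size = sample_shard
    if not 0 <= rank < world_size:
        raise ValueError(f"invalid shard {sample_shard}")
    return list(samples[rank::world_size])


def combine_rank_means(per_rank: Sequence[Tuple[int, float]]) -> float:
    """Global mean from (count, mean) pairs, as SUM-all-reducing count and count * mean would."""
    total = sum(count for count, _ in per_rank)
    if total == 0:
        return 0.0
    return sum(count * mean for count, mean in per_rank) / total
```

Weighting by count matters because strided shards can differ in size by one sample (and by more once skipped prompts are excluded), so a plain average of per-rank means would be biased.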

nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator._cleanup_cuda() -> None
nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator._greedy_generate_manual(
model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
max_new_tokens: int,
eos_token_id: typing.Optional[int]
) -> torch.Tensor

Greedy decode using only model.forward().

Several Automodel custom model classes (notably Qwen2ForCausalLM) inherit from HFCheckpointingMixin and Qwen2PreTrainedModel but not from transformers.generation.GenerationMixin, so the FSDP-wrapped instance has no .generate() method. We fall back to a minimal token-by-token greedy decode that only requires the forward pass to return logits. There is no KV cache, so every step re-runs the forward pass over the whole prefix; the number of token positions processed per sample is therefore O(L * (P + L)), where P is the prompt length and L is max_new_tokens. This is acceptable for the small eval budgets used here (default 256 tokens).
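A minimal pure-Python sketch of cache-free greedy decoding. The forward callable is a hypothetical stand-in for model(input_ids).logits (one logits row per position); the real method operates on torch tensors with an attention mask, and its exact return value may differ. no_cache_positions (also hypothetical) counts the token positions the loop feeds through the forward pass, which is where the O(L * (P + L)) figure comes from.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical forward interface: token ids -> one logits row per position.
ForwardFn = Callable[[Sequence[int]], List[List[float]]]


def greedy_generate_no_cache(
    forward: ForwardFn,
    prompt_ids: Sequence[int],
    max_new_tokens: int,
    eos_token_id: Optional[int] = None,
) -> List[int]:
    """Greedy decode using only a forward pass; returns prompt + generated ids."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        # No KV cache: the full prefix is re-encoded at every step.
        last_logits = forward(ids)[-1]
        next_id = max(range(len(last_logits)), key=last_logits.__getitem__)
        ids.append(next_id)
        if eos_token_id is not None and next_id == eos_token_id:
            break  # stop early once EOS is produced (EOS is kept in the output)
    return ids


def no_cache_positions(prompt_len: int, max_new_tokens: int) -> int:
    """Positions processed without early EOS: sum of (P + t) for t = 0 .. L-1."""
    return max_new_tokens * prompt_len + max_new_tokens * (max_new_tokens - 1) // 2
```

For example, no_cache_positions(1000, 256) is 288640, compared with roughly P + L positions when a KV cache is used.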

nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator._iter_my_samples() -> typing.List[typing.Dict[str, typing.Any]]
nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator._load_samples() -> typing.List[typing.Dict[str, typing.Any]]
nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator._render_prompt_ids(
tokenizer,
sample: typing.Dict[str, typing.Any],
skip_reasons: typing.Optional[typing.Dict[str, int]] = None
) -> typing.Optional[typing.List[int]]

Render one eval sample’s prompt through apply_chat_template.

We deliberately split the chat-template render (tokenize=False) from the tokenization step: some templates / transformers versions return a list of token strings under tokenize=True, which then crashes torch.tensor(..., dtype=long) downstream. Going through text first sidesteps that and matches the canonical HF usage shown in the model cards.

Returns None if the template raises (e.g. doesn’t accept the tools kwarg) or if the prompt exceeds max_prompt_tokens.
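The two-step render described above can be sketched as follows, against any tokenizer exposing the HF-style apply_chat_template(..., tokenize=False) and __call__ interfaces. The function name, the skip-reason keys, add_generation_prompt=True (to open the assistant turn), and add_special_tokens=False (assuming the template already emits any BOS token) are assumptions, not confirmed details of _render_prompt_ids.

```python
from typing import Any, Dict, List, Optional


def render_prompt_ids(
    tokenizer: Any,
    sample: Dict[str, Any],
    max_prompt_tokens: Optional[int] = None,
    skip_reasons: Optional[Dict[str, int]] = None,
) -> Optional[List[int]]:
    """Render a sample's prompt to text first, then tokenize it separately."""

    def _skip(reason: str) -> None:
        # Count skipped samples by reason so the caller can log a summary once.
        if skip_reasons is not None:
            skip_reasons[reason] = skip_reasons.get(reason, 0) + 1

    try:
        # Step 1: template render only; tokenize=False yields a plain string.
        text = tokenizer.apply_chat_template(
            sample["prompt_messages"],
            tools=sample.get("tools"),
            tokenize=False,
            add_generation_prompt=True,  # assumption: open the assistant turn
        )
    except Exception:
        _skip("template_error")  # e.g. template rejects the tools kwarg
        return None

    # Step 2: tokenize the rendered text into integer ids.
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if max_prompt_tokens is not None and len(ids) > max_prompt_tokens:
        _skip("prompt_too_long")
        return None
    return [int(i) for i in ids]  # plain ints, safe for torch.tensor(..., dtype=long)
```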

nemo_automodel.components.eval.tool_call_evaluator.ToolCallAccuracyEvaluator.evaluate(
model,
tokenizer
) -> typing.Dict[str, float]

Run generation-based tool-call evaluation against model.

The caller is expected to have placed the model in eval mode and on the appropriate device. The evaluator infers the device from the first model parameter, so it works with FSDP, DDP, or single-GPU layouts without explicit configuration.

Parameters:

model

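a HuggingFace causal LM. Its .generate() method is used when present; models without one fall back to the forward-only greedy decode in _greedy_generate_manual.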

tokenizer

tokenizer paired with model; must have a chat template that supports the tools kwarg.

Returns: Dict[str, float]

Dict of metric name -> float. All metric values are means in

nemo_automodel.components.eval.tool_call_evaluator.logger = logging.getLogger(__name__)