nemo_automodel.components.eval.tool_call_evaluator
Generation-based evaluator for tool-call accuracy during agent SFT.
The loss-only validation that ships with the training recipe cannot
distinguish “loss going down because the model learned the format” from
“loss going down because the model is overfitting the response style
while emitting wrong tool names”. This evaluator closes that gap by
running model.generate() on held-out prompts that terminate right
before an assistant tool-call turn, parsing the generated text with
the nemo_automodel.components.eval.tool_call_parser module, and comparing
against the ground-truth tool calls extracted from the dataset.
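Here is a minimal sketch of the comparison step. It scores one sample's parsed tool calls against its ground truth. The sketch assumes each call is a dict with "name" and "arguments" keys, where arguments may be either a dict or a JSON string. The helper score_sample and the metric names parse_success, name_match and exact_match are hypothetical and need not match what this module actually reports.

```python
import json
from typing import Any, Dict, List

# Assumed tool-call shape: {"name": str, "arguments": dict or JSON string}.
ToolCall = Dict[str, Any]


def _normalize_args(args: Any) -> Any:
    """Decode JSON-string arguments so '{"a": 1}' and {"a": 1} compare equal."""
    if isinstance(args, str):
        try:
            return json.loads(args)
        except json.JSONDecodeError:
            return args  # leave malformed JSON as the raw string
    return args


def score_sample(predicted: List[ToolCall], expected: List[ToolCall]) -> Dict[str, float]:
    """Hypothetical per-sample scores; the metric names are illustrative only."""
    pred_names = [call.get("name") for call in predicted]
    gold_names = [call.get("name") for call in expected]
    # Order-sensitive: same tools, same number of calls, same order.
    name_match = pred_names == gold_names
    # Exact match additionally requires identical (normalized) arguments.
    exact_match = name_match and all(
        _normalize_args(p.get("arguments")) == _normalize_args(g.get("arguments"))
        for p, g in zip(predicted, expected)
    )
    return {
        "parse_success": float(len(predicted) > 0),
        "name_match": float(name_match),
        "exact_match": float(exact_match),
    }
```

Splitting name_match from exact_match separates "called the right tool" from "called it correctly". Loss alone cannot make that distinction.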
The evaluator is intentionally framework-agnostic: it operates on any
HuggingFace-style model with a .generate() method and a tokenizer
that supports apply_chat_template(..., tools=...). Distributed
sharding and all-reduce of metrics are left to the caller (the training
recipe), which already has the dist environment in hand.
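The framework-agnostic contract can be written down as structural types. This is only a sketch. ChatTemplateTokenizer and GenerativeModel are hypothetical names, not classes exported by this module, and they capture only the two methods the paragraph relies on.

```python
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable


@runtime_checkable
class ChatTemplateTokenizer(Protocol):
    """Hypothetical structural type: any tokenizer with a tools-aware chat template."""

    def apply_chat_template(
        self,
        conversation: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> Any: ...


@runtime_checkable
class GenerativeModel(Protocol):
    """Hypothetical structural type: any model exposing a .generate() method."""

    def generate(self, *args: Any, **kwargs: Any) -> Any: ...
```

Because the protocols are runtime_checkable, isinstance() checks only that the methods exist, not their signatures. That matches how much a duck-typed interface can promise.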
Module Contents
API
Generation-based tool-call accuracy evaluator for agent SFT.
The evaluator lazily loads a list of eval samples (one per assistant
tool-call position in the source dataset). On each call to
evaluate(), it renders each sample’s prompt_messages and
tools through the tokenizer’s chat template, generates a
continuation, parses any tool calls out of the generated text, and
aggregates per-sample metrics into a corpus-level dict.
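The final aggregation step could look like the sketch below. Per-sample scores are averaged, and metric_prefix is applied to every key. The key format (prefix, underscore, metric name) and the _count entry are assumptions modeled on the sample_shard description. The real key names may differ.

```python
from typing import Dict, List


def aggregate(per_sample: List[Dict[str, float]], metric_prefix: str) -> Dict[str, float]:
    """Hypothetical: average per-sample scores into one prefixed dict (key format assumed)."""
    n = len(per_sample)
    out: Dict[str, float] = {f"{metric_prefix}_count": float(n)}
    if n == 0:
        return out  # nothing scored, e.g. every prompt was skipped
    for name in per_sample[0]:
        out[f"{metric_prefix}_{name}"] = sum(s[name] for s in per_sample) / n
    return out
```

Reporting the count alongside the means lets a caller recombine partial results without access to the individual samples.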
Constructor args (all keyword-only):
  dataset_name: HF Hub dataset id to load eval samples from.
      Mutually exclusive with path.
  path: Local JSON/JSONL file (or list of files) to load eval
      samples from. Mutually exclusive with dataset_name.
  split: Dataset split (only used with dataset_name).
  limit_dataset_samples: Cap on dialogues read before they are expanded
      into per-tool-call eval samples.
  max_eval_samples: Cap on the total number of expanded eval samples.
  max_new_tokens: Generation budget per sample.
  max_prompt_tokens: If set, prompts longer than this many tokens
      are skipped (logged once). Prevents OOM on degenerate samples.
  do_sample: Whether to sample during generation. Defaults to greedy
      decoding for reproducibility across validation checkpoints.
  metric_prefix: Prefix applied to all returned metric keys.
  sample_shard: Optional (rank, world_size) tuple. When set, only
      every world_size-th sample starting at rank is processed; the
      caller is responsible for all-reducing the returned _count and
      the count-weighted metric sums.
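The sample_shard contract can be checked in plain Python. Both helpers in this sketch are hypothetical. shard_samples reproduces the strided split. combine_ranks stands in for the caller's SUM all-reduce: it weights each rank's means by that rank's sample count, then divides by the global count.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def shard_samples(samples: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Hypothetical: every world_size-th sample starting at rank."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return list(samples[rank::world_size])


def combine_ranks(per_rank: List[Dict[str, float]], count_key: str) -> Dict[str, float]:
    """Hypothetical pure-Python stand-in for the caller's SUM all-reduce.

    Each rank reports means over its own shard, so each mean is multiplied
    by that rank's count before summing, then divided by the total count.
    """
    if not per_rank:
        return {count_key: 0.0}
    total = sum(m[count_key] for m in per_rank)
    out: Dict[str, float] = {count_key: total}
    for key in per_rank[0]:
        if key == count_key:
            continue
        weighted_sum = sum(m[key] * m[count_key] for m in per_rank)
        out[key] = weighted_sum / total if total else 0.0
    return out
```

In a real caller, the sums would typically come from torch.distributed.all_reduce with ReduceOp.SUM over a tensor such as [count, mean_1 * count, ...]. Averaging the per-rank means directly is wrong whenever shards differ in size. For example, suppose one rank scores 4 samples at 0.5 and another scores 3 samples at 1.0. The naive mean is 0.75, while the corpus mean is 5/7 (about 0.714).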
The (rank, world_size) shard, or None to score every sample.
The training recipe sets this so that each data-parallel rank scores a disjoint subset, but only when the model is replicated on every rank (DDP). Sharded strategies such as FSDP must keep every rank on the same samples, because each forward pass issues collectives that all ranks have to join.
Greedy decode using only model.forward().
Several Automodel custom model classes (notably Qwen2ForCausalLM)
inherit from HFCheckpointingMixin and Qwen2PreTrainedModel but not
from transformers.generation.GenerationMixin, so the FSDP-wrapped
instance has no .generate() method. We fall back to a minimal
token-by-token greedy decode that only requires the forward pass to
return logits. There is no KV cache, so each step re-runs the forward
pass over the whole prefix, and a sample costs O(L * (P + L)) token
positions, where P is the prompt length and L is max_new_tokens. That
is fine for the small eval budgets used here (default 256 tokens).
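Here is a minimal sketch of a forward-only greedy loop without a KV cache. To keep the sketch self-contained, the model is replaced by a hypothetical callable that maps the full token sequence to next-token logits. With an HF-style model, that would be roughly model(input_ids).logits[0, -1] followed by an argmax. The name greedy_decode and the eos_token_id early stop are illustrative assumptions.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for model.forward(): takes the whole token sequence and
# returns next-token logits for its LAST position only.
NextTokenLogits = Callable[[Sequence[int]], Sequence[float]]


def greedy_decode(
    forward: NextTokenLogits,
    prompt_ids: Sequence[int],
    max_new_tokens: int,
    eos_token_id: Optional[int] = None,
) -> List[int]:
    ids = list(prompt_ids)
    new_tokens: List[int] = []
    for _ in range(max_new_tokens):
        # No KV cache: every step re-runs forward over prompt + generated prefix.
        logits = forward(ids)
        next_id = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        new_tokens.append(next_id)
        ids.append(next_id)
        if eos_token_id is not None and next_id == eos_token_id:
            break  # stop early once the end-of-sequence token is produced
    return new_tokens
```

At step t the input has P + t tokens. Summing over t = 0, ..., L - 1 gives L*P + L(L - 1)/2 token positions, which is where the O(L * (P + L)) figure comes from. For example, with P = 2 and L = 3, the steps see 2 + 3 + 4 = 9 positions.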
Render one eval sample’s prompt through apply_chat_template.
We deliberately split the chat-template render (tokenize=False)
from the tokenization step: some templates or transformers versions
return a list of token strings under tokenize=True, which then
crashes torch.tensor(..., dtype=torch.long) downstream. Going
through text first sidesteps that and matches the canonical HF
usage shown in the model cards.
Returns None if the template raises (e.g. doesn’t accept the
tools kwarg) or if the prompt exceeds max_prompt_tokens.
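The two-step render described above might look like the following sketch. The helper name render_prompt_ids is hypothetical. Two settings are assumptions about what a caller would want, not confirmed details of this module: add_generation_prompt=True opens the assistant turn the model should complete, and add_special_tokens=False avoids a duplicate BOS when the template already emits one. The tokenizer is duck-typed, so any object with apply_chat_template and __call__ works.

```python
from typing import Any, Dict, List, Optional


def render_prompt_ids(
    tokenizer: Any,
    prompt_messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]],
    max_prompt_tokens: Optional[int] = None,
) -> Optional[List[int]]:
    """Hypothetical helper: render to text first, then tokenize; None on failure."""
    try:
        # Step 1: text only, so no template/version can hand back token strings.
        text = tokenizer.apply_chat_template(
            prompt_messages,
            tools=tools,
            tokenize=False,
            add_generation_prompt=True,  # assumption: open the assistant turn
        )
    except Exception:
        return None  # e.g. a template that rejects the tools kwarg
    # Step 2: tokenize the rendered text. Assumption: the template already emits
    # its own special tokens, so the tokenizer must not add a second BOS.
    input_ids = list(tokenizer(text, add_special_tokens=False)["input_ids"])
    if max_prompt_tokens is not None and len(input_ids) > max_prompt_tokens:
        return None  # too long: skip rather than risk OOM
    return input_ids
```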
Run generation-based tool-call evaluation against model.
The caller is expected to have placed the model in eval mode and on the appropriate device. The evaluator infers the device from the first model parameter, so it works with FSDP, DDP, or single-GPU layouts without explicit configuration.
Parameters:
  model: a HuggingFace causal-LM with a .generate() method.
  tokenizer: the tokenizer paired with model; must have a chat
      template that supports the tools kwarg.
Returns: Dict[str, float]
Dict of metric name -> float. All metric values are means in