bridge.training.eval

Module Contents

Functions

evaluate

Runs evaluation and returns the averaged losses, any collected non-loss data, and a time-limit flag.

evaluate_and_print_results

Helper function that runs evaluation and prints the results, optionally writing them to TensorBoard.

API

bridge.training.eval.evaluate(
state: megatron.bridge.training.state.GlobalState,
forward_step_func: Callable,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
process_non_loss_data_func: Optional[Callable],
config: megatron.bridge.training.config.ConfigContainer,
verbose: bool = False,
non_loss_data_func: Optional[Callable] = None,
) -> tuple[Optional[dict[str, torch.Tensor]], Optional[Any], bool]

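Evaluation function. Runs the model chunks over the evaluation data and returns the averaged losses, any data collected by non_loss_data_func, and a flag indicating whether the time limit was reached.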

Parameters:
  • state (GlobalState) – The global state object.

  • forward_step_func (Callable) – The function that performs a forward step.

  • data_iterator (Optional[Union[RerunDataIterator, list[RerunDataIterator]]]) – Iterator over evaluation data.

  • model (list[MegatronModule]) – List of model chunks.

  • process_non_loss_data_func (Optional[Callable]) – Function to process non-loss data.

  • config (ConfigContainer) – Configuration container (potentially redundant).

  • verbose (bool, optional) – Whether to print evaluation progress. Defaults to False.

  • non_loss_data_func (Optional[Callable], optional) – Function to compute non-loss data. Defaults to None.

Returns:

A tuple containing:
  • total_loss_dict – Dictionary of averaged losses.

  • collected_non_loss_data – Data collected by non_loss_data_func.

  • timelimit_hit – Boolean indicating if the time limit was reached.

Return type:

tuple[Optional[dict[str, torch.Tensor]], Optional[Any], bool]
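
The three-element return value can be unpacked by position. The following is a minimal sketch of how a caller might consume it, not part of the module itself: `summarize_eval_output` is a hypothetical helper, and the tuple passed to it is a stand-in for what `evaluate` would return. It assumes each loss value can be converted with `float()`, which holds for scalar tensors and for plain numbers.

```python
from typing import Any, Mapping, Optional


def summarize_eval_output(
    result: tuple[Optional[Mapping[str, Any]], Optional[Any], bool],
) -> str:
    """Format the documented 3-tuple into a one-line summary (hypothetical helper)."""
    total_loss_dict, collected_non_loss_data, timelimit_hit = result

    # Check the flag first: the loss dictionary is Optional and may be absent.
    if timelimit_hit:
        return "evaluation stopped: time limit reached"
    if total_loss_dict is None:
        return "evaluation produced no losses"

    # float() accepts scalar tensors as well as plain Python numbers.
    parts = [f"{name}: {float(value):.4f}" for name, value in total_loss_dict.items()]
    return " | ".join(parts)


# Stand-in values; a real caller would pass the tuple returned by evaluate().
print(summarize_eval_output(({"lm loss": 2.5}, None, False)))
```

Checking `timelimit_hit` before touching `total_loss_dict` keeps the caller safe regardless of what the first two elements hold when the limit is reached.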

bridge.training.eval.evaluate_and_print_results(
state: megatron.bridge.training.state.GlobalState,
prefix: str,
forward_step_func: Callable,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
config: megatron.bridge.training.config.ConfigContainer,
verbose: bool = False,
write_to_tensorboard: bool = True,
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) -> None

Helper function that runs evaluation and prints the results to the screen, optionally writing them to TensorBoard.

Parameters:
  • state (GlobalState) – The global state object.

  • prefix (str) – Prefix for logging evaluation results.

  • forward_step_func (Callable) – The function that performs a forward step.

  • data_iterator (Optional[Union[RerunDataIterator, list[RerunDataIterator]]]) – Iterator over evaluation data.

  • model (list[MegatronModule]) – List of model chunks.

  • config (ConfigContainer) – Configuration container (potentially redundant).

  • verbose (bool, optional) – Whether to print evaluation progress. Defaults to False.

  • write_to_tensorboard (bool, optional) – Whether to write results to TensorBoard. Defaults to True.

  • process_non_loss_data_func (Optional[Callable], optional) – Function to process non-loss data. Defaults to None.

  • non_loss_data_func (Optional[Callable], optional) – Function to compute non-loss data. Defaults to None.
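
The sketch below illustrates the documented argument order for a single validation pass. It runs without Megatron installed by using a stand-in: `fake_evaluate_and_print_results` is hypothetical and only mirrors the signature listed above, `run_validation` is an illustrative wrapper, and the prefix format is an assumption (the documentation says only that the prefix labels the logged results).

```python
from typing import Any, Callable


def run_validation(
    eval_fn: Callable[..., None],
    state: Any,
    forward_step_func: Callable,
    valid_data_iterator: Any,
    model: list,
    config: Any,
    iteration: int,
) -> str:
    """Call an evaluate_and_print_results-style function for one validation pass."""
    # Hypothetical prefix format, chosen for illustration only.
    prefix = f"iteration {iteration} on validation set"
    eval_fn(
        state,
        prefix,
        forward_step_func,
        valid_data_iterator,
        model,
        config,
        verbose=False,
        write_to_tensorboard=True,
    )
    return prefix


# Stand-in that records its arguments; it follows the documented parameter order.
calls = []


def fake_evaluate_and_print_results(
    state, prefix, forward_step_func, data_iterator, model, config,
    verbose=False, write_to_tensorboard=True,
    process_non_loss_data_func=None, non_loss_data_func=None,
):
    calls.append({"prefix": prefix, "write_to_tensorboard": write_to_tensorboard})


run_validation(fake_evaluate_and_print_results, None, lambda *a, **k: None, None, [], None, 100)
print(calls[0]["prefix"])
```

Passing the optional flags by keyword, as done here, keeps the call readable and avoids depending on the position of the defaulted parameters.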