bridge.training.eval
Module Contents
Functions
evaluate | Evaluation function.
evaluate_and_print_results | Helper function to evaluate and dump results on screen.
API
bridge.training.eval.evaluate(
    state: megatron.bridge.training.state.GlobalState,
    forward_step_func: Callable,
    data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
    model: list[megatron.core.transformer.MegatronModule],
    process_non_loss_data_func: Optional[Callable],
    config: megatron.bridge.training.config.ConfigContainer,
    verbose: bool = False,
    non_loss_data_func: Optional[Callable] = None,
)
Evaluation function.
- Parameters:
state (GlobalState) – The global state object.
forward_step_func (Callable) – The function that performs a forward step.
data_iterator (Optional[Union[RerunDataIterator, list[RerunDataIterator]]]) – Iterator over evaluation data.
model (list[MegatronModule]) – list of model chunks.
process_non_loss_data_func (Optional[Callable]) – Function to process non-loss data.
config (ConfigContainer) – Configuration container (potentially redundant).
verbose (bool, optional) – Whether to print evaluation progress. Defaults to False.
non_loss_data_func (Optional[Callable], optional) – Function to compute non-loss data. Defaults to None.
- Returns:
A tuple containing:
total_loss_dict: Dictionary of averaged losses.
collected_non_loss_data: Data collected by non_loss_data_func.
timelimit_hit: Boolean indicating if the time limit was reached.
- Return type:
tuple[Optional[dict[str, torch.Tensor]], Optional[Any], bool]
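The three-element return value can be unpacked as in the minimal sketch below. `fake_evaluate` is a hypothetical stand-in that only mimics the documented return shape (plain floats replace `torch.Tensor` so the snippet runs without Megatron or PyTorch), and `format_eval_summary` is an illustrative helper, not part of `bridge.training.eval`. Treating a `None` loss dictionary as "nothing to report" is an assumption drawn from the `Optional` in the return type, not from the library's source.

```python
from typing import Any, Optional


def fake_evaluate() -> tuple[Optional[dict[str, float]], Optional[Any], bool]:
    """Hypothetical stand-in that returns the documented tuple shape."""
    return {"lm loss": 2.5}, None, False


def format_eval_summary(
    total_loss_dict: Optional[dict[str, float]],
    timelimit_hit: bool,
) -> str:
    """Build a one-line summary from the first and third tuple elements."""
    if timelimit_hit:
        # Assumption: results are not reported when the time limit was hit.
        return "evaluation stopped: time limit reached"
    if total_loss_dict is None:
        return "evaluation produced no losses"
    parts = [
        f"{name}: {float(value):.4f}"
        for name, value in sorted(total_loss_dict.items())
    ]
    return " | ".join(parts)


# Unpack in the documented order: losses, non-loss data, time-limit flag.
total_loss_dict, collected_non_loss_data, timelimit_hit = fake_evaluate()
summary = format_eval_summary(total_loss_dict, timelimit_hit)
print(summary)
```

Checking `timelimit_hit` before reading the loss dictionary keeps the caller from dereferencing a value that the return type allows to be `None`.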
bridge.training.eval.evaluate_and_print_results(
    state: megatron.bridge.training.state.GlobalState,
    prefix: str,
    forward_step_func: Callable,
    data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
    model: list[megatron.core.transformer.MegatronModule],
    config: megatron.bridge.training.config.ConfigContainer,
    verbose: bool = False,
    write_to_tensorboard: bool = True,
    process_non_loss_data_func: Optional[Callable] = None,
    non_loss_data_func: Optional[Callable] = None,
)
Helper function to evaluate and dump results on screen.
- Parameters:
state (GlobalState) – The global state object.
prefix (str) – Prefix for logging evaluation results.
forward_step_func (Callable) – The function that performs a forward step.
data_iterator (Optional[Union[RerunDataIterator, list[RerunDataIterator]]]) – Iterator over evaluation data.
model (list[MegatronModule]) – list of model chunks.
config (ConfigContainer) – Configuration container (potentially redundant).
verbose (bool, optional) – Whether to print evaluation progress. Defaults to False.
write_to_tensorboard (bool, optional) – Whether to write results to TensorBoard. Defaults to True.
process_non_loss_data_func (Optional[Callable], optional) – Function to process non-loss data. Defaults to None.
non_loss_data_func (Optional[Callable], optional) – Function to compute non-loss data. Defaults to None.
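Several of these parameters share the type `Callable`, so passing arguments by keyword may help avoid mixing them up. The sketch below illustrates such a call under stated assumptions: `run_validation` and `recording_stub` are hypothetical names, the stub stands in for the real function so the snippet runs without Megatron, and the `"validation"` prefix is only an illustrative value. The keyword names and defaults are taken from the signature documented above.

```python
from typing import Any, Callable


def run_validation(eval_fn: Callable[..., Any], **components: Any) -> None:
    """Call an evaluate_and_print_results-style function by keyword."""
    eval_fn(
        state=components["state"],
        prefix=components.get("prefix", "validation"),  # illustrative value
        forward_step_func=components["forward_step_func"],
        data_iterator=components["data_iterator"],
        model=components["model"],
        config=components["config"],
        verbose=components.get("verbose", False),
        write_to_tensorboard=components.get("write_to_tensorboard", True),
    )


recorded: dict[str, Any] = {}


def recording_stub(**kwargs: Any) -> None:
    """Hypothetical stand-in that records the keyword arguments it receives."""
    recorded.update(kwargs)


# Placeholder objects replace the real GlobalState, model chunks, and config.
run_validation(
    recording_stub,
    state=object(),
    forward_step_func=lambda *args, **kwargs: None,
    data_iterator=None,  # the documented type is Optional
    model=[],
    config=object(),
)
print(sorted(recorded))
```

In real use, the first argument would be `evaluate_and_print_results` itself, and the placeholders would be the objects built by the training setup.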