core.rerun_state_machine
Module Contents
Classes
Caller | Class capturing where validate_result() is called from.
Call | Class capturing a function call.
RerunDiagnostic | Enum representing the different diagnostic attributions.
RerunMode | Enum representing the different run modes for the rerun state machine.
RerunState | Enum representing the different states of the rerun state machine.
RerunValidationStatus | Enum representing the status of a record in the tracker log file.
RerunStateMachine | Class implementing the re-run state machine used to validate calculations.
RerunDataIterator | A wrapper class for data iterators that adds replay capability.
QuickStats | Simple class to keep track of the distribution of a statistic.
RerunErrorInjector | A class to manage error injection into the rerun state machine.
Functions
initialize_rerun_state_machine | Helper function to initialize the rerun machine instance.
destroy_rerun_state_machine | Helper function to shut down the rerun machine instance.
get_rerun_state_machine | Helper function to return the singleton instance of the rerun machine.
_set_rerun_state_machine | Internal function to set the singleton instance of the rerun machine.
_safe_get_rank | Internal function that safely checks and returns the rank of the caller.
_compare_floats | Internal function that implements the default compare_func.
Data
API
- core.rerun_state_machine.logger
‘getLogger(…)’
- core.rerun_state_machine._GLOBAL_RERUN_STATE_MACHINE: Optional[core.rerun_state_machine.RerunStateMachine]
None
- core.rerun_state_machine.EXIT_CODE_RESUME_TO_DISAMBIGUATE: int
16
- core.rerun_state_machine.EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int
17
- core.rerun_state_machine.SerializableStateType
None
- core.rerun_state_machine.DataIteratorArgType
None
- class core.rerun_state_machine.Caller
Bases: typing.NamedTuple
Class capturing where validate_result() is called from.
- message: str
None
- rank: int
None
- class core.rerun_state_machine.Call
Bases: typing.NamedTuple
Class capturing a function call.
- caller: core.rerun_state_machine.Caller
None
- sequence: int
None
- class core.rerun_state_machine.RerunDiagnostic
Bases: str, enum.Enum
Enum representing the different diagnostic attributions.
CORRECT_RESULT: the result was the expected result given the input.
TRANSIENT_ERROR: the result could not be reproduced on the same GPU.
PERSISTENT_ERROR: the result could be reproduced on the same GPU, but not on a different GPU.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- CORRECT_RESULT
‘correct_result’
- TRANSIENT_ERROR
‘transient_error’
- PERSISTENT_ERROR
‘persistent_error’
- class core.rerun_state_machine.RerunMode
Bases: str, enum.Enum
Enum representing the different run modes for the rerun state machine.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- DISABLED
‘disabled’
- VALIDATE_RESULTS
‘validate_results’
- REPORT_DETERMINISM_STATS
‘report_determinism_stats’
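Because RerunMode derives from both str and enum.Enum, a mode can be looked up from its string value, for example when it arrives as a command-line flag. The sketch below redefines the enum locally with the values listed above instead of importing it, and parse_rerun_mode is a hypothetical helper, not part of this module.

```python
import enum


class RerunMode(str, enum.Enum):
    # Local copy of the documented enum, for illustration only.
    DISABLED = "disabled"
    VALIDATE_RESULTS = "validate_results"
    REPORT_DETERMINISM_STATS = "report_determinism_stats"


def parse_rerun_mode(text: str) -> RerunMode:
    # Lookup by value; raises ValueError for unknown strings.
    return RerunMode(text.strip().lower())


mode = parse_rerun_mode("validate_results")
print(mode is RerunMode.VALIDATE_RESULTS)  # True
print(mode == "validate_results")          # True: members are also str instances
```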
- class core.rerun_state_machine.RerunState(*args, **kwds)
Bases: enum.Enum
Enum representing the different states of the rerun state machine.
Description of states (would benefit from a diagram):
NOT_RUNNING_YET: State before the should_run_forward_backward() while loop has been entered (when not restarting from a checkpoint for a second re-run), and after the loop has completed successfully (all validations succeeded).
INITIAL_RUN: State during the initial run of the should_run_forward_backward() while loop.
RERUNNING_IN_PLACE: State during the second run of the should_run_forward_backward() while loop (one or more validations have failed).
WILL_RERUN_FROM_CHECKPOINT: State after the should_run_forward_backward() while loop has exited (on the initial job run) and before the while loop has been entered (on the second job run, restarted from the checkpoint), when the first re-run yielded the same result as the initial run.
RERUNNING_FROM_CHECKPOINT: State during the first (and only) run of the should_run_forward_backward() while loop when the job was restarted from a checkpoint.
RERUNNING_AGAIN_FROM_CHECKPOINT: State when the re-run from checkpoint was rescheduled on the same, potentially faulty, GPU.
Initialization
- NOT_RUNNING_YET
0
- INITIAL_RUN
1
- RERUNNING_IN_PLACE
2
- WILL_RERUN_FROM_CHECKPOINT
3
- RERUNNING_FROM_CHECKPOINT
4
- RERUNNING_AGAIN_FROM_CHECKPOINT
5
- class core.rerun_state_machine.RerunValidationStatus
Bases: str, enum.Enum
Enum representing the status of a record in the tracker log file.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- RERUN_DISABLED
‘rerun_disabled’
- INITIAL_RUN
‘initial_run’
- FIRST_RERUN_NOT_REPRODUCIBLE
‘first_rerun_not_reproducible’
- FIRST_RERUN_REPRODUCIBLE
‘first_rerun_reproducible’
- SECOND_RERUN_NOT_REPRODUCIBLE
‘second_rerun_not_reproducible’
- SECOND_RERUN_REPRODUCIBLE
‘second_rerun_reproducible’
- core.rerun_state_machine.COMPARISON_MATCH: float
0.0
- core.rerun_state_machine.COMPARISON_MISMATCH: float
None
- class core.rerun_state_machine.RerunStateMachine(
    state_save_func: Optional[Callable[[], core.rerun_state_machine.SerializableStateType]] = None,
    state_restore_func: Optional[Callable[[core.rerun_state_machine.SerializableStateType], None]] = None,
    mode: core.rerun_state_machine.RerunMode = RerunMode.DISABLED,
    error_injector: Optional[core.rerun_state_machine.RerunErrorInjector] = None,
    result_rejected_tracker_filename: Optional[str] = None,
)
Class implementing the re-run state machine used to validate calculations.
This class is a singleton and should not be instantiated directly. The instance should be initialized by calling the initialize_rerun_state_machine() helper function instead.
- Parameters:
state_save_func – optional function to save any additional state that needs to be restored to re-run the iteration.
state_restore_func – optional function to restore the state saved by state_save_func.
mode – operating mode for the rerun state machine; default is disabled.
error_injector – optional error injection engine; default is no error injection.
result_rejected_tracker_filename – optional name of the file tracking result-rejected events.
Example usage:
```python
from core.rerun_state_machine import (
    RerunDiagnostic,
    RerunErrorInjector,
    initialize_rerun_state_machine,
)


def state_save_func():
    # Save any custom state that may change during the forward-backward pass
    # and that needs to be saved/restored when re-running the iteration
    # (Python/NumPy/PyTorch/CUDA RNG states are already taken care of).
    # get_state() and restore_state() stand for your own code.
    return {'mystate': get_state(...)}


def state_restore_func(state_dict):
    restore_state(state_dict['mystate'])


initialize_rerun_state_machine(
    state_save_func=state_save_func,
    state_restore_func=state_restore_func,
    error_injector=RerunErrorInjector(
        error_injection_rate=100000,
        error_injection_type=RerunDiagnostic.TRANSIENT_ERROR,
    ),
)
```
To use the rerun state machine, the training code needs to be modified as described in the documentation for each of the public methods.
Caveats and assumptions:
A core assumption of the rerun state machine is that execution (flow control) of the iteration is deterministic with respect to the state captured by the rerun state machine (the _save_state() and _restore_state() methods below). More specifically, a re-run of the iteration must make the same calls to validate_result() as the initial run. Computations, on the other hand, are NOT required to be deterministic, i.e. results may vary slightly across re-runs of the iteration.
The re-run logic is currently only able to re-run the current step. An unexpected result (e.g. a spiky loss) may be caused by a calculation that happened in a previous iteration; the current implementation will not catch such issues. We’re planning to add the capability to re-run multiple steps in a future implementation.
Initialization
- REPORTING_INTERVAL_ITERATIONS: int
2
- set_mode(mode: core.rerun_state_machine.RerunMode) → None
Method to set the operating mode.
- get_mode() → core.rerun_state_machine.RerunMode
Method to get the operating mode.
- _reduce_any(value: Union[bool, List[bool]])
All-reduce a boolean value (or multiple boolean values) across the world group.
If any rank has a True value, return True. If all ranks have a False value, return False.
For multiple inputs, return a tuple.
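The reduction semantics can be illustrated without a process group. The sketch below is a single-process simulation, assuming per_rank_values[r] holds what rank r would contribute; reduce_any_across_ranks is a hypothetical name. In a real distributed job, one common way to get the same result is to put the flags in an integer tensor and call torch.distributed.all_reduce with ReduceOp.MAX, since the max of 0/1 values is their logical OR; whether _reduce_any() does exactly this is not stated here.

```python
from typing import List, Sequence, Tuple, Union


def reduce_any_across_ranks(
    per_rank_values: Sequence[Union[bool, List[bool]]],
) -> Union[bool, Tuple[bool, ...]]:
    """Simulate an "any" all-reduce over the values contributed by each rank."""
    if isinstance(per_rank_values[0], bool):
        # Single flag per rank: True if at least one rank reported True.
        return any(per_rank_values)
    # Several flags per rank: reduce position by position and return a tuple.
    return tuple(any(flags) for flags in zip(*per_rank_values))


print(reduce_any_across_ranks([False, True, False]))            # True
print(reduce_any_across_ranks([[False, True], [False, False]]))  # (False, True)
```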
- should_run_forward_backward(data_iterator: core.rerun_state_machine.DataIteratorArgType)
Method instructing whether to (re)run the forward-backward pass.
- Parameters:
data_iterator – data iterator or list of data iterators used in this step, or None if no data iterator is used.
- Returns:
A boolean telling whether the forward-backward pass should be (re)run.
Example usage:
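```python
from core.rerun_state_machine import get_rerun_state_machine


def train_step(data_iterator, model, optimizer, loss_fn):
    rerun_state_machine = get_rerun_state_machine()
    # (Re)run the forward-backward pass as instructed by the state machine.
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        optimizer.zero_grad()
        data = next(data_iterator)
        outputs = model(data)
        loss = loss_fn(outputs)
        loss.backward()
    ...
    optimizer.step()
```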
- should_checkpoint_and_exit() → Tuple[bool, bool, int]
Method instructing whether to checkpoint and/or abort the job.
- Parameters:
None
- Returns:
A tuple formed of:
a boolean telling whether a checkpoint should be taken;
a boolean telling whether the job should be aborted;
an exit code (int) to return if aborting (0 if not aborting).
- Return type:
Tuple[bool, bool, int]
Example usage:
```python
import sys

from core.rerun_state_machine import get_rerun_state_machine


def train_step(data_iterator, optimizer):
    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        ...
    should_checkpoint, should_exit, exit_code = (
        rerun_state_machine.should_checkpoint_and_exit()
    )
    if should_checkpoint:
        save_checkpoint()  # your own checkpointing routine
    if should_exit:
        sys.exit(exit_code)
    optimizer.step()
```
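A job that calls sys.exit(exit_code) as above is normally restarted by an external launcher. The sketch below shows how such a launcher might branch on the two exit codes defined in this module (16 and 17). next_action and the action names are hypothetical, and the policy for code 17 is an assumption based on the constant's name.

```python
# Exit codes documented for core.rerun_state_machine.
EXIT_CODE_RESUME_TO_DISAMBIGUATE = 16
EXIT_CODE_FAILED_ON_RESULT_VALIDATION = 17


def next_action(exit_code: int) -> str:
    """Map a training job's exit code to a (hypothetical) launcher action."""
    if exit_code == 0:
        return "done"
    if exit_code == EXIT_CODE_RESUME_TO_DISAMBIGUATE:
        # A checkpoint was saved: relaunch from it, ideally on different GPUs,
        # so the re-run can tell a persistent error from a correct result.
        return "resume_from_checkpoint"
    if exit_code == EXIT_CODE_FAILED_ON_RESULT_VALIDATION:
        # Assumed policy: stop and surface the validation failure.
        return "stop_and_report"
    return "generic_failure"


print(next_action(16))  # resume_from_checkpoint
```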
- validate_result(
    result: Any,
    rejection_func: Callable[[Any], bool],
    message: str,
    comparison_func: Optional[Callable[[Any, Any], float]] = None,
    tolerance: float = 0.0,
    fatal: bool = True,
)
This method verifies a result and possibly triggers a re-run.
- Parameters:
result – result to verify.
rejection_func – function taking a result as input and returning whether the result fails validation (e.g. torch.isnan, returns True if result is NaN).
message – message describing the validation test (e.g. “spiky loss”).
comparison_func – optional function used to compare the results of the original run and of a re-run. It should return a float representing the relative difference between the two. The default implementation is for 0-dim float tensors.
tolerance – tolerance used in combination with comparison_func to determine reproducibility of results. Default is no tolerance (deterministic calculations).
fatal – whether to abort the job when fault attribution is complete (transient error, persistent error, or not a hardware error).
- Returns:
None
Example usage:
```python
import torch

from core.rerun_state_machine import get_rerun_state_machine


def train_step(data_iterator, model, optimizer, loss_fn):
    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        optimizer.zero_grad()
        data = next(data_iterator)
        outputs = model(data)
        loss = loss_fn(outputs)
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isnan,  # rejects result if NaN
            message="loss is NaN",
            tolerance=0.001,  # max 0.1% difference in results due to non-determinism
            fatal=True,  # abort job if validation fails
        )
        loss.backward()
```
We establish the diagnostic using this overall flow:
an irreproducible result (TRANSIENT_ERROR) is detected by re-running the iteration locally (same GPU) and verifying that the result is different.
a mismatching result (PERSISTENT_ERROR) is detected by re-running the iteration on a different GPU and verifying that the result is different.
an expected result (CORRECT_RESULT) is detected by re-running the iteration on a different GPU and verifying that the result is the same.
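The decision logic of this flow can be sketched as a small function. attribute() and the local Diagnostic copy are hypothetical; in the state machine itself, these steps are spread across the in-place re-run and a job restart from a checkpoint.

```python
import enum
from typing import Optional


class Diagnostic(str, enum.Enum):
    # Local copy of RerunDiagnostic's values.
    CORRECT_RESULT = "correct_result"
    TRANSIENT_ERROR = "transient_error"
    PERSISTENT_ERROR = "persistent_error"


def attribute(
    same_gpu_matches: bool, other_gpu_matches: Optional[bool] = None
) -> Diagnostic:
    """Classify a rejected result from the outcomes of its re-runs.

    same_gpu_matches: the in-place re-run on the same GPU reproduced the result.
    other_gpu_matches: the re-run on a different GPU reproduced the result
    (only consulted when the in-place re-run matched).
    """
    if not same_gpu_matches:
        return Diagnostic.TRANSIENT_ERROR
    if other_gpu_matches is None:
        raise ValueError("a re-run on a different GPU is needed to decide")
    if other_gpu_matches:
        return Diagnostic.CORRECT_RESULT
    return Diagnostic.PERSISTENT_ERROR


print(attribute(same_gpu_matches=True, other_gpu_matches=False).value)  # persistent_error
```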
- is_unexpectedly_large(
    result: torch.Tensor,
    threshold: float,
    context: str,
    num_samples: int = 100,
    resample: bool = False,
)
Helper method to estimate whether a result is unexpectedly large.
Some calculation errors manifest themselves as results with unexpectedly large exponents, e.g. spiky loss or grads. This method keeps track of a value over time and flags it if it exceeds a threshold expressed as a multiple of the max value observed.
- Parameters:
result – a zero-dim tensor containing the current value (e.g. the loss).
threshold – a float representing the minimum trigger threshold, e.g. 10 means > 10x the max absolute value observed.
context – a string identifying the value. This is used to differentiate between different invocations of validate_result() targeting different values, e.g. loss and grads.
num_samples – the sample size used to estimate the max value. Default is 100 value samples.
resample – whether to resample the max value. Default is False.
- Returns:
A boolean telling whether the current value exceeds the max absolute value observed by a factor greater than the threshold.
This method can be passed as a rejection function to the validate_result() method.
Example usage:
```python
from functools import partial

from core.rerun_state_machine import get_rerun_state_machine


def train_step(data_iterator, model, optimizer, loss_fn):
    rerun_machine = get_rerun_state_machine()
    while rerun_machine.should_run_forward_backward(data_iterator):
        optimizer.zero_grad()
        data = next(data_iterator)
        outputs = model(data)
        loss = loss_fn(outputs)
        rerun_machine.validate_result(
            result=loss,
            rejection_func=partial(
                rerun_machine.is_unexpectedly_large,
                threshold=10,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,
            fatal=False,
        )
```
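The warm-up-then-threshold behaviour described above can be sketched on plain floats. This is not the library implementation: SpikeDetector is a hypothetical class, ignoring NaN/Inf values is an assumption (such values are better caught by a dedicated check such as torch.isnan), and the max is frozen once num_samples values have been seen for a context.

```python
import math
from typing import Dict


class SpikeDetector:
    """Hypothetical stand-in for is_unexpectedly_large(), on plain floats."""

    def __init__(self) -> None:
        self.counts: Dict[str, int] = {}        # samples seen per context
        self.max_values: Dict[str, float] = {}  # max |value| per context

    def is_unexpectedly_large(
        self,
        value: float,
        threshold: float,
        context: str,
        num_samples: int = 100,
        resample: bool = False,
    ) -> bool:
        magnitude = abs(value)
        if math.isnan(magnitude) or math.isinf(magnitude):
            return False  # assumption: leave NaN/Inf to a separate check
        if resample or context not in self.counts:
            # Start (or restart) estimating the max for this context.
            self.counts[context] = 0
            self.max_values[context] = 0.0
        if self.counts[context] < num_samples:
            # Warm-up: collect samples to estimate the max, never flag.
            self.counts[context] += 1
            self.max_values[context] = max(self.max_values[context], magnitude)
            return False
        return magnitude > threshold * self.max_values[context]
```

Each context keeps its own estimate, so "loss" and "grads" checks do not interfere with each other.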
- state_dict(
    data_iterator: core.rerun_state_machine.DataIteratorArgType,
    ckpt_format: str,
    force: bool = False,
)
Method that returns a state dict to be checkpointed.
- Parameters:
data_iterator – the data iterator that needs to be checkpointed (or None if this checkpoint is not requested by the rerun state machine).
ckpt_format – the checkpoint format to use.
- Returns:
A state dict representing the rerun state machine.
Example usage:
```python
from core.rerun_state_machine import get_rerun_state_machine


def save_my_model_checkpoint(data_iterator, *args, **kwargs):
    checkpoint = {}
    ...
    rerun_state_machine = get_rerun_state_machine()
    checkpoint['rerun_state_machine'] = (
        rerun_state_machine.state_dict(data_iterator, "torch_dist")
    )
    ...
    return checkpoint
```
- validate_state_dict(state_dict: dict[str, Any]) → bool
Method that validates a checkpoint state dict before loading it.
- Parameters:
state_dict – the state dict saved in the checkpoint and originally obtained from state_dict().
- Returns:
bool
- load_state_dict(state_dict: dict[str, Any]) → None
Method that restores the state from a checkpoint.
- Parameters:
state_dict – the state dict saved in the checkpoint and originally obtained from state_dict().
- Returns:
None
Example usage:
```python
from core.rerun_state_machine import get_rerun_state_machine


def load_checkpoint(checkpoint, *args, **kwargs):
    ...
    if 'rerun_state_machine' in checkpoint:
        rerun_state_machine = get_rerun_state_machine()
        rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
```
- _sanitize_data_iterators(data_iterator: core.rerun_state_machine.DataIteratorArgType)
- _get_validation_call_info(message: str)
Internal method to get the context about the caller to validate_result().
- _save_state() → None
Internal method that saves the state that needs to be restored when rewound.
Any state that may change during the execution of a step before the optimizer is updated, e.g. RNG state, should be saved here. The state of the data iterator is taken care of separately by the RerunDataIterator class.
Currently, this consists only of the RNG state.
- _restore_state() → None
Internal method that restores the state that was saved in _save_state().
- _maybe_report_stats() → None
Internal method that reports stats if needed.
- _log_validation_error_to_file(
    status: core.rerun_state_machine.RerunValidationStatus,
    result: Any,
    message: str,
)
- classmethod get_skipped_iterations_from_tracker_file(tracker_file_name: str)
Get the list of iterations to skip from results recorded in the tracker file. If an “abnormality” (e.g., NaN or infinity in gradient) is seen more than once on a given rank and iteration, the corresponding iteration is skipped.
- Parameters:
tracker_file_name (str) – Name of tracker file.
- Returns:
List of iterations to skip.
- Return type:
list[int]
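The skipping rule can be sketched as follows. The tracker file format is not documented here, so the sketch assumes a HYPOTHETICAL plain-text layout of one "rank,iteration,message" record per line; skipped_iterations is likewise a hypothetical name.

```python
from collections import Counter
from typing import Iterable, List


def skipped_iterations(records: Iterable[str]) -> List[int]:
    """Return iterations whose abnormality was seen more than once on one rank.

    Each record is assumed (hypothetically) to be "rank,iteration,message".
    """
    counts = Counter()
    for line in records:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        rank, iteration, _message = line.split(",", 2)
        counts[(int(rank), int(iteration))] += 1
    # Skip an iteration if any single rank recorded it more than once.
    return sorted({iteration for (_rank, iteration), n in counts.items() if n > 1})


log = [
    "0,100,NaN in grad",
    "0,100,NaN in grad",  # same rank and iteration again: skip 100
    "1,300,Inf in grad",
    "2,300,Inf in grad",  # different ranks, once each: keep 300
]
print(skipped_iterations(log))  # [100]
```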
- class core.rerun_state_machine.RerunDataIterator(iterable: Iterable[Any])
A wrapper class for data iterators that adds replay capability.
- Parameters:
iterable – data iterator that needs the replay capability.
make_iterable – if set, iterator is created by calling iter() on iterable.
The RerunStateMachine class uses the rewind capability to replay all the microbatches fetched during an iteration.
Example usage:
```python
from core.rerun_state_machine import RerunDataIterator


class MyDataIterator:
    ...


data_iterator = MyDataIterator(...)
replay_data_iterator = RerunDataIterator(data_iterator)
```
Initialization
- __next__() → Any
Override of __next__() that adds replay capability.
- rewind() → None
Method to rewind the data iterator to the first microbatch of the iteration.
- advance() → None
Method to drop all the buffered microbatches and jump to the next iteration.
- state_dict() → core.rerun_state_machine.SerializableStateType
Method to capture the state of the iterator as a serializable dict.
- load_state_dict(state_dict: core.rerun_state_machine.SerializableStateType)
Method to restore the state saved as a serializable dict.
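To make the replay semantics concrete, here is a minimal sketch of an iterator with rewind() and advance(). ReplayIterator is hypothetical; it omits state_dict()/load_state_dict() and only illustrates the buffering idea, not the RerunDataIterator internals.

```python
from typing import Any, Iterable, Iterator, List


class ReplayIterator:
    """Hypothetical minimal replay wrapper (not the library class)."""

    def __init__(self, iterable: Iterable[Any]) -> None:
        self._source: Iterator[Any] = iter(iterable)
        self._buffer: List[Any] = []  # microbatches fetched in this iteration
        self._replay_pos = 0          # next buffered item to replay
        self._replaying = False

    def __iter__(self) -> "ReplayIterator":
        return self

    def __next__(self) -> Any:
        if self._replaying and self._replay_pos < len(self._buffer):
            item = self._buffer[self._replay_pos]
            self._replay_pos += 1
            return item
        # Buffer exhausted (or not replaying): fetch fresh data and remember it.
        self._replaying = False
        item = next(self._source)  # StopIteration propagates when exhausted
        self._buffer.append(item)
        return item

    def rewind(self) -> None:
        # Replay this iteration's microbatches from the first one.
        self._replaying = True
        self._replay_pos = 0

    def advance(self) -> None:
        # Drop the buffered microbatches; the next call fetches fresh data.
        self._buffer.clear()
        self._replay_pos = 0
        self._replaying = False


it = ReplayIterator(range(10))
first_pass = [next(it), next(it)]  # [0, 1]
it.rewind()
second_pass = [next(it), next(it)]  # [0, 1] again
```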
- class core.rerun_state_machine.QuickStats(max_size: int = 100000)
Simple class to keep track of the distribution of a statistic.
- Parameters:
max_size – maximum number of samples to keep.
Initialization
- record(data: float) → None
Record a new sample.
- combine(others: list[core.rerun_state_machine.QuickStats]) → None
Append the samples from multiple instances into one object.
- reset() → None
Forget all data.
- print_stats() → str
Return a string describing the data distribution.
- __getstate__() → Any
Pickle method, used by torch.distributed.gather_object.
- __setstate__(state: Any) → Any
Unpickle method, used by torch.distributed.gather_object.
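A bounded sample collector with the same four operations might look like the sketch below. BoundedStats is hypothetical; in particular, dropping new samples once max_size is reached and the summary format of print_stats() are assumptions, since the behaviour of QuickStats in those cases is not described here.

```python
import statistics
from typing import List


class BoundedStats:
    """Hypothetical sketch in the spirit of QuickStats."""

    def __init__(self, max_size: int = 100000) -> None:
        self.max_size = max_size
        self.samples: List[float] = []

    def record(self, data: float) -> None:
        # Assumption: once full, further samples are ignored.
        if len(self.samples) < self.max_size:
            self.samples.append(data)

    def combine(self, others: List["BoundedStats"]) -> None:
        # Append the other collectors' samples, still respecting max_size.
        for other in others:
            for x in other.samples:
                self.record(x)

    def reset(self) -> None:
        self.samples.clear()

    def print_stats(self) -> str:
        if not self.samples:
            return "no data"
        return (
            f"n={len(self.samples)} min={min(self.samples):.3g} "
            f"median={statistics.median(self.samples):.3g} "
            f"max={max(self.samples):.3g}"
        )
```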
- class core.rerun_state_machine.RerunErrorInjector(
    error_injection_rate: int = 0,
    error_injection_type: core.rerun_state_machine.RerunDiagnostic = RerunDiagnostic.TRANSIENT_ERROR,
)
A class to manage error injection into the rerun state machine.
Initialization
- _ERROR_NAMES: dict[core.rerun_state_machine.RerunDiagnostic, str]
None
- maybe_inject() → bool
Method that decides whether to inject an error.
- maybe_miscompare(
    comparison_func: Callable[[Any, Any], float],
    initial_result: Any,
    result: Any,
    state: core.rerun_state_machine.RerunState,
)
Method that introduces mismatching results during re-runs when an error is injected.
When no error is injected, this method defers to the user-provided comparison function. When an error is injected, it returns matching or mismatching results depending on the type of error being injected and on the re-run state.
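Based on the definitions in RerunDiagnostic, an injected error has to make re-runs compare differently depending on where they run: a transient error must not reproduce on the in-place re-run, while a persistent error must reproduce in place but not after the restart from the checkpoint. The sketch below encodes that policy. injected_comparison is hypothetical, and COMPARISON_MISMATCH is set to math.inf as an assumption because its documented value is not shown.

```python
import enum
import math
from typing import Any, Callable

COMPARISON_MATCH = 0.0
COMPARISON_MISMATCH = math.inf  # assumption: actual value not shown above


class State(enum.Enum):
    # Subset of RerunState, values copied from the listing above.
    RERUNNING_IN_PLACE = 2
    RERUNNING_FROM_CHECKPOINT = 4


def injected_comparison(
    injected: bool,
    error_type: str,
    comparison_func: Callable[[Any, Any], float],
    initial_result: Any,
    result: Any,
    state: State,
) -> float:
    if not injected:
        # No injection: defer to the user-provided comparison.
        return comparison_func(initial_result, result)
    in_place = state is State.RERUNNING_IN_PLACE
    if error_type == "transient_error":
        # Must not reproduce on the same GPU.
        return COMPARISON_MISMATCH if in_place else COMPARISON_MATCH
    if error_type == "persistent_error":
        # Reproduces on the same GPU, but not after the restart on another GPU.
        return COMPARISON_MATCH if in_place else COMPARISON_MISMATCH
    # "correct_result": every re-run matches the initial result.
    return COMPARISON_MATCH
```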
- state_dict() → core.rerun_state_machine.SerializableStateType
Method to capture the state of the error injector as a serializable dict.
- load_state_dict(state_dict: core.rerun_state_machine.SerializableStateType)
Method to restore the state saved as a serializable dict.
- core.rerun_state_machine.initialize_rerun_state_machine(**kwargs) → None
Helper function to initialize the rerun machine instance.
Check the RerunStateMachine class for the details.
- core.rerun_state_machine.destroy_rerun_state_machine() → None
Helper function to shut down the rerun machine instance.
- core.rerun_state_machine.get_rerun_state_machine() → core.rerun_state_machine.RerunStateMachine
Helper function to return the singleton instance of the rerun machine.
- core.rerun_state_machine._set_rerun_state_machine(rerun_state_machine) → None
Internal function to set the singleton instance of the rerun machine.
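These helpers manage a module-level singleton stored in _GLOBAL_RERUN_STATE_MACHINE. A minimal sketch of that pattern, with a placeholder class, is shown below. Machine and the helper names are hypothetical, and creating a default instance on first access in get_machine() is a design choice of this sketch rather than documented behaviour.

```python
from typing import Any, Optional


class Machine:
    """Placeholder standing in for RerunStateMachine."""

    def __init__(self, **kwargs: Any) -> None:
        self.config = kwargs


_GLOBAL_MACHINE: Optional[Machine] = None


def initialize_machine(**kwargs: Any) -> None:
    """Create (or replace) the singleton, forwarding kwargs to the constructor."""
    global _GLOBAL_MACHINE
    _GLOBAL_MACHINE = Machine(**kwargs)


def get_machine() -> Machine:
    """Return the singleton, creating a default one on first use (sketch choice)."""
    global _GLOBAL_MACHINE
    if _GLOBAL_MACHINE is None:
        _GLOBAL_MACHINE = Machine()
    return _GLOBAL_MACHINE


def destroy_machine() -> None:
    """Drop the singleton so a later initialize/get starts fresh."""
    global _GLOBAL_MACHINE
    _GLOBAL_MACHINE = None
```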
- core.rerun_state_machine._safe_get_rank() → int
Internal function that safely checks and returns the rank of the caller.
- core.rerun_state_machine._compare_floats(a: torch.Tensor, b: torch.Tensor) → float
Internal function that implements the default compare_func.
Check the validate_result() method of the RerunStateMachine class for details.
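For 0-dim float tensors, the comparison returns a relative difference that validate_result() then checks against tolerance. The sketch below does the same on plain Python floats. relative_difference and is_reproducible are hypothetical names, the max(|a|, |b|) denominator is one common definition of relative difference rather than necessarily the one _compare_floats uses, and treating two NaNs as a match is an assumption.

```python
import math


def relative_difference(a: float, b: float) -> float:
    """Relative difference between an initial result and a re-run result."""
    if math.isnan(a) and math.isnan(b):
        return 0.0  # assumption: NaN reproduced as NaN counts as a match
    if a == b:
        return 0.0  # identical values, including both being zero
    return abs(a - b) / max(abs(a), abs(b))


def is_reproducible(a: float, b: float, tolerance: float = 0.0) -> bool:
    # tolerance=0.001 accepts up to a 0.1% relative difference.
    return relative_difference(a, b) <= tolerance


print(is_reproducible(2.0, 2.001, tolerance=0.001))  # True
print(is_reproducible(2.0, 2.1, tolerance=0.001))    # False
```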