- class Fitter
  Bases: abc.ABC
  This class defines required methods for fitter implementations.
  - abstract close()
    Shuts down the fitter.
  - abstract fit()
    Executes the fitting loop.
  - abstract get_train_context()
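The sketch below shows what a concrete fitter must provide. Because importing ai4med is not assumed, it defines a stand-in `Fitter` that mirrors only the three abstract methods listed above, not the library source. `ToyFitter`, its dict-based train context, and the least-squares training loop are hypothetical and chosen only to keep the example self-contained.

```python
import abc


class Fitter(abc.ABC):
    """Stand-in mirroring the documented abstract interface (not the ai4med source)."""

    @abc.abstractmethod
    def close(self):
        """Shuts down the fitter."""

    @abc.abstractmethod
    def fit(self):
        """Executes the fitting loop."""

    @abc.abstractmethod
    def get_train_context(self):
        """Returns the training context (contents are implementation-specific)."""


class ToyFitter(Fitter):
    """Hypothetical fitter: learns w in y = w * x by gradient descent."""

    def __init__(self, xs, ys, num_epochs=100, learning_rate=0.05):
        self.xs, self.ys = xs, ys
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        # Hypothetical train context: a plain dict holding the loop state.
        self._ctx = {"epoch": 0, "w": 0.0, "closed": False}

    def fit(self):
        n = len(self.xs)
        for epoch in range(self.num_epochs):
            w = self._ctx["w"]
            # Gradient of the mean squared error (1/n) * sum((w*x - y)^2) w.r.t. w.
            grad = sum(2 * (w * x - y) * x for x, y in zip(self.xs, self.ys)) / n
            self._ctx["w"] = w - self.learning_rate * grad
            self._ctx["epoch"] = epoch + 1

    def get_train_context(self):
        return self._ctx

    def close(self):
        # A real fitter would release sessions, writers, and data pipelines here.
        self._ctx["closed"] = True
```

A subclass that omits any of the three abstract methods cannot be instantiated; Python raises `TypeError` at construction time, which is how `abc.ABC` enforces the contract.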
class
SupervisedFitter
(num_epochs: int, train_data_source: ai4med.components.data.data_pipeline.DataPipeline, initial_learning_rate, learning_rate_placeholder, model_input_placeholder, label_input_placeholder, is_train_placeholder, model_output_tensor, step_tensors: dict, task='segmentation', tf_config=None, graph=None, running_in_ngc=False, multi_gpu=False, summary_tensors=None, print_tensors=None, val_output_tensors=None, model_log_dir=None, val_data_source=None, val_inferer=None, val_interval=0, val_metrics=None, save_checkpoint=True, log_graph=False, continue_from_ckpt_epoch=False, ckpt_preload_path=None, train_summary_recording_interval=1, lr_policy=None, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, val_stats_logger=None, val_events_logger=None, microbatches_size=None, global_round=0, federate_learn=False, num_steps_for_aggr=0, stop_threshold=None, infer_in_training_mode=False, handlers=None, build_ctx=None) Bases:
ai4med.workflows.fitters.fitter.Fitter
  - close()
    Shuts down the fitter.
  - do_validation()
  - fit()
    Executes the fitting loop.
  - get_next_batch(data_source)
  - get_train_context()
  - graph_reset()
  - initialize()
  - restore_session(session, train_ctx)
  - validate_and_log_tensorboard(all_val_metrics, my_rank, session, train_ctx)
- log_to_tensorboard(summary_writer, step, summary_dict)
  Utility function for logging to TensorBoard.
  Parameters:
    summary_writer – a TF summary file writer
    step – iteration number
    summary_dict – dictionary whose key-value pairs are metric names and metric values, respectively
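The behavior described for log_to_tensorboard (one scalar per metric, all written at the same step) can be sketched as below. This is not the ai4med implementation: instead of a TF summary file writer, the sketch assumes a writer exposing `add_scalar(tag, scalar_value, global_step)`, the interface of `torch.utils.tensorboard.SummaryWriter` and tensorboardX. `RecordingWriter` is a hypothetical in-memory writer used only to show the calls.

```python
from numbers import Real
from typing import Mapping, Protocol


class ScalarWriter(Protocol):
    """Assumed writer interface (matches SummaryWriter.add_scalar in PyTorch/tensorboardX)."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: ...


def log_to_tensorboard(summary_writer: ScalarWriter, step: int,
                       summary_dict: Mapping[str, Real]) -> None:
    """Write each metric_name -> metric_value pair as one scalar at `step`."""
    for metric_name, metric_value in summary_dict.items():
        summary_writer.add_scalar(metric_name, float(metric_value), step)


class RecordingWriter:
    """Hypothetical writer that records calls in memory instead of writing event files."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, scalar_value, global_step):
        self.records.append((tag, scalar_value, global_step))
```

With a TensorFlow 2 writer from `tf.summary.create_file_writer`, the loop body would instead call `tf.summary.scalar(metric_name, metric_value, step=step)` inside `with summary_writer.as_default():`.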
- print_progress_train(i, num_iterations, e, epochs, print_summaries, duration)
- print_progress_validation(e, epochs, train_summaries, validation_summaries, duration)