ai4med.workflows.fitters package

class Fitter

Bases: abc.ABC

This class defines the methods that every fitter implementation must provide.

abstract close()

Shuts down the fitter.

abstract fit()

Executes the fitting loop.

get_train_context()
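The interface above can be illustrated with a minimal sketch. It does not import ai4med: `Fitter` here is a local stand-in built on `abc.ABC` with the same three methods, and `CountingFitter` is a hypothetical toy subclass whose "training" only counts epochs. The return values of `fit()` and `get_train_context()` are not documented above, so the ones used here are assumptions.

```python
import abc


class Fitter(abc.ABC):
    """Local stand-in mirroring the documented interface (not imported from ai4med)."""

    @abc.abstractmethod
    def close(self):
        """Shuts down the fitter."""

    @abc.abstractmethod
    def fit(self):
        """Executes the fitting loop."""

    def get_train_context(self):
        # Concrete in the documented base class; its real return value is
        # undocumented, so this stand-in simply returns None.
        return None


class CountingFitter(Fitter):
    """Hypothetical toy fitter whose 'training' just counts epochs."""

    def __init__(self, num_epochs):
        self.num_epochs = num_epochs
        self.epochs_run = 0
        self.closed = False

    def fit(self):
        for _ in range(self.num_epochs):
            self.epochs_run += 1  # placeholder for one epoch of real training
        return self.epochs_run

    def close(self):
        self.closed = True  # a real fitter would release sessions/resources here

    def get_train_context(self):
        # Assumed shape: a plain dict describing training progress.
        return {"epochs_run": self.epochs_run}


# The abstract base cannot be instantiated until close() and fit() are implemented.
try:
    Fitter()
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

fitter = CountingFitter(num_epochs=3)
try:
    fitter.fit()
finally:
    fitter.close()  # always shut down, even if fitting raises
```

Wrapping `fit()` in `try`/`finally` guarantees that `close()` runs even when the fitting loop fails, which is the natural pairing of the two abstract methods.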
class SupervisedFitter(num_epochs: int, train_data_source: ai4med.components.data.data_pipeline.DataPipeline, initial_learning_rate, learning_rate_placeholder, model_input_placeholder, label_input_placeholder, is_train_placeholder, model_output_tensor, step_tensors: dict, task='segmentation', tf_config=None, graph=None, running_in_ngc=False, multi_gpu=False, summary_tensors=None, print_tensors=None, val_output_tensors=None, model_log_dir=None, val_data_source=None, val_inferer=None, val_interval=0, val_metrics=None, save_checkpoint=True, log_graph=False, continue_from_ckpt_epoch=False, ckpt_preload_path=None, train_summary_recording_interval=1, lr_policy=None, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, val_stats_logger=None, val_events_logger=None, microbatches_size=None, global_round=0, federate_learn=False, num_steps_for_aggr=0, stop_threshold=None, infer_in_training_mode=False, handlers=None, build_ctx=None)

Bases: ai4med.workflows.fitters.fitter.Fitter

close()

Shuts down the fitter.

do_validation()
fit()

Executes the fitting loop.

get_next_batch(data_source)
get_train_context()
graph_reset()
initialize()
restore_session(session, train_ctx)
validate_and_log_tensorboard(all_val_metrics, my_rank, session, train_ctx)
log_to_tensorboard(summary_writer, step, summary_dict)

Utility function for logging metrics to TensorBoard.

Parameters
  • summary_writer – a TF summary file writer

  • step – iteration number

  • summary_dict – a dictionary whose keys are metric names and whose values are the corresponding metric values
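The expected shape of `summary_dict` can be shown with a small sketch. Both names in it are hypothetical: `RecordingWriter` is an in-memory stand-in whose `add_scalar(tag, value, step)` method is invented for this example (it is not the TF writer API), and `log_metrics` is not the library's implementation of `log_to_tensorboard`.

```python
class RecordingWriter:
    """Hypothetical stand-in for a TF summary file writer; records scalars in memory."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, step):
        self.records.append((step, tag, float(value)))


def log_metrics(summary_writer, step, summary_dict):
    """Hypothetical helper: write one scalar per (metric_name, metric_value) pair."""
    # Sorting by metric name keeps the write order deterministic.
    for metric_name, metric_value in sorted(summary_dict.items()):
        summary_writer.add_scalar(metric_name, metric_value, step)


writer = RecordingWriter()
log_metrics(writer, step=100, summary_dict={"train_loss": 0.42, "train_dice": 0.87})
```

With a real TF 1.x `tf.compat.v1.summary.FileWriter`, each pair would typically be wrapped as `tf.compat.v1.Summary.Value(tag=name, simple_value=value)` inside a `tf.compat.v1.Summary` and written with `add_summary(summary, step)`.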

print_progress_train(i, num_iterations, e, epochs, print_summaries, duration)
print_progress_validation(e, epochs, train_summaries, validation_summaries, duration)
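The console format of `print_progress_train` and `print_progress_validation` is not documented here. As a hedged sketch only, the hypothetical formatters below show one way their parameters could be combined into a progress line, assuming `i` and `e` are already 1-based counters, `duration` is in seconds, and each summary argument maps metric names to floats.

```python
def format_train_progress(i, num_iterations, e, epochs, print_summaries, duration):
    """Hypothetical formatter for one training-progress line."""
    # Assumes i and e are 1-based and duration is in seconds.
    metrics = ", ".join(
        f"{name}: {value:.4f}" for name, value in sorted(print_summaries.items())
    )
    return f"Epoch: {e}/{epochs}, Iter: {i}/{num_iterations} [{duration:.2f}s] {metrics}"


def format_validation_progress(e, epochs, train_summaries, validation_summaries, duration):
    """Hypothetical formatter for an end-of-epoch line with train and validation metrics."""

    def fmt(prefix, summaries):
        # Prefix each metric so train and validation values are distinguishable.
        return ", ".join(
            f"{prefix}{name}: {value:.4f}" for name, value in sorted(summaries.items())
        )

    return (
        f"Epoch: {e}/{epochs} [{duration:.2f}s] "
        f"{fmt('train_', train_summaries)} | {fmt('val_', validation_summaries)}"
    )


print(format_train_progress(5, 10, 2, 20, {"loss": 0.5}, 1.234))
# Epoch: 2/20, Iter: 5/10 [1.23s] loss: 0.5000
```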
© Copyright 2020, NVIDIA. Last updated on Feb 2, 2023.