ai4med.workflows.trainers package

class SupervisedTrainer(task: str, num_epochs: int, initial_learning_rate, model, loss, train_data_source, optimizer=None, train_summary_recording_interval=1, is_multi_gpu=False, dynamic_input_shape=False, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, extra_inputs=None, stop_threshold=None, infer_in_training_mode=False, handlers=None)

Bases: ai4med.workflows.trainers.trainer.Trainer

This class implements a supervised training workflow based on TensorFlow 1.x (TF1).

Parameters
  • task (str) – type of the training task: segmentation or classification

  • num_epochs (int) – number of training epochs

  • model (Model) – the model architecture

  • initial_learning_rate (float) – initial learning rate

  • loss (Loss) – the loss component

  • train_data_source (DataPipeline) – the data pipeline producing training data

  • optimizer (Optimizer) – the optimizer. If not specified, defaults to Adam.

  • train_summary_recording_interval (int) – interval (number of steps) for logging training summary data. 0 means every epoch.

  • is_multi_gpu (bool) – whether to use multiple GPUs for training

  • dynamic_input_shape (bool) – whether to use dynamic input shape when building the graph

  • train_epoch_stats_logger – logger for logging per-epoch stats. Called at the end of each epoch. If not specified, PrintTrainEpochStats is used.

  • train_step_stats_logger – logger for logging per-step stats. Called at the end of each training step. If not specified, PrintTrainStepStats is used.

  • train_events_logger – logger for recording training summary data. Called at the end of each step. If not specified, RecordTrainEvents is used.
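The following sketch illustrates how train_summary_recording_interval could be interpreted. The helper should_record_summary is hypothetical and not part of ai4med; it assumes that a positive value means "record every N global steps" and that 0 means "record once per epoch, at its last step".

```python
def should_record_summary(global_step: int, step_in_epoch: int,
                          steps_per_epoch: int, interval: int) -> bool:
    """Hypothetical helper mirroring the documented meaning of
    train_summary_recording_interval (not the ai4med implementation).

    global_step: 1-based count of training steps across all epochs
    step_in_epoch: 1-based index of the step inside the current epoch
    """
    if interval < 0:
        raise ValueError("interval must be >= 0")
    if interval == 0:
        # 0 means once per epoch; assumed here to be at the epoch's last step
        return step_in_epoch == steps_per_epoch
    # Positive interval: record every `interval` global steps
    return global_step % interval == 0


# Two epochs of 4 steps each, with interval = 0 (once per epoch)
steps_per_epoch = 4
recorded = []
for epoch in range(2):
    for step_in_epoch in range(1, steps_per_epoch + 1):
        global_step = epoch * steps_per_epoch + step_in_epoch
        if should_record_summary(global_step, step_in_epoch, steps_per_epoch, 0):
            recorded.append(global_step)
print(recorded)  # [4, 8]
```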

add_auxiliary_operation(op)

Add an aux op to the graph. Call this method once for each aux op needed.

Parameters

op – the aux op to be added

add_validation_metric(m: ai4med.components.metrics.metric.Metric)

Add a validation metric. Call this method once for each metric needed.

Parameters

m – the validation metric component

close()

Shuts down the trainer.

continue_from_checkpoint_epoch()

Start the epoch counter from the epoch of the previously trained model instead of from 0.
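A minimal sketch of the epoch-counter behavior, assuming num_epochs is the total target number of epochs (so a resumed run only performs the remaining ones) and that epochs are counted from 0. The helper epochs_to_run is hypothetical and not part of ai4med.

```python
def epochs_to_run(num_epochs: int, checkpoint_epoch: int,
                  continue_from_checkpoint: bool) -> range:
    """Hypothetical helper: 0-based epoch indices a trainer would run.

    Assumes num_epochs is the total (target) epoch count, so resuming at
    checkpoint_epoch runs only the remaining epochs. ai4med may count
    differently.
    """
    start = checkpoint_epoch if continue_from_checkpoint else 0
    return range(start, num_epochs)


print(list(epochs_to_run(5, 3, continue_from_checkpoint=True)))   # [3, 4]
print(list(epochs_to_run(5, 3, continue_from_checkpoint=False)))  # [0, 1, 2, 3, 4]
```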

get_train_context()

set_checkpoint_preload_path(path)

Set the path to the checkpoint files of a previously trained model. Call this method only if you want to fine-tune a previously trained model.

Parameters

path – the path to the checkpoint files of the previously trained model

set_learning_rate_policy(policy)

Set the learning rate policy. Only one LR policy can be set. Calling this method will overwrite the current policy.

Parameters

policy – the LR policy to be set.
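The add_* methods accumulate components, while set_learning_rate_policy keeps only the most recent policy. The class ToyTrainerConfig below is a hypothetical, framework-free stand-in used only to illustrate that contract; the metric and policy names are placeholder strings, not ai4med components.

```python
class ToyTrainerConfig:
    """Hypothetical stand-in (not ai4med) that mimics the documented
    accumulate-vs-overwrite behavior of the trainer's configuration methods."""

    def __init__(self):
        self.aux_ops = []
        self.validation_metrics = []
        self.lr_policy = None

    def add_auxiliary_operation(self, op):
        # add_*: each call appends one more item
        self.aux_ops.append(op)

    def add_validation_metric(self, m):
        self.validation_metrics.append(m)

    def set_learning_rate_policy(self, policy):
        # set_*: only one policy is kept; a later call replaces the earlier one
        self.lr_policy = policy


cfg = ToyTrainerConfig()
cfg.add_validation_metric("mean_dice")
cfg.add_validation_metric("accuracy")
cfg.set_learning_rate_policy("step_decay")
cfg.set_learning_rate_policy("cosine")
print(cfg.validation_metrics)  # ['mean_dice', 'accuracy']
print(cfg.lr_policy)           # cosine
```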

set_model_log_dir(log_dir: str)

Set the model log directory. This directory is used to save results produced during training, specifically the trained model checkpoint files and TensorBoard event files.

If this method is not called or the log_dir is empty, then training results will not be saved.

Parameters

log_dir (str) – the log directory
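A small sketch of the documented rule for saving results. The helper prepare_log_dir is hypothetical, and the assumption that the directory is created when missing is illustrative rather than a statement about ai4med.

```python
import os
import tempfile
from typing import Optional


def prepare_log_dir(log_dir: Optional[str]) -> Optional[str]:
    """Hypothetical helper mirroring the documented rule: no log_dir (None)
    or an empty string means nothing is saved. Otherwise the directory is
    created if needed (an assumption) and returned as the location for
    checkpoint files and TensorBoard event files."""
    if not log_dir:
        return None  # training results will not be saved
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


print(prepare_log_dir(None))  # None
print(prepare_log_dir(""))    # None
with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "models", "run1")
    print(prepare_log_dir(target) == target)  # True
    print(os.path.isdir(target))              # True
```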

set_validation(data_source, inferer, val_interval, val_stats_logger=None, val_events_logger=None)

Set properties for model validation during training. If this method is not called, then validation will not be performed during training.

Parameters
  • data_source (DataPipeline) – the data pipeline producing validation samples

  • inferer (Inferer) – the inferer for performing inference

  • val_interval (int) – interval (number of epochs) for performing validation

  • val_stats_logger (StatsLogger) – logger for logging validation stats. If not specified, PrintValidationStats is used as the default.

  • val_events_logger – logger for logging validation summary data. If not specified, RecordValidationEvents is used as the default.
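Because val_interval is measured in epochs, a sketch of which epochs would trigger validation is shown below. The helper validation_epochs is hypothetical; the rule that validation runs after every epoch whose 1-based number is divisible by val_interval is an assumption, and val_interval=None stands in for "set_validation() was never called".

```python
from typing import List, Optional


def validation_epochs(num_epochs: int, val_interval: Optional[int]) -> List[int]:
    """Hypothetical helper: 1-based epoch numbers after which validation runs.

    None models "set_validation() was never called" (no validation).
    The divisibility rule below is an assumption, not taken from ai4med.
    """
    if val_interval is None:
        return []
    if val_interval <= 0:
        raise ValueError("val_interval must be a positive number of epochs")
    return [e for e in range(1, num_epochs + 1) if e % val_interval == 0]


print(validation_epochs(10, 3))     # [3, 6, 9]
print(validation_epochs(10, None))  # []
```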

train()

Start the training workflow.

The trainer builds the computation graph, creates and initializes a supervised fitter, and calls the fitter to fit.
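The three steps performed by train() can be sketched without TensorFlow as follows. ToySupervisedTrainer and ToyFitter are hypothetical stand-ins, not the ai4med classes, and the "graph" is a plain dictionary. The sketch also pairs train() with close() in a try/finally block so the trainer is shut down even if training fails; that pattern is a suggestion, not documented behavior.

```python
class ToyFitter:
    """Hypothetical stand-in for the supervised fitter."""

    def __init__(self, graph, num_epochs):
        self.graph = graph
        self.num_epochs = num_epochs
        self.log = []

    def fit(self):
        # Placeholder training loop: one log entry per epoch
        for epoch in range(self.num_epochs):
            self.log.append(f"fit epoch {epoch}")
        return self.log


class ToySupervisedTrainer:
    """Hypothetical sketch of the three steps train() is documented to perform."""

    def __init__(self, num_epochs):
        self.num_epochs = num_epochs
        self.closed = False

    def _build_graph(self):
        # Step 1: build the computation graph (a placeholder dict here)
        return {"model": "model", "loss": "loss", "optimizer": "Adam"}

    def train(self):
        graph = self._build_graph()
        fitter = ToyFitter(graph, self.num_epochs)  # step 2: create and initialize the fitter
        return fitter.fit()                         # step 3: call the fitter to fit

    def close(self):
        self.closed = True


trainer = ToySupervisedTrainer(num_epochs=2)
try:
    print(trainer.train())  # ['fit epoch 0', 'fit epoch 1']
finally:
    trainer.close()         # shut down even if training raises
print(trainer.closed)       # True
```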

class Trainer

Bases: abc.ABC

This class defines the required methods for trainer implementations.

build(build_ctx)

abstract close()

Shuts down the trainer.

get_train_context()

abstract train()

Start the training process.
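A custom trainer must implement the two abstract methods, close() and train(). The sketch below mirrors the documented interface with Python's abc module; TrainerInterface and EchoTrainer are hypothetical names, and the bodies of build() and get_train_context() are placeholders because the reference gives no description for them.

```python
from abc import ABC, abstractmethod


class TrainerInterface(ABC):
    """Hypothetical mirror of the documented Trainer interface (not the ai4med class)."""

    def build(self, build_ctx):
        # Documented without a description; a no-op placeholder here
        pass

    def get_train_context(self):
        # Documented without a description; returns None here
        return None

    @abstractmethod
    def close(self):
        """Shut down the trainer."""

    @abstractmethod
    def train(self):
        """Start the training process."""


class EchoTrainer(TrainerInterface):
    """Minimal concrete trainer: both abstract methods are implemented."""

    def __init__(self):
        self.events = []

    def train(self):
        self.events.append("train")

    def close(self):
        self.events.append("close")


try:
    TrainerInterface()  # abstract methods are missing, so this fails
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

t = EchoTrainer()
t.train()
t.close()
print(t.events)  # ['train', 'close']
```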

© Copyright 2020, NVIDIA. Last updated on Feb 2, 2023.