ai4med.workflows.trainers package
-
class
SupervisedTrainer
(task: str, num_epochs: int, initial_learning_rate, model, loss, train_data_source, optimizer=None, train_summary_recording_interval=1, is_multi_gpu=False, dynamic_input_shape=False, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, extra_inputs=None, stop_threshold=None, infer_in_training_mode=False, handlers=None)
Bases: ai4med.workflows.trainers.trainer.Trainer
This class implements a TF1-based supervised training workflow.
- Parameters
task (str) – type of the training task: segmentation or classification
num_epochs (int) – number of training epochs
model (Model) – the model architecture
initial_learning_rate (float) – initial learning rate
loss (Loss) – the loss component
train_data_source (DataPipeline) – the data pipeline producing training data
optimizer (Optimizer) – the optimizer. If not specified, defaults to Adam.
train_summary_recording_interval (int) – interval (number of steps) for logging training summary data. 0 means every epoch.
is_multi_gpu (bool) – whether to use multiple GPUs for training
dynamic_input_shape (bool) – whether to use dynamic input shape when building the graph
train_epoch_stats_logger – logger for logging per-epoch stats. Called at end of each epoch. If not specified, PrintTrainEpochStats is used.
train_step_stats_logger – logger for logging per-step stats. Called at end of each training step. If not specified, PrintTrainStepStats is used.
train_events_logger – logger for recording training summary data. Called at end of each step. If not specified, RecordTrainEvents is used.
-
add_auxiliary_operation
(op) Add an aux op to the graph. Call this method once for each aux op needed.
- Parameters
op – the aux op to be added
Returns:
-
add_validation_metric
(m: ai4med.components.metrics.metric.Metric) Add a validation metric. Call this method once for each metric needed.
- Parameters
m – the validation metric component
Returns:
-
close
() Shuts down the trainer.
Returns:
-
continue_from_checkpoint_epoch
() Specify whether to start the epoch counter from the epoch of the previously trained model, instead of starting from 0.
Returns:
-
get_train_context
()
-
set_checkpoint_preload_path
(path) Set the path to the checkpoint files of a previously trained model. Call this method only if you want to fine-tune the model based on a previously trained model.
- Parameters
path – the path to the checkpoint files of the previously trained model
Returns:
-
set_learning_rate_policy
(policy) Set the learning rate policy. Only one LR policy can be set. Calling this method will overwrite the current policy.
- Parameters
policy – the LR policy to be set.
Returns:
-
set_model_log_dir
(log_dir: str) Set the model log directory. This directory is used for saving results produced during training, specifically the trained model checkpoint files and TensorBoard event files.
If this method is not called or the log_dir is empty, then training results will not be saved.
- Parameters
log_dir (str) – the log directory
Returns:
-
set_validation
(data_source, inferer, val_interval, val_stats_logger=None, val_events_logger=None) Set properties for model validation during training. If this method is not called, then validation will not be performed during training.
- Parameters
data_source (DataPipeline) – the data pipeline producing validation samples
inferer (Inferer) – the inferer for performing inference
val_interval (int) – interval (number of epochs) for performing validation
val_stats_logger (StatsLogger) – logger for logging validation stats. If not specified, PrintValidationStats will be used as default.
val_events_logger – logger for logging validation summary data. If not specified, RecordValidationEvents will be used as default.
Returns:
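The following sketch shows one way to combine set_validation with add_validation_metric, using the parameter names from the signature above. The helper configure_validation is hypothetical, and calling set_validation before the metrics are added is an assumption, not a documented requirement. A MagicMock stands in for the trainer, and the string arguments are placeholders for real DataPipeline, Inferer, and Metric objects.

```python
from unittest.mock import MagicMock


def configure_validation(trainer, data_source, inferer, metrics, val_interval=1):
    """Enable validation every `val_interval` epochs and register metrics."""
    # Loggers are left unset so the documented defaults apply.
    trainer.set_validation(
        data_source=data_source,
        inferer=inferer,
        val_interval=val_interval,
    )
    for metric in metrics:
        # One call per metric, as the method description asks.
        trainer.add_validation_metric(metric)


# Placeholder objects; real code would pass ai4med components.
demo_trainer = MagicMock()
configure_validation(
    demo_trainer,
    data_source="validation-pipeline",
    inferer="inferer",
    metrics=["metric-a", "metric-b"],
    val_interval=2,
)
print(demo_trainer.method_calls)
```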
-
train
() Start the training workflow.
The trainer builds the computation graph, creates and initializes a supervised fitter, and calls the fitter to fit.
Returns:
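Putting the configuration methods together, a driver might look like the sketch below. The helper run_training is hypothetical, and calling close in a finally block after train is an assumed convention rather than something this page prescribes. A MagicMock stands in for an already constructed SupervisedTrainer; the log directory and policy value are placeholders.

```python
from unittest.mock import MagicMock


def run_training(trainer, log_dir=None, lr_policy=None):
    """Configure outputs and the LR policy, train, and always shut down."""
    if log_dir:
        # Without a non-empty log_dir, training results are not saved.
        trainer.set_model_log_dir(log_dir)
    if lr_policy is not None:
        # Only one policy is kept; a later call would overwrite this one.
        trainer.set_learning_rate_policy(lr_policy)
    try:
        trainer.train()
    finally:
        # Shut the trainer down even if training raised an exception.
        trainer.close()


demo_trainer = MagicMock()
run_training(demo_trainer, log_dir="/tmp/model_logs", lr_policy="placeholder-policy")
call_order = [call[0] for call in demo_trainer.method_calls]
print(call_order)
```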
-
class
Trainer
Bases: abc.ABC
This class defines the required methods for trainer implementations.
-
build
(build_ctx)
-
abstract
close
() Shuts down the trainer.
Returns:
-
get_train_context
()
-
abstract
train
() Start training process.
Returns:
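To show what the abstract markers imply for implementers, here is a minimal mirror of the interface built with Python's abc module. TrainerSketch and NoOpTrainer are hypothetical names and this is not the ai4med source. The bodies of build and get_train_context are placeholders, since their behavior is not described on this page.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Illustrative mirror of the documented interface."""

    def build(self, build_ctx):
        # Placeholder: the real behavior is not documented here.
        pass

    def get_train_context(self):
        # Placeholder: the real return value is not documented here.
        return None

    @abstractmethod
    def train(self):
        """Start the training process."""

    @abstractmethod
    def close(self):
        """Shut down the trainer."""


class NoOpTrainer(TrainerSketch):
    """Smallest concrete subclass: implements both abstract methods."""

    def __init__(self):
        self.events = []

    def train(self):
        self.events.append("train")

    def close(self):
        self.events.append("close")


# A class with unimplemented abstract methods cannot be instantiated.
try:
    TrainerSketch()
    abstract_rejected = False
except TypeError:
    abstract_rejected = True

trainer = NoOpTrainer()
trainer.train()
trainer.close()
print(abstract_rejected, trainer.events)
```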
-