nvmidl.apps.fed_learn.trainers package
-
class ClientTrainer(trainer, uid, num_epochs, server_config, client_config, privacy, secure_train, save_checkpoint=False, dynamic_input_shape=False, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, model_reader_writer=None, model_validator=None, pre_processors=None, post_processors=None, req_processors=None, extra_inputs=None, handlers=None)
Bases: <a href="../ai4med/ai4med.workflows.trainers.html#ai4med.workflows.trainers.supervised_trainer.SupervisedTrainer">ai4med.workflows.trainers.supervised_trainer.SupervisedTrainer</a>
-
close()
Shuts down the trainer.
Returns:
-
create_fed_client()
-
deploy()
-
train()
Starts the training workflow.
The trainer builds the computation graph, creates and initializes a supervised fitter, and then calls the fitter to fit the model.
Returns:
-
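The three steps described for ClientTrainer.train() (build the graph, create and initialize a fitter, call the fitter to fit) can be pictured with the minimal sketch below. Every name in it (SketchFitter, SketchClientTrainer, build_graph, initialize, fit) is a hypothetical stand-in chosen for illustration. None of them is taken from nvmidl or ai4med, and the real trainer may order or name these steps differently.

```python
# Illustrative sketch only. All classes and methods here are hypothetical
# stand-ins and are NOT part of the nvmidl or ai4med APIs.


class SketchFitter:
    """Stand-in for a supervised fitter that runs the epoch loop."""

    def __init__(self, graph, num_epochs):
        self.graph = graph
        self.num_epochs = num_epochs
        self.initialized = False
        self.history = []

    def initialize(self):
        # A real fitter would set up sessions, variables, and so on.
        self.initialized = True

    def fit(self):
        if not self.initialized:
            raise RuntimeError("fitter must be initialized before fit()")
        for epoch in range(self.num_epochs):
            # A real fitter would run training steps here.
            self.history.append(f"epoch {epoch}")
        return self.history


class SketchClientTrainer:
    """Stand-in trainer that mirrors the three documented steps."""

    def __init__(self, num_epochs):
        self.num_epochs = num_epochs
        self.steps = []

    def build_graph(self):
        # Step 1: build the computation graph (a plain dict in this sketch).
        self.steps.append("build_graph")
        return {"nodes": ["input", "model", "loss"]}

    def train(self):
        graph = self.build_graph()
        # Step 2: create and initialize the fitter.
        fitter = SketchFitter(graph, self.num_epochs)
        self.steps.append("create_fitter")
        fitter.initialize()
        self.steps.append("initialize_fitter")
        # Step 3: hand control to the fitter.
        history = fitter.fit()
        self.steps.append("fit")
        return history


sketch_trainer = SketchClientTrainer(num_epochs=2)
sketch_history = sketch_trainer.train()
```

In this sketch the trainer only orchestrates the steps, while the fitter owns the epoch loop; that split is an assumption drawn from the one-sentence description above.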
-
class ServerTrainer(task, loss, initial_learning_rate, model, data_source, optimizer=None, is_multi_gpu=False, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors=None, post_processors=None, cmd_modules=None, result_processors=None, secure_train=False, dynamic_input_shape=False, extra_inputs=None, handlers=None, mmar_validator=None)
Bases: <a href="../ai4med/ai4med.workflows.trainers.html#ai4med.workflows.trainers.trainer.Trainer">ai4med.workflows.trainers.trainer.Trainer</a>
-
close()
Shuts down the trainer.
Returns:
-
create_TF_builder(services)
-
create_fl_server()
-
deploy()
-
set_model_log_dir(log_dir)
-
set_server_config(server_config)
-
start_training(services)
-
train()
Starts the training process.
Returns:
-
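Both ClientTrainer and ServerTrainer expose train() and close(). One way to make sure close() runs even when training raises is the standard library's contextlib.closing, which calls close() on exit from the with block. The sketch below uses hypothetical stand-in classes (StandInTrainer, FailingTrainer) and a hypothetical helper (run_and_close) rather than the real trainers, whose constructors need a full configuration. Whether the real trainers expect close() to be called in this way is an assumption.

```python
from contextlib import closing


class StandInTrainer:
    """Hypothetical stand-in exposing only train() and close()."""

    def __init__(self):
        self.events = []

    def train(self):
        self.events.append("train")

    def close(self):
        self.events.append("close")


class FailingTrainer(StandInTrainer):
    """Hypothetical stand-in whose training step fails."""

    def train(self):
        self.events.append("train")
        raise RuntimeError("simulated training failure")


def run_and_close(trainer):
    """Run train() and guarantee close() is called afterwards."""
    # closing() calls trainer.close() when the block exits,
    # whether it exits normally or through an exception.
    with closing(trainer):
        trainer.train()
    return trainer.events


ok_events = run_and_close(StandInTrainer())

failing = FailingTrainer()
try:
    run_and_close(failing)
except RuntimeError:
    pass  # the failure propagates, but close() has already run
```

The same helper would accept any object with train() and close() methods, so it could in principle wrap either trainer once it has been constructed.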