nvmidl.apps.fed_learn.trainers package
-
class ClientTrainer(trainer, uid, num_epochs, server_config, client_config, privacy, secure_train, save_checkpoint=False, dynamic_input_shape=False, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, model_reader_writer=None, model_validator=None, pre_processors=None, post_processors=None, req_processors=None, extra_inputs=None, handlers=None)
Bases: ai4med.workflows.trainers.supervised_trainer.SupervisedTrainer
close() Shuts down the trainer.
Returns:
-
create_fed_client()
-
deploy()
-
train() Start the training workflow.
The trainer builds the computation graph, creates and initializes a supervised fitter, and calls the fitter to fit.
Returns:
-
-
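The workflow described for ClientTrainer.train() (build the computation graph, create and initialize a fitter, then call the fitter to fit) can be pictured with a minimal, self-contained sketch. Every name below (SketchTrainer, SketchFitter, build_graph, initialize, fit, steps) is a hypothetical stand-in invented for illustration. None of it is the nvmidl or ai4med API, and the real trainer may organize these steps differently.

```python
# Hypothetical stand-ins only: these classes are NOT part of nvmidl or ai4med.
# They illustrate the order of steps described for train(), nothing more.


class SketchFitter:
    """Toy fitter that must be initialized before it can fit."""

    def __init__(self, graph):
        self.graph = graph
        self.initialized = False
        self.epochs_run = 0

    def initialize(self):
        # Stand-in for whatever setup a real fitter performs.
        self.initialized = True

    def fit(self, num_epochs):
        if not self.initialized:
            raise RuntimeError("fitter must be initialized before fit()")
        for _ in range(num_epochs):
            self.epochs_run += 1  # placeholder for one epoch of training
        return self.epochs_run


class SketchTrainer:
    """Toy trainer that records the steps it takes, in order."""

    def __init__(self, num_epochs):
        self.num_epochs = num_epochs
        self.steps = []

    def build_graph(self):
        # Stand-in for building the computation graph.
        self.steps.append("build_graph")
        return {"nodes": ["input", "model", "loss"]}

    def train(self):
        graph = self.build_graph()            # 1. build the computation graph
        fitter = SketchFitter(graph)          # 2. create the fitter ...
        self.steps.append("create_fitter")
        fitter.initialize()                   #    ... and initialize it
        self.steps.append("initialize_fitter")
        epochs = fitter.fit(self.num_epochs)  # 3. call the fitter to fit
        self.steps.append("fit")
        return epochs
```

Calling SketchTrainer(num_epochs=3).train() runs the three stages in sequence and returns the number of epochs completed. The steps list exists only so the ordering can be inspected.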
class ServerTrainer(task, loss, initial_learning_rate, model, data_source, optimizer=None, is_multi_gpu=False, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors=None, post_processors=None, cmd_modules=None, result_processors=None, secure_train=False, dynamic_input_shape=False, extra_inputs=None, handlers=None, mmar_validator=None)
Bases: ai4med.workflows.trainers.trainer.Trainer
close() Shuts down the trainer.
Returns:
-
create_TF_builder(services)
-
create_fl_server()
-
deploy()
-
set_model_log_dir(log_dir)
-
set_server_config(server_config)
-
start_training(services)
-
train() Start the training process.
Returns:
-