nvmidl.apps.fed_learn.trainers package

class ClientTrainer(trainer, uid, num_epochs, server_config, client_config, privacy, secure_train, save_checkpoint=False, dynamic_input_shape=False, train_epoch_stats_logger=None, train_step_stats_logger=None, train_events_logger=None, model_reader_writer=None, model_validator=None, pre_processors=None, post_processors=None, req_processors=None, extra_inputs=None, handlers=None)

Bases: ai4med.workflows.trainers.supervised_trainer.SupervisedTrainer

close()

Shuts down the trainer.

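Because close() shuts down the trainer, a caller will usually want it to run even if training raises an error. The following is a minimal sketch that assumes only that the trainer exposes train() and close() as listed here. It uses Python's standard contextlib.closing, which calls close() when the with block exits. StubTrainer is a hypothetical stand-in so the example runs without the SDK. In practice, a fully constructed ClientTrainer would take its place.

```python
import contextlib


class StubTrainer:
    """Hypothetical stand-in for ClientTrainer; only mimics train() and close()."""

    def __init__(self):
        self.trained = False
        self.closed = False

    def train(self):
        self.trained = True

    def close(self):
        self.closed = True


trainer = StubTrainer()
# contextlib.closing calls trainer.close() when the block exits,
# including when train() raises an exception.
with contextlib.closing(trainer) as t:
    t.train()

print(trainer.trained, trainer.closed)  # True True
```

The same pattern applies to any object with a close() method, so it is not specific to this class.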

create_fed_client()
deploy()
train()

Start the training workflow.

The trainer builds the computation graph, creates and initializes a supervised fitter, and then calls the fitter to run the training.

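The following sketch illustrates the order of the three steps in ClientTrainer.train(): build the graph, create and initialize a fitter, then fit. It is a hypothetical illustration of that ordering only. SketchClientTrainer, HypotheticalFitter, _build_graph(), initialize() and fit() are invented names. They are not the actual ClientTrainer internals or the ai4med fitter API.

```python
from typing import List


class HypotheticalFitter:
    """Hypothetical stand-in for the supervised fitter; records lifecycle calls."""

    def __init__(self, log: List[str]):
        self.log = log
        self.initialized = False

    def initialize(self) -> None:
        self.initialized = True
        self.log.append("initialize fitter")

    def fit(self) -> None:
        # Fitting is only meaningful once the fitter has been initialized.
        if not self.initialized:
            raise RuntimeError("fitter must be initialized before fit()")
        self.log.append("fit")


class SketchClientTrainer:
    """Hypothetical sketch of the step order described for ClientTrainer.train()."""

    def __init__(self):
        self.log: List[str] = []

    def _build_graph(self) -> None:
        # Placeholder for building the model's computation graph.
        self.log.append("build graph")

    def train(self) -> List[str]:
        self._build_graph()                    # 1. build the computation graph
        fitter = HypotheticalFitter(self.log)  # 2a. create the fitter
        self.log.append("create fitter")
        fitter.initialize()                    # 2b. initialize it
        fitter.fit()                           # 3. run the fit
        return self.log


print(SketchClientTrainer().train())
# ['build graph', 'create fitter', 'initialize fitter', 'fit']
```

Keeping graph construction ahead of fitter creation reflects the dependency in the description: the fitter operates on a graph that must already exist.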

class ServerTrainer(task, loss, initial_learning_rate, model, data_source, optimizer=None, is_multi_gpu=False, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors=None, post_processors=None, cmd_modules=None, result_processors=None, secure_train=False, dynamic_input_shape=False, extra_inputs=None, handlers=None, mmar_validator=None)

Bases: ai4med.workflows.trainers.trainer.Trainer

close()

Shuts down the trainer.


create_TF_builder(services)
create_fl_server()
deploy()
set_model_log_dir(log_dir)
set_server_config(server_config)
start_training(services)
train()

Start the training process.


© Copyright 2020, NVIDIA. Last updated on Feb 2, 2023.