medl.apps.fed_learn.executors package

class FederatedLearner(local_epochs=1, steps_aggregation=0, model_reader_writer=None, integration_config=None)

Bases: nvflare.app_common.abstract.learner_spec.Learner

Initialize the FLComponent.

The FLComponent is the base class of all FL components (executors, controllers, responders, filters, aggregators, and widgets are all FLComponents).

FLComponents can handle and fire events, and they provide various methods for logging.

abort(fl_ctx: nvflare.apis.fl_context.FLContext)

Called (from another thread) to abort the current task (validate or train).

Note: this aborts only the current task, not the Learner itself. After aborting, the Learner may still be called to perform another task.

Parameters

fl_ctx – FLContext of the running environment
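Because `abort` is invoked from a different thread than the running task, an implementation usually handles it cooperatively: `abort` sets a flag that the task loop checks between steps. The following self-contained sketch illustrates only that pattern. `ToyLearner`, its `started` event, and the use of `threading.Event` are hypothetical and are not taken from medl or NVFlare. It also shows that an aborted learner can still run a new task afterwards.

```python
import threading
import time


class ToyLearner:
    """Hypothetical Learner-like class; NOT the medl FederatedLearner."""

    def __init__(self, total_steps=200):
        self.total_steps = total_steps
        self._abort_flag = threading.Event()  # assumption: cooperative abort flag
        self.started = threading.Event()      # tells the caller a task is running

    def abort(self):
        # Called from another thread: flags the current task only.
        self._abort_flag.set()

    def train(self):
        # Each new task starts with a cleared flag, so an earlier abort
        # does not block later tasks.
        self._abort_flag.clear()
        self.started.set()
        steps = 0
        for _ in range(self.total_steps):
            if self._abort_flag.is_set():
                return {"status": "aborted", "steps": steps}
            time.sleep(0.005)  # stands in for one training step
            steps += 1
        return {"status": "ok", "steps": steps}


learner = ToyLearner()
result = {}
worker = threading.Thread(target=lambda: result.update(learner.train()))
worker.start()
learner.started.wait()  # wait until the task is actually running
learner.abort()         # abort from the main thread
worker.join()
print(result)           # status is "aborted", after only a few steps

second = learner.train()  # the same learner can run another task
print(second)             # {'status': 'ok', 'steps': 200}
```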

finalize(fl_ctx: nvflare.apis.fl_context.FLContext)

Called to finalize the Learner (close/release resources gracefully).

After this call, the Learner will be destroyed.

Parameters

fl_ctx – FLContext of the running environment
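A host would normally guarantee that `finalize` runs even when a task fails, for example with `try`/`finally`. The sketch below illustrates that contract with a temporary file standing in for a real resource. `FileBackedLearner` and `run_lifecycle` are hypothetical names and are not part of medl or NVFlare.

```python
import os
import tempfile


class FileBackedLearner:
    """Hypothetical learner that holds an open file between calls."""

    def __init__(self):
        self.log_file = None
        self.path = None

    def initialize(self):
        # Acquire the resource once, before any train/validate call.
        fd, self.path = tempfile.mkstemp(suffix=".log")
        self.log_file = os.fdopen(fd, "w")

    def train(self, round_num):
        if round_num == 2:
            raise RuntimeError("simulated failure during training")
        self.log_file.write(f"round {round_num} done\n")

    def finalize(self):
        # Release resources gracefully; safe to call more than once.
        if self.log_file is not None and not self.log_file.closed:
            self.log_file.close()


def run_lifecycle(learner, rounds):
    """Hypothetical driver: finalize runs even if a task raises."""
    learner.initialize()
    try:
        for r in range(rounds):
            learner.train(r)
    finally:
        learner.finalize()


learner = FileBackedLearner()
try:
    run_lifecycle(learner, rounds=3)
except RuntimeError as err:
    print("task failed:", err)
print(learner.log_file.closed)  # True: finalize still ran
with open(learner.path) as f:
    contents = f.read()
print(contents)  # rounds 0 and 1 were logged before the failure
os.remove(learner.path)
```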

get_model_for_validation(model_name: str, fl_ctx: nvflare.apis.fl_context.FLContext) → nvflare.apis.shareable.Shareable

Called to return the trained model from the Learner.

Parameters
  • model_name – type of the model for validation

  • fl_ctx – FLContext of the running environment

Returns: trained model for validation
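`model_name` selects which of the Learner's models to hand back. One way to read that contract is as a lookup over named snapshots that returns a copy, so the caller cannot mutate the Learner's internal state. In the sketch below, the names `"best"` and `"final"`, the `SnapshotLearner` class, and the use of a plain `dict` in place of `Shareable` are all assumptions for illustration.

```python
import copy


class SnapshotLearner:
    """Hypothetical learner keeping named weight snapshots (dict stands in for Shareable)."""

    def __init__(self):
        self.snapshots = {}

    def record(self, model_name, weights):
        # Store a private copy of the weights under a model name.
        self.snapshots[model_name] = copy.deepcopy(weights)

    def get_model_for_validation(self, model_name):
        if model_name not in self.snapshots:
            raise KeyError(f"no model named {model_name!r}")
        # Return a copy so the caller cannot alter the stored snapshot.
        return copy.deepcopy(self.snapshots[model_name])


learner = SnapshotLearner()
learner.record("best", {"w": [0.5, 1.0]})
learner.record("final", {"w": [0.4, 1.1]})

model = learner.get_model_for_validation("best")
model["w"][0] = 99.0  # mutating the returned copy...
print(learner.get_model_for_validation("best"))  # ...leaves the snapshot intact: {'w': [0.5, 1.0]}
```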

initialize(parts: dict, fl_ctx: nvflare.apis.fl_context.FLContext)

Initialize the Learner object. This is called before the Learner can train or validate.

This is called only once.

Parameters
  • parts – components to be used by the Learner

  • fl_ctx – FLContext of the running environment
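Since `initialize` is called exactly once, before any training or validation, an implementation typically resolves the components it needs from `parts` there and guards against misuse. A minimal sketch, assuming `parts` maps component ids to objects; the ids `"data_loader"` and `"optimizer"` and the class `PartsLearner` are hypothetical.

```python
class PartsLearner:
    """Hypothetical learner showing one-time initialization from a parts dict."""

    REQUIRED = ("data_loader", "optimizer")  # hypothetical component ids

    def __init__(self):
        self._initialized = False
        self.components = {}

    def initialize(self, parts):
        if self._initialized:
            raise RuntimeError("initialize() must be called only once")
        missing = [name for name in self.REQUIRED if name not in parts]
        if missing:
            raise ValueError(f"missing components: {missing}")
        # Keep only the components this learner actually uses.
        self.components = {name: parts[name] for name in self.REQUIRED}
        self._initialized = True

    def train(self, data):
        if not self._initialized:
            raise RuntimeError("train() called before initialize()")
        loader = self.components["data_loader"]
        return {"num_batches": len(loader), "input_keys": sorted(data)}


learner = PartsLearner()
learner.initialize({"data_loader": [[1, 2], [3, 4]], "optimizer": "sgd", "unused": 0})
print(learner.train({"w": 0.0}))  # {'num_batches': 2, 'input_keys': ['w']}
```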

train(data: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext, abort_signal: nvflare.apis.signal.Signal) → nvflare.apis.shareable.Shareable

Called to perform training. Can be called many times during the lifetime of the Learner.

Parameters
  • data – the training input data (e.g. model weights)

  • fl_ctx – FLContext of the running environment

  • abort_signal – signal to abort training

Returns: training result as a Shareable
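In each call, `train` receives weights in `data`, trains locally, and returns updated weights. The self-contained sketch below makes several assumptions: a plain `dict` stands in for `Shareable`; a small `ToySignal` class with a `triggered` attribute stands in for NVFlare's `Signal`; `local_epochs` is read as the number of local passes per call (the constructor exposes a parameter with that name, but its exact semantics are not documented here); and the "model" is a single scalar fitted by full-batch gradient descent on mean squared error.

```python
class ToySignal:
    """Stand-in for an abort signal; only exposes a `triggered` flag."""

    def __init__(self):
        self.triggered = False


def train(data, abort_signal, local_epochs=1, lr=0.1, dataset=(1.0, 2.0, 3.0)):
    """Hypothetical train step: fit scalar w to minimize mean squared error."""
    w = float(data["w"])  # weights received as input
    for epoch in range(local_epochs):
        if abort_signal.triggered:
            # Abort between epochs and report the partial result.
            return {"w": w, "epochs_done": epoch, "aborted": True}
        # Full-batch gradient of mean((w - x)**2) with respect to w.
        grad = sum(2.0 * (w - x) for x in dataset) / len(dataset)
        w -= lr * grad
    return {"w": w, "epochs_done": local_epochs, "aborted": False}


result = train({"w": 0.0}, ToySignal(), local_epochs=5)
print(result)  # w moved from 0.0 toward the data mean 2.0 (about 1.345 after 5 epochs)

stopped = ToySignal()
stopped.triggered = True
aborted = train({"w": 0.0}, stopped, local_epochs=5)
print(aborted)  # {'w': 0.0, 'epochs_done': 0, 'aborted': True}
```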

validate(data: nvflare.apis.shareable.Shareable, fl_ctx: nvflare.apis.fl_context.FLContext, abort_signal: nvflare.apis.signal.Signal) → nvflare.apis.shareable.Shareable

Called to perform validation. Can be called many times during the lifetime of the Learner.

Parameters
  • data – the input data to validate (e.g. model weights)

  • fl_ctx – FLContext of the running environment

  • abort_signal – signal to abort validation

Returns: validation result as a Shareable
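`validate` receives the same kind of input as `train` but only evaluates it, leaving the weights unchanged. A minimal, self-contained sketch assuming a plain `dict` in place of `Shareable`, a stand-in signal with a `triggered` attribute in place of NVFlare's `Signal`, and mean squared error on a hypothetical held-out set as the metric.

```python
class ToySignal:
    """Stand-in for an abort signal; only exposes a `triggered` flag."""

    def __init__(self):
        self.triggered = False


def validate(data, abort_signal, val_set=(1.0, 3.0)):
    """Hypothetical validation: mean squared error of scalar weight w."""
    w = float(data["w"])  # read the weights; never modify them
    total = 0.0
    for i, x in enumerate(val_set):
        if abort_signal.triggered:
            return {"aborted": True, "samples_seen": i}
        total += (w - x) ** 2
    return {"aborted": False, "val_mse": total / len(val_set)}


weights = {"w": 2.0}
print(validate(weights, ToySignal()))  # {'aborted': False, 'val_mse': 1.0}
print(weights)                         # {'w': 2.0} (unchanged)
```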

class ModelShareableManager(model_reader_writer)

Bases: object

assign_current_model(network, multi_gpu, model_weights, fl_ctx)

Assign the shareable to the current model network.

Parameters
  • model_weights – model weights

  • fl_ctx – FLContext of the running environment
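The signature suggests that `assign_current_model` copies received `model_weights` into `network`, with `multi_gpu` affecting how parameters are addressed. The sketch below is an assumption, not the medl implementation: it models the network state as a plain `dict`, omits `fl_ctx`, and assumes that in multi-GPU mode parameter names carry a `module.` prefix, as the `state_dict()` of PyTorch's `DataParallel`/`DistributedDataParallel` wrappers does.

```python
def assign_current_model(network_state, multi_gpu, model_weights):
    """Hypothetical sketch: copy received weights into a dict-based network state.

    Assumes multi-GPU wrappers prefix parameter names with "module."
    (as torch.nn.DataParallel does); fl_ctx is omitted here.
    """
    prefix = "module." if multi_gpu else ""
    for name, value in model_weights.items():
        key = prefix + name
        if key not in network_state:
            # Fail loudly instead of silently dropping a weight.
            raise KeyError(f"network has no parameter {key!r}")
        network_state[key] = value


single = {"conv.weight": 0.0, "conv.bias": 0.0}
wrapped = {"module.conv.weight": 0.0, "module.conv.bias": 0.0}
incoming = {"conv.weight": 0.7, "conv.bias": -0.1}

assign_current_model(single, multi_gpu=False, model_weights=incoming)
assign_current_model(wrapped, multi_gpu=True, model_weights=incoming)
print(single)   # {'conv.weight': 0.7, 'conv.bias': -0.1}
print(wrapped)  # {'module.conv.weight': 0.7, 'module.conv.bias': -0.1}
```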

extract_current_model(network, multi_gpu, model_vars, train_ctx, fl_ctx)
© Copyright 2021, NVIDIA. Last updated on Feb 2, 2023.