medl.apps.fed_learn package

class ByowFLEngine(engine)

Bases: medl.apps.fed_learn.fl_engine.FLEngineSpec

abort()

Call to terminate the currently running train/validate job.

close()

Call to terminate and close the engine.

evaluate()

Call to evaluate the current model.

get_epoch()

Get the current epoch number of the trainer. Returns: epoch number.

get_iteration()

Get the current iteration number of the trainer. Returns: iteration number.

get_key_metric_name()

Get the key metric name of the validator. Returns: key metric name.

get_metrics()

Get the current validation metrics of the validator. Returns: validation metrics.
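`get_key_metric_name()` and `get_metrics()` are documented separately, so a caller that wants the single score used for model selection has to combine them. The sketch below assumes `get_metrics()` returns a dict keyed by metric name, which these docs do not state; `FakeValidatorEngine`, `key_metric_value`, and the metric names are hypothetical.

```python
from typing import Dict


class FakeValidatorEngine:
    """Hypothetical stand-in exposing the two validator accessors documented above."""

    def __init__(self, metrics: Dict[str, float], key_metric_name: str):
        self._metrics = metrics
        self._key = key_metric_name

    def get_metrics(self) -> Dict[str, float]:
        # Assumption: metrics are returned as a dict keyed by metric name.
        return dict(self._metrics)

    def get_key_metric_name(self) -> str:
        return self._key


def key_metric_value(engine) -> float:
    """Return the value of the engine's key metric from its current metrics."""
    metrics = engine.get_metrics()
    name = engine.get_key_metric_name()
    if name not in metrics:
        raise KeyError(f"key metric {name!r} not in {sorted(metrics)}")
    return metrics[name]


engine = FakeValidatorEngine({"val_mean_dice": 0.82, "val_loss": 0.31}, "val_mean_dice")
print(key_metric_value(engine))  # 0.82
```

Raising on a missing key, rather than returning a default, keeps a misconfigured key metric name from silently skewing model selection.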

get_network()

Get the current network of the trainer. Returns: network.

get_num_of_gpu()

Get the number of GPUs used by the trainer. Returns: number of GPUs.

get_validation_network()

Get the current network of the validator. Returns: network.

init_train()

Initialize training at the beginning of the round.

initialize(parts: dict, integration_config: [], log_dir: str)

Call to initialize the training engine.

Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.

train()

Call the engine to train the model.

validate()

Call to validate the current model.

validate_before_train()

Validate at the beginning of each training round.
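Taken together, these methods suggest a per-round lifecycle. The following sketch drives a stand-in engine through one plausible order: `initialize` once; then `validate_before_train`, `init_train`, `train`, and `validate` in each round; and `close` at the end. This order is an assumption inferred from the method descriptions, not a documented contract, and `RecordingEngine` and `run_rounds` are hypothetical.

```python
from typing import List


class RecordingEngine:
    """Hypothetical engine that records which lifecycle methods were called."""

    def __init__(self):
        self.calls: List[str] = []

    def initialize(self, parts: dict, integration_config: list, log_dir: str):
        self.calls.append("initialize")

    def validate_before_train(self):
        self.calls.append("validate_before_train")

    def init_train(self):
        self.calls.append("init_train")

    def train(self):
        self.calls.append("train")

    def validate(self):
        self.calls.append("validate")

    def close(self):
        self.calls.append("close")


def run_rounds(engine, num_rounds: int, log_dir: str = "/tmp/fl_logs"):
    """Assumed order: initialize once, run each round's steps, always close."""
    engine.initialize(parts={}, integration_config=[], log_dir=log_dir)
    try:
        for _ in range(num_rounds):
            engine.validate_before_train()
            engine.init_train()
            engine.train()
            engine.validate()
    finally:
        # Close the engine even if a round raises.
        engine.close()


eng = RecordingEngine()
run_rounds(eng, num_rounds=2)
print(eng.calls)
```

Putting `close()` in a `finally` block means the engine is shut down even when a round fails partway through.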

class FLEngineSpec

Bases: medl.apps.engine_spec.EngineSpec

get_epoch()

Get the current epoch number of the trainer. Returns: epoch number.

get_iteration()

Get the current iteration number of the trainer. Returns: iteration number.

get_key_metric_name()

Get the key metric name of the validator. Returns: key metric name.

get_learning_rate()

Get the current learning rate of the trainer. Returns: learning rate.

get_metrics()

Get the current validation metrics of the validator. Returns: validation metrics.

get_network()

Get the current network of the trainer. Returns: network.

get_num_of_gpu()

Get the number of GPUs used by the trainer. Returns: number of GPUs.

get_validation_network()

Get the current network of the validator. Returns: network.

init_train()

Initialize training at the beginning of the round.

initialize(parts: dict, integration_config: [], log_dir: str)

Call to initialize the training engine.

Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.

validate_before_train()

Validate at the beginning of each training round.
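FLEngineSpec reads as an interface that concrete engines such as ByowFLEngine and IgniteFLEngine implement. A minimal sketch of that pattern uses Python's `abc` module and mirrors a subset of the methods listed above. Whether FLEngineSpec itself is declared with `abc` is not stated here, and `FLEngineSpecSketch` and `ToyEngine` are hypothetical.

```python
from abc import ABC, abstractmethod


class FLEngineSpecSketch(ABC):
    """Hypothetical abstract mirror of part of the FLEngineSpec interface."""

    @abstractmethod
    def initialize(self, parts: dict, integration_config: list, log_dir: str) -> None: ...

    @abstractmethod
    def init_train(self) -> None: ...

    @abstractmethod
    def get_epoch(self) -> int: ...

    @abstractmethod
    def get_iteration(self) -> int: ...

    @abstractmethod
    def get_learning_rate(self) -> float: ...


class ToyEngine(FLEngineSpecSketch):
    """Concrete engine: every abstract method must be overridden to instantiate it."""

    def __init__(self, learning_rate: float = 1e-4):
        self.learning_rate = learning_rate
        self.epoch = 0
        self.iteration = 0
        self.log_dir = None

    def initialize(self, parts, integration_config, log_dir):
        self.log_dir = log_dir

    def init_train(self):
        # Assumed behavior: reset per-round counters at the start of a round.
        self.epoch = 0
        self.iteration = 0

    def get_epoch(self):
        return self.epoch

    def get_iteration(self):
        return self.iteration

    def get_learning_rate(self):
        return self.learning_rate


engine = ToyEngine()
engine.initialize(parts={}, integration_config=[], log_dir="/tmp/fl_logs")
engine.init_train()
print(engine.get_epoch(), engine.get_learning_rate())  # 0 0.0001
```

With `abc`, forgetting to implement one of the spec's methods fails at instantiation time with a `TypeError` rather than later in the middle of a federated round.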

class IgniteFLEngine(engine, aggregation_epochs, aggregation_steps, cross_validator)

Bases: medl.apps.fed_learn.fl_engine.FLEngineSpec

abort()

Call to terminate the currently running train/validate job.

close()

Call to terminate and close the engine.

evaluate()

Call to evaluate the current model.

get_epoch()

Get the current epoch number of the trainer. Returns: epoch number.

get_iteration()

Get the current iteration number of the trainer. Returns: iteration number.

get_key_metric_name()

Get the key metric name of the validator. Returns: key metric name.

get_learning_rate()

Get the current learning rate of the trainer. Returns: learning rate.

get_metrics()

Get the current validation metrics of the validator. Returns: validation metrics.

get_network()

Get the current network of the trainer. Returns: network.

get_num_of_gpu()

Get the number of GPUs used by the trainer. Returns: number of GPUs.

get_validation_network()

Get the current network of the validator. Returns: network.

init_train()

Initialize training at the beginning of the round.

initialize(parts: dict, integration_config: [], log_dir: str)

Call to initialize the training engine.

Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.

train()

Call the engine to train the model.

validate()

Call to validate the current model.

validate_before_train()

Validate at the beginning of each training round.
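IgniteFLEngine takes `aggregation_epochs` and `aggregation_steps` as constructor arguments, but their semantics are not documented here. Read literally, the names suggest a budget of local epochs or steps to run before the client's update is sent for aggregation. The sketch below simulates one local round under hypothetical rules: a value of 0 or less disables that budget, and the round stops as soon as any enabled budget is reached. It is not the engine's actual logic, and `local_round` is hypothetical.

```python
from typing import Tuple


def local_round(iters_per_epoch: int, aggregation_epochs: int,
                aggregation_steps: int) -> Tuple[int, int]:
    """Simulate one local round; return (completed_epochs, completed_steps).

    Hypothetical semantics: a budget <= 0 is disabled, and the round ends as
    soon as any enabled budget is reached.
    """
    if iters_per_epoch <= 0:
        raise ValueError("iters_per_epoch must be > 0")
    if aggregation_epochs <= 0 and aggregation_steps <= 0:
        raise ValueError("enable at least one of aggregation_epochs / aggregation_steps")
    epochs, steps = 0, 0
    while True:
        for _ in range(iters_per_epoch):
            steps += 1  # one optimizer step would run here
            if aggregation_steps > 0 and steps >= aggregation_steps:
                return epochs, steps
        epochs += 1
        if aggregation_epochs > 0 and epochs >= aggregation_epochs:
            return epochs, steps


print(local_round(10, aggregation_epochs=2, aggregation_steps=0))   # (2, 20)
print(local_round(10, aggregation_epochs=0, aggregation_steps=15))  # (1, 15)
```

A step budget allows the round to end partway through an epoch, which matters when clients hold datasets of very different sizes.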

class TrainContext

Bases: dlmed.common.ctx.SimpleContext

class FLTrainer(mmar_root, args, aggregation_epochs, aggregation_steps)

Bases: object

MMAR_CROSS_VALIDATION_CONFIG = 'config/config_cross_validation.json'
MMAR_TRAIN_CONFIG = 'config/config_train.json'
get_fl_engine()
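The two FLTrainer class constants look like paths relative to an MMAR root directory, which FLTrainer receives as `mmar_root`. Assuming they are resolved by joining them onto that root, the lookup might be sketched as follows. `mmar_config_paths` and the example root directory are hypothetical.

```python
import os

# Values copied from the FLTrainer class attributes above.
MMAR_TRAIN_CONFIG = "config/config_train.json"
MMAR_CROSS_VALIDATION_CONFIG = "config/config_cross_validation.json"


def mmar_config_paths(mmar_root: str) -> dict:
    """Resolve the MMAR-relative config paths against an MMAR root (assumed layout)."""
    return {
        "train": os.path.join(mmar_root, MMAR_TRAIN_CONFIG),
        "cross_validation": os.path.join(mmar_root, MMAR_CROSS_VALIDATION_CONFIG),
    }


paths = mmar_config_paths("/workspace/my_mmar")
print(paths["train"])  # on POSIX: /workspace/my_mmar/config/config_train.json
```

Keeping the paths relative lets the same MMAR be moved between machines with only `mmar_root` changing.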
© Copyright 2021, NVIDIA. Last updated on Feb 2, 2023.