medl.apps.fed_learn package
- class ByowFLEngine(engine)
Bases: medl.apps.fed_learn.fl_engine.FLEngineSpec
- abort()
Terminate the currently running train/validate job.
- close()
Terminate and close the engine.
- evaluate()
Evaluate the current model.
- get_epoch()
Get the current epoch number of the trainer. Returns: epoch number
- get_iteration()
Get the current iteration number of the trainer. Returns: iteration number
- get_key_metric_name()
Get the key metric name of the validator. Returns: key metric name
- get_metrics()
Get the current validation metrics of the validator. Returns: validation metrics
- get_network()
Get the current network of the trainer. Returns: network
- get_num_of_gpu()
Get the number of GPUs of the trainer. Returns: num_of_gpu
- get_validation_network()
Get the current network of the validator. Returns: network
- init_train()
Initialize training at the beginning of the round.
- initialize(parts: dict, integration_config: [], log_dir: str)
Initialize the training engine.
Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.
- train()
Call the engine to train the model.
- validate()
Validate the current model.
- validate_before_train()
Validate at the beginning of each training round.
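The constructor signature ByowFLEngine(engine) suggests that the class wraps a user-supplied engine object. The following is a minimal sketch of that wrapping pattern, assuming simple delegation; `MyTrainer` and `WrapperSketch` are hypothetical stand-ins written for illustration and are not part of medl, and the real class may behave differently.

```python
class MyTrainer:
    """Hypothetical user-supplied engine; not part of medl."""

    def __init__(self):
        self.epoch = 0
        self.iteration = 0

    def train(self):
        # Pretend one epoch of 10 iterations has been run.
        self.epoch += 1
        self.iteration += 10


class WrapperSketch:
    """Hypothetical wrapper exposing a few of the documented method names."""

    def __init__(self, engine):
        # Keep a reference to the wrapped engine (assumed delegation).
        self.engine = engine

    def train(self):
        self.engine.train()

    def get_epoch(self):
        return self.engine.epoch

    def get_iteration(self):
        return self.engine.iteration


wrapper = WrapperSketch(MyTrainer())
wrapper.train()
print(wrapper.get_epoch(), wrapper.get_iteration())
```

The wrapper only forwards calls, so any object with the expected attributes could be substituted for `MyTrainer` in this sketch.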
- class FLEngineSpec
Bases: medl.apps.engine_spec.EngineSpec
- get_epoch()
Get the current epoch number of the trainer. Returns: epoch number
- get_iteration()
Get the current iteration number of the trainer. Returns: iteration number
- get_key_metric_name()
Get the key metric name of the validator. Returns: key metric name
- get_learning_rate()
Get the current learning rate of the trainer. Returns: learning_rate
- get_metrics()
Get the current validation metrics of the validator. Returns: validation metrics
- get_network()
Get the current network of the trainer. Returns: network
- get_num_of_gpu()
Get the number of GPUs of the trainer. Returns: num_of_gpu
- get_validation_network()
Get the current network of the validator. Returns: network
- init_train()
Initialize training at the beginning of the round.
- initialize(parts: dict, integration_config: [], log_dir: str)
Initialize the training engine.
Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.
- validate_before_train()
Validate at the beginning of each training round.
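The getters listed for FLEngineSpec can be combined into a single progress snapshot. The sketch below works against any object exposing those method names. `StubEngine`, `snapshot`, the metric name, and all returned values are hypothetical, and the assumption that get_metrics() returns a dict keyed by metric name is not confirmed by this reference.

```python
class StubEngine:
    """Hypothetical stand-in exposing the documented getter names."""

    def get_epoch(self):
        return 3

    def get_iteration(self):
        return 120

    def get_learning_rate(self):
        return 1e-4

    def get_key_metric_name(self):
        return "val_mean_dice"

    def get_metrics(self):
        # Assumed shape: a dict mapping metric name to value.
        return {"val_mean_dice": 0.87}


def snapshot(engine):
    """Collect the current training state into one dictionary."""
    name = engine.get_key_metric_name()
    metrics = engine.get_metrics()
    return {
        "epoch": engine.get_epoch(),
        "iteration": engine.get_iteration(),
        "learning_rate": engine.get_learning_rate(),
        "key_metric_name": name,
        "key_metric_value": metrics.get(name),
    }


snap = snapshot(StubEngine())
print(snap)
```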
- class IgniteFLEngine(engine, aggregation_epochs, aggregation_steps, cross_validator)
Bases: medl.apps.fed_learn.fl_engine.FLEngineSpec
- abort()
Terminate the currently running train/validate job.
- close()
Terminate and close the engine.
- evaluate()
Evaluate the current model.
- get_epoch()
Get the current epoch number of the trainer. Returns: epoch number
- get_iteration()
Get the current iteration number of the trainer. Returns: iteration number
- get_key_metric_name()
Get the key metric name of the validator. Returns: key metric name
- get_learning_rate()
Get the current learning rate of the trainer. Returns: learning_rate
- get_metrics()
Get the current validation metrics of the validator. Returns: validation metrics
- get_network()
Get the current network of the trainer. Returns: network
- get_num_of_gpu()
Get the number of GPUs of the trainer. Returns: num_of_gpu
- get_validation_network()
Get the current network of the validator. Returns: network
- init_train()
Initialize training at the beginning of the round.
- initialize(parts: dict, integration_config: [], log_dir: str)
Initialize the training engine.
Parameters:
parts: configured component parts to be integrated into the engine.
integration_config: configuration for integrating the FL components with the engine.
- train()
Call the engine to train the model.
- validate()
Validate the current model.
- validate_before_train()
Validate at the beginning of each training round.
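The method descriptions imply a per-round lifecycle: init_train() and validate_before_train() run at the beginning of a round, followed by train(), with close() at the end. A minimal sketch of one possible driver loop follows. `RecordingEngine` and `run_rounds` are hypothetical, the argument values are placeholders, and the exact call order (in particular, init_train() before validate_before_train()) is an assumption rather than documented behavior.

```python
class RecordingEngine:
    """Hypothetical stand-in that records calls; not the medl class."""

    def __init__(self):
        self.calls = []

    def initialize(self, parts, integration_config, log_dir):
        self.calls.append("initialize")

    def init_train(self):
        self.calls.append("init_train")

    def validate_before_train(self):
        self.calls.append("validate_before_train")

    def train(self):
        self.calls.append("train")

    def close(self):
        self.calls.append("close")


def run_rounds(engine, num_rounds):
    """Drive the engine for num_rounds rounds (assumed call order)."""
    engine.initialize(parts={}, integration_config=[], log_dir="logs")
    try:
        for _ in range(num_rounds):
            engine.init_train()
            engine.validate_before_train()
            engine.train()
    finally:
        # Close the engine even if a round raises.
        engine.close()


engine = RecordingEngine()
run_rounds(engine, num_rounds=2)
print(engine.calls)
```

Placing close() in a finally block keeps the engine from being left open when a round fails part-way through.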
- class TrainContext
Bases: dlmed.common.ctx.SimpleContext
- class FLTrainer(mmar_root, args, aggregation_epochs, aggregation_steps)
Bases: object
- MMAR_CROSS_VALIDATION_CONFIG = 'config/config_cross_validation.json'
- MMAR_TRAIN_CONFIG = 'config/config_train.json'
- get_fl_engine()
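The two class constants are relative paths, which suggests that they are resolved against the mmar_root constructor argument. A minimal sketch of that resolution follows, assuming a simple path join; `resolve_config` and the example root `/workspace/mmar` are hypothetical, and how FLTrainer actually combines the paths is not stated in this reference.

```python
from pathlib import PurePosixPath

# Values copied from the FLTrainer class constants listed above.
MMAR_CROSS_VALIDATION_CONFIG = 'config/config_cross_validation.json'
MMAR_TRAIN_CONFIG = 'config/config_train.json'


def resolve_config(mmar_root, relative_path):
    """Join an MMAR root directory with a relative config path."""
    return str(PurePosixPath(mmar_root) / relative_path)


train_config = resolve_config("/workspace/mmar", MMAR_TRAIN_CONFIG)
cross_val_config = resolve_config("/workspace/mmar", MMAR_CROSS_VALIDATION_CONFIG)
print(train_config)
print(cross_val_config)
```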