fed_learn.components package

class DataCompressor

Bases: fed_learn.components.data_processor.DataProcessor

process(data_ctx, app_ctx)

To compress the data_ctx.

Parameters
  • data_ctx

  • app_ctx

Returns

class DataDeCompressor

Bases: fed_learn.components.data_processor.DataProcessor

process(data_ctx, app_ctx)

To decompress the data_ctx.

Parameters
  • data_ctx

  • app_ctx

Returns

class DataProcessor

Bases: object

abort(app_ctx)

Called to abort its work immediately.

process(data_ctx, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns
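
As a rough illustration, a concrete DataProcessor might look like the sketch below. The base-class import path follows the Bases lines above; the data attribute on data_ctx, the gzip payload, and the returned context are assumptions for illustration, not the documented API.

    import gzip

    from fed_learn.components.data_processor import DataProcessor


    class GzipCompressor(DataProcessor):
        """Compresses the payload carried by data_ctx, in the spirit of DataCompressor."""

        def process(self, data_ctx, app_ctx):
            # Transform the (assumed) raw bytes held by the data context.
            data_ctx.data = gzip.compress(data_ctx.data)
            return data_ctx

        def abort(self, app_ctx):
            # Nothing to clean up for an in-memory transform.
            pass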

class ModelDecryptor

Bases: fed_learn.components.data_processor.DataProcessor

process(data_ctx, app_ctx)

To decrypt the data_ctx.

Parameters
  • data_ctx

  • app_ctx

Returns

class ModelEncryptor

Bases: fed_learn.components.data_processor.DataProcessor

process(data_ctx, app_ctx)

To encrypt the data_ctx.

Parameters
  • data_ctx

  • app_ctx

Returns

class DictModelReaderWriter

Bases: fed_learn.components.model_reader_writer.ModelProcessor

This is a training-platform-independent model reader-writer.

  • The local model is dict-based.

  • The local model is kept in the client's local memory.

apply_model(model_params, options=None)

Assign params from the remote model to the local dict model.

Parameters
  • model_params – a ModelData message or dict of locally trained weights.

  • options – additional options when applying the model. no_proto_convert: if set, the proto conversion is skipped.

Returns

True if any vars in model_params are successfully assigned.

extract_model(model_vars)

Build a dict of the weights whose names appear in model_vars.

Parameters

model_vars – list of var names to get

Returns

A dict of weights, which may be a partial or complete set of the local_model.

get_local_models()
initialize(fitter=None)
A fitter is not needed here; this simply initializes the local model to an empty dict.
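
A rough usage sketch for DictModelReaderWriter follows, assuming plain-dict weights are acceptable (the apply_model notes above mention a dict of locally trained weights and a no_proto_convert option). The import path, the weight values, and the options-dict format are assumptions for illustration.

    import numpy as np

    # Assumed import path for illustration.
    from fed_learn.components.dict_model_reader_writer import DictModelReaderWriter

    rw = DictModelReaderWriter()
    rw.initialize()  # no fitter needed; the local model starts as an empty dict

    # Apply a plain dict of locally trained weights, skipping proto conversion.
    weights = {"conv1.weight": np.zeros((8, 3, 3, 3)), "fc.bias": np.zeros(10)}
    changed = rw.apply_model(weights, options={"no_proto_convert": True})

    # Extract only the named variables; the result may be a partial view of local_model.
    subset = rw.extract_model(["fc.bias"])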

class BestMetricGenerator(field)

Bases: fed_learn.components.metadata_generator.MetadataGenerator

process(data_ctx: fed_learn.model_meta.FLContext, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns

class InitialValidationMetricGenerator(field)

Bases: fed_learn.components.metadata_generator.MetadataGenerator

process(data_ctx: fed_learn.model_meta.FLContext, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns

class IteratorNumberGenerator(field)

Bases: fed_learn.components.metadata_generator.MetadataGenerator

process(data_ctx: fed_learn.model_meta.FLContext, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns

class MetadataGenerator(field)

Bases: fed_learn.components.data_processor.DataProcessor

get_field()
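
The metadata generators above share a pattern: stamp a configured field into the outgoing context. A minimal sketch of that pattern is below; the meta dict on data_ctx, the dict-like app_ctx, and the class name are assumptions for illustration.

    from fed_learn.components.metadata_generator import MetadataGenerator


    class IterationNumberStamper(MetadataGenerator):
        """Writes the current iteration count into data_ctx under the configured field."""

        def process(self, data_ctx, app_ctx):
            # Both the meta dict on data_ctx and the dict-like app_ctx are assumptions.
            data_ctx.meta[self.get_field()] = app_ctx.get("current_iters", 0)
            return data_ctx
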
class MetricManager(metric_map=None)

Bases: object

Manages metrics for federated learning.

add_group(group_name, metric_dict)
encode_metrics() → dict

Encodes the metric map (a two-level dict) into a single-level dict.

get_group(group_name)
to_dict()
decode_metrics(encoded_metrics)

Decodes encoded_metrics (a single-level dict) back into a two-level dict (see the round-trip sketch below).

validate_metrics(metric_map)
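
To make the encode/decode round trip concrete, here is a standalone sketch of flattening a two-level metric map and restoring it. The "group/metric" key format is an assumption for illustration; the library may join keys differently.

    def encode_metrics(metric_map: dict) -> dict:
        # Flatten {group: {metric: value}} into {"group/metric": value}.
        return {f"{group}/{name}": value
                for group, metrics in metric_map.items()
                for name, value in metrics.items()}


    def decode_metrics(encoded_metrics: dict) -> dict:
        # Rebuild the two-level map from the flattened keys.
        decoded = {}
        for key, value in encoded_metrics.items():
            group, name = key.split("/", 1)
            decoded.setdefault(group, {})[name] = value
        return decoded


    metric_map = {"train": {"loss": 0.42}, "validate": {"accuracy": 0.91}}
    assert decode_metrics(encode_metrics(metric_map)) == metric_map
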
class ModelEvaluator

Bases: abc.ABC

abstract evaluate()

Runs the evaluation process.

abstract get_metric()

Gets the metric used for selecting the best model (see the evaluator sketch below).

Returns

A float number representing the metric.

initialize()
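
A minimal sketch of a concrete ModelEvaluator is shown below. The module path and the placeholder metric value are assumptions; a real evaluator would run the model on validation data inside evaluate().

    from fed_learn.components.model_evaluator import ModelEvaluator  # assumed module path


    class AccuracyEvaluator(ModelEvaluator):
        def __init__(self):
            self._metric = 0.0

        def initialize(self):
            # Load validation data, build the model, etc.
            pass

        def evaluate(self):
            # Run validation and cache the resulting metric (placeholder value here).
            self._metric = 0.93

        def get_metric(self):
            # A float used by the framework to select the best model.
            return self._metric
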
class ModelProcessor

Bases: abc.ABC

abstract apply_model(model_params, options=None)
abstract extract_model(model_vars)
abstract get_local_models()
abstract initialize(fitter)
class ModelSaver

Bases: abc.ABC

initialize(builder=None)
abstract load_model()
abstract save_model(model, is_best)
class ModelValidator

Bases: abc.ABC

abstract close()
abstract validate_model(configer, checkpoint)
class ModelVisualizer

Bases: abc.ABC

abstract visualize(model_data)
feed_vars(model: torch.nn.modules.module.Module, model_params)

Feed variable values from model_params into the PyTorch state_dict.

Parameters
  • model – the local pytorch model

  • model_params – a ModelData message

Returns

the assigned ops
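
The sketch below illustrates the idea behind feed_vars for PyTorch, assuming model_params has already been decoded into a plain {name: numpy array} dict (the documented function takes a ModelData message). It is not the library's implementation; the function name is hypothetical.

    import torch


    def feed_numpy_weights(model: torch.nn.Module, weights: dict):
        # weights: {var_name: numpy array}, already decoded from the ModelData message.
        state_dict = model.state_dict()
        assigned = []
        for name, array in weights.items():
            if name in state_dict:
                target = state_dict[name]
                state_dict[name] = torch.as_tensor(array).reshape(target.shape).to(target.dtype)
                assigned.append(name)
        model.load_state_dict(state_dict)
        return assigned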

class PTModelReaderWriter

Bases: fed_learn.components.model_reader_writer.ModelProcessor

apply_model(model_params, options=None)

Set the local model according to model_params.

Parameters

model_params – a ModelData message

Returns

True if the local model changed

extract_model(model_vars)
get_local_models()
initialize(fitter)

Set the fitter for the PyTorch model reader and writer.

class PTModelSaver(exclude_vars=None, model_log_dir=None, ckpt_preload_path=None)

Bases: fed_learn.components.model_saver.ModelSaver

close()
initialize(builder=None)
load_model()

Convert the initialised model into a protobuf message. This function sets self.model to a ModelData protobuf message.

save_model(model, is_best)
class RoundRobinClientFieldProcessor

Bases: fed_learn.components.data_processor.DataProcessor

This processor is called before the end of GetModel to handle the assignment of round-robin-related fields.

abort(app_ctx)

Called to abort its work immediately.

process(data_ctx, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns

class RoundRobinClientStateProcessor

Bases: fed_learn.components.data_processor.DataProcessor

This processor is called during the get_client_state API to pass the round-robin-specific rr_train_steps.

abort(app_ctx)

Called to abort its work immediately.

process(data_ctx, app_ctx)

Called to perform data processing.

Parameters
  • data_ctx – the context that contains the data points transformed/generated

  • app_ctx – the overall app context (e.g. current phase, round, task, …)

Returns

class TensorboardModelVisualizer

Bases: fed_learn.components.model_visualizer.ModelVisualizer

visualize(model_data)

Add model diff values to TensorBoard for visualization.

Parameters

model_data – it should contain model_diff values.

add_tensorboard_histogram(tf_writer, step, tag, values, bins=1000)

Add a numpy array to TensorBoard as a histogram.
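
A common way to implement a helper with this signature under TF 1.x is to bucket the array with numpy and emit a HistogramProto summary. The sketch below illustrates that pattern; it is not necessarily the library's exact code.

    import numpy as np
    import tensorflow as tf


    def add_histogram(tf_writer, step, tag, values, bins=1000):
        values = np.asarray(values).flatten()
        counts, bin_edges = np.histogram(values, bins=bins)

        hist = tf.compat.v1.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(values.size)
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # TensorBoard expects the right-hand bin edges only.
        hist.bucket_limit.extend(float(e) for e in bin_edges[1:])
        hist.bucket.extend(int(c) for c in counts)

        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)])
        tf_writer.add_summary(summary, step)
        tf_writer.flush()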

copy_ckpt_model(model_dir, src_prefix, dst_prefix)

Copy checkpoint-related model files from src_prefix to dst_prefix within the same model_dir.

feed_vars(tf_session, model_params)

Feed variable values from model_params into tf_session (see the sketch after make_feedable below).

Parameters
  • tf_session – a TensorFlow session instance

  • model_params – a ModelData message

Returns

the assign ops

global_vars()
make_feedable(tf_graph)

Make the graph writable via placeholders.

Parameters

tf_graph – the TensorFlow graph whose trainable vars are to be written
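
The placeholder/assign pattern behind make_feedable and feed_vars can be sketched as follows for a TF 1.x graph. The function names, return values, and extra arguments here are illustrative; the documented functions keep this state internally.

    import tensorflow as tf


    def make_graph_feedable(tf_graph):
        # One placeholder and one assign op per trainable variable.
        placeholders, assign_ops = {}, {}
        with tf_graph.as_default():
            for var in tf.compat.v1.trainable_variables():
                ph = tf.compat.v1.placeholder(var.dtype.base_dtype, shape=var.shape)
                placeholders[var.name] = ph
                assign_ops[var.name] = var.assign(ph)
        return placeholders, assign_ops


    def feed_weight_values(tf_session, weights, placeholders, assign_ops):
        # weights: {var_name: numpy array}, already decoded from the ModelData message.
        for name, value in weights.items():
            if name in assign_ops:
                tf_session.run(assign_ops[name], feed_dict={placeholders[name]: value})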

make_ndarray(tensor)
make_tensor(value)
new_session(fitter)
tf_reset_default_graph()
class TFKerasV1ModelSaver(exclude_vars=None, model_log_dir=None, model_path=None, options=None)

Bases: fed_learn.components.model_saver.ModelSaver

This class handles the model-saving logic for a TF 1.x + Keras trained HDF5 model.

close()
get_proto_model_from_hdf5()

Convert the loaded Keras model into a protobuf message. This function sets self.model to a ModelData protobuf message.

initialize(builder=None)
load_model()
property model_log_dir
save_model(model, is_best)
class TFModelReaderWriter

Bases: fed_learn.components.model_reader_writer.ModelProcessor

apply_model(model_params, options=None)

Set the local TF session's model according to model_params.

Parameters

model_params – a ModelData message

Returns

True if the local model changed

extract_model(model_vars)
get_local_models()

Returns the local models as a dictionary of filename to bytes. Includes the best model and the final_model (see the sketch after this class listing).

Returns

A dictionary of filename to bytes.

initialize(fitter)

Set the fitter for the TF model reader and writer.
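
The {filename: bytes} dictionary that get_local_models returns can be pictured with the standalone sketch below; the directory layout and file names are assumptions for illustration.

    import os


    def collect_model_files(model_dir, filenames=("best_model.ckpt", "final_model.ckpt")):
        # Read each model file that exists into a {filename: bytes} dict.
        models = {}
        for name in filenames:
            path = os.path.join(model_dir, name)
            if os.path.isfile(path):
                with open(path, "rb") as f:
                    models[name] = f.read()
        return models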

class TFModelSaver(exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, options=None)

Bases: fed_learn.components.model_saver.ModelSaver

close()
initialize(builder=None)
initialize_uninitialized(session)
load_model()
load_weights_into_session(session)

Create a TF session for the server, loading an existing checkpoint if required.

make_init_proto()

Convert the initialised model into a protobuf message. This function sets self.model to a ModelData protobuf message.

save_checkpoint(is_best)

Save the model as a checkpoint.

save_model(model, is_best)
class TFModelValidator

Bases: fed_learn.components.model_validator.ModelValidator

close()
validate_model(cross_site_val_configer, checkpoint)

Validate a model using a CrossValConfiger and checkpoint files.

Parameters
  • cross_site_val_configer – An instance of CrossValConfigurer

  • checkpoint – A dict of name to bytes for the checkpoint files.
