Bring your own components for federated learning

The components used in Clara Federated Learning are modular, so users can bring their own components in the same way as described in Bring your own components (BYOC).

Please look through this FL Jupyter notebook for detailed examples.

Below is a list of example components users can add.

Model Aggregator

Model aggregation happens on the FL Server as specified in the config_fed_server.json file. Clara Train comes with a built-in aggregator:

Built-in aggregator

This aggregator is based on an algorithm in the paper "Federated Learning for Breast Density Classification: A Real-World Implementation". The ModelAggregator computes a weighted sum of the model gradients from each client, where the default weights are based on the number of training iterations that the client executed in this round of FL. The user can further adjust the client weights by adding custom weights in the arguments of this component in config_fed_server.json.
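
The weighting described above can be illustrated with a minimal NumPy sketch. It is not the built-in ModelAggregator: the function name `aggregate_weighted`, the tuple layout of `contributions`, and the treatment of custom client weights as a multiplier on the iteration count are all assumptions made for illustration. Like the custom sample later on this page, it normalizes the weighted sum by the total weight and adds the result to the current global value.

```python
import numpy as np


def aggregate_weighted(global_params, contributions, client_weights=None):
    """Fold client updates into the global parameters (illustrative only).

    global_params: dict mapping variable name -> ndarray.
    contributions: list of (client_id, n_iter, updates) tuples, where
        updates maps variable name -> ndarray update from that client.
    client_weights: optional dict of client_id -> extra weight
        (hypothetical stand-in for the custom weights in the server config).
    """
    client_weights = client_weights or {}
    new_params = dict(global_params)
    for name, current in global_params.items():
        weighted, total = [], 0.0
        for client_id, n_iter, updates in contributions:
            if name not in updates:
                continue  # this client did not send the variable
            # Default weight is the number of local iterations.
            w = float(n_iter) * client_weights.get(client_id, 1.0)
            weighted.append(updates[name] * w)
            total += w
        if total == 0.0:
            continue  # nothing usable received for this variable
        new_params[name] = current + np.sum(weighted, axis=0) / total
    return new_params


global_params = {"w": np.array([1.0, 1.0])}
contributions = [
    ("site-a", 10, {"w": np.array([1.0, 0.0])}),
    ("site-b", 30, {"w": np.array([0.0, 1.0])}),
]
result = aggregate_weighted(global_params, contributions)
# site-b ran three times as many iterations, so its update counts three
# times as much: result["w"] is [1.25, 1.75].
```

Weighting by iteration count means a client that trained longer in the round pulls the global model further toward its update; an extra per-client weight of 0 would exclude that client entirely.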

Bring your own Aggregator to FL

Users can write their own aggregators and point to the custom code by replacing the component's name tag with a path tag, as described in Bring your own components (BYOC). A sample custom aggregator is shown below:

import logging

import numpy as np

from fed_learn.numproto import proto_to_ndarray, ndarray_to_proto
from fed_learn.server.model_aggregator import Aggregator
from fed_learn.model_meta import FLContext


class CustomModelAggregator(Aggregator):
    def __init__(self):
        self.logger = logging.getLogger('CustomModelAggregator')

    def process(self, accumulator: [FLContext], fl_ctx: FLContext):
        """Aggregate the contributions from all the submitted FL clients.

        For the FLContext type we can use get_model() method to get the model data.
        The model data is a protobuf message and its format is defined as below.

        // A model consists of multiple tensors
        message ModelData {
            map<string, NDArray> params = 1;
        }

        // NDArray data for protobuf
        message NDArray {
          bytes ndarray = 1;
        }

        This aggregation method weights each contribution by its number of local
        iterations and adds the resulting weighted average to the current model value.

        This function is not thread-safe.

        :param accumulator: List of all the contributions in FLContext.
        :param fl_ctx: An instance of FLContext.
        :return: True to indicate that the current model is the best model so far.
        """
        # The model data is in model.params as a dict.
        model = fl_ctx.get_model()
        vars_to_aggregate = [set(item.get_model().params) for item in accumulator]
        vars_to_aggregate = set.union(*vars_to_aggregate)

        for v_name in vars_to_aggregate:
            n_local_iters, np_vars = [], []
            for item in accumulator:
                data = item.get_model()
                if v_name not in data.params:
                    continue  # this item doesn't have the variable from client

                # contribution is a protobuf msg
                #   it has `n_iter` which represents number of local iterations
                #   used to compute this contribution
                acc = item.get_prop('_contribution')
                # np.float was removed in NumPy 1.24; the built-in float is equivalent
                float_n_iter = float(acc.n_iter)
                n_local_iters.append(float_n_iter)

                # it also has `client` and client has `uid`
                self.logger.info(f'Get contribution from client {acc.client.uid}')

                # weighted using local iterations
                weighted_value = proto_to_ndarray(data.params[v_name]) * float_n_iter
                np_vars.append(weighted_value)
            if not n_local_iters:
                continue  # didn't receive this variable from any clients
            new_val = np.sum(np_vars, axis=0) / np.sum(n_local_iters)
            new_val += proto_to_ndarray(model.params[v_name])

            # Update the params in model using CopyFrom because it is a ProtoBuf structure
            model.params[v_name].CopyFrom(ndarray_to_proto(new_val))
        return False