Federated learning

Federated learning motivation

Deep learning, the fastest growing field in AI, is driving immense progress in scientific and real-life applications. It is widely accepted that the more data used in deep learning training, the better the resulting models. However, deep learning algorithms face difficult challenges when applied to real-world domains such as finance and health care. The data in these areas is typically subject to strong privacy regulations, and it is often impossible to share data across organizations. Annotated data is also hard to obtain, and it represents an asset for the institutions that own it.

How can you obtain models as good as those trained on large datasets without violating privacy and property constraints? One approach is federated learning. Different institutions contribute to the construction of a powerful model through collaborative training, without sharing any data. A generic pre-trained model is fine-tuned for a specific application or a specific patient population. The institutions share the trained model, not the actual data! With this approach, you can train better models while still protecting data privacy.

How does federated learning work

../../_images/fl_overview.png

Federated learning is split into two parts: the server and the clients. The server manages the overall model training progress and broadcasts the initial model to all participating clients. Model training happens locally on each client's side, so the server never needs access to the training data. The data remains protected behind each client's private access controls. Clients share model updates instead of data, and each client has its own privacy controls over what percentage of its model weights is sent to the server for aggregation.

Once the server receives the latest round of models from the clients, it can apply its own algorithm for aggregating them. This could be a simple average across all clients, or a weighted average based on each client's historical contributions.
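
As a rough illustration of such an aggregation step, here is a minimal sketch assuming each client update is a dictionary of named weight arrays. The function and variable names are hypothetical and do not reflect the actual Clara Train aggregation code:

def aggregate(client_updates, client_weights=None):
    """Average client updates into a new global model.

    client_updates: list of dicts mapping variable name -> weight array
                    (e.g. numpy arrays).
    client_weights: optional per-client weights, e.g. derived from each
                    client's historical contribution or local dataset size.
    """
    num_clients = len(client_updates)
    if client_weights is None:
        # Simple average: every client contributes equally.
        client_weights = [1.0 / num_clients] * num_clients
    total = sum(client_weights)
    global_model = {}
    for name in client_updates[0]:
        global_model[name] = sum(
            w * update[name] for w, update in zip(client_weights, client_updates)
        ) / total
    return global_model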

The server has overall control of how many rounds of federated learning training to conduct. Participating clients can be added or removed during any round of the training. Federated learning provides benefits for every participant: a stronger central model and better local models.

Federated learning in Clara Train

Clara Train SDK provides sophisticated medical-imaging-specific deep learning transform modules through the Clara Train API. Clara Train supports the import and export of trained models from different training sources and allows for easy creation of training and evaluation workflows. Furthermore, Clara Train provides ready-to-use, cutting-edge deep learning models with pre-processing pipelines addressing a wide range of real-life medical cases. The fully configurable deep learning model training and the open APIs make it easy to integrate training with external models. With all of this in place, it is natural to extend Clara Train to support federated learning across different institutions.

The approach we are taking is to use Clara Train on each client locally to perform training with the transforms made available by the Clara Train API. On top of that, we build a federated learning server to facilitate the overall model training. The server manages the consolidated model from all participating clients, controls the training pace and the number of overall training rounds, and aggregates the training results. Only the trained models are shared by the clients, and each client controls how much of its weights to share with the server and how much protection to add for data privacy.

Federated learning system architecture

../../_images/fl_overall_workflow.png

When running a federated learning model training, a server training service must first be started. The server session controls the minimum number of clients needed to start a round of FL training and the maximum number of clients that can join a round. If a client intends to join an FL session, it must first submit a login request to apply for FL training. The server checks the credentials of the FL client and performs the authentication process to validate the client. If authentication succeeds, the server sends an FL token back to the client for use in the following FL client-server communication; otherwise, it sends an authentication rejection.

The client then sends another request to get the current training model from the server and starts the current round of model training. The client has its own control over how many epochs it runs during each round of FL training. Once the client finishes the current round of training, it sends the updated model to the server using the existing FL token. After the server receives the updated models from all participating clients, it performs the model aggregation based on its weight aggregation algorithm and obtains the updated overall training model. This completes the current round of FL training, and training continues until it reaches the maximum number of rounds set on the server (num_rounds in the server configuration file).

Federated learning server-client communication protocol

../../_images/fl_comm_protocol.png

Federated learning uses the gRPC protocol between the server and clients during model training. There are 4 basic commands during the model training session, listed below (a sketch of this interface follows the list):

Register: Client uses this command to notify its intent to join an FL training session. The server performs the client authentication and then returns an FL token to the client.

GetModel: Client uses this command to acquire the model of the current round from the server. The server checks the token, and sends back the global model to the client.

SubmitUpdate: Client uses this command to send the updated local model after the current round of FL training to the server for the aggregation.

Quit: The command to use when the client decides to quit from the current FL model training.
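
The following is a hypothetical sketch of these four commands expressed as a client-side interface. The names and signatures are illustrative only; they are not the actual gRPC stubs generated from Clara's protocol definitions:

from abc import ABC, abstractmethod

class FLClientChannel(ABC):
    """Illustrative client-side view of the four FL commands."""

    @abstractmethod
    def register(self, credentials) -> str:
        """Ask to join FL training; returns an FL token on successful authentication."""

    @abstractmethod
    def get_model(self, token: str) -> bytes:
        """Fetch the current global model for this round."""

    @abstractmethod
    def submit_update(self, token: str, local_model: bytes) -> None:
        """Send the locally updated model back to the server for aggregation."""

    @abstractmethod
    def quit(self, token: str) -> None:
        """Leave the current FL model training session."""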

Federated learning client-server authentication

../../_images/fl_authentication.png

In order for a client to participate in the federated learning model training, it needs to provide its credentials to be authenticated. There are many ways to authenticate clients in a client-server SOA implementation. For federated learning deployments, we do not expect the clients to change their credentials frequently. In this case, we choose to use self-signed SSL certificates issued by the server authority to authenticate the client identities.

The client first sends a login request to the server to participate in federated learning training. After successful authentication, an FL token is returned to the client. The client uses this FL token for all subsequent communication requests. Token expiration is managed on the server side. Once the FL token expires, the client needs to make a new login request to get a new FL token.
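
As a small illustration of this token life cycle, the sketch below shows a client logging in again when the server reports an expired token. The exception and method names are hypothetical, reusing the illustrative FLClientChannel interface above, and are not Clara Train APIs:

class TokenExpiredError(Exception):
    """Hypothetical error raised when the server rejects an expired FL token."""

def call_with_token(channel, credentials, request):
    """Run a request, re-logging in whenever the FL token has expired."""
    token = channel.register(credentials)      # initial login request
    while True:
        try:
            return request(channel, token)     # e.g. get_model or submit_update
        except TokenExpiredError:
            # The server-managed token has expired: log in again for a new token.
            token = channel.register(credentials)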

Federated learning server-side workflow

../../_images/fl_server_workflow.png

When starting the federated learning server-side service, the server-side config files, including the FL service name, gRPC communication ports, SSL certificate keys, and the minimum and maximum numbers of clients, are used to initialize the service and restore the initial model before starting the FL service. After initialization, the server enters a loop: waiting for clients' join requests, issuing the model to the clients, and waiting for the clients to send back their updated models. Once the server receives all the updated models from the clients, it performs the aggregation based on its weight aggregation algorithm and updates the current overall model. This updated overall model is then used for the next round of model training, and this process repeats until the server reaches the maximum number of federated learning training rounds.
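
A conceptual sketch of this server-side round loop is shown below. The callables passed in are hypothetical placeholders standing in for the real Clara Train components (client communication, aggregation, checkpointing), not actual Clara Train internals:

def run_fl_server(config, restore_initial_model, wait_for_clients,
                  broadcast_model, collect_updates, aggregate, save_global_model):
    """Illustrative FL server round loop driven by the server configuration."""
    model = restore_initial_model(config)                  # initial or pre-trained model
    clients = wait_for_clients(config["min_num_clients"],
                               config["max_num_clients"])  # wait for join requests
    for _round in range(config["start_round"], config["num_rounds"]):
        broadcast_model(clients, model)                    # serve GetModel requests
        updates = collect_updates(clients)                 # wait for SubmitUpdate from clients
        model = aggregate(updates)                         # weight aggregation algorithm
    save_global_model(model)                               # final global model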

Federated learning client-side workflow

../../_images/fl_client_workflow.png

On the client side of federated learning model training, the client first initializes itself using the client configuration. Then, the client uses its credentials to make a login request to the server and obtain an FL training token. Once the token is obtained, the client requests the current model from the server. It uses the current global model to build and restore the TF session and starts the local training, fitting the current model with the local data. During the local training, the client controls how many epochs to run for each round of FL training. It also controls whether the local training runs on a single GPU or multiple GPUs.

Once the client finishes the current round of local model training, it sends the updated local model to the server. The client can configure its own privacy-preserving policies on how much of the weights to send back to the server for aggregation. After that, the client makes a request to the server asking for the new global model to start a new round of federated learning training.
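
A conceptual sketch of this client-side loop, reusing the hypothetical FLClientChannel interface sketched earlier, is shown below; none of these names are actual Clara Train APIs:

def run_fl_client(channel, credentials, train_local, apply_privacy,
                  num_rounds, local_epochs):
    """Illustrative FL client loop: login, fetch, train locally, submit."""
    token = channel.register(credentials)               # login, obtain FL token
    for _round in range(num_rounds):
        global_model = channel.get_model(token)          # current global model
        local_model = train_local(global_model,
                                  epochs=local_epochs)   # trains on local data only
        update = apply_privacy(local_model)              # e.g. partial weight sharing
        channel.submit_update(token, update)             # send update for aggregation
    channel.quit(token)                                  # leave the FL session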

Federated learning Clara MMAR integration

The federated learning functions are packaged with the same MMAR structure as in Clara. In terms of the model training configurations, the FL model training uses the same Clara transform pipeline solution. The MMAR folder structure stays the same, and the Clara training commands and the training pipeline configurations stay the same.

There are two additional server and client configurations that describe the federated learning behavior. The server trainer controls how many rounds of FL training to conduct for the whole model training process, aggregates the overall model from the participating clients, and coordinates the global model training progress. The client trainer controls how many epochs the model training runs for each FL round, and the privacy protection policy to use when publishing the local training model back to the FL server for aggregation.

When starting a federated learning model training, you first start the FL training service from the server with the server_train.sh command. This service manages the FL training task identity and the gRPC communication service location URLs. During the FL life cycle, the service listens for clients to join, broadcasts the global model, and aggregates the updated models from the clients.

From the client side, the client uses the client_train.sh command to start an FL client training task. It gets an FL token through a login request, gets the global model from the server, trains and updates the local model using locally protected data, and submits the model to the server for aggregation after each round of FL training.

Federated learning deployment security

One of the major motivations for federated learning is to safeguard data privacy. Federated learning model training typically involves multiple organizations. Through federated learning, each organization trains the model locally, sharing only the model, not the private data. However, it is also critical to keep the client-server communication secure so that the data and model exchanges cannot be compromised.

To achieve federated learning security, the FL token is used to establish trust between the client and server. The FL token is used throughout the FL training session life cycle. Clients need to verify the server identity, and the server needs to verify and authenticate the clients. The client-server data exchanges are based on the HTTPS protocol for secure communication, and self-signed SSL certificates are used to build the client-server trust.

Instructions for creating the self-signed SSL Certificate Authority and the server and client certificates:

## Create the Root Key
`openssl genrsa -out rootCA.key 2048`
or, with a password:
`openssl genrsa -des3 -out rootCA.key 2048`

## Create the self-signed Root certificate
`openssl req -x509 -new -nodes -key rootCA.key -sha256 -days 1024 -out rootCA.pem`

## Create the certificate for server and clients

### Create a private key
`openssl genrsa -out device.key 2048`

### Generate the certificate signing request
`openssl req -new -key device.key -out device.csr`

### Sign the CSR using the CA root key to generate the certificate
`openssl x509 -req -in device.csr -CA rootCA.pem -CAkey rootCA.key -CAcreateserial -out device.crt -days 500 -sha256`

For development purposes, insecure gRPC communication between the server and clients is supported. This communication mode is not recommended for federated learning production deployments.

How to use Clara federated learning model training

FL server configuration

FL server configuration file: config_fed_server.json

Example:

{
    "servers": [
        {
            "name": "prostate_segmentation",
            "service": {
                "target": "localhost:8002",
                "options": [
                    ["grpc.max_send_message_length",    1000000000],
                    ["grpc.max_receive_message_length", 1000000000]
                ]
            },
            "ssl_private_key": "resources/certs/server.key",
            "ssl_cert": "resources/certs/server.crt",
            "ssl_root_cert": "resources/certs/rootCA.pem",
            "min_num_clients": 2,
            "max_num_clients": 100,
            "start_round": 0,
            "num_rounds": 300,
            "exclude_vars": "dummy",
            "num_server_workers": 100
        }
    ]
}

Variable | Description
-------- | -----------
servers | The list of servers running the FL service
name | The FL model training task name
target | FL gRPC service location URL
grpc.max_send_message_length | Maximum length of a sent gRPC message
grpc.max_receive_message_length | Maximum length of a received gRPC message
ssl_private_key | gRPC secure communication private key
ssl_cert | gRPC secure communication SSL certificate
ssl_root_cert | gRPC secure communication trusted root certificate
min_num_clients | Minimum number of clients required for FL model training
max_num_clients | Maximum number of clients allowed in FL model training
start_round | FL training starting round number
num_rounds | Round number until which training is conducted
exclude_vars | Variables excluded from privacy preserving
num_server_workers | Maximum number of workers supporting the FL model training

More details on the variables above:

start_round

The current FL server training will start from this number and continue until the value of num_rounds. Depending on the status of the FL training, you can adjust this accordingly.

exclude_vars

This option accepts a string argument, and this string will be interpreted as a regular expression. The exclude_vars regex is then used to filter out server model parameters which are not to be shared with the clients.
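
A minimal sketch of how such a regular expression could be applied is shown below; this is illustrative only, not the Clara Train implementation:

import re

def filter_excluded(model_vars, exclude_vars):
    """Drop variables whose names match the exclude_vars regular expression.

    model_vars: dict of variable name -> value.
    exclude_vars: regular expression string from the configuration.
    """
    pattern = re.compile(exclude_vars)
    return {name: value for name, value in model_vars.items()
            if not pattern.search(name)}

With the example configuration above (exclude_vars set to "dummy"), any model parameter whose name matches "dummy" would be withheld from the clients.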

num_server_workers

This is used to control how many workers are allocated for the gRPC services from the server side. This may slightly affect the performance of gRPC communication.
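
Clara's server internals are not shown here, but in a typical Python gRPC service such a setting maps to the size of the server's request thread pool. The snippet below is only an assumption-level illustration of that relationship, using the standard grpc package rather than Clara Train code:

from concurrent import futures
import grpc

# Hypothetical illustration: a gRPC server whose thread pool size plays the
# role that num_server_workers plays in the FL server configuration.
NUM_SERVER_WORKERS = 100

server = grpc.server(futures.ThreadPoolExecutor(max_workers=NUM_SERVER_WORKERS))
# Service handlers would be registered here before starting the server.
# An insecure port is used only for brevity; production deployments should
# use the SSL credentials described above.
server.add_insecure_port("localhost:8002")
server.start()
server.wait_for_termination()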

FL client configuration

FL client configuration file: config_fed_client.json

Example:

{
    "servers": [
        {
            "name": "prostate_segmentation",
            "service": {
                "target": "localhost:8002",
                "options": [
                    ["grpc.max_send_message_length",    1000000000],
                    ["grpc.max_receive_message_length", 1000000000]
                ]
            }
        }
    ],
    "client": {
        "local_epochs": 20,
        "exclude_vars": "dummy",
        "privacy": {
            "dp_type": "none",
            "percentile": 75,
            "gamma": 1
        },
        "ssl_private_key": "resources/certs/client1.key",
        "ssl_cert": "resources/certs/client1.crt",
        "ssl_root_cert": "resources/certs/rootCA.pem"
    }
}

Variable | Description
-------- | -----------
servers | Same as the server configuration, for FL training task identification and service location URLs
client | The section describing the FL client
local_epochs | How many epochs to run for each FL training round
exclude_vars | Variables excluded from privacy preserving
privacy | Privacy-preserving algorithm settings
ssl_private_key | gRPC secure communication private key
ssl_cert | gRPC secure communication SSL certificate
ssl_root_cert | gRPC secure communication trusted root certificate

More details on the variables above:

servers

Currently this list supports one element only.

privacy

The amount of noise depends on the learning rate, the data pre-processing steps, and the type of optimizer the user has chosen. Just like setting a learning rate, setting the privacy parameters is data and task dependent. A parameter tuning strategy is to start with no noise added (set "dp_type" to "none", meaning no privacy-related model data processing), then gradually add noise while checking the convergence and accuracy on a pilot study dataset (the noise is controlled by "gamma", which can be an arbitrarily small positive number, while "noise_var" can be an arbitrarily large number). With this information, choose the largest amount of noise that still achieves the lowest acceptable model accuracy. When "dp_type" is "laplacian", the options "gamma", "tau", "epsilon", "fraction", and "noise_var" correspond to "gamma", "tau", "epsilon_1", "Q", and "epsilon_3", respectively, as described in the paper https://arxiv.org/abs/1910.00962. When "dp_type" is "partial", "percentile" indicates the percentage of the model parameters shared with the server and "gamma" indicates the clipping value applied to all model parameters before sharing.
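
As a conceptual illustration of the "partial" idea, the sketch below clips the weight updates to plus or minus gamma and shares only the values at or above the chosen magnitude percentile. The exact semantics of "percentile" and "gamma" in Clara Train may differ, and the function name is hypothetical, so treat this purely as an example of the concept:

import numpy as np

def partial_share(update, percentile=75, gamma=1.0):
    """Clip and partially share a flat array of model weight differences."""
    clipped = np.clip(update, -gamma, gamma)
    # Share only values whose magnitude is at or above the given percentile;
    # everything else is zeroed out and effectively withheld from the server.
    threshold = np.percentile(np.abs(clipped), percentile)
    return np.where(np.abs(clipped) >= threshold, clipped, 0.0)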

In the above example, target is configured as localhost:8002. Based on the network environment in which the FL server and clients are deployed, users need to ensure DNS is configured correctly so clients can find the server by the server's name. Another way to resolve the name without DNS is to edit the client's /etc/hosts file: before starting the clients, add the host IP address and the server's name to the /etc/hosts file. The following is one example:

127.0.0.1   localhost
::1 localhost ip6-localhost ip6-loopback
fe00::0     ip6-localnet
ff00::0     ip6-mcastprefix
ff02::1     ip6-allnodes
ff02::2     ip6-allrouters
172.18.0.2  c96cc6ba74ab

# Add the following for IP and name of FL server
10.110.11.22 fedserver

When launching the FL server Docker container with docker run, users also need to expose the gRPC port via the -p option.

How to run Clara federated learning model training

The commands to start the FL model training are: server_train.sh (from the server machine) and client_train.sh (from the client machine).

There are two options for starting the federated learning server: training from scratch, or starting from a previously trained model. By adding the option "MMAR_CKPT=$MMAR_ROOT/models/FL_global_model.ckpt" to the command line, the FL server starts FL training from the pre-trained model "FL_global_model.ckpt". Without this option, the FL server starts training from scratch.

Example (training from scratch)

server_train.sh:

python3 -u  -m nvmidl.apps.fed_learn.server.fed_aggregate \
    -m $MMAR_ROOT \
    -c $CONFIG_FILE \
    -e $ENVIRONMENT_FILE \
    -s $SERVER_FILE \
    --set \
    secure_train=true

env_server.json:

{
    "PROCESSING_TASK": "segmentation",
    "MMAR_CKPT_DIR": "models"
}

Example (starting from previously trained model)

server_train.sh:

python3 -u  -m nvmidl.apps.fed_learn.server.fed_aggregate \
    -m $MMAR_ROOT \
    -c $CONFIG_FILE \
    -e $ENVIRONMENT_FILE \
    -s $SERVER_FILE \
    --set \
    MMAR_CKPT=$MMAR_ROOT/models/FL_global_model.ckpt \
    secure_train=true

env_server.json:

{
    "MMAR_CKPT": "FL_global_model.ckpt",
    "PROCESSING_TASK": "segmentation",
    "MMAR_CKPT_DIR": "models"
}

Note

To train from scratch, make sure that "MMAR_CKPT" is not set to a pre-existing checkpoint, either in the command used to launch training or in any environment files.

The sequence for starting federated learning model training is to start the server first, then start the clients to join the FL training. The client requires the server service to be available during startup in order to log in and obtain the FL token. The server side prints out the status of how many clients have joined the FL training and the FL token issued to each client.

The following limitations and known issues apply to federated learning:

  • Federated learning client training with multiple GPUs is not yet supported.

  • Very rarely, FL training clients may not exit gracefully after the training has finished successfully. You can use CTRL-C to shut down the FL client, or kill the FL client process. This does not affect any model training results, and the trained model checkpoint is saved correctly.