fed_learn.server package
-
class
ClientReply
(client_name, req: fed_learn.admin_defs.Message, reply: fed_learn.admin_defs.Message) Bases:
object
-
class
FedAdminServer
(fed_admin_interface, users, cmd_modules, file_upload_dir, file_download_dir, allowed_shell_cmds, host, port, ca_cert_file_name, server_cert_file_name, server_key_file_name, accepted_client_cns=None, mmar_validator=None) Bases:
dlmed.hci.server.hci.AdminServer
The FedAdminServer is the framework for developing admin commands.
- Parameters
fed_admin_interface – the Fed Server’s admin interface
users – a dict mapping user name => password hash
cmd_modules – a list of CommandModules
file_upload_dir – the directory for uploaded files
file_download_dir – the directory for files to be downloaded
allowed_shell_cmds – list of allowed shell commands. If not specified, all commands are allowed.
host – the IP address of the admin server
port – the port number of the admin server
ca_cert_file_name – the root CA’s certificate file name
server_cert_file_name – the server’s certificate file, signed by the CA
server_key_file_name – the server’s private key file
accepted_client_cns – list of accepted client Common Names (CNs), if specified
Constructor. May be extended; do not override.
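The allowed_shell_cmds and accepted_client_cns parameters describe two allow-lists. The sketch below shows one way such checks could work; it is not the FedAdminServer implementation. Both function names are hypothetical, and the sketch assumes that only the first word of a shell command line is checked and that None means "no restriction".

```python
from typing import Iterable, Optional


def is_shell_cmd_allowed(cmd_line: str, allowed_shell_cmds: Optional[Iterable[str]] = None) -> bool:
    """Return True if the command (first word of cmd_line) may be run."""
    if allowed_shell_cmds is None:
        # Documented rule: no allow-list means every command is allowed.
        return True
    words = cmd_line.split()
    # An empty command line is rejected when an allow-list is in force.
    return bool(words) and words[0] in set(allowed_shell_cmds)


def is_client_cn_accepted(common_name: str, accepted_client_cns: Optional[Iterable[str]] = None) -> bool:
    """Return True if a client certificate's Common Name passes the allow-list."""
    if accepted_client_cns is None:
        # Assumption: with no CN list, any client certificate signed by the CA is accepted.
        return True
    return common_name in set(accepted_client_cns)
```

Checking only the first word keeps the rule simple, but a real server would also need to guard against shell metacharacters in the arguments.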
-
accept_reply
(client_name, reply: fed_learn.admin_defs.Message) Called by the FL Engine after a client response is received, so it runs in the FL Engine’s message-processing thread.
- Parameters
client_name – name of the client that sent the reply
reply – the reply message
-
client_dead
(client_name) This method is called by the Fed Engine to indicate the client is dead.
- Parameters
client_name – name of the client
-
client_heartbeat
(client_name) This method is called by the Fed Engine to indicate the client is alive.
- Parameters
client_name – name of the client
-
get_client_names
() → [<class ‘str’>] Get the names of existing clients.
-
get_outgoing_requests
(client_name, max_reqs=0) Called by the FL Engine to get the outgoing messages for a client so that it can send them to that client.
- Parameters
client_name – name of the client
max_reqs – maximum number of requests to return
-
send_request_to_client
(req: fed_learn.admin_defs.Message, client_name: str, timeout_secs=2.0) → fed_learn.server.admin.ClientReply
-
send_request_to_clients
(req: fed_learn.admin_defs.Message, clients: [], timeout_secs=2.0) → [<class ‘fed_learn.server.admin.ClientReply’>]
-
send_requests
(requests: dict, timeout_secs=2.0) → [<class ‘fed_learn.server.admin.ClientReply’>] Used by a Command Handler to send requests to clients, so it runs in the Command Handler’s handling thread. This is a blocking call: it returns only after all responses are received or the timeout expires.
- Parameters
requests – a dict mapping client name => request message or list of messages
timeout_secs – how long to wait for replies, in seconds, before timing out
Returns: a list of ClientReply
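The send_requests contract (a Command Handler thread blocks while the FL Engine thread drains requests with get_outgoing_requests and delivers answers through accept_reply) can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical: RequestBroker and Reply stand in for FedAdminServer and ClientReply, messages are plain strings rather than fed_learn.admin_defs.Message objects, and max_reqs=0 is assumed to mean "no limit".

```python
import queue
import threading
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union


@dataclass
class Reply:
    client_name: str
    req_id: int
    body: Optional[str]  # None: no reply arrived before the timeout


class RequestBroker:
    """Hypothetical stand-in: command threads enqueue requests, the engine thread delivers them."""

    def __init__(self) -> None:
        self._outgoing: Dict[str, queue.Queue] = {}
        self._replies: Dict[int, str] = {}
        self._cond = threading.Condition()
        self._next_id = 0

    def send_requests(self, requests: Dict[str, Union[str, List[str]]],
                      timeout_secs: float = 2.0) -> List[Reply]:
        """Queue all requests, then block until every reply arrives or the timeout expires."""
        pending: Dict[int, str] = {}  # request id -> client name
        with self._cond:
            for client_name, msgs in requests.items():
                if isinstance(msgs, str):  # a single message or a list of messages
                    msgs = [msgs]
                q = self._outgoing.setdefault(client_name, queue.Queue())
                for msg in msgs:
                    self._next_id += 1
                    pending[self._next_id] = client_name
                    q.put((self._next_id, msg))
            deadline = time.monotonic() + timeout_secs
            while not all(rid in self._replies for rid in pending):
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                self._cond.wait(remaining)  # releases the lock while waiting
            return [Reply(name, rid, self._replies.pop(rid, None)) for rid, name in pending.items()]

    def get_outgoing_requests(self, client_name: str, max_reqs: int = 0) -> List[Tuple[int, str]]:
        """Engine side: drain up to max_reqs queued requests (0 = no limit, an assumption)."""
        with self._cond:
            q = self._outgoing.setdefault(client_name, queue.Queue())
        out: List[Tuple[int, str]] = []
        while max_reqs <= 0 or len(out) < max_reqs:
            try:
                out.append(q.get_nowait())
            except queue.Empty:
                break
        return out

    def accept_reply(self, req_id: int, body: str) -> None:
        """Engine side: record a reply and wake any waiting command thread."""
        with self._cond:
            self._replies[req_id] = body
            self._cond.notify_all()
```

A condition variable lets the waiting thread wake as soon as the last reply arrives instead of polling; requests still unanswered at the deadline come back with body None. This sketch does not purge replies that arrive after their deadline.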
-
send_requests_to_all_clients
(reqs: [], timeout_secs=2.0) → [<class ‘fed_learn.server.admin.ClientReply’>] Send multiple request messages to all clients and wait for their replies.
- Parameters
reqs – requests to be sent
timeout_secs – how long to wait for replies, in seconds, before timing out
Returns: a list of ClientReply
-
send_to_all_clients
(req: fed_learn.admin_defs.Message, timeout_secs=2.0) → [<class ‘fed_learn.server.admin.ClientReply’>] Used by a Command Handler to send a request to all clients, so it runs in the Command Handler’s handling thread. This is a blocking call: it returns only after all responses are received or the timeout expires.
- Parameters
req – the request to be sent
timeout_secs – how long to wait for replies, in seconds, before timing out
Returns: a list of ClientReply
-
stop
()
-
new_message
(conn: dlmed.hci.conn.Connection, topic, body) → fed_learn.admin_defs.Message
-
class
ClientManager
(task_name=None, min_num_clients=2, max_num_clients=10) Bases:
object
-
authenticate
(request, context)
-
authenticated_client
(client_login, context) Use the SSL certificate to authenticate the client.
- Parameters
client_login – the client login message
context – gRPC connection context
-
get_clients
() Get the list of registered clients.
-
get_max_clients
()
-
get_min_clients
()
-
heartbeat
(token, client_id, context) Update the heartbeat of the client.
- Parameters
token – client ID token
- Returns
whether a new client needs to be created
Simple authentication of the client.
- Returns
True if it is a recognised client
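The heartbeat bookkeeping described above, combined with BaseServer’s heart_beat_timeout (600 seconds by default) and remove_dead_clients, can be sketched as a timestamp table. HeartbeatTracker is a hypothetical class, and the rule that a client counts as dead once its last heartbeat is older than heart_beat_timeout is an assumption, not something this reference states.

```python
import time
from typing import Callable, Dict, List


class HeartbeatTracker:
    """Hypothetical tracker: records the last heartbeat time per client token."""

    def __init__(self, heart_beat_timeout: float = 600,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.heart_beat_timeout = heart_beat_timeout
        self._clock = clock  # injectable so tests can control time
        self._last_seen: Dict[str, float] = {}

    def heartbeat(self, token: str) -> bool:
        """Record a heartbeat; return True if the token was not known before."""
        is_new = token not in self._last_seen
        self._last_seen[token] = self._clock()
        return is_new

    def remove_dead_clients(self) -> List[str]:
        """Drop and return tokens whose last heartbeat is older than the timeout."""
        now = self._clock()
        dead = [t for t, seen in self._last_seen.items() if now - seen > self.heart_beat_timeout]
        for t in dead:
            del self._last_seen[t]
        return dead
```

Using a monotonic clock avoids false timeouts when the wall clock is adjusted.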
-
is_valid_task
(task) Check whether the requested task matches the server’s task.
-
login_client
(client_login, context) Validate the client’s login message.
- Parameters
context – gRPC connection context
- Returns
the client ID if it is a valid client
-
remove_client
(token)
-
validate_client
(client_state, context, allow_new=False) Validate the client state message.
- Parameters
client_state – a ClientState message received by the server
context – gRPC connection context
allow_new – whether to allow a new client; its task must still match the server’s
- Returns
the client ID if it is a valid client
-
Federated server for aggregating and sharing the federated model.
-
class
BaseServer
(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, heart_beat_timeout=600, handlers: [] = None, cmd_modules=None) Bases:
object
-
client_cleanup
()
-
close
() Shut down the server.
-
deploy
(grpc_args=None, secure_train=False) Start a gRPC server listening on the designated port.
-
fl_shutdown
()
-
get_all_clients
()
-
abstract
remove_client_data
(token)
-
remove_dead_clients
()
-
set_admin_server
(admin_server)
-
property
should_stop
num_rounds < 0 means the server never stops
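A minimal sketch of the should_stop rule. Only the num_rounds < 0 case is documented here; the condition used for non-negative num_rounds (stop after num_rounds rounds counted from start_round) is an assumption, and RoundState is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass
class RoundState:
    start_round: int = 1
    num_rounds: int = -1
    current_round: int = 1

    @property
    def should_stop(self) -> bool:
        # Documented rule: a negative num_rounds means training never stops.
        if self.num_rounds < 0:
            return False
        # Assumed rule: stop once num_rounds rounds have run since start_round.
        return self.current_round >= self.start_round + self.num_rounds
```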
-
start
()
-
-
class
FederatedServer
(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors: [] = None, post_processors: [] = None, cmd_modules=None, result_processors=[], heart_beat_timeout=600, handlers: [] = None) Bases:
fed_learn.server.fed_server.BaseServer, fed_learn.protos.federated_pb2_grpc.FederatedTrainingServicer, fed_learn.server.sai.ServerAdminInterface, fed_learn.protos.admin_pb2_grpc.AdminCommunicatingServicer
Federated model aggregation services.
- Parameters
start_round – 0 means the global model is initialized randomly.
min_num_clients – minimum number of contributors at each round.
max_num_clients – maximum number of contributors at each round.
-
GetModel
(request, context) Process a client’s request for the current global model.
-
GetValidationModels
(request, context) Send validation models to the requesting client.
-
Heartbeat
(request, context) Client-to-server keep-alive heartbeat.
-
Quit
(request, context) An existing client quits the federated training process. The server stops sharing the global model with the client, and further contributions from it are rejected.
This function does not change min_num_clients and max_num_clients.
-
Register
(request, context) Register new clients on the fly. Each client must register before getting the global model. The server expects updates from the registered clients over multiple federated rounds.
This function does not change min_num_clients and max_num_clients.
-
Retrieve
(request, context) Client retrieves requests from the server.
-
SendReply
(request, context) Client sends a reply to the server.
-
SendResult
(request, context) Client sends processing results to the server.
-
SubmitBestLocalModel
(request, context) Receive the best local model from clients.
-
SubmitCrossSiteValidationResults
(request, context) Receive the cross-site validation results from a client.
-
SubmitUpdate
(request, context) Handle a client’s submission of federated updates, running aggregation if there are enough updates.
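SubmitUpdate aggregates only once enough updates have arrived, after which aggregate() resets the accumulator. The following single-threaded sketch shows that flow, assuming "enough" means min_num_clients contributions and that contributions are plain dicts; UpdateAccumulator is hypothetical and, like aggregate(), not thread-safe.

```python
from typing import Callable, List, Optional


class UpdateAccumulator:
    """Hypothetical accumulator: aggregate once enough client updates have arrived."""

    def __init__(self, min_num_clients: int, aggregate_fn: Callable[[List[dict]], dict]) -> None:
        self.min_num_clients = min_num_clients
        self.aggregate_fn = aggregate_fn
        self.accumulator: List[dict] = []
        self.global_model: Optional[dict] = None
        self.current_round = 0

    def submit_update(self, contribution: dict) -> bool:
        """Store one contribution; return True if it triggered aggregation."""
        self.accumulator.append(contribution)
        if len(self.accumulator) < self.min_num_clients:
            return False
        self.global_model = self.aggregate_fn(self.accumulator)
        self.accumulator = []  # reset for the next round
        self.current_round += 1
        return True
```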
-
aggregate
() Invoke model aggregation using the accumulator’s content, then reset the tokens and the accumulator.
-
close
() Shut down the server.
-
get_current_model_meta_data
() Get the model meta data, which usually contains additional fields.
-
is_valid_contribution
(contrib_meta_data) Check whether the client submitted a valid contribution. The contribution meta data must be for the current task and the current round, matching the server’s model meta data.
- Parameters
contrib_meta_data – Contribution message’s meta data
- Returns
the meta data if the contribution’s meta data is valid, None otherwise.
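The rule in is_valid_contribution (same task and same round as the server’s model meta data) can be sketched as follows. ModelMeta and its two fields are hypothetical stand-ins for the real meta data object, which likely carries more fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ModelMeta:
    task_name: str
    current_round: int


def is_valid_contribution(contrib_meta: ModelMeta, server_meta: ModelMeta) -> Optional[ModelMeta]:
    """Return contrib_meta if it matches the server's task and round, else None."""
    if contrib_meta.task_name != server_meta.task_name:
        return None  # contribution is for a different task
    if contrib_meta.current_round != server_meta.current_round:
        return None  # outdated (or future) round
    return contrib_meta
```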
-
property
model_meta_info
The model_meta_info uniquely identifies the current model; it is used to reject outdated client updates.
- Returns
model meta data object
-
register_processor
(processor: fed_learn.server.result_processor.ResultProcessor)
-
remove_client_data
(token)
-
reset_tokens
() Restart the token set so that each client can take a token and start fetching the current global model. This function is not thread-safe.
-
save_contribution
(client_contrib_id, data) Save the client’s current contribution.
- Returns
True if and only if it was saved successfully
-
set_builder
(builder)
-
start
()
-
stop_training
()
-
class
TestFederatedServer
(task_name=None, min_num_clients=2, max_num_clients=10, heart_beat_timeout=600, cmd_modules=None, result_processors=None) Bases:
fed_learn.server.fed_server.FederatedServer
- Parameters
start_round – 0 means the global model is initialized randomly.
min_num_clients – minimum number of contributors at each round.
max_num_clients – maximum number of contributors at each round.
-
close
() Shut down the server.
-
property
current_round
-
property
should_stop
True if the main thread should stop
-
start
()
-
class
MMARAuthzService
Bases:
object
-
static
initialize
(mmar_validator)
-
mmar_validator
= None
-
static
-
class
Aggregator
Bases:
fed_learn.components.data_processor.DataProcessor
-
process
(accumulator: [], fl_ctx: fed_learn.model_meta.FLContext) Aggregate the contributions from all FL clients that submitted one.
- Parameters
accumulator – list of all the contributions
fl_ctx – the FLContext
-
-
class
ModelAggregator
(exclude_vars=None, aggregation_weights=None) Bases:
fed_learn.server.model_aggregator.Aggregator
-
process
(accumulator: [], fl_ctx: fed_learn.model_meta.FLContext) Aggregate model variables. This function is not thread-safe.
- Returns
True if the current model is the best model so far.
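The following sketches weighted model averaging in the spirit of ModelAggregator(exclude_vars, aggregation_weights). The following are assumptions, not confirmed by this reference: each contribution is a (client_name, {variable_name: flat list of floats}) pair, aggregation_weights maps client names to weights (default 1.0), and exclude_vars is a regular expression whose matching variables are dropped from the result. Plain lists keep it dependency-free; a real aggregator would operate on framework tensors or NumPy arrays.

```python
import re
from typing import Dict, List, Optional, Tuple

Weights = Dict[str, List[float]]


def weighted_aggregate(
    accumulator: List[Tuple[str, Weights]],
    aggregation_weights: Optional[Dict[str, float]] = None,
    exclude_vars: Optional[str] = None,
) -> Weights:
    """Weighted element-wise average of each variable across client contributions."""
    aggregation_weights = aggregation_weights or {}
    pattern = re.compile(exclude_vars) if exclude_vars else None
    sums: Weights = {}
    totals: Dict[str, float] = {}  # sum of client weights seen per variable
    for client_name, variables in accumulator:
        w = aggregation_weights.get(client_name, 1.0)
        for name, values in variables.items():
            if pattern is not None and pattern.search(name):
                continue  # excluded variables are not aggregated
            summed = sums.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                summed[i] += w * v
            totals[name] = totals.get(name, 0.0) + w
    return {name: [s / totals[name] for s in summed] for name, summed in sums.items()}
```

Normalizing by the per-variable weight total means a variable missing from some contributions is still averaged correctly over the clients that sent it.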
-
-
class
ResultProcessor
Bases:
object
The ResultProcessor is responsible for processing a request.
-
get_topics
() → [<class ‘str’>] Get the topics that this processor handles.
- Returns
list of topics
-
process
(client_name, req: fed_learn.admin_defs.Message) Called to process the specified request.
- Parameters
client_name – name of the client that sent the request
req – the request to process
- Returns
a reply message
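ResultProcessor subclasses advertise topics through get_topics, and FederatedServer.register_processor accepts them. A hypothetical registry that routes each incoming message to the processor for its topic could look like the sketch below; messages are plain dicts instead of fed_learn.admin_defs.Message, and the topic name "validation_results" is invented for illustration.

```python
from typing import Dict, List


class ResultProcessorBase:
    """Hypothetical stand-in for ResultProcessor: declares topics, processes messages."""

    def get_topics(self) -> List[str]:
        raise NotImplementedError

    def process(self, client_name: str, req: dict) -> dict:
        raise NotImplementedError


class CountingValidationProcessor(ResultProcessorBase):
    def get_topics(self) -> List[str]:
        return ["validation_results"]  # hypothetical topic name

    def process(self, client_name: str, req: dict) -> dict:
        return {"status": "ok", "from": client_name, "n_metrics": len(req.get("metrics", {}))}


class ProcessorRegistry:
    """Routes each incoming message to the processor registered for its topic."""

    def __init__(self) -> None:
        self._by_topic: Dict[str, ResultProcessorBase] = {}

    def register_processor(self, processor: ResultProcessorBase) -> None:
        for topic in processor.get_topics():
            self._by_topic[topic] = processor

    def dispatch(self, client_name: str, topic: str, req: dict) -> dict:
        processor = self._by_topic.get(topic)
        if processor is None:
            return {"status": "error", "reason": f"no processor for topic {topic!r}"}
        return processor.process(client_name, req)
```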
-
-
class
ValidateResultProcessor
Bases:
fed_learn.server.result_processor.ResultProcessor
-
get_topics
() → [<class ‘str’>] Get the topics that this processor handles.
- Returns
list of topics
-
process
(client_name, message: fed_learn.admin_defs.Message) Called to process the specified message.
- Parameters
client_name – name of the client that sent the message
message – the message to process
- Returns
a reply message
-
Federated server for aggregating and sharing the federated model.
-
class
RoundRobinFederatedServer
(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors: [] = None, post_processors: [] = None, cmd_modules=None, result_processors=[], heart_beat_timeout=600, handlers: [] = None) Bases:
fed_learn.server.fed_server.FederatedServer
Federated model aggregation services.
Init for the round-robin server.
-
GetModel
(request, context) Process a client’s request for the current global model.
-
Quit
(request, context) An existing client quits the federated training process. The server stops sharing the global model with the client, and further contributions from it are rejected.
This function does not change min_num_clients and max_num_clients.
-
aggregate
() Invoke model aggregation using the accumulator’s content, then reset the tokens and the accumulator.
-
remove_client_data
(token)
-
remove_dead_clients
()
-
reset_active_client
()
-
-
class
ServerAdminInterface
(server, args, workers=3) Bases:
object
-
check_aggregation
()
-
check_status
() → Tuple[str,Any,Any]
-
close
()
-
delete_run_number
(num)
-
deploy_mmar
(src, dest)
-
fl_restart
() → str
-
fl_shutdown
() → str
-
get_all_client_names
()
-
get_all_clients
()
-
get_all_instance_names
()
-
get_all_register_tokens
(client_name)
-
get_all_taskname
() → str
-
get_all_tokens_from_inputs
(inputs)
-
get_client_mmar
(client_name)
-
get_client_name_from_instance_name
(instance_name)
-
get_cross_val_directory
() → str
-
get_cross_val_filename
() → str
-
get_cross_val_results
(model_client, data_client) → dict
-
get_instance_name_from_token
(token)
-
get_mmar_path
(src)
-
get_run_number
()
-
initialize_cross_val
()
-
load_cross_val_dict
()
-
remove_custom_path
()
-
remove_dead_client
(token)
-
reset_model_for_client
(client)
-
set_min_clients
(num)
-
set_run_number
(num)
-
start_server_training
() → str
-
stop_server_training
() → str
-
update_cross_val_dict
(data_client, val_client_names, val_client_metrics) Update the cross-site validation results.
- Parameters
data_client (str) – name of the client that sent the results
val_client_names (list) – list of the clients that were validated
val_client_metrics (list) – metric dicts for each client
- Returns
the updated dictionary
- Return type
dict
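The following sketches the nested-dictionary update performed by update_cross_val_dict. The layout cross_val[data_client][validated_client] = metrics is an assumption, as is the hypothetical lookup_result helper.

```python
from typing import Dict, List

CrossVal = Dict[str, Dict[str, dict]]


def update_cross_val_dict(cross_val: CrossVal, data_client: str,
                          val_client_names: List[str], val_client_metrics: List[dict]) -> CrossVal:
    """Record the metrics data_client computed for each validated client's model."""
    if len(val_client_names) != len(val_client_metrics):
        raise ValueError("val_client_names and val_client_metrics must have equal length")
    row = cross_val.setdefault(data_client, {})
    for name, metrics in zip(val_client_names, val_client_metrics):
        row[name] = metrics  # later results for the same pair overwrite earlier ones
    return cross_val


def lookup_result(cross_val: CrossVal, model_client: str, data_client: str) -> dict:
    """Hypothetical helper: metrics of model_client's model on data_client's data."""
    return cross_val.get(data_client, {}).get(model_client, {})
```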
-
write_cross_val_results
(cross_val_dict)
-
-
copy_new_server_properties
(server, new_server)
-
server_shutdown
(server, touch_file)
-
start_server_training
(server, args, mmar_root)
-
class
ServerModelManager
(start_round=0, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None) Bases:
object
Global model manager lives on the server side.
-
close
() TODO final saving before quitting
-
initialize
(builder)
-
run_validation
() Run validation
-
property
should_stop
num_rounds < 0 means non-stopping
-
update_model
(accumulator, fl_ctx: fed_learn.model_meta.FLContext) Aggregate TensorFlow variables. This function is not thread-safe.
- Parameters
accumulator – list of all the contributions
fl_ctx – the FLContext
-
-
class
ServerStatus
Bases:
object
-
TRAINING_NOT_STARTED
= 0
-
TRAINING_STARTED
= 2
-
TRAINING_STARTING
= 1
-
TRAINING_STOPPED
= 3
-
status_messages
= {0: 'training not started', 1: 'training starting', 2: 'training started', 3: 'training stopped'}
-
-
get_status_message
(status)
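get_status_message maps the integer ServerStatus codes to their messages. This sketch mirrors the documented constants and status_messages; the fallback for unknown codes is an assumption.

```python
class ServerStatus:
    # Values copied from the ServerStatus class documented above.
    TRAINING_NOT_STARTED = 0
    TRAINING_STARTING = 1
    TRAINING_STARTED = 2
    TRAINING_STOPPED = 3

    status_messages = {
        0: "training not started",
        1: "training starting",
        2: "training started",
        3: "training stopped",
    }


def get_status_message(status: int) -> str:
    # Assumed behavior for unknown codes: fall back to a generic message.
    return ServerStatus.status_messages.get(status, f"unknown status {status}")
```

For example, get_status_message(ServerStatus.TRAINING_STARTED) returns "training started".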
-
class
ShellCommandModule
Bases:
dlmed.hci.reg.CommandModule
-
get_spec
()
-
-
class
SystemCommandModule
(allowed_commands=None) Bases:
dlmed.hci.reg.CommandModule
-
get_client_cmd_reply
(sai, requests, server, conn)
-
get_spec
()
-
sys_info
(conn: dlmed.hci.conn.Connection, args: [] )
-
-
class
TrainingCommandModule
(allowed_commands=None) Bases:
dlmed.hci.reg.CommandModule
-
abort_clients
(clients, conn, message, requests, sai)
-
abort_training
(conn: dlmed.hci.conn.Connection, args: [] )
-
check_status
(conn: dlmed.hci.conn.Connection, args: [] )
-
delete_run_number
(conn: dlmed.hci.conn.Connection, args: [] )
-
deploy
(conn: dlmed.hci.conn.Connection, args: [] )
-
get_client_cmd_reply
(sai, requests, server)
-
get_spec
()
-
remove_client
(conn: dlmed.hci.conn.Connection, args: [] )
-
restart
(conn: dlmed.hci.conn.Connection, args: [] )
-
restart_clients
(clients, conn, message, requests, sai)
-
set_min_clients
(conn: dlmed.hci.conn.Connection, args: [] )
-
set_run_number
(conn: dlmed.hci.conn.Connection, args: [] )
-
set_timeout
(conn: dlmed.hci.conn.Connection, args: [] )
-
shutdown
(conn: dlmed.hci.conn.Connection, args: [] )
-
start_mgpu_training
(conn: dlmed.hci.conn.Connection, args: [] )
-
start_training
(conn: dlmed.hci.conn.Connection, args: [] )
-
-
class
ValidationCommandModule
(allowed_commands=None) Bases:
dlmed.hci.reg.CommandModule
-
do_validation_command
(conn: dlmed.hci.conn.Connection, args: [] )
-
get_spec
()
-
get_taskname
(conn: dlmed.hci.conn.Connection, args: [] )
-