NVIDIA Clara Train 3.1

fed_learn.server package

class ClientReply(client_name, req: fed_learn.admin_defs.Message, reply: fed_learn.admin_defs.Message)

Bases: object

class FedAdminServer(fed_admin_interface, users, cmd_modules, file_upload_dir, file_download_dir, allowed_shell_cmds, host, port, ca_cert_file_name, server_cert_file_name, server_key_file_name, accepted_client_cns=None, mmar_validator=None)

Bases: dlmed.hci.server.hci.AdminServer

The FedAdminServer is the framework for developing admin commands.

Parameters
  • fed_admin_interface – the Fed Server’s admin interface

  • users – a dict of user name => pwd hash

  • cmd_modules – a list of CommandModules

  • file_upload_dir – the directory for uploaded files

  • file_download_dir – the directory for files to be downloaded

  • allowed_shell_cmds – list of shell commands allowed. If not specified, all allowed.

  • host – the IP address of the admin server

  • port – port number of admin server

  • ca_cert_file_name – the root CA’s cert file name

  • server_cert_file_name – server’s cert, signed by the CA

  • server_key_file_name – server’s private key file

  • accepted_client_cns – list of accepted Common Names from client, if specified

Constructor. May be extended, do not override.

accept_reply(client_name, reply: fed_learn.admin_defs.Message)

This method is to be called by the FL Engine after a client response is received. Hence it is called from the FL Engine’s message processing thread.

Parameters
  • client_name

  • reply

client_dead(client_name)

This method is called by the Fed Engine to indicate the client is dead.

Parameters

client_name

client_heartbeat(client_name)

This method is called by the Fed Engine to indicate the client is alive.

Parameters

client_name

get_client_names() → List[str]

Get names of existing clients.

get_outgoing_requests(client_name, max_reqs=0)

This method is called by the FL Engine to get outgoing messages to the client, so it can send them to the client.

Parameters
  • client_name

  • max_reqs

send_request_to_client(req: fed_learn.admin_defs.Message, client_name: str, timeout_secs=2.0) → fed_learn.server.admin.ClientReply
send_request_to_clients(req: fed_learn.admin_defs.Message, clients: [ ], timeout_secs=2.0) → List[fed_learn.server.admin.ClientReply]
send_requests(requests: dict, timeout_secs=2.0) → List[fed_learn.server.admin.ClientReply]

This method is to be used by a Command Handler to send requests to Clients; hence it runs in the Command Handler’s handling thread. This is a blocking call: it returns only after all responses are received or the timeout expires.

Parameters
  • requests – a dict of requests: client name => req message or list of msgs

  • timeout_secs – how long to wait for reply before timeout

Returns: a list of ClientReply

send_requests_to_all_clients(reqs: [ ], timeout_secs=2.0) → List[fed_learn.server.admin.ClientReply]

Send multiple request messages to all clients and wait for replies.

Parameters
  • reqs – requests to be sent

  • timeout_secs – how long to wait for reply before timeout

Returns: a list of ClientReply

send_to_all_clients(req: fed_learn.admin_defs.Message, timeout_secs=2.0) → List[fed_learn.server.admin.ClientReply]

This method is to be used by a Command Handler to send a request to all Clients; hence it runs in the Command Handler’s handling thread. This is a blocking call: it returns only after all responses are received or the timeout expires.

Parameters
  • req – the request to be sent

  • timeout_secs – how long to wait for reply before timeout

Returns: a list of ClientReply
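The blocking broadcast behavior described above (send to every client, then wait until all replies arrive or the timeout expires) can be sketched with standard-library primitives. The SimpleReply class and the deliver callback below are illustrative stand-ins, not part of the Clara API:

```python
import threading

class SimpleReply:
    """Illustrative stand-in for ClientReply."""
    def __init__(self, client_name, reply):
        self.client_name = client_name
        self.reply = reply

def send_to_all(clients, req, deliver, timeout_secs=2.0):
    """Send req to every client and block until all replies arrive or timeout.

    deliver(client_name, req, on_reply) is assumed to dispatch the request
    and invoke on_reply(reply) when that client answers.
    """
    replies = {}
    lock = threading.Lock()
    done = threading.Event()

    def on_reply(name, reply):
        with lock:
            replies[name] = SimpleReply(name, reply)
            if len(replies) == len(clients):
                done.set()  # all responses received; stop waiting early

    for name in clients:
        deliver(name, req, lambda r, n=name: on_reply(n, r))

    done.wait(timeout=timeout_secs)  # blocks at most timeout_secs
    with lock:
        # clients that never answered get a None reply, mirroring a timeout
        return [replies.get(n, SimpleReply(n, None)) for n in clients]
```

A Command Handler calling this gets one reply object per client in a single list, whether or not every client responded in time.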

stop()
new_message(conn: dlmed.hci.conn.Connection, topic, body) → fed_learn.admin_defs.Message
class ClientManager(task_name=None, min_num_clients=2, max_num_clients=10)

Bases: object

authenticate(request, context)
authenticated_client(client_login, context)

Use the SSL certificate to authenticate the client.

get_clients()

Get the list of registered clients.

get_max_clients()
get_min_clients()
heartbeat(token, client_id, context)

Update the heartbeat of the client.

Parameters

token – client ID token

Returns

whether a new client record needs to be created
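The heartbeat bookkeeping this implies (record the last time each client checked in; treat clients silent for longer than heart_beat_timeout as dead) can be sketched as follows. The class name and structure are assumptions for illustration, not the Clara ClientManager:

```python
import time

class HeartbeatMonitor:
    """Illustrative heartbeat tracker; not the Clara implementation."""

    def __init__(self, heart_beat_timeout=600.0):
        self.heart_beat_timeout = heart_beat_timeout
        self._last_seen = {}  # token -> timestamp of last heartbeat

    def heartbeat(self, token, now=None):
        """Record a heartbeat. Returns True if the token was unknown,
        i.e. a new client record needs to be created."""
        now = time.monotonic() if now is None else now
        is_new = token not in self._last_seen
        self._last_seen[token] = now
        return is_new

    def dead_clients(self, now=None):
        """Tokens whose last heartbeat is older than the timeout."""
        now = time.monotonic() if now is None else now
        return [t for t, ts in self._last_seen.items()
                if now - ts > self.heart_beat_timeout]
```

This matches the 600-second default heart_beat_timeout seen in the server constructors above.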

is_from_authorized_client(client_id)

Simple authentication of the client.

Returns

True if it is a recognised client

is_valid_task(task)

Check whether the requested task matches the server’s task.

login_client(client_login, context)

Validate the client state message.

Parameters

context – gRPC connection context

Returns

client id if it’s a valid client

remove_client(token)
validate_client(client_state, context, allow_new=False)

Validate the client state message.

Parameters
  • client_state – A ClientState message received by server

  • context – gRPC connection context

  • allow_new – whether to allow new client. Its task should still match server’s.

Returns

client id if it’s a valid client
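The validation flow documented above (task must match; known clients pass; unknown clients are admitted only when allow_new is set and capacity remains) can be sketched like this. The dict-based client_state and the field names are assumptions for illustration:

```python
import uuid

class SimpleClientManager:
    """Illustrative validator mirroring the documented checks;
    not the real Clara ClientManager."""

    def __init__(self, task_name, max_num_clients=10):
        self.task_name = task_name
        self.max_num_clients = max_num_clients
        self.clients = {}  # token -> client name

    def validate_client(self, client_state, allow_new=False):
        """Return a client token if the state message is valid, else None."""
        if client_state.get("task") != self.task_name:
            return None  # requested task must match the server's task
        token = client_state.get("token")
        if token in self.clients:
            return token  # already-registered client
        if allow_new and len(self.clients) < self.max_num_clients:
            token = uuid.uuid4().hex  # issue a fresh registration token
            self.clients[token] = client_state.get("client_name")
            return token
        return None
```

Register would call this with allow_new=True; later requests (GetModel, SubmitUpdate) would call it with the default allow_new=False.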

federated server for aggregating and sharing federated model

class BaseServer(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, heart_beat_timeout=600, handlers: [ ] = None, cmd_modules=None)

Bases: object

client_cleanup()
close()

shutdown the server.

Returns

deploy(grpc_args=None, secure_train=False)

Start a gRPC server listening on the designated port.

fl_shutdown()
get_all_clients()
abstract remove_client_data(token)
remove_dead_clients()
set_admin_server(admin_server)
property should_stop

num_rounds < 0 means non-stopping: the server runs until explicitly stopped.
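The should_stop semantics can be made concrete with a small sketch; the class below is illustrative, not the BaseServer implementation:

```python
class RoundTracker:
    """Illustrative sketch of the should_stop semantics: a negative
    num_rounds means the server never stops on its own."""

    def __init__(self, start_round=1, num_rounds=-1):
        self.current_round = start_round
        self.num_rounds = num_rounds

    @property
    def should_stop(self):
        if self.num_rounds < 0:
            return False  # non-stopping mode: run until explicitly stopped
        return self.current_round >= self.num_rounds
```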

start()
class FederatedServer(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors: [ ] = None, post_processors: [ ] = None, cmd_modules=None, result_processors=[], heart_beat_timeout=600, handlers: [ ] = None)

Bases: <a href="#fed_learn.server.fed_server.BaseServer">fed_learn.server.fed_server.BaseServer</a>, <a href="fed_learn.protos.html#fed_learn.protos.federated_pb2_grpc.FederatedTrainingServicer">fed_learn.protos.federated_pb2_grpc.FederatedTrainingServicer</a>, <a href="#fed_learn.server.sai.ServerAdminInterface">fed_learn.server.sai.ServerAdminInterface</a>, <a href="fed_learn.protos.html#fed_learn.protos.admin_pb2_grpc.AdminCommunicatingServicer">fed_learn.protos.admin_pb2_grpc.AdminCommunicatingServicer</a>

Federated model aggregation services

Parameters
  • start_round – 0 indicates initializing the global model randomly.

  • min_num_clients – minimum number of contributors at each round.

  • max_num_clients – maximum number of contributors at each round.

GetModel(request, context)

Process a client’s request for the current global model.

GetValidationModels(request, context)

Send validation models to the server.

Heartbeat(request, context)

Client-to-server heartbeat to keep the connection alive.

Quit(request, context)

An existing client quits the federated training process. The server will stop sharing the global model with the client, and further contributions will be rejected.

This function does not change min_num_clients and max_num_clients.

Register(request, context)

Register new clients on the fly. Each client must be registered before getting the global model. The server will expect updates from the registered clients for multiple federated rounds.

This function does not change min_num_clients and max_num_clients.

Retrieve(request, context)

Process client retrieve requests.

SendReply(request, context)

Client sends a reply to the server.

SendResult(request, context)

Client sends process results to the server.

SubmitBestLocalModel(request, context)

Receive the best local model from clients.

SubmitCrossSiteValidationResults(request, context)

Get the cross validation results from client.

SubmitUpdate(request, context)

Handle a client’s submission of federated updates, running aggregation once enough updates have been received.

aggregate()

Invoke model aggregation using the accumulator’s content, then reset the tokens and the accumulator.

Returns

close()

shutdown the server.

Returns

get_current_model_meta_data()

Get the model meta data, which usually contains additional fields.

is_valid_contribution(contrib_meta_data)

Check whether the client submitted a valid contribution. The contribution meta data should be for the current task and the current round, matching the server’s model meta data.

Parameters

contrib_meta_data – Contribution message’s meta data

Returns

the meta data if the contrib’s meta data is valid, None otherwise.
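The check described above (accept a contribution only when its meta data targets the server's current task and round, returning the meta data on success and None otherwise) can be sketched as follows; the dict layout and field names are assumptions:

```python
def is_valid_contribution(contrib_meta, server_meta):
    """Illustrative validity check; field names are assumptions,
    not Clara's actual meta-data schema."""
    if contrib_meta is None:
        return None
    if contrib_meta.get("task") != server_meta.get("task"):
        return None  # contribution targets a different task
    if contrib_meta.get("current_round") != server_meta.get("current_round"):
        return None  # outdated (or future) round is rejected
    return contrib_meta  # valid: hand back the meta data
```

This is the mechanism that lets model_meta_info reject outdated client updates: a submission carrying last round's meta data simply fails the round comparison.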

property model_meta_info

The model_meta_info uniquely defines the current model; it is used to reject outdated client updates.

Returns

model meta data object

register_processor(processor: fed_learn.server.result_processor.ResultProcessor)
remove_client_data(token)
reset_tokens()

Restart the token set so that each client can take a token and start fetching the current global model. This function is not thread-safe.

save_contribution(client_contrib_id, data)

Save the client’s current contribution.

Returns

True iff it is successfully saved

set_builder(builder)
start()
stop_training()
class TestFederatedServer(task_name=None, min_num_clients=2, max_num_clients=10, heart_beat_timeout=600, cmd_modules=None, result_processors=None)

Bases: <a href="#fed_learn.server.fed_server.FederatedServer">fed_learn.server.fed_server.FederatedServer</a>

Parameters
  • start_round – 0 indicates initializing the global model randomly.

  • min_num_clients – minimum number of contributors at each round.

  • max_num_clients – maximum number of contributors at each round.

close()

shutdown the server.

Returns

property current_round
property should_stop

Returns

True to stop the main thread

start()
class MMARAuthzService

Bases: object

static authorize_deploy(mmar_path: str, sites: [ ]) -> ( , )
static authorize_upload(mmar_path: str) -> ( , )
static initialize(mmar_validator)
mmar_validator = None
class Aggregator

Bases: <a href="fed_learn.components.html#fed_learn.components.data_processor.DataProcessor">fed_learn.components.data_processor.DataProcessor</a>

process(accumulator: [ ], fl_ctx: fed_learn.model_meta.FLContext)

Aggregate the contributions from all the submitted FL clients.

Parameters
  • accumulator – list of all the contributions

  • fl_ctx – FLContext

class ModelAggregator(exclude_vars=None, aggregation_weights=None)

Bases: <a href="#fed_learn.server.model_aggregator.Aggregator">fed_learn.server.model_aggregator.Aggregator</a>

process(accumulator: [ ], fl_ctx: fed_learn.model_meta.FLContext)

Aggregate model variables. This function is not thread-safe.

Returns

True to indicate that the current model is the best model so far
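A minimal sketch of what ModelAggregator's two constructor options mean in practice: exclude_vars filters variables out of aggregation by regex, and aggregation_weights weights each client's contribution in the average. This is an illustration of the FedAvg-style idea on plain lists, not Clara's implementation, which operates on TensorFlow model variables inside an FLContext:

```python
import re

def aggregate(accumulator, exclude_vars=None, aggregation_weights=None):
    """Weighted average of client model updates (FedAvg-style sketch).

    accumulator: list of (client_name, {var_name: list_of_floats}) pairs.
    exclude_vars: optional regex; matching variable names are skipped.
    aggregation_weights: optional {client_name: weight}; default 1.0 each.
    """
    exclude = re.compile(exclude_vars) if exclude_vars else None
    weights = aggregation_weights or {}
    sums, totals = {}, {}
    for client, model in accumulator:
        w = weights.get(client, 1.0)
        for name, values in model.items():
            if exclude and exclude.search(name):
                continue  # e.g. keep batch-norm statistics local
            acc = sums.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                acc[i] += w * v
            totals[name] = totals.get(name, 0.0) + w
    # divide by the total weight to get the weighted mean per variable
    return {name: [v / totals[name] for v in acc]
            for name, acc in sums.items()}
```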

class ResultProcessor

Bases: object

The ResultProcessor is responsible for processing a request.

get_topics() → List[str]

Get topics that this processor will handle.

Returns

list of topics

process(client_name, req: fed_learn.admin_defs.Message)

Called to process the specified request.

Returns

a reply message

class ValidateResultProcessor

Bases: <a href="#fed_learn.server.result_processor.ResultProcessor">fed_learn.server.result_processor.ResultProcessor</a>

get_topics() → List[str]

Get topics that this processor will handle.

Returns

list of topics

process(client_name, message: fed_learn.admin_defs.Message)

Called to process the specified request.

Returns

a reply message

federated server for aggregating and sharing federated model

class RoundRobinFederatedServer(task_name=None, min_num_clients=2, max_num_clients=10, wait_after_min_clients=10, start_round=1, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None, pre_processors: [ ] = None, post_processors: [ ] = None, cmd_modules=None, result_processors=[], heart_beat_timeout=600, handlers: [ ] = None)

Bases: <a href="#fed_learn.server.fed_server.FederatedServer">fed_learn.server.fed_server.FederatedServer</a>

Federated model aggregation services

Init for Round Robin server.

GetModel(request, context)

Process a client’s request for the current global model.

Quit(request, context)

An existing client quits the federated training process. The server will stop sharing the global model with the client, and further contributions will be rejected.

This function does not change min_num_clients and max_num_clients.

aggregate()

Invoke model aggregation using the accumulator’s content, then reset the tokens and the accumulator.

Returns

remove_client_data(token)
remove_dead_clients()
reset_active_client()
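The round-robin idea behind this server (serve the global model to one active client at a time, with reset_active_client advancing the rotation) can be sketched as follows; the scheduler class and method names besides reset_active_client are assumptions:

```python
class RoundRobinScheduler:
    """Illustrative round-robin rotation over registered clients;
    not the Clara RoundRobinFederatedServer implementation."""

    def __init__(self, client_tokens):
        self.client_tokens = list(client_tokens)
        self._index = 0

    @property
    def active_client(self):
        return self.client_tokens[self._index]

    def reset_active_client(self):
        """Advance to the next client, wrapping around."""
        self._index = (self._index + 1) % len(self.client_tokens)
        return self.active_client

    def may_get_model(self, token):
        """Only the active client is served the current global model."""
        return token == self.active_client
```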
class ServerAdminInterface(server, args, workers=3)

Bases: object

check_aggregation()
check_status() → Tuple[str, Any, Any]
close()
delete_run_number(num)
deploy_mmar(src, dest)
fl_restart() → str
fl_shutdown() → str
get_all_client_names()
get_all_clients()
get_all_instance_names()
get_all_register_tokens(client_name)
get_all_taskname() → str
get_all_tokens_from_inputs(inputs)
get_client_mmar(client_name)
get_client_name_from_instance_name(instance_name)
get_cross_val_directory() → str
get_cross_val_filename() → str
get_cross_val_results(model_client, data_client) → dict
get_instance_name_from_token(token)
get_mmar_path(src)
get_run_number()
initialize_cross_val()
load_cross_val_dict()
remove_custom_path()
remove_dead_client(token)
reset_model_for_client(client)
set_min_clients(num)
set_run_number(num)
start_server_training() → str
stop_server_training() → str
update_cross_val_dict(data_client, val_client_names, val_client_metrics)

Update cross-site validation results.

Parameters
  • data_client (str) – name of the client that sent the results

  • val_client_names (list) – list of clients validated

  • val_client_metrics (list) – metric dicts for each validated client

Returns

Updated dictionary

Return type

dict
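The merge step this method performs can be sketched as a nested-dict update keyed by the client that sent the results. The {data_client: {validated_client: metrics}} layout is an assumption for illustration, not Clara's on-disk format:

```python
def update_cross_val_dict(cross_val, data_client, val_client_names,
                          val_client_metrics):
    """Illustrative merge of cross-site validation results.

    cross_val: nested dict {data_client: {validated_client: metrics}}
    (assumed layout). Pairs each validated client name with its metrics
    and returns the updated dictionary.
    """
    row = cross_val.setdefault(data_client, {})
    for name, metrics in zip(val_client_names, val_client_metrics):
        row[name] = metrics  # overwrite any earlier result for this pair
    return cross_val
```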

write_cross_val_results(cross_val_dict)
copy_new_server_properties(server, new_server)
server_shutdown(server, touch_file)
start_server_training(server, args, mmar_root)
class ServerModelManager(start_round=0, num_rounds=-1, exclude_vars=None, model_log_dir=None, ckpt_preload_path=None, model_aggregator=None, model_saver=None)

Bases: object

The global model manager lives on the server side.

close()

TODO final saving before quitting

initialize(builder)
run_validation()

Run validation

property should_stop

num_rounds < 0 means non-stopping: the server runs until explicitly stopped.

update_model(accumulator, fl_ctx: fed_learn.model_meta.FLContext)

Aggregate TensorFlow variables. This function is not thread-safe.

class ServerStatus

Bases: object

TRAINING_NOT_STARTED = 0
TRAINING_STARTED = 2
TRAINING_STARTING = 1
TRAINING_STOPPED = 3
status_messages = {0: 'training not started', 1: 'training starting', 2: 'training started', 3: 'training stopped'}
get_status_message(status)
class ShellCommandModule

Bases: dlmed.hci.reg.CommandModule

get_spec()
class SystemCommandModule(allowed_commands=None)

Bases: dlmed.hci.reg.CommandModule

authorize_sys_info(conn: dlmed.hci.conn.Connection, args: [ ])
get_client_cmd_reply(sai, requests, server, conn)
get_spec()
sys_info(conn: dlmed.hci.conn.Connection, args: [ ])
class TrainingCommandModule(allowed_commands=None)

Bases: dlmed.hci.reg.CommandModule

abort_clients(clients, conn, message, requests, sai)
abort_training(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_check_status(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_deploy(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_operate(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_remove_client(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_set_run_number(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_start_mgpu_training(conn: dlmed.hci.conn.Connection, args: [ ])
authorize_start_training(conn: dlmed.hci.conn.Connection, args: [ ])
check_status(conn: dlmed.hci.conn.Connection, args: [ ])
delete_run_number(conn: dlmed.hci.conn.Connection, args: [ ])
deploy(conn: dlmed.hci.conn.Connection, args: [ ])
get_client_cmd_reply(sai, requests, server)
get_spec()
remove_client(conn: dlmed.hci.conn.Connection, args: [ ])
restart(conn: dlmed.hci.conn.Connection, args: [ ])
restart_clients(clients, conn, message, requests, sai)
set_min_clients(conn: dlmed.hci.conn.Connection, args: [ ])
set_run_number(conn: dlmed.hci.conn.Connection, args: [ ])
set_timeout(conn: dlmed.hci.conn.Connection, args: [ ])
shutdown(conn: dlmed.hci.conn.Connection, args: [ ])
start_mgpu_training(conn: dlmed.hci.conn.Connection, args: [ ])
start_training(conn: dlmed.hci.conn.Connection, args: [ ])
class ValidationCommandModule(allowed_commands=None)

Bases: dlmed.hci.reg.CommandModule

do_validation_command(conn: dlmed.hci.conn.Connection, args: [ ])
get_spec()
get_taskname(conn: dlmed.hci.conn.Connection, args: [ ])
© Copyright 2020, NVIDIA. Last updated on Feb 2, 2023.