问题描述
TFF教程和研究项目中的代码通常只跟踪服务器状态。我希望有内部客户端状态(例如,完全分散的客户端内部神经网络,并且不会以联合方式进行更新),这会影响联合客户端的计算。
但是,在我所看到的客户端计算中,它们只是服务器状态和数据的函数。是否可以完成上述任务?
解决方法
是的,这很容易在TFF中表达,并且将在默认执行堆栈中很好地执行。
您已经注意到,TFF存储库通常包含跨设备联合学习(Kairouz et. al 2019)的示例。通常,我们谈论状态具有tff.SERVER
的位置,并且一轮“联合”学习的函数签名具有结构(有关TFF类型速记的详细信息,请参见教程的Federated data部分):
(<State@SERVER,{Dataset}@CLIENTS> -> State@Server)
我们可以通过简单地扩展签名来表示有状态的客户端:
(<State@SERVER,{State}@Clients,{Dataset}@CLIENTS> -> <State@Server,{State}@Clients>)
实施包含客户端状态对象的联邦平均(McMahan et. al 2016)版本可能类似于:
@tff.tf_computation(
model_type,client_state_type,# additional state parameter
client_data_type)
def client_training_fn(model,state,dataset):
model_update,new_state = # do some local training
return model_update,new_state # return a tuple including updated state
@tff.federated_computation(
tff.FederatedType(server_state_type,tff.SERVER),tff.FederatedType(client_state_type,tff.CLIENTS),# new parameter for state
tff.FederatedType(client_data_type,tff.CIENTS))
def run_fed_avg(server_state,client_states,client_datasets):
client_initial_models = tff.federated_broadcast(server_state.model)
client_updates,new_client_state = tff.federated_map(client_training_fn,# Pass the client states as an argument.
(client_initial_models,client_datasets))
average_update = tff.federated_mean(client_updates)
new_server_state = tff.federated_map(server_update_fn,(server_state,average_update))
# Make sure to return the client states so they can be used in later rounds.
return new_server_state,new_client_states
run_fed_avg
的调用需要为参与回合的每个客户端传递张量/结构的Python list
,方法调用的结果将是服务器状态,以及列表客户状态。