TFF客户端是否有内部状态的方法?

问题描述

TFF教程和研究项目中的代码通常只跟踪服务器状态。我希望有内部客户端状态(例如,完全分散的客户端内部神经网络,并且不会以联合方式进行更新),这会影响联合客户端的计算。

但是,在我所看到的客户端计算中,它们只是服务器状态和数据的函数。是否可以完成上述任务?

解决方法

是的,这很容易在TFF中表达,并且将在默认执行堆栈中很好地执行。

您已经注意到,TFF存储库通常包含跨设备联合学习Kairouz et. al 2019)的示例。通常,我们谈论状态具有tff.SERVER的位置,并且一轮“联合”学习的函数签名具有结构(有关TFF类型速记的详细信息,请参见教程的Federated data部分):

(<State@SERVER,{Dataset}@CLIENTS> -> State@Server)

我们可以通过简单地扩展签名来表示有状态的客户端:

(<State@SERVER,{State}@Clients,{Dataset}@CLIENTS> -> <State@Server,{State}@Clients>)

实施包含客户端状态对象的联邦平均(McMahan et. al 2016)版本可能类似于:

@tff.tf_computation(
  model_type,client_state_type,# additional state parameter
  client_data_type)
def client_training_fn(model,state,dataset):
  model_update,new_state = # do some local training
  return model_update,new_state # return a tuple including updated state

@tff.federated_computation(
  tff.FederatedType(server_state_type,tff.SERVER),tff.FederatedType(client_state_type,tff.CLIENTS),# new parameter for state
  tff.FederatedType(client_data_type,tff.CIENTS))
def run_fed_avg(server_state,client_states,client_datasets):
  client_initial_models = tff.federated_broadcast(server_state.model)
  client_updates,new_client_state = tff.federated_map(client_training_fn,# Pass the client states as an argument.
    (client_initial_models,client_datasets))
  average_update = tff.federated_mean(client_updates)
  new_server_state = tff.federated_map(server_update_fn,(server_state,average_update))
  # Make sure to return the client states so they can be used in later rounds.
  return new_server_state,new_client_states

run_fed_avg的调用需要为参与回合的每个客户端传递张量/结构的Python list,方法调用的结果将是服务器状态,以及列表客户状态。