是否可以将不同的权重子集发送给不同的客户端?

问题描述

我正在尝试使用 tensorflow-federated 在服务器上选择不同的权重子集并将它们发送给客户端。然后客户端将训练并发送回训练好的权重。服务器汇总结果并开始新的通信回合。

主要问题是我无法访问权重的 numpy 版本,因此我不知道如何为每一层访问它们的子集。我尝试使用 tf.gather_nd 和 tf.tensor_scatter_nd_update 来执行选择和更新,但它们仅适用于张量,而不适用于张量列表(因为 server_state 在 tensorflow-federated 中)。

有没有人有任何提示可以解决这个问题?是否可以向每个客户端发送不同的权重?

解决方法

如果我正确地遵循,编写 TFF 类型速记中描述的高级计算的方法是:

@tff.federated_computation(...)
def run_one_round(server_state,client_datasets):
  weights_subset = tff.federated_map(subset_fn,server_state)
  clients_weights_subset = tff.federated_broadcast(weights_subset)
  client_models = tff.federated_map(client_training_fn,(clients_weights_subset,client_datasets))
  aggregated_update = tff.federated_aggregate(client_models,...)
  new_server_state = tff.federated_map(apply_aggregated_update_fn,server_state)
  return new_server_state

如果这是真的,似乎大部分工作需要在 subset_fn 中进行,它获取服务器状态并返回全局模式权重的子集。通常,模型是 tf.Tensor 的结构(listdict,可能嵌套),正如您所观察到的,它不能用作 tf.gather_nd 或 {{3} 的参数}.但是,它们可以逐点应用于使用 tf.tensor_scatter_nd_update 的张量结构。例如,从三个张量的嵌套结构中选择 [0,0] 处的值:

import tensorflow as tf
import pprint
struct_of_tensors = {
    'trainable': [tf.constant([[2.0,4.0,6.0]]),tf.constant([[5.0]])],'non_trainable': [tf.constant([[1.0]])],}
pprint.pprint(tf.nest.map_structure(
    lambda tensor: tf.gather_nd(params=tensor,indices=[[0,0]]),struct_of_tensors))

>>> {'non_trainable': [<tf.Tensor: shape=(1,),dtype=float32,numpy=array([1.],dtype=float32)>],'trainable': [<tf.Tensor: shape=(1,numpy=array([2.],dtype=float32)>,<tf.Tensor: shape=(1,numpy=array([5.],dtype=float32)>]}

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...