Tensorflow联合如何从服务器更新模型

问题描述

Tensorflow的新手,因此不确定这是否是Tensorflow Federated的特定问题。

我正在此code中研究针对联合学习的对抗性攻击。我很好奇在客户端如何更新从服务器接收到的权重。

例如,以下是“良性”更新的代码

@tf.function
def compute_benign_update():
  """compute benign update sent back to the server."""
  tf.nest.map_structure(lambda a,b: a.assign(b),model_weights,initial_weights)

  num_examples_sum = benign_dataset.reduce(initial_state=tf.constant(0),reduce_func=reduce_fn)

  weights_delta_benign = tf.nest.map_structure(lambda a,b: a - b,model_weights.trainable,initial_weights.trainable)

  aggregated_outputs = model.report_local_outputs()
  return weights_delta_benign,aggregated_outputs,num_examples_sum

我可以看到,从服务器收到的初始权重已分配给model_weights,然后使用reduce_fn在本地客户端上训练一批数据。

@tf.function
def reduce_fn(num_examples_sum,batch):
  """Runs `tff.learning.Model.train_on_batch` on local client batch."""
  with tf.GradientTape() as tape:
    output = model.forward_pass(batch)
  gradients = tape.gradient(output.loss,model.trainable_variables)
  optimizer.apply_gradients(zip(gradients,model.trainable_variables))
  return num_examples_sum + tf.shape(output.predictions)[0]

在此功能内进行训练,并且(我认为)model.trainable_variables已更新。对我来说没有意义的部分是weights_delta_benign的计算方式:

weights_delta_benign = tf.nest.map_structure(lambda a,initial_weights.trainable)

似乎使用了model_weights.trainableinitial_weights.trainable间的区别,但是我们最初不是在compute_benign_update()函数的第一行中将它们设置为相等吗?我假设reduce_fn以某种方式更改了initial_weights,但是我看不到reduce函数中使用的model.trainable_variablesinitial_weights.trainable_variables间的联系。

谢谢,感谢您的帮助!

解决方法

在您指向的代码中,initial_weights只是值的集合(tf.Tensor对象),model_weights是对model变量的引用( tf.Variable个对象)。我们使用initial_weights为模型变量分配初始值。

然后,在对optimizer.apply_gradients(zip(gradients,model.trainable_variables))的调用中,您仅修改模型的变量。 ({model.trainable_variables,指的是与model_weights.trainable相同的对象。我承认,这有点令人困惑。)

因此weights_delta_benign的后续计算是计算模型的可训练变量在客户训练过程的开始和结束之间的差。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...