TensorFlow 联合压缩:如何实现在 TFF 的 build_federated_averaging_process 中使用的有状态编码器?

问题描述

在 Tensorflow Federated (TFF) 中,您可以向 tff.learning.build_federated_averaging_process 传递一个 broadcast_process一个 aggregation_process,它们可以嵌入定制的编码器,例如应用自定义压缩。

说到我的问题,我正在尝试实现一个编码器来稀疏模型更新/模型权重。

我正在尝试通过实现 EncodingStageInterface 中的 tensorflow_model_optimization.python.core.internal 来构建这样的编码器。 但是,我正在努力实现(本地)状态以逐轮累积模型更新/模型权重的归零坐标。请注意,不应传达此状态,而只需要在本地维护(因此 AdaptiveEncodingStageInterface 应该没有帮助)。一般来说,问题是如何在 Encoder 内部维护一个本地状态,然后传递给 fedavg 进程。

我附上了我的编码器实现的代码(除了我想添加的状态之外,它可以像预期的那样无状态地正常工作)。 然后我附上我使用编码器实现的代码的摘录。 如果我对 stateful_encoding_stage_topk.py 中的注释部分进行注释,则代码不起作用:我无法弄清楚如何在 TF 非渴望模式下管理状态(即张量)。

stateful_encoding_stage_topk.py

import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  
  
  def __init__(self):
    super().__init__()
    # Here I would like to init my state
    #self.A = tf.zeros([800],dtype=tf.float32)

  @property
  def name(self):
    """See base class."""
    return 'stateful_topk'

  @property
  def compressible_tensors_keys(self):
    """See base class."""
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    """See base class."""
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class."""
    return True

  def get_params(self):
    """See base class."""
    return {},{}

  def encode(self,x,encode_params):
    """See base class."""
    del encode_params  # Unused.

    dW = tf.reshape(x,[-1])
    # Here I would like to retrieve the state
    A = tf.zeros([800],dtype=tf.float32)
    #A = self.residual
    
    dW_and_A = tf.math.add(A,dW)

    percentage = tf.constant(0.4,dtype=tf.float32)
    k_float = tf.multiply(percentage,tf.cast(tf.size(dW),tf.float32))
    k_int = tf.cast(tf.math.round(k_float),dtype=tf.int32)

    values,indices = tf.math.top_k(tf.math.abs(dW_and_A),k = k_int,sorted = False)
    indices = tf.expand_dims(indices,1)
    sparse_dW = tf.scatter_nd(indices,values,tf.shape(dW_and_A))
    
    # Here I would like to update the state
    A_updated = tf.math.subtract(dW_and_A,sparse_dW)
    #self.A = A_updated
    
    encoded_x = {self.ENCODED_VALUES_KEY: values,self.INDICES_KEY: indices}

    return encoded_x

  def decode(self,encoded_tensors,decode_params,num_summands=None,shape=None):
    """See base class."""
    del decode_params,num_summands  # Unused.
    
    indices = encoded_tensors[self.INDICES_KEY]
    values = encoded_tensors[self.ENCODED_VALUES_KEY]
    tensor = tf.fill([800],0.0)
    decoded_values = tf.tensor_scatter_nd_update(tensor,indices,values)
    
    return tf.reshape(decoded_values,shape)



def sparse_quantizing_encoder():
  encoder = te.core.EncoderComposer(
      StatefulTopKEncodingStage() )  
  return encoder.make()

fedavg_with_sparsification.py

[...]

def sparsification_broadcast_encoder_fn(value):
  spec = tf.TensorSpec(value.shape,value.dtype)
  return te.encoders.as_simple_encoder(te.encoders.identity(),spec)

def sparsification_mean_encoder_fn(value):
  spec = tf.TensorSpec(value.shape,value.dtype)
  
  if value.shape.num_elements() == 800:
    return te.encoders.as_gather_encoder(
        stateful_encoding_stage_topk.sparse_quantizing_encoder(),spec)

  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(),spec)
  
encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn,sparsification_broadcast_encoder_fn))

encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn,sparsification_mean_encoder_fn))


iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),client_weight_fn=lambda _: tf.constant(1.0),broadcast_process=encoded_broadcast_process,aggregation_process=encoded_mean_process)

[...]

我正在使用:

  • 张量流 2.4.0
  • 张量流联合 0.17.0

解决方法

我会试着分两部分来回答; (1) 没有状态的 top_k 编码器和 (2) 在 TFF 中实现您似乎想要的有状态的想法。

(1)

为了让 TopKEncodingStage 在没有状态的情况下工作,我需要更改一些细节。

commutes_with_sum 属性应设置为 False。在伪代码中,它的含义是是否 sum_x(decode(encode(x))) == decode(sum_x(encode(x))) 。对于您的 encode 方法返回的表示而言,情况并非如此——对 indices 求和效果不佳。我认为 decode 方法的实现可以简化为

return tf.scatter_nd(
    indices=encoded_tensors[self.INDICES_KEY],updates=encoded_tensors[self.ENCODED_VALUES_KEY],shape=shape)

(2)

使用 tff.learning.build_federated_averaging_process 无法以这种方式实现您所指的内容。该方法返回的进程没有任何维护客户端/本地状态的机制。无论您的 StatefulTopKEncodingStage 中表达的状态是什么,最终都将成为服务器状态,而不是本地状态。

要使用客户端/本地状态,您可能需要编写更多自定义代码。对于初学者,请参阅 examples/stateful_clients,您可以对其进行调整以存储您引用的状态。

请记住,在 TFF 中,这需要表示为函数转换。将值存储在类的属性中并在其他地方使用它们可能会导致令人惊讶的错误。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...