如何创建使用少数客户端权重的 FL 算法?

问题描述

基于此 link 我正在尝试编写一种新的 FL 算法方法。我训练所有客户端并将所有客户端的模型参数发送到服务器,服务器在聚合过程中只会对所有客户端的30%的模型参数进行加权平均。作为选择 30% 客户的模型参数的标准,我想通过使用 weights_delta 的 30% 的客户和较少的 loss_sum 客户进行加权平均。

以下代码是针对此 link修改代码

@tf.function
def client_update(model,dataset,server_message,client_optimizer):

model_weights = model.weights
initial_weights = server_message.model_weights
tff.utils.assign(model_weights,initial_weights)

num_examples = tf.constant(0,dtype=tf.int32)
loss_sum = tf.constant(0,dtype=tf.float32)

for batch in iter(dataset):
    with tf.GradientTape() as tape:
        outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss,model_weights.trainable)
    grads_and_vars = zip(grads,model_weights.trainable)
    client_optimizer.apply_gradients(grads_and_vars)
    batch_size = tf.shape(batch['x'])[0]
    num_examples += batch_size
    loss_sum += outputs.loss * tf.cast(batch_size,tf.float32)        

weights_delta = tf.nest.map_structure(lambda a,b: a - b,model_weights.trainable,initial_weights.trainable)
client_weight = tf.cast(num_examples,tf.float32)

client_loss = loss_sum #add

return ClientOutput(weights_delta,client_weight,loss_sum / client_weight,client_loss) 

client_output 中有以下属性

weights_delta = attr.ib()
client_weight = attr.ib()
model_output = attr.ib()
client_loss = attr.ib() 

之后,我通过序列的形式制作了client_output collected_output = tff.federated_collect(client_output)round_model_delta = tff.federated_map(selecting_fn,(collected_output,weight_denom))here 中。

   @tff.federated_computation(federated_server_state_type,federated_dataset_type)

    def run_one_round(server_state,federated_dataset):
    
    server_message = tff.federated_map(server_message_fn,server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn,(federated_dataset,server_message_at_client))

    weight_denom = client_outputs.client_weight

    collected_output = tff.federated_collect(client_outputs)  # add        
    
    round_model_delta = tff.federated_map(selecting_fn,weight_denom)) #add       

    server_state = tff.federated_map(server_update_fn,(server_state,round_model_delta))

    round_loss_metric = tff.federated_mean(client_outputs.model_output,weight=weight_denom)

    return server_state,round_loss_metric

另外,添加了以下代码here来实现selecting_fn函数

@tff.tf_computation()  # append
def selecting_fn(collected_output,weight_denom):
    #Todo
    return round_model_delta

我不确定按照上面的方式编写代码是否正确。 我尝试了各种方法,但主要是 TypeError: The value to be mapped must be a FederatedType or implicitly convertible to a FederatedType (got a <<model_weights=<trainable=<float32[5,5,1,32],float32[32],float32[5,32,64],float32[64],float32[3136,512],float32[512],float32[512,10],float32[10]>,non_trainable=<>>,optimizer_state= <int64>,round_num=int32>@SERVER,{<float32[5,512 ],float32[10]>}@CLIENTS>) 我收到此错误

我想知道序列类型 collected_output 如何访问每个客户端的 client_loss(= loss_sum) 并对它们进行排序,还想知道在应用了 weight_denom 的情况下计算加权平均值时使用什么方法

解决方法

我看到的一个问题是,在调用 tff.federated_map(selecting_fn,(collected_output,weight_denom) 中,collected_output 将放置在 tff.SERVER,而 weight_denom 将放置在 tff.CLIENTS,所以这是行不通的。我想你想先把所有东西都送到 tff.SERVER

我不确定您需要什么行为,但这里有一个示例代码,您可以从中开始和开发。它从客户端值(比如它的 ID)开始,采样一个随机值,将 (ID,value) 对收集到服务器,然后选择具有最大值的对 - 看起来与您描述的相似。

@tff.tf_computation()
def client_sample_fn():
  return tf.random.uniform((1,))

# Type annotation optional for tff.tf_computation. Added here for clarity.
idx_sample_type = tff.to_type(((tf.int32,(1,)),(tf.float32,))))
@tff.tf_computation(tff.SequenceType(idx_sample_type))
def select_fn(idx_sample_dataset):  # Inside,this is a tf.data.Dataset.
  # Concatenate all pairs in the dataset.
  concat_fn = lambda a,b: tf.concat([a,b],axis=0)
  reduce_fn = lambda x,y: (concat_fn(x[0],y[0]),concat_fn(x[1],y[1]))
  reduce_zero = (tf.constant((),dtype=tf.int32,shape=(0,tf.constant((),dtype=tf.float32,)))
  idx_tensor,sample_tensor = idx_sample_dataset.reduce(reduce_zero,reduce_fn)
  # Find 3 largest samples.
  top_3_val,top_3_idx = tf.math.top_k(sample_tensor,k=3)
  return tf.gather(idx_tensor,top_3_idx),top_3_val

@tff.federated_computation(tff.type_at_clients((tf.int32,))))
def fed_fn(client_idx):
  client_sample = tff.federated_eval(client_sample_fn,tff.CLIENTS)
  # First zip,to have a dataset of pairs,rather than a pair of datasets.
  client_idx_sample_pair = tff.federated_zip((client_idx,client_sample))
  collected_idx_sample_pair = tff.federated_collect(client_idx_sample_pair)
  return tff.federated_map(select_fn,collected_idx_sample_pair)

client_idx = [(10,),(11,(12,(13,(14,)]
fed_fn(client_idx)

使用示例输出:

(array([11,10,14],dtype=int32),array([0.8220736,0.81413555,0.6984291 ],dtype=float32))

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...