如何访问序列类型中的值?

问题描述

client_output 中有以下属性

weights_delta = attr.ib()
client_weight = attr.ib()
model_output = attr.ib()
client_loss = attr.ib() 

之后,我通过序列的形式制作了client_output a = tff.federated_collect(client_output)round_model_delta = tff.federated_map(selecting_fn,a)here 中。我宣布 `

@tff.tf_computation()  # append
def selecting_fn(a):
    #Todo
    return round_model_delta

here。在服务器上求平均值的过程中,我想通过选择一些损失值较小的客户端来对weights_delta进行平均。所以我尝试通过 a.weights_delta 访问它,但它不起作用。

解决方法

tff.federated_collect 返回一个位于 tff.SequenceTypetff.SERVER,您可以像处理 for example 客户端数据集通常在 {{1} }.

请注意,您必须在 tff.tf_computation 的范围内使用 tff.federated_collect 运算符。您可能想要做的 [*] 是使用 tff.federated_computation 运算符将其传递给 tff.tf_computation。进入 tff.federated_map 后,您可以将其视为 tff.tf_computation 对象,并且 tf.data.Dataset 模块中的所有内容都可用。

[*] 我猜。对您想要实现的目标进行更详细的说明会有所帮助。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...