问题描述
client_output
中有以下属性
weights_delta = attr.ib()
client_weight = attr.ib()
model_output = attr.ib()
client_loss = attr.ib()
之后,我通过序列的形式制作了client_output
a = tff.federated_collect(client_output)
和 round_model_delta = tff.federated_map(selecting_fn,a)
在 here 中。我宣布
`
@tff.tf_computation() # append
def selecting_fn(a):
#Todo
return round_model_delta
在here。在服务器上求平均值的过程中,我想通过选择一些损失值较小的客户端来对weights_delta
进行平均。所以我尝试通过 a.weights_delta
访问它,但它不起作用。
解决方法
tff.federated_collect
返回一个位于 tff.SequenceType
的 tff.SERVER
,您可以像处理 for example 客户端数据集通常在 {{1} }.
请注意,您必须在 tff.tf_computation
的范围内使用 tff.federated_collect
运算符。您可能想要做的 [*] 是使用 tff.federated_computation
运算符将其传递给 tff.tf_computation
。进入 tff.federated_map
后,您可以将其视为 tff.tf_computation
对象,并且 tf.data.Dataset
模块中的所有内容都可用。
[*] 我猜。对您想要实现的目标进行更详细的说明会有所帮助。