我可以在Tensorflow联合学习TFF的keras模型中使用class_weight

问题描述

我的数据集是类不平衡的,所以我想使用class_weight,它启用了分类器重的次要类。在一般情况下,我可以按以下方式分配班级权重:

weighted_history = weighted_model.fit(
    train_features,train_labels,batch_size=BATCH_SIZE,epochs=EPOCHS,callbacks=[early_stopping],validation_data=(val_features,val_labels),# The class weights go here
    class_weight=class_weight) 

在Tensorflow联合学习中,有什么方法可以分配class_weight?我的联合学习代码如下:

def create_keras_model(output_bias=None):
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(12,activation='relu',input_shape(5,)),tf.keras.layers.Dense(8,activation='relu'),tf.keras.layers.Dense(5,tf.keras.layers.Dense(3,tf.keras.layers.Dense(1,activation='sigmoid')])

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,input_spec=preprocessed_example_dataset.element_spec,loss=tf.keras.losses.BinaryCrossentropy(),metrics=[tf.keras.metrics.BinaryAccuracy()])

解决方法

不直接。主要问题是tf.keras.Model.fit方法在概念上并未映射到从分散数据进行训练的思想。

如果要在TFF中进行此工作,第一步是确定应执行的算法。据我所知,这没有一个明显的答案-例如,如果您无法直接访问数据,如何确定class_weights是什么?

但是,假设您以某种方式获得了该信息,并且只想修改客户的本地培训程序。从examples/simple_fedavg开始,实现这一目标的方法是适当修改this loop中的梯度计算方式。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...