具有sample_weights的TFRecords,与辍学层冲突

问题描述

我一直在积极寻找答案,但找不到答案。我在张量流方面也很新,所以如果不清楚或愚蠢,请原谅我。

我正在尝试根据图像预测浮动值,而CNN的最后一层是:

x = Flatten()(x)
x = Dense(32,activation='relu')(x)
x = Dropout(rate = 0.5)(x)
x = Dense(32,activation='relu')(x)
x = Dropout(rate = 0.25)(x)
x = Dense(1,activation = 'linear)(x)

图像和目标浮点值以TFRecords格式保存。为了获得加权损失(MSE或MAE),添加一个额外的浮点数(每个图像一个)。

def write_tfrecords(out_path,images,labels,fweights):
    assert len(images) == len(labels)
    with tf.io.TFRecordWriter(out_path) as writer:  
        for i in range(len(labels)):
            img_bytes = images[i].tostring()  
            labels_temp = labels[i] 
            fweight_temp= fweights[i]
            data = {'image': _bytes_feature(img_bytes),'label': _float_feature(labels_temp),'fweight':_float_feature(fweight_temp)}
            feature = tf.train.Features(feature=data)  
            example = tf.train.Example(features=feature)  
            serialized = example.SerializetoString()  
            writer.write(serialized)  

请参阅文档,在解析TFRecord时,第三个输出将作为sample_weight。编译模型时,我还设置了sample_weight_mode=None

def parse_example(serialized,shape=(INPUT_HEIGHT,INPUT_WIDTH,1)):
    features = {'image': tf.io.FixedLenFeature([],tf.string),'label': tf.io.FixedLenFeature((),tf.float32),'fweight': tf.io.FixedLenFeature((),tf.float32)}
    parsed_example = tf.io.parse_single_example(serialized=serialized,features=features)
    image_raw = parsed_example['image']  
    image = tf.decode_raw(image_raw,tf.float32) 
    image = tf.reshape(image,shape=shape)
    label = parsed_example['label']
    fweight = parsed_example['fweight']
    return image,label,fweight

现在,出现以下错误,批处理大小为10(如果将batch_size更改为1,则更改为1):

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: The second input must be a scalar,but it has shape [10]
     [[{{node dropout_44/cond/Switch}}]]
     [[loss_19/mul/_3891]]
  (1) Invalid argument: The second input must be a scalar,but it has shape [10]
     [[{{node dropout_44/cond/Switch}}]]
0 successful operations.
0 derived errors ignored.

如果我删除2个辍学层,则该错误不会发生,并且模型运行良好。

我感觉到这些层将输出形式从[None,32]更改为[[None,32]],从而使损失乘以sample_weight成为不可能。但是,我不知道如何解决。我已经在这些图层的输出上尝试了tf.squeeze,但这是不可能的,而Flatten()则无济于事。

您有什么建议吗? :)

非常感谢!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)