问题描述
我一直在积极寻找答案,但找不到答案。我在张量流方面也很新,所以如果不清楚或愚蠢,请原谅我。
我正在尝试根据图像预测浮动值,而CNN的最后一层是:
x = Flatten()(x)
x = Dense(32,activation='relu')(x)
x = Dropout(rate = 0.5)(x)
x = Dense(32,activation='relu')(x)
x = Dropout(rate = 0.25)(x)
x = Dense(1,activation = 'linear)(x)
图像和目标浮点值以TFRecords格式保存。为了获得加权损失(MSE或MAE),添加了一个额外的浮点数(每个图像一个)。
def write_tfrecords(out_path,images,labels,fweights):
assert len(images) == len(labels)
with tf.io.TFRecordWriter(out_path) as writer:
for i in range(len(labels)):
img_bytes = images[i].tostring()
labels_temp = labels[i]
fweight_temp= fweights[i]
data = {'image': _bytes_feature(img_bytes),'label': _float_feature(labels_temp),'fweight':_float_feature(fweight_temp)}
feature = tf.train.Features(feature=data)
example = tf.train.Example(features=feature)
serialized = example.SerializetoString()
writer.write(serialized)
请参阅文档,在解析TFRecord时,第三个输出将作为sample_weight。编译模型时,我还设置了sample_weight_mode=None
。
def parse_example(serialized,shape=(INPUT_HEIGHT,INPUT_WIDTH,1)):
features = {'image': tf.io.FixedLenFeature([],tf.string),'label': tf.io.FixedLenFeature((),tf.float32),'fweight': tf.io.FixedLenFeature((),tf.float32)}
parsed_example = tf.io.parse_single_example(serialized=serialized,features=features)
image_raw = parsed_example['image']
image = tf.decode_raw(image_raw,tf.float32)
image = tf.reshape(image,shape=shape)
label = parsed_example['label']
fweight = parsed_example['fweight']
return image,label,fweight
现在,出现以下错误,批处理大小为10(如果将batch_size更改为1,则更改为1):
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: The second input must be a scalar,but it has shape [10]
[[{{node dropout_44/cond/Switch}}]]
[[loss_19/mul/_3891]]
(1) Invalid argument: The second input must be a scalar,but it has shape [10]
[[{{node dropout_44/cond/Switch}}]]
0 successful operations.
0 derived errors ignored.
我感觉到这些层将输出形式从[None,32]
更改为[[None,32]]
,从而使损失乘以sample_weight成为不可能。但是,我不知道如何解决。我已经在这些图层的输出上尝试了tf.squeeze,但这是不可能的,而Flatten()则无济于事。
您有什么建议吗? :)
非常感谢!
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)