问题描述
我正在尝试为数据集的每个图像提取标签值。我的想法是拆分训练和测试:
data,info= tfds.load("cats_vs_dogs",as_supervised = True,split = 'train',with_info=True)
# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)
然后我想应用预处理来根据训练中的标签进行区分:
def preprocess_start(image,label):
#cast the image values from integers to floating and then divide by 255
image = tf.image.resize(image,[100,100])
image = tf.cast(image,tf.float32)/ 255.0
if label == 0: #this check is not correct
code ...
else:
code ...
return image,label
train_test= train_test.map(preprocess_start)
但是主要的问题是标签是:
Tensor("args_1:0",shape=(),dtype=int64)
如何提取整数值?
解决方法
# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)
@tf.function
def preprocess_start(image,label):
# cast the image values from integers to floating and then divide by 255
image = tf.image.resize(image,[100,100])
image = tf.cast(image,tf.float32) / 255.0
if tf.equal(label,0 ): # this check is not correct
tf.print("Zero",label)
else:
tf.print("Not Zero",label)
return image,label
train_test = train_dataset.map(preprocess_start)
for e in train_test:
tf.print(e)
所以对于更简单的功能,这应该可以工作。如果您需要更复杂的逻辑,您必须阅读 rules
Zero 0
Zero 0
Not Zero 1
Zero 0
Zero 0
Zero 0
Not Zero 1
因此,例如,此检查应按预期工作。
l = tf.constant(0)
if tf.equal(l,0 ): # this check is not correct
tf.print("Zero",l)
else:
tf.print("Not Zero",l)