问题描述
def get_label(file_path):
# convert the path to a list of path components
parts = tf.strings.split(file_path,os.path.sep)
class_names = ['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
# The second to last is the class-directory
one_hot = parts[-2] == class_names
# Integer encode the label
return tf.argmax(one_hot)
def decode_img(img):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_jpeg(img,channels=3)
# resize the image to the desired size
return tf.image.resize(img,[img_height,img_width])
def process_path(file_path):
label = get_label(file_path)
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img,label
train_ds = train_ds.map(process_path,num_parallel_calls=AUTOTUNE)
如果我使用具有2个标签的其他数据集来更改此代码,则class_names = ['dog','cat']
会发现此错误
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64
那么我如何更新 def get_label(file_path)
解决方法
我的猜测是tf.argmax需要这些数据类型之一(我现在无法测试)
float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64
所以您要做的就是转换输出
one_hot = parts[-2] == class_names
对于int,“ ==“的计算结果为True / False,这可能是不允许的。
,我遇到了同样的问题。遵循最后一篇文章的想法:
from django.contrib.auth import logout
from django.http import HttpResponse
def logout_view(request):
logout(request)
return HttpResponse('OK')