TypeError:传递给参数'input'的值的DataType bool不在允许值列表中:float32,float64,int32,uint8,int16,int8

问题描述

我有一个包含5个标签的数据集

def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path,os.path.sep)
  class_names = ['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
  # The second to last is the class-directory
  one_hot = parts[-2] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img,channels=3)
  # resize the image to the desired size
  return tf.image.resize(img,[img_height,img_width])

def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img,label

train_ds = train_ds.map(process_path,num_parallel_calls=AUTOTUNE)

如果我使用具有2个标签的其他数据集来更改此代码,则class_names = ['dog','cat']会发现此错误 TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64 那么我如何更新 def get_label(file_path)

解决方法

我的猜测是tf.argmax需要这些数据类型之一(我现在无法测试)

float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64

所以您要做的就是转换输出

one_hot = parts[-2] == class_names

对于int,“ ==“的计算结果为True / False,这可能是不允许的。

,

我遇到了同样的问题。遵循最后一篇文章的想法:

from django.contrib.auth import logout
from django.http import HttpResponse

def logout_view(request):
  logout(request)
  return HttpResponse('OK')