问题描述
尝试将一列转换为用于 NN 分类的分类数据。该列有 6 个类 Few rows of Dataset used
from tensorflow.keras.utils import to_categorical
y_train = to_categorical(y_train,num_classes=5)
y_test = to_categorical(y_test,num_classes=5)
得到的错误是
IndexError:对于大小为 5 的轴 1,索引 5 超出范围。我应该怎么做才能清除它?
解决方法
如果该列有 6 个类,那么您为什么要在 num_classes=5
中传递 to_categorical
。
试试
y_train = to_categorical(y_train,num_classes=6)
y_test = to_categorical(y_test,num_classes=6)