索引错误:索引 5 超出大小为 5 的轴 1 的范围

问题描述

尝试将一列转换为用于 NN 分类分类数据。该列有 6 个类 Few rows of Dataset used

from tensorflow.keras.utils import to_categorical
y_train = to_categorical(y_train,num_classes=5)
y_test = to_categorical(y_test,num_classes=5)

得到的错误

IndexError:对于大小为 5 的轴 1,索引 5 超出范围。我应该怎么做才能清除它?

解决方法

如果该列有 6 个类,那么您为什么要在 num_classes=5 中传递 to_categorical
试试

y_train = to_categorical(y_train,num_classes=6)
y_test = to_categorical(y_test,num_classes=6)