分层的KFold交叉验证KerasValueError:找到的数组为暗4估计量应小于等于2

问题描述

我需要使用分层的kfold(不平衡的多类任务)交叉验证keras模型。是否可以在(folds = list(StratifiedKFold(k,shuffle = True,random_state = 1).split(x_train,y_train)))中将image_image_generator(flow_from_directory)与x_train / y_train一起使用?在Kaggle({{3 }}),但x_train,y_train = next(train_generator)不能正确映射数据和标签。非常感谢您的帮助!

train_generator = ImageDataGenerator(rescale=1./255).flow_from_directory(
 directory=train_path,target_size=input_img[:-1],color_mode="rgb",batch_size=BATCH_SIZE,classes=target_names,class_mode="input")

test_generator = ImageDataGenerator(rescale=1./255).flow_from_directory(
 directory=test_path,class_mode="input",shuffle=False)

test_labels = test_generator.classes

#Instantiate to load data and generate k stratified folds
k = 5
def load_data_kfold(k):
 #For StratifiedKFold,labels (y_train) must be 1-D array of labels (Cannot be one-hot)
 x_train,y_train = next(train_generator)
 print(x_train)
 print(y_train)
 folds = list(StratifiedKFold(k,shuffle=True,random_state=1).split(x_train,y_train))

 return folds,x_train,y_train

folds,y_train = load_data_kfold(k)

Traceback (most recent call last):
File "C:/Users/LaRoche Lab/PycharmProjects/pythonProject2/R.py",line 122,in <module>
folds,y_train = load_data_kfold(k)
File "C:/Users/LaRoche Lab/PycharmProjects/pythonProject2/R.py",line 118,in load_data_kfold
folds = list(StratifiedKFold(k,y_train))
File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site- 
packages\sklearn\model_selection\_split.py",line 735,in split
y = check_array(y,ensure_2d=False,dtype=None)
File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site-packages\sklearn\utils\validation.py",line 73,in inner_f return f(**kwargs)File "C:\Users\LaRoche Lab\Anaconda3\envs\tensorflow\lib\site- 
packages\sklearn\utils\validation.py",line 642,in check_array
% (array.ndim,estimator_name)) 
ValueError: Found array with dim 4. Estimator expected <= 2.

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...