Error when implementing cross-validation

Problem description

I am trying to use cross-validation to evaluate a model (MNIST):

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
skfolds = StratifiedKFold(n_splits=5,random_state=42)

When I run the third line, I get the following warning:

C:\Users\nextg\Desktop\sample_project\env\lib\site-packages\sklearn\model_selection\_split.py:293: FutureWarning: Setting a random_state has no effect since shuffle is False. This will raise an error in 0.24. You should leave random_state to its default (None), or set shuffle=True.
  warnings.warn(

Ignoring the warning, I wrote this code:

for train_index,test_index in skfolds.split(X_train,y_test_5):
   clone_clf = clone(sgd_clf)
   X_train_folds = X_train[train_index]
   y_train_folds = y_train[train_index]
   X_test_fold = X_test[test_index]
   y_test_fold = y_test_5[test_index]

   clone_clf.fit(X_train_folds,y_train_folds)
   y_pred = clone_clf.predict(X_test_fold)
   n_correct = sum(y_pred == y_test_fold)
   print(n_correct / len(y_pred))

Running this code produces the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-66-7e786591c439> in <module>
 ----> 1 for train_index,test_index in skfolds.split(X_train,y_test_5):
  2     clone_clf = clone(sgd_clf)
  3     X_train_folds = X_train[train_index]
  4     y_train_folds = y_train[train_index]
  5     X_test_fold = X_test[test_index]

 ~\Desktop\sample_project\env\lib\site-packages\sklearn\model_selection\_split.py in split(self,X,y,groups)
     326             The testing set indices for that split.
     327         """
 --> 328         X,y,groups = indexable(X,y,groups)
     329         n_samples = _num_samples(X)
     330         if self.n_splits > n_samples:

   ~\Desktop\sample_project\env\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
    291     """
    292     result = [_make_indexable(X) for X in iterables]
--> 293     check_consistent_length(*result)
    294     return result
    295 

 ~\Desktop\sample_project\env\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
    254     uniques = np.unique(lengths)
    255     if len(uniques) > 1:
--> 256         raise ValueError("Found input variables with inconsistent numbers of"
    257                          " samples: %r" % [int(l) for l in lengths])
    258 

 ValueError: Found input variables with inconsistent numbers of samples: [60000,10000]

Can someone help me resolve this error?

Solution

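This expression does not make sense: skfolds.split(X_train,y_test_5)

It should have the form skfolds.split(X,y), where X and y describe the same samples. Here X_train has 60000 rows while y_test_5 has only 10000 labels, which is exactly what the ValueError reports.

From the docs, the two arguments must satisfy: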

X.shape[0] == y.shape[0]
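
A quick way to see this requirement in action is to call split on arrays of matching and non-matching length. The sketch below is only an illustration: it uses small random stand-ins for X_train, y_train_5 and y_test_5 (60 and 10 samples, scaled down from the 60000 and 10000 in the question), the helper try_split is hypothetical, and it assumes NumPy and scikit-learn are installed.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical stand-ins for the arrays in the question, scaled down.
rng = np.random.default_rng(0)
X_train = rng.random((60, 4))          # 60 samples, 4 features
y_train_5 = np.arange(60) % 2 == 0     # labels for the same 60 samples
y_test_5 = np.arange(10) % 2 == 0      # labels for a different, 10-sample set

skfolds = StratifiedKFold(n_splits=5)


def try_split(X, y):
    """Return None if split(X, y) yields a fold, else the error message."""
    try:
        # split() is a generator, so the length check runs on the first next().
        next(skfolds.split(X, y))
        return None
    except ValueError as err:
        return str(err)


mismatch_error = try_split(X_train, y_test_5)   # 60 rows vs 10 labels
match_error = try_split(X_train, y_train_5)     # 60 rows vs 60 labels

print("mismatched lengths:", mismatch_error)
print("matching lengths:  ", match_error)
```

Comparing X.shape[0] with len(y) before calling split catches the same mistake without waiting for the exception.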

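It should be skfolds.split(X_train,y_train_5) rather than skfolds.split(X_train,y_test_5). Inside the for loop, the labels must come from y_train_5 as well: use y_train_folds = y_train_5[train_index] instead of y_train_folds = y_train[train_index], and y_test_fold = y_train_5[test_index] instead of y_test_fold = y_test_5[test_index]. For the same reason, X_test_fold should be taken from X_train, not X_test, because the fold indices refer to rows of the training set.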
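
The reason every fold has to be sliced from the training arrays is that split returns row positions into the array it was given. A small sketch with made-up data (12 samples; the names mirror the question, but the values are hypothetical) checks that property:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X_train = np.arange(24).reshape(12, 2)      # 12 hypothetical samples
y_train_5 = np.array([True, False] * 6)     # one label per sample

skfolds = StratifiedKFold(n_splits=3)

held_out = []
for train_index, test_index in skfolds.split(X_train, y_train_5):
    # Both index arrays address rows of X_train / y_train_5.
    assert train_index.max() < len(X_train)
    assert test_index.max() < len(X_train)
    # Within one fold, training rows and held-out rows never overlap.
    assert np.intersect1d(train_index, test_index).size == 0
    held_out.extend(test_index.tolist())

# Across all folds, each training row is held out exactly once.
covered_once = sorted(held_out) == list(range(len(X_train)))
print(covered_once)
```

Using test_index to slice a separate X_test or y_test_5 would therefore pick unrelated rows, or fail outright when an index is out of range.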

The whole problem started with using the Tab key, presumably by autocompleting to the wrong variable names.


This works:

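from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

# shuffle=True is needed for random_state to have any effect
skfolds = StratifiedKFold(n_splits=3,random_state=42,shuffle=True)

# Split on the training features and the matching training labels
for train_index,test_index in skfolds.split(X_train,y_train_5):
    clone_clf = clone(sgd_clf)  # fresh, unfitted copy for each fold
    # .values assumes X_train is a pandas DataFrame; drop it for a NumPy array
    X_train_folds = X_train.values[train_index]
    y_train_folds = y_train_5[train_index]
    # The held-out fold also comes from the training set, not from X_test
    X_test_fold = X_train.values[test_index]
    y_test_fold = y_train_5[test_index]

    clone_clf.fit(X_train_folds,y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))  # accuracy on this fold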
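
Because the snippet above depends on X_train, y_train_5 and sgd_clf already existing in the session, here is a self-contained variant that can be run as-is. It is only a sketch: the data comes from make_classification as a synthetic stand-in for the MNIST "is it a 5?" task, and the SGDClassifier settings are assumptions rather than the asker's actual configuration.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold

# Synthetic binary data standing in for (X_train, y_train_5); hypothetical.
X_train, y_train_5 = make_classification(
    n_samples=300, n_features=20, random_state=42
)
y_train_5 = y_train_5.astype(bool)

sgd_clf = SGDClassifier(random_state=42)  # assumed classifier settings
skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

scores = []
for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)  # unfitted copy, so folds do not share state
    clone_clf.fit(X_train[train_index], y_train_5[train_index])
    y_pred = clone_clf.predict(X_train[test_index])
    # Fraction of correct predictions on the held-out fold.
    scores.append(float(np.mean(y_pred == y_train_5[test_index])))

print(scores)
```

Every slice here is taken from the same pair of arrays that was passed to split, which is the whole fix. scikit-learn's cross_val_score with cv=skfolds and scoring="accuracy" runs essentially the same procedure if the manual loop is not needed.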
