Multinomial logit returns NaNs on some cross-validation folds

Problem description

I wrote the code below, which uses stratified k-fold splitting to partition the dataset, fits a multinomial logit on each training fold, and then computes accuracy. My X is an array with 19 variables (the last one is the clustering variable), and my y has 3 classes (0, 1, 2).

import numpy as np
import statsmodels.api as sm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, accuracy_score

X = np.asarray(df[[*all 19 columns here*]], dtype="float64")
y = np.asarray(df["categoric_var"], dtype="int")

acc_test = []
acc_train = []
skf = StratifiedKFold(n_splits=5, shuffle=True)
split_n = 0

for train_ix, test_ix in skf.split(X, y):
    split_n += 1
    X_train, X_valid = X[train_ix], X[test_ix]
    y_train, y_valid = y[train_ix], y[test_ix]
    cluster_groups = X_train[:, -1]
    X_train2 = X_train[:, :-1].astype("float64")  # remove clustering variable
    X_valid2 = X_valid[:, :-1].astype("float64")  # remove clustering variable

    mnl = sm.MNLogit(y_train, X_train2).fit(cov_type="cluster",
                                            cov_kwds={"groups": cluster_groups})
    print(mnl.summary())
    train_pred = mnl.predict(X_train2)

    # turn predicted probabilities into a final classification, as a list
    pred_list_train = []
    for row in train_pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_train.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_train.append(1)
        else:
            pred_list_train.append(2)

    print('MNLogit Regression, training set, fold ', split_n, ': ',
          classification_report(y_train, pred_list_train))

    pred = mnl.predict(X_valid2)

    # turn predicted probabilities into a final classification, as a list
    pred_list_test = []
    for row in pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_test.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_test.append(1)
        else:
            pred_list_test.append(2)

    # measure of the fit of the model
    print('MNLogit Regression, validation set, fold ', split_n, ': ',
          classification_report(y_valid, pred_list_test))

    acc_test.append(accuracy_score(y_valid, pred_list_test))
    acc_train.append(accuracy_score(y_train, pred_list_train))
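
As an aside, the two probability-to-class loops above can be collapsed into a single np.argmax call per prediction matrix. The lines below are only a sketch of an equivalent vectorized version: they assume, as in the code above, that the class labels are exactly 0, 1 and 2, so the column index of the largest predicted probability is already the label. This also avoids comparing the result of np.where against a scalar, which is what triggers the DeprecationWarnings about empty arrays further down when a row of predicted probabilities is all NaN.

# sketch: vectorized class assignment from the predicted probability rows
# (assumes class labels are 0, 1, 2, matching the column order of mnl.predict)
pred_list_train = np.argmax(train_pred, axis=1)
pred_list_test = np.argmax(pred, axis=1)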

The problem is that I have two versions of y: in one the classes are more imbalanced (version 1), in the other they are more balanced (version 2).

When I run this code with version 1 of y, it works perfectly. But when I run it with version 2, some folds return a regression that is all NaNs... Here is an example (apologies for the length). This is the output of the first two folds:

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2251: RuntimeWarning: divide by zero encountered in log
  logprob = np.log(self.cdf(np.dot(self.exog, params)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2252: RuntimeWarning: invalid value encountered in multiply
  return np.sum(d * logprob)
Optimization terminated successfully.
         Current function value: nan
         Iterations 14
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater
  return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less
  return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal
  cond2 = cond0 & (x <= _a)

                          MNLogit Regression Results
==============================================================================
Dep. Variable:                      y   No. Observations:                13852
Model:                        MNLogit   Df Residuals:                    13814
Method:                           MLE   Df Model:                           36
Date:                Thu, 13 Aug 2020   Pseudo R-squ.:                     nan
Time:                        23:04:09   Log-Likelihood:                    nan
converged:                       True   LL-Null:                       -13943.
Covariance Type:              cluster   LLR p-value:                       nan
==============================================================================
       y=1       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
x1            -0.0012      0.009     -0.126      0.900      -0.020       0.017
x2             0.0001    1.8e-05      6.207      0.000    7.63e-05       0.000
x3            -0.6074      0.621     -0.978      0.328      -1.825       0.610
x4             8.5373      1.219      7.004      0.000       6.148      10.926
x5             0.0136      0.002      5.906      0.000       0.009       0.018
x6             0.0024      0.066      0.037      0.970      -0.127       0.131
x7            -0.0060      0.003     -1.972      0.049      -0.012   -3.76e-05
x8            -0.0263      0.015     -1.695      0.090      -0.057       0.004
x9            -0.0237      0.026     -0.926      0.355      -0.074       0.026
x10           -0.0008      0.002     -0.404      0.686      -0.005       0.003
x11            0.0713      0.031      2.308      0.021       0.011       0.132
x12        -9.272e-05   1.54e-05     -6.003      0.000      -0.000   -6.24e-05
x13           -0.0012      0.000     -4.696      0.000      -0.002      -0.001
x14          5.53e-05   1.06e-05      5.215      0.000    3.45e-05    7.61e-05
x15           -0.0007      0.000     -3.538      0.000      -0.001      -0.000
x16         7.334e-05   6.94e-05      1.056      0.291   -6.27e-05       0.000
x17           -0.0098      0.001     -9.659      0.000      -0.012      -0.008
x18           -0.0506      0.036     -1.409      0.159      -0.121       0.020
x19            0.0953      0.017      5.682      0.000       0.062       0.128
------------------------------------------------------------------------------
       y=2       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
x1             0.0354      0.025      1.411      0.158      -0.014       0.084
x2             0.0003      0.000      1.996      0.046    5.62e-06       0.001
x3             3.3663      3.177      1.060      0.289      -2.860       9.593
x4            16.6473      8.483      1.962      0.050       0.021      33.273
x5             0.0507      0.026      1.963      0.050    7.82e-05       0.101
x6             0.3423      0.278      1.232      0.218      -0.202       0.887
x7             0.0274      0.026      1.051      0.293      -0.024       0.079
x8             0.0998      0.071      1.397      0.162      -0.040       0.240
x9            -0.0231      0.049     -0.466      0.641      -0.120       0.074
x10            0.0126      0.006      1.969      0.049    5.65e-05       0.025
x11            0.2219      0.129      1.720      0.085      -0.031       0.475
x12           -0.0002    8.6e-05     -2.286      0.022      -0.000    -2.8e-05
x13           -0.0022      0.001     -2.591      0.010      -0.004      -0.001
x14            0.0001   5.35e-05      2.313      0.021    1.89e-05       0.000
x15           -0.0018      0.001     -2.209      0.027      -0.003      -0.000
x16         6.439e-05      0.000      0.468      0.640      -0.000       0.000
x17           -0.8636      0.047    -18.523      0.000      -0.955      -0.772
x18            1.7166      4.104      0.418      0.676      -6.328       9.761
x19            0.0713      0.052      1.375      0.169      -0.030       0.173
==============================================================================

MNLogit Regression, fold  21 :                precision    recall  f1-score   support

           0       0.89      0.78      0.83      3679
           1       0.76      0.83      0.80      2738
           2       0.97      1.00      0.98      7435

    accuracy                           0.91     13852
   macro avg       0.87      0.87      0.87     13852
weighted avg       0.91      0.91      0.90     13852

MNLogit Regression, fold  21 :                precision    recall  f1-score   support

           0       0.88      0.78      0.83       920
           1       0.77      0.82      0.79       685
           2       0.97      1.00      0.98      1859

    accuracy                           0.90      3464
   macro avg       0.87      0.86      0.87      3464
weighted avg       0.90      0.90      0.90      3464

shape xtrain:  (13853, 19)
shape ytrain:  (13853,)
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2219: RuntimeWarning: overflow encountered in exp
  eXB = np.column_stack((np.ones(len(X)), np.exp(X)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2220: RuntimeWarning: invalid value encountered in true_divide
  return eXB/eXB.sum(1)[:, None]
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\optimizer.py:300: RuntimeWarning: invalid value encountered in greater
  oldparams) > tol)):
Optimization terminated successfully.
         Current function value: nan
         Iterations 6

                          MNLogit Regression Results
==============================================================================
Dep. Variable:                      y   No. Observations:                13853
Model:                        MNLogit   Df Residuals:                    13815
Method:                           MLE   Df Model:                           36
Date:                Thu, 13 Aug 2020   Pseudo R-squ.:                     nan
Time:                        23:04:10   Log-Likelihood:                    nan
converged:                       True   LL-Null:                       -13944.
Covariance Type:              cluster   LLR p-value:                       nan
==============================================================================
       y=1       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
x1                nan        nan        nan        nan         nan         nan
x2                nan        nan        nan        nan         nan         nan
x3                nan        nan        nan        nan         nan         nan
x4                nan        nan        nan        nan         nan         nan
x5                nan        nan        nan        nan         nan         nan
x6                nan        nan        nan        nan         nan         nan
x7                nan        nan        nan        nan         nan         nan
x8                nan        nan        nan        nan         nan         nan
x9                nan        nan        nan        nan         nan         nan
x10               nan        nan        nan        nan         nan         nan
x11               nan        nan        nan        nan         nan         nan
x12               nan        nan        nan        nan         nan         nan
x13               nan        nan        nan        nan         nan         nan
x14               nan        nan        nan        nan         nan         nan
x15               nan        nan        nan        nan         nan         nan
x16               nan        nan        nan        nan         nan         nan
x17               nan        nan        nan        nan         nan         nan
x18               nan        nan        nan        nan         nan         nan
x19               nan        nan        nan        nan         nan         nan
------------------------------------------------------------------------------
       y=2       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
x1                nan        nan        nan        nan         nan         nan
x2                nan        nan        nan        nan         nan         nan
x3                nan        nan        nan        nan         nan         nan
x4                nan        nan        nan        nan         nan         nan
x5                nan        nan        nan        nan         nan         nan
x6                nan        nan        nan        nan         nan         nan
x7                nan        nan        nan        nan         nan         nan
x8                nan        nan        nan        nan         nan         nan
x9                nan        nan        nan        nan         nan         nan
x10               nan        nan        nan        nan         nan         nan
x11               nan        nan        nan        nan         nan         nan
x12               nan        nan        nan        nan         nan         nan
x13               nan        nan        nan        nan         nan         nan
x14               nan        nan        nan        nan         nan         nan
x15               nan        nan        nan        nan         nan         nan
x16               nan        nan        nan        nan         nan         nan
x17               nan        nan        nan        nan         nan         nan
x18               nan        nan        nan        nan         nan         nan
x19               nan        nan        nan        nan         nan         nan
==============================================================================

__main__:42: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:44: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
__main__:54: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:56: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))

MNLogit Regression, fold  21 :                precision    recall  f1-score   support

           0       0.00      0.00      0.00      3679
           1       0.00      0.00      0.00      2739
           2       0.54      1.00      0.70      7435

    accuracy                           0.54     13853
   macro avg       0.18      0.33      0.23     13853
weighted avg       0.29      0.54      0.37     13853

MNLogit Regression, fold  21 :                precision    recall  f1-score   support

           0       0.00      0.00      0.00       920
           1       0.00      0.00      0.00       684
           2       0.54      1.00      0.70      1859

    accuracy                           0.54      3463
   macro avg       0.18      0.33      0.23      3463
weighted avg       0.29      0.54      0.38      3463

I don't know what could be going on here, because nothing has really changed except the values of the dependent variable.
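
Not an answer, just a note on what the warnings seem to point at: the traceback lines quoted above come from MNLogit's probability calculation, where the linear predictor is exponentiated and then normalized row by row (eXB = np.column_stack((np.ones(len(X)), np.exp(X))) followed by eXB / eXB.sum(1)[:, None]). If the linear predictor gets large enough during the iterations, np.exp overflows to inf and the normalization produces inf/inf = nan, which then propagates into the log-likelihood and every reported statistic. The snippet below is only a minimal NumPy sketch of that numerical mechanism with made-up values; it does not explain why the more balanced version of y pushes the optimizer into that regime (separation or very differently scaled regressors would be the usual suspects, but that is an assumption, not something the output proves). While debugging, one simple guard is to check np.isnan(mnl.params).any() after each fit and flag the folds where it happens.

import numpy as np

# Minimal sketch of the failure mode suggested by the RuntimeWarnings above:
# exp() of a large linear predictor overflows, and the row-wise normalization
# then turns inf/inf into nan. The numbers here are made up for illustration.
xb = np.array([[5.0, 10.0],       # moderate linear predictors: valid probabilities
               [800.0, 900.0]])   # large linear predictors: np.exp overflows to inf

eXB = np.column_stack((np.ones(len(xb)), np.exp(xb)))  # same expression as in discrete_model.py
probs = eXB / eXB.sum(1)[:, None]

print(probs)
# first row: proper probabilities summing to 1
# second row: [0., nan, nan] -- the NaNs that later show up in the summary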
