Problem description
I have the following function:
def create_variables(name, probabilities, labels):
    print('function called')
    # Metrics is my own class; it computes the curves from probabilities and labels
    model = Metrics(probabilities, labels)
    prec_curve = model.precision_curve()
    kappa_curve = model.kappa_curve()
    tpr_curve = model.tpr_curve()
    fpr_curve = model.fpr_curve()
    # Area under each curve: precision-recall, ROC and kappa
    pr_auc = auc(tpr_curve, prec_curve)
    roc_auc = auc(fpr_curve, tpr_curve)
    auk = auc(fpr_curve, kappa_curve)
    return [name, prec_curve, kappa_curve, tpr_curve, fpr_curve, pr_auc, roc_auc, auk]
I have the following variables:
import pandas as pd

svm = pd.read_csv('SVM.csv')
svm_prob_1 = svm.probability[svm.fold_number == 1]
svm_prob_2 = svm.probability[svm.fold_number == 2]
svm_label_1 = svm.true_label[svm.fold_number == 1]
svm_label_2 = svm.true_label[svm.fold_number == 2]
I want to run the following lines:
svm1 = create_variables('svm_fold1',svm_prob_1,svm_label_1)
svm2 = create_variables('svm_fold2',svm_prob_2,svm_label_2)
Python handles svm1 as expected. However, when it gets to svm2, I receive the following error:
svm2 = create_variables('svm_fold2',svm_prob_2,svm_label_2)
function called
Traceback (most recent call last):
File "<ipython-input-742-702cfac4d100>",line 1,in <module>
svm2 = create_variables('svm_fold2',svm_prob_2,svm_label_2)
File "<ipython-input-741-b8b5a84f0298>",line 6,in create_variables
prec_curve = model.precision_curve()
File "<ipython-input-734-dd9c309be961>",line 59,in precision_curve
self.tp,self.tn,self.fp,self.fn = self.confusion_matrix(self.preds)
File "<ipython-input-734-dd9c309be961>",line 72,in confusion_matrix
if pred == self.labels[i]:
File "C:\Users\20200016\AppData\Local\Continuum\anaconda3\lib\site-packages\pandas\core\series.py",line 1068,in __getitem__
result = self.index.get_value(self,key)
File "C:\Users\20200016\AppData\Local\Continuum\anaconda3\lib\site-packages\pandas\core\indexes\base.py",line 4730,in get_value
return self._engine.get_value(s,k,tz=getattr(series.dtype,"tz",None))
File "pandas\_libs\index.pyx",line 80,in pandas._libs.index.IndexEngine.get_value
File "pandas\_libs\index.pyx",line 88,line 131,in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\hashtable_class_helper.pxi",line 992,in pandas._libs.hashtable.Int64HashTable.get_item
File "pandas\_libs\hashtable_class_helper.pxi",line 998,in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 0
svm_prob_1 and svm_prob_2 have the same shape and contain non-zero values. svm_label_2 contains only 0s and 1s and has the same length as svm_prob_2.
Furthermore, the error seems to be tied to svm_label_2: when I swap in svm_label_1 instead, the following line does work:
svm2 = create_variables('svm_fold2',svm_prob_2,svm_label_1)
Judging by the output below, there appears to be no difference between svm_label_1 and svm_label_2.
type(svm_label_1)
Out[806]: pandas.core.series.Series
type(svm_label_2)
Out[807]: pandas.core.series.Series
min(svm_label_1)
Out[808]: 0
min(svm_label_2)
Out[809]: 0
max(svm_label_1)
Out[810]: 1
max(svm_label_2)
Out[811]: 1
sum(svm_label_1)
Out[812]: 81
sum(svm_label_2)
Out[813]: 89
len(svm_label_1)
Out[814]: 856
len(svm_label_2)
Out[815]: 856
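One property the checks above do not look at is the index of each Series. The sketch below uses a small made-up DataFrame in place of SVM.csv and assumes the rows are stored fold after fold (an assumption, the real file layout is not shown). It illustrates how two filtered Series can agree on type and length while carrying different index labels.

```python
import pandas as pd

# Hypothetical stand-in for SVM.csv: 4 rows per fold, stored fold after fold.
svm = pd.DataFrame({
    "fold_number": [1, 1, 1, 1, 2, 2, 2, 2],
    "true_label":  [0, 1, 0, 0, 1, 0, 0, 0],
})
label_1 = svm.true_label[svm.fold_number == 1]
label_2 = svm.true_label[svm.fold_number == 2]

# Type and length agree, just like in the session above.
same_type = type(label_1) is type(label_2)
same_len = len(label_1) == len(label_2)

# Boolean filtering keeps the original row labels, so the indexes differ.
index_1 = list(label_1.index)
index_2 = list(label_2.index)

# Only the first fold still has a row labelled 0.
has_label_zero = (0 in label_1.index, 0 in label_2.index)
```

Running the same two `.index` checks on the real svm_label_1 and svm_label_2 would show whether this is what happens here.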
Does anyone know what is going wrong here?
Solution
I don't know why it works, but converting svm_label_2 to a list does the trick:
svm_label_2 = list(svm.true_label[svm.fold_number == 2])
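For comparison, here is a minimal sketch of two alternatives that should have the same effect as the list conversion, assuming the goal is simply to make position 0 addressable again. The DataFrame is made up for illustration; `reset_index(drop=True)` and `to_numpy()` are standard pandas methods.

```python
import pandas as pd

# Hypothetical stand-in for SVM.csv.
svm = pd.DataFrame({
    "fold_number": [1, 1, 1, 1, 2, 2, 2, 2],
    "true_label":  [0, 1, 0, 0, 1, 0, 0, 0],
})
mask = svm.fold_number == 2

# The conversion used above: a plain list has no index, only positions.
as_list = list(svm.true_label[mask])

# Stays a Series, but the index is renumbered 0..n-1.
as_series = svm.true_label[mask].reset_index(drop=True)

# Plain NumPy array, also addressed purely by position.
as_array = svm.true_label[mask].to_numpy()

# All three now return the first label of fold 2 for [0].
first_values = (as_list[0], as_series[0], as_array[0])
```

Which one fits best probably depends on whether the rest of the Metrics class relies on pandas methods or only on element access.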
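Since svm_label_1 and svm_label_2 are of the same type, I don't understand why the latter raises an error while the former does not. I would therefore still welcome any explanation of this behaviour.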
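A likely explanation, going by the traceback: `self.labels[i]` on a Series with an integer index is a lookup by label, not by position, and a Series produced by boolean filtering keeps the row labels of the original DataFrame. If the second fold starts at a later row, there is no label 0, hence `KeyError: 0`. The sketch below reproduces this with made-up data. The `count_matches` function is hypothetical (the real `confusion_matrix` is not shown) and uses `.iloc` for positional access.

```python
import pandas as pd

# Hypothetical fold-2 labels whose index starts at 856 instead of 0.
labels = pd.Series([1, 0, 0], index=[856, 857, 858])
preds = [1, 1, 0]

def count_matches(preds, labels):
    """Hypothetical loop in the spirit of the comparison in the traceback."""
    hits = 0
    for i, pred in enumerate(preds):
        # .iloc is positional, so it does not depend on the index labels.
        if pred == labels.iloc[i]:
            hits += 1
    return hits

# Plain [] on an integer-indexed Series looks up the label 0, which is absent.
try:
    labels[0]
    lookup_failed = False
except KeyError:
    lookup_failed = True

matches = count_matches(preds, labels)
```

Under this reading, svm_label_1 only works because its labels happen to coincide with the positions 0 to 855.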