序数回归算法似乎将预测转移了一类

问题描述

对于具有 3 个类 (1,2,3) 的序数回归问题,我正在运行以下算法:

class OrdinalClassifier():

    def __init__(self,clf):
        self.clf = clf
        self.clfs = {}

    def fit(self,X,y):
        self.unique_class = np.sort(np.unique(y))
        if self.unique_class.shape[0] > 2:
            for i in range(self.unique_class.shape[0]-1):
                # for each k - 1 ordinal value we fit a binary classification problem
                binary_y = (y > self.unique_class[i]).astype(np.uint8)
                clf = clone(self.clf)
                clf.fit(X,binary_y)
                self.clfs[i] = clf

    def predict_proba(self,X):
        clfs_predict = {k:self.clfs[k].predict_proba(X) for k in self.clfs}
        predicted = []
        for i,y in enumerate(self.unique_class):
            if i == 0:
                # V1 = 1 - Pr(y > V1)
                predicted.append(1 - clfs_predict[y-1][:,1])
            #elif y in clfs_predict:
            elif y < self.unique_class.shape[0]:
                # Vi = Pr(y > Vi-1) - Pr(y > Vi)
                predicted.append(clfs_predict[y-2][:,1] - clfs_predict[y-1][:,1])
            else:
                # Vk = Pr(y > Vk-1)
                predicted.append(clfs_predict[y-2][:,1])
        return np.vstack(predicted).T

    def predict(self,X):
        return np.argmax(self.predict_proba(X),axis=1)

我通过给它一个 clf 来称呼它:

clf = RandomForestClassifier()
forest = OrdinalClassifier(clf)

并通过调用 fit 来训练它:

forest.fit(X_train,y_train)

最后我通过调用得到预测:

#add 1 so 0->1,1 -> 2,2 -> 3
pred  = forest.predict(y_test) + 1

我相信这个算法应该按照我的意图工作。然而,在使用不同的超参数集运行模型时,这些类似乎以某种方式混淆了。对于几乎所有超参数集,我都发现了相同的模式。

  1. 预测为第 1 类的类占实际第 2 类的百分比最大
  2. 预测为第 2 类的类占实际第 3 类的百分比最大
  3. 预测为第 3 类的类占实际第 1 类的百分比最大

我觉得实际上似乎找到正确序数的一两个超参数组合更多是由运气造成的,而不是实际上产生了一个好的模型。总之,我的算法似乎确实在数据中找到了序数关系,但这些关系不正确/解码不正确。

问题:我的模型有问题吗?我应该只选择在验证集上按预期执行的超参数的确切组合吗?或者我应该将我的预测解码为 3 类 => 1 类,1 类 => 2 类,2 类 => 3 类?

预先感谢您的帮助!

编辑:对于任何感兴趣的人,我的 LGBM 模型开始以正确顺序看到模式的参数是当学习率变得非常高(接近或等于 1)而参数为l1 正则化 >0.

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...