在 Tensorflow 中实现深度集成学习以进行不确定性估计

问题描述

我正在尝试在 Tensorflow 中实现深度集成模型。具体来说,我试图重现论文“使用深度集成Balaji Lakshminarayanan et al.的简单和可扩展的预测不确定性估计”中讨论的结果。作者指出,可以使用一组神经网络来估计预测的不确定性。对于回归问题,作者指示在对应于均值预测和方差的最后一层输出两个值。要最小化的成本函数是负对数似然。我已经实现了如下,

def base_model():
        # Base model upon which ensembles will be built
        inputs = Input(shape=(1,))
        x = Dense(50,activation='relu')(inputs)
        x = Dense(50,activation='relu')(x)
        # Note two outputs for the final layer corresponding to mean and variance
        x = Dense(2,activation=None)(x)

        model = keras.Model(inputs,dist)
        model.compile(keras.optimizers.Adam(learning_rate=self.lr),loss=NLL)
        return model

def NLL(ytrue,ypred):
        # Negative log-likelihood activation function
        # variance is sotfmax of the NN output so that it is always positive number
        var = tf.math.log(1.0+tf.math.exp(ypred[...,1:2])) + 1e-6
        # Negative log-likelihood formula
        NLL = tf.math.log(var)*0.5 + 0.5*tf.math.divide(tf.math.square(ytrue-ypred[...,0:1]),var)
        return NLL

def train(self,xtrain,ytrain,batch,epochs,validation_data=None):
        h = []
        models = []
        # Train 10 models
        for i in range(10):
            # Get model
            models.append(base_model())
            # Train model
            h1=models[i].fit(xtrain,batch_size=batch,validation_data=validation_data,epochs=epochs)
            h.append(h1)
        return h,models

现在根据上面引用的论文,可以使用神经网络的集成来估计方差。我训练了一个由 10 个神经网络组成的集成,并试图在论文中重现图 1(最后一个面板),如下

enter image description here

在此图中,红点是训练示例,灰色带是预测均值和 3 个标准差。需要注意的重要一点是,在没有训练示例的区域中方差很高。但是,在我的实施中,方差没有显示出这样的趋势(下图示例)。我尝试了各种配置,但无法重现论文的结果。然而,我在 medium 中找到了一篇文章,其中已经完成了同一篇论文的实现,但该实现与我的截然不同且更复杂。

我想知道我的实现中是否存在一些逻辑错误

enter image description here

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...