混合密度网络的K折交叉验证

问题描述

我正在尝试使用MSE和R ^ 2得分作为计算交叉验证得分的指标,将k倍交叉验证应用于混合物密度模型。 (我受到this article的启发,该技术使用这种技术来评估竞争模型)。

该模型有效并且 y_pred 很有意义,但是我不确定如何解释这些指标,甚至不确定它们是否正确计算。如果 y_true 的形状为(1000,1),则3种混合物的 y_pred 的形状为(1000,6)。

因此,我的第一个问题是,在计算MSE和R ^ 2分数时,Keras对 y_pred 使用什么价值? (由于该代码而获得的值根本没有意义。)如果这些值不是正确的值,那么我应该使用哪些值来计算交叉验证得分?

谢谢。

# k-folds cross-validator
num_folds = 5
kfold = KFold(n_splits=num_folds,shuffle=True)
cvscores_mse = []
cvscores_r2 = []

def r2_score(y_true,y_pred):
    SS_res = bk.sum(bk.square(y_true - y_pred))
    SS_tot = bk.sum(bk.square(y_true - bk.mean(y_true)))
    return (1 - SS_res/(SS_tot + bk.epsilon()))

                                                                                                
# Train model for each fold
for train,test in kfold.split(X_train,y_train):
    model = ks.Model(inputs=inputs,outputs=outputVector)
    model.compile(optimizer=ks.optimizers.Adam(learning_rate=lr,clipvalue=1.0),loss=mean_log_gaussian_like,metrics=['mse',r2_score])

    # Fit the model
    history = model.fit(X_train[train],y_train[train],validation_data=(X_train[test],y_train[test]),batch_size=batch,epochs=epoch)

    # Preserve the history 
    mse = history.history['mse']
    val_mse = history.history['val_mse']
    r2_score = history.history['r2_score']
    val_r2_score= history.history['val_r2_score']

    # Evaluate the model on the test data using 'evaluate'
    scores = model.evaluate(X_train[test],y_train[test],verbose=0,batch_size=128)
    print("%s: %.2f%%" % (model.metrics_names[1],scores[1] * 100))
    print("%s: %.2f%%" % (model.metrics_names[2],scores[2] * 100))
    print("Validation loss,Validation RMSE,Validation R^2:",scores)

    cvscores_mse.append(scores[1] * 100)
    cvscores_r2.append(scores[1] * 100)
    print("%.2f%%(+/-%.2f%%)" % (np.mean(cvscores_r2),np.std(cvscores_r2)))

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...