具有keras的具有损失功能的批次元素产品

问题描述

我正在尝试在keras中编写自定义损失函数,需要通过中间层的输出(形状为(batch_size,(batch_size,64,64))在y_true和y_pred(形状:(b​​atch_size,64,64))之间加权MSE。 1))。

我需要的操作只是将MSE的每个批处理元素权重(乘以一个因子),即weight_tensor中的相应批处理元素。

我尝试了以下

def loss_aux_wrapper(weight_tensor):
    def loss_aux(y_true,y_pred):
        K.print_tensor(weight_tensor,message='weight = ')
        _shape = K.shape(y_true)
        return K.reshape(K.batch_dot(K.batch_flatten(mse(y_true,y_pred)),weight_tensor,axes=[1,1]),_shape)
    return loss_aux

但我知道

tensorflow.python.framework.errors_impl.InvalidArgumentError:  In[0] mismatch In[1] shape: 4096 vs. 1: [32,1,4096] [32,1] 0 0      [[node loss/aux_motor_output_loss/MatMul (defined at /code/icub_sensory_enhancement/scripts/models.py:327) ]] [Op:__inference_keras_scratch_graph_6548]

K.print_tensor不会输出任何内容,我相信是因为它是在编译时调用的?

任何建议都将不胜感激!

解决方法

为了加权MSE损失函数,可以在调用sample_weight=函数时使用mse参数。根据{{​​3}},

如果 sample_weight是大小[batch_size]的张量,则总损耗 对于该批次的每个样品,将通过相应的元素重新缩放 在sample_weight向量中。如果sample_weight的形状是 [batch_size,d0,.. dN-1](或可以广播到此形状),然后 y_pred的每个损失元素均按的相应值缩放 sample_weight。 (请注意onN-1:所有损失函数按一维减少, 通常为axis=-1。)

在您的情况下,weight_tensor的形状为( batch size,1 )。所以首先我们需要重塑它,

reshaped_weight_tensor = K.reshape( weight_tensor,shape=( batch_size ) )

然后该张量可以与MeanSquaredError一起使用,

def loss_aux_wrapper(weight_tensor):

    def loss_aux(y_true,y_pred):
        reshaped_weight_tensor = K.reshape( weight_tensor,shape=( batch_size ) )
        return mse( y_true,y_pred,sample_weight=reshaped_weight_tensor )
        
    return loss_aux