计算来自中间层的输出的梯度并使用优化器更新权重

问题描述

我正在尝试实施以下架构,不确定是否正确应用了渐变磁带。

LFFD research paper

在上面的体系结构中,我们可以看到,输出取自蓝色框中的多层。每个蓝色框在paper中称为损失分支,其中包含两个损失,即交叉熵和l2损失。我在tensorflow 2中编写了架构,并使用渐变带进行自定义训练。我不确定的一件事是如何使用梯度带来更新损耗。

我有两个查询

  1. 在这种情况下,我应该如何使用梯度带进行多次损失。我对查看代码感兴趣!
  2. 例如,考虑上图中的第三个蓝色框(第三个损失分支),我们将从 conv 13 获取输入,并获得两个输出一个用于分类,另一个用于回归。 因此,在计算了损失后,我应该如何更新权重,应该更新上面的所有层(从conv 1到conv 13)还是应该仅更新获取 conv 13 的层权重(转换11、12和13)。

我还要附上一个link,昨天我在上面详细张贴了一个问题。

下面是我尝试进行梯度下降的代码段。如果我错了,请纠正我。

        images = batch.data[0]
        images = (images - 127.5) / 127.5

        targets = batch.label

        with tensorflow.GradientTape() as tape:
            outputs = self.net(images)
            loss = self.loss_criterion(outputs,targets)
        
        self.scheduler(i,self.optimizer)
        grads = tape.gradient(loss,self.net.trainable_variables)
        self.optimizer.apply_gradients(zip(grads,self.net.trainable_variables))

下面是自定义损失函数代码,用作上面的loss_criterion。

    losses = []
    for i in range(self.num_output_scales):
        pred_score = outputs[i * 2]
        pred_bBox = outputs[i * 2 + 1]
        gt_mask = targets[i * 2]
        gt_label = targets[i * 2 + 1]

        pred_score_softmax = tensorflow.nn.softmax(pred_score,axis=1)
        loss_mask = tensorflow.ones(pred_score_softmax.shape,tensorflow.float32)

        if self.hnm_ratio > 0:
            pos_flag = (gt_label[:,:,:] > 0.5)
            pos_num = tensorflow.math.reduce_sum(tensorflow.cast(pos_flag,dtype=tensorflow.float32)) 
        if pos_num > 0:
            neg_flag = (gt_label[:,1,:] > 0.5)
            neg_num = tensorflow.math.reduce_sum(tensorflow.cast(neg_flag,dtype=tensorflow.float32))
            neg_num_selected = min(int(self.hnm_ratio * pos_num),int(neg_num))
            neg_prob = tensorflow.where(neg_flag,pred_score_softmax[:,:],\
            tensorflow.zeros_like(pred_score_softmax[:,:]))
            neg_prob_sort = tensorflow.sort(tensorflow.reshape(neg_prob,shape=(1,-1)),direction='ASCENDING')
            prob_threshold = neg_prob_sort[0][int(neg_num_selected)]
            neg_grad_flag = (neg_prob <= prob_threshold)
            loss_mask = tensorflow.concat([tensorflow.expand_dims(pos_flag,axis=1),tensorflow.expand_dims(neg_grad_flag,axis=1)],axis=1)
        else:
            neg_choice_ratio = 0.1
            neg_num_selected = int(tensorflow.cast(tensorflow.size(pred_score_softmax[:,:]),dtype=tensorflow.float32) * 0.1)
            neg_prob = pred_score_softmax[:,:]
            neg_prob_sort = tensorflow.sort(tensorflow.reshape(neg_prob,direction='ASCENDING')
            prob_threshold = neg_prob_sort[0][int(neg_num_selected)]
            neg_grad_flag = (neg_prob <= prob_threshold)                
            loss_mask = tensorflow.concat([tensorflow.expand_dims(pos_flag,axis=1)

        pred_score_softmax_masked = tensorflow.where(loss_mask,pred_score_softmax,tensorflow.zeros_like(pred_score_softmax,dtype=tensorflow.float32))
        pred_score_log = tensorflow.math.log(pred_score_softmax_masked)
        score_cross_entropy = - tensorflow.where(loss_mask,gt_label[:,:2,tensorflow.zeros_like(gt_label[:,dtype=tensorflow.float32)) * pred_score_log
        loss_score = tensorflow.math.reduce_sum(score_cross_entropy) / 
        tensorflow.cast(tensorflow.size(score_cross_entropy),tensorflow.float32)

        mask_bBox = gt_mask[:,2:6,:]
        predict_bBox = pred_bBox * mask_bBox
        label_bBox = gt_label[:,:] * mask_bBox
        # l2 loss of Boxes
        # loss_bBox = tensorflow.math.reduce_sum(tensorflow.nn.l2_loss((label_bBox - predict_bBox)) ** 2) / 2
        loss_bBox = mse(label_bBox,predict_bBox) / tensorflow.math.reduce_sum(mask_bBox)

        # Adding only losses relevant to a branch and sending them for back prop
        losses.append(loss_score + loss_bBox)
        # losses.append(loss_bBox)
    
        # Adding all losses and sending to back prop Approach 1
        # loss_cls += loss_score
        # loss_reg += loss_bBox
        # loss_branch.append(loss_score)
        # loss_branch.append(loss_bBox)
        # loss = loss_cls + loss_reg

    return losses

我没有收到任何错误,但损失没有减少到最低。这是我训练的log

请有人帮我解决这个问题。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...