单元测试pytorch前进功能

问题描述

我想在Pytorch中对我的Network模型的覆盖转发功能进行单元测试。因此,我使用setUp方法加载了模型(从Zoo进行了预训练),加载了种子并创建了一些随机批处理。在我的方法testForward中,我测试了shape和numel对正向的结果,但我还想检查apears的特定值是否为0。我对此并不确定,因此也检查了setUp中的参数,似乎不为0。

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8,pretrained=True)
        torch.manual_seed(0)
        self.x = torch.rand((4,3,45,45))
        for param in self.model.parameters():
            print(param.data)

    def testForward(self):
        self.assertEqual(self.model.forward(self.x).shape.numel(),64800)
        self.assertEqual(str(self.model.forward(self.x).shape),'torch.Size([4,8,45])')
        print(self.model.named_parameters)


if __name__ == "__main__":
    unittest.main()

所以我的问题是:正向张量的sahpe是我所期望的,但是为什么这个张量完全为零?我期望至少有几个值。

导入的modell基于VGG16网络,并对之后的ConvLayer 4、8和16进行了升序处理。

解决方法

确定并调试转发功能后,我得出以下解释:

有关体系结构的一些信息

如果您从Andrew Ng或其他人那里上课,您将学会不将权重初始化为相同的值,例如“ 0”。这就是原始FCN论文的作者所做的,他们说,这是因为它不会改变性能或不会加快收敛速度​​(FCN-Paper)。

我的解决方案

因此,出于测试目的,我在测试模块中初始化了种子随机值,可以针对这些随机值进行测试:

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8,pretrained=True)
        torch.manual_seed(0)
        # instead of zero init for score tensors use random init
        self.model.score_fr[6].weight.data.random_()
        self.model.score_fr[6].bias.data.random_()
        self.model.score_pool3.weight.data.random_()
        self.model.score_pool3.bias.data.random_()
        self.model.score_pool4.weight.data.random_()
        self.model.score_pool4.bias.data.random_()
        self.x = torch.rand((4,3,45,45))

    def testForward(self):
        self.assertEqual(
            self.model.forward(self.x).shape.numel(),64800)
        self.assertEqual(
            list(self.model.forward(self.x).shape),[4,8,45])
        self.assertEqual(
            float(self.model.forward(self.x)[3][4][44][4]),2277257216.0))

if __name__ == "__main__":
    unittest.main()