python – PyTorch,nn.Sequential(),访问nn.Sequential()中特定模块的权重

这应该是一个快速的.当我在PyTorch中使用预定义模块时,我通常可以非常轻松地访问其权重.但是,如果我先将模块包装在nn.Sequential()中,如何访问它们?请看下面的玩具示例

class My_Model_1(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_1,self).__init__()
        self.layer = nn.Linear(D_in,D_out)
    def forward(self,x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self,D_out):
        super(My_Model_2,self).__init__()
        self.layer = nn.Sequential(nn.Linear(D_in,D_out))
    def forward(self,x):
        out = self.layer(x)
        return out

model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)
# How do I print the weights Now?
# model_2.layer.0.weight doesn't work.

解决方法

访问权重的一个简单权重是使用模型的state_dict().

这适用于您的情况:

for k,v in model_2.state_dict().iteritems():
    print("Layer {}".format(k))
    print(v)

一个选择是获取modules()迭代器.如果您事先知道图层的类型,这也应该有效:

for layer in model_2.modules():
   if isinstance(layer,nn.Linear):
        print(layer.weight)

相关文章

功能概要:(目前已实现功能)公共展示部分:1.网站首页展示...
大体上把Python中的数据类型分为如下几类: Number(数字) ...
开发之前第一步,就是构造整个的项目结构。这就好比作一幅画...
源码编译方式安装Apache首先下载Apache源码压缩包,地址为ht...
前面说完了此项目的创建及数据模型设计的过程。如果未看过,...
python中常用的写爬虫的库有urllib2、requests,对于大多数比...