问题描述
我想提取处理 forward
函数中定义的输入数据的模型层。例如,给定以下源代码:
def forward(self,inputs):
inputs = self.embedding(inputs)
inputs = F.dropout(inputs,0.25,self.training)
return inputs
我想提取处理输入数据的层,即:
input -> embedding -> dropout -> output
如何在不运行代码的情况下执行此操作?
解决方法
model.children()
将遍历网络中的子模块(层)。