Problem description
I want to remove Batchnorm from my model, so I am considering fusing it with Linear. My model structure is as follows:
Linear -> ReLU -> Batchnorm -> Dropout -> Linear
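For concreteness, a minimal sketch of such a block (the layer sizes here are placeholders, not from the original post):

```python
import torch.nn as nn

block = nn.Sequential(
    nn.Linear(512, 4096),   # Linear
    nn.ReLU(),              # ReLU
    nn.BatchNorm1d(4096),   # Batchnorm to be fused away
    nn.Dropout(0.5),        # Dropout (identity in eval mode)
    nn.Linear(4096, 10),    # Linear
)
```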
I tried to fuse Batchnorm -> Linear, but could not get it to work with the available code. Is there any way to fuse Batchnorm with any of the layers above?
Solution
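At inference time Batchnorm is just an affine transform built from its frozen running statistics, so a Linear layer followed directly by Batchnorm collapses exactly into a single Linear. With BN scale $\gamma$ (`bn.weight`), shift $\beta$ (`bn.bias`), running mean $\mu$, and running variance $\sigma^2$:

```latex
y = \gamma \odot \frac{(Wx + b) - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  = W'x + b',
\qquad
W' = \mathrm{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}\right) W,
\qquad
b' = \frac{\gamma \odot (b - \mu)}{\sqrt{\sigma^{2} + \epsilon}} + \beta
```

This is exactly what `fuse_1` below computes; the row-wise `reshape` applies the per-output scale $\gamma / \sqrt{\sigma^{2} + \epsilon}$ to each row of $W$. Note the folding is exact only in `eval()` mode, when the BatchNorm uses its frozen statistics.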
```python
import torch
import torch.nn as nn


class DummyModule_1(nn.Module):
    """Identity module that stands in for the removed BatchNorm, so the
    indices of the surrounding Sequential stay valid after fusion."""
    def __init__(self):
        super(DummyModule_1, self).__init__()

    def forward(self, x):
        return x

def fuse_1(linear, bn):
    """Fold a BatchNorm1d that directly follows `linear` into one nn.Linear."""
    w = linear.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift
    if linear.bias is not None:
        b = linear.bias
    else:
        b = mean.new_zeros(mean.shape)
    # y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta  ==  W'x + b'
    w = w * (gamma / var_sqrt).reshape([-1, 1])  # scale each output row of W
    b = (b - mean) / var_sqrt * gamma + beta
    fused_linear = nn.Linear(linear.in_features, linear.out_features)
    fused_linear.weight = nn.Parameter(w)
    fused_linear.bias = nn.Parameter(b)
    return fused_linear
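
# A quick sanity check of fuse_1 (sizes here are arbitrary, not from the
# original post): in eval mode the fused layer must reproduce the output
# of Linear -> BatchNorm1d up to floating-point error.
_lin, _bn = nn.Linear(16, 8), nn.BatchNorm1d(8).eval()
_x = torch.randn(4, 16)
with torch.no_grad():
    assert torch.allclose(_bn(_lin(_x)), fuse_1(_lin, _bn)(_x), atol=1e-6)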

# Shared traversal state: `count` tracks how many submodules named 'linear'
# have been seen, and `c18` remembers the one that feeds the BatchNorm.
count = 0
c18 = None


def fuse_module_1(m):
    global count
    global c18
    for name, child in m.named_children():
        if name == 'linear':
            count = count + 1
            if count == 2:
                c18 = child  # the Linear directly preceding the BatchNorm
        elif name == '2' and isinstance(child, nn.BatchNorm1d):
            # Fold the BatchNorm into the remembered Linear, then replace
            # both layers; the indices and attribute names here are
            # hard-coded for this particular model.
            bc = fuse_1(c18, child)
            m.classifier[1].linear = bc
            m.classifier[2] = DummyModule_1()
        else:
            fuse_module_1(child)
```
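A minimal sketch of how this might be driven, assuming a model whose `classifier` matches the layout above (`MyModel` and the checkpoint path are hypothetical, not from the original post). The model must be in `eval()` mode: the fusion relies on the frozen running statistics, and Dropout must act as the identity.

```python
model = MyModel()                                  # hypothetical model class
model.load_state_dict(torch.load('weights.pth'))   # hypothetical checkpoint
model.eval()  # required: freeze BN statistics, disable Dropout

fuse_module_1(model)  # Batchnorm is now folded into the preceding Linear

with torch.no_grad():
    out = model(torch.randn(1, 512))  # input size is model-specific
```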