问题描述
def AdaIN(x):
    """Apply adaptive instance normalization inside a Keras ``Lambda`` layer.

    Args:
        x: A list of three tensors ``[features, gamma, beta]``.
            ``features`` is the image representation to normalize,
            ``gamma`` is the per-channel scale and ``beta`` the per-channel
            bias. Assumes ``features`` is channels-last
            ``(batch, height, width, channels)`` -- TODO confirm against the
            model's data format.

    Returns:
        The normalized features multiplied by ``gamma`` and shifted by
        ``beta``.
    """
    features, gamma, beta = x[0], x[1], x[2]

    # Normalize each sample/channel over its spatial axes. The std must use
    # the same axes as the mean; without ``axis`` it would be computed over
    # the whole tensor (batch and channels included).
    spatial_axes = [1, 2]
    mean = K.mean(features, axis=spatial_axes, keepdims=True)
    std = K.std(features, axis=spatial_axes, keepdims=True) + 1e-7  # avoid /0
    y = (features - mean) / std

    # Reshape scale and bias to (batch, 1, 1, channels) so that they broadcast
    # across both spatial axes of the rank-4 feature map.
    pool_shape = [-1, 1, 1, y.shape[-1]]
    scale = K.reshape(gamma, pool_shape)
    bias = K.reshape(beta, pool_shape)

    # Multiply by GAMMA and add BETA.
    return y * scale + bias
def g_block(input_tensor, latent_vector, filters):
    """Build one generator block: upsample, convolve, AdaIN, then ReLU.

    Args:
        input_tensor: Feature map fed into the block.
        latent_vector: Latent/style tensor from which the AdaIN scale and
            bias are predicted by two ``Dense`` layers.
        filters: Number of convolution filters; also the width of the
            predicted scale and bias vectors.

    Returns:
        The activated output tensor of the block.
    """
    # Style parameters: the scale starts around 1 (bias initialised to ones).
    style_scale = Dense(filters, bias_initializer='ones')(latent_vector)
    style_shift = Dense(filters)(latent_vector)

    # Upsample, then convolve with a 3x3 'same' kernel.
    upsampled = UpSampling2D()(input_tensor)
    features = Conv2D(filters, 3, padding='same')(upsampled)

    # Lambda hands the three tensors to AdaIN as a single list argument.
    modulated = Lambda(AdaIN)([features, style_scale, style_shift])
    return Activation('relu')(modulated)
请参见上面的代码。我目前正在学习 StyleGAN,并尝试把这段代码转换为 PyTorch,但我不明白 Lambda 在 g_block 中起什么作用。按照声明,AdaIN 只接受一个参数,那么 gamma 和 beta 是如何作为输入传进去的?请告诉我 Lambda 在这段代码中的作用。
非常感谢您。
解决方法
keras 中的 Lambda 层用于在模型内部调用自定义函数。在 g_block 中,Lambda 调用 AdaIN 函数,并把 out、gamma、beta 放在一个列表里作为参数传入。AdaIN 函数接收到的 x 就是封装了这三个张量的列表,因此在 AdaIN 函数内部可以通过索引该列表(x[0]、x[1]、x[2])来访问这些张量。
下面是等效的 pytorch 实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaIN(nn.Module):
    """Adaptive instance normalization for ``(batch, channels, H, W)`` maps."""

    def forward(self, out, gamma, beta):
        """Normalize ``out`` per sample and channel, then scale and shift it.

        Args:
            out: Feature map of shape ``(batch, channels, height, width)``.
            gamma: Per-channel scale of shape ``(batch, channels)``.
            beta: Per-channel bias of shape ``(batch, channels)``.

        Returns:
            Tensor with the same shape as ``out``.
        """
        # Statistics are taken over the spatial axes only, keeping the batch
        # and channel axes so they broadcast against ``out``.
        mean = out.mean(dim=(2, 3), keepdim=True)
        # Population std (unbiased=False) mirrors the Keras backend ``std`` and
        # stays finite for 1x1 feature maps; the epsilon avoids division by 0.
        std = out.std(dim=(2, 3), unbiased=False, keepdim=True) + 1e-7
        y = (out - mean) / std

        # (batch, channels) -> (batch, channels, 1, 1) for broadcasting.
        scale = gamma.unsqueeze(-1).unsqueeze(-1)
        bias = beta.unsqueeze(-1).unsqueeze(-1)
        return y * scale + bias


class g_block(nn.Module):
    """Generator block: 2x upsample -> 3x3 conv -> AdaIN -> ReLU."""

    def __init__(self, filters, latent_vector_shape, input_tensor_channels):
        """Create the layers of the block.

        Args:
            filters: Number of output channels of the convolution, and the
                size of the predicted AdaIN scale/bias vectors.
            latent_vector_shape: Number of features in the latent vector.
            input_tensor_channels: Number of channels of the input feature map.
        """
        super().__init__()
        self.gamma = nn.Linear(in_features=latent_vector_shape, out_features=filters)
        # Initialize the scale branch's bias to 1 (Keras bias_initializer='ones').
        nn.init.ones_(self.gamma.bias)
        self.beta = nn.Linear(in_features=latent_vector_shape, out_features=filters)
        # 3x3 kernel with padding=1 keeps the spatial size (Keras padding='same').
        self.conv = nn.Conv2d(input_tensor_channels, filters, kernel_size=3, padding=1)
        self.adain = AdaIN()

    def forward(self, input_tensor, latent_vector):
        """Run the block.

        Args:
            input_tensor: Feature map ``(batch, input_tensor_channels, H, W)``.
            latent_vector: Latent tensor ``(batch, latent_vector_shape)``.

        Returns:
            Tensor of shape ``(batch, filters, 2*H, 2*W)``.
        """
        gamma = self.gamma(latent_vector)
        beta = self.beta(latent_vector)
        # Keras UpSampling2D defaults to a factor of 2 with nearest-neighbour
        # interpolation, so 'nearest' is the matching mode here.
        out = F.interpolate(input_tensor, scale_factor=2, mode='nearest')
        out = self.conv(out)
        out = self.adain(out, gamma, beta)
        return torch.relu(out)
# Sample: run one generator block on random data.
# The feature map must be 4-D (batch, channels, height, width) so that
# input_tensor.shape[1] really is the channel count handed to g_block.
input_tensor = torch.randn((1, 3, 10, 10))
# The latent vector is (batch, latent_features).
latent_vector = torch.randn((1, 5))
g = g_block(3, latent_vector.shape[1], input_tensor.shape[1])
out = g(input_tensor, latent_vector)
print(out)
注意:创建 g_block 时需要传入 latent_vector 和 input_tensor 的形状。