Implementing the forward function (masked decoder forward) of a PyTorch Transformer

Problem description

I am trying to use and learn a PyTorch Transformer on the DeepMind mathematics dataset. I have tokenized the input sequences (at the character level, not the word level). The model's forward function runs the encoder once and the decoder multiple times (until every sequence in the batch has produced the <eos> token; that part is still a TODO). I am struggling with the Transformer masks and the decoder forward, because it raises this error:

    k = k.contiguous().view(-1,bsz * num_heads,head_dim).transpose(0,1)
    RuntimeError: shape '[-1,24,64]' is invalid for input of size 819200.

The source is N = 32, S = 50, E = 512; the target is N = 32, S = 3, E = 512. Maybe my mask implementation is wrong, or it is because the source and target lengths differ; I am not sure.
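
One sanity check I already did: the numbers in the error seem to line up with my source batch rather than with any of the mask shapes (my own arithmetic, so I may be misreading it):

    # 819200 is exactly the element count of the embedded source batch:
    32 * 50 * 512      # = 819200  (N * S * E)
    # 24 and 64 look like bsz * num_heads and head_dim:
    3 * 8, 512 // 8    # = (24, 64) -> implies bsz = 3, which is my target length, not 32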

import math
import numpy as np
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
# function to positionally encode src and target sequences
def __init__(self,d_model,dropout=0.1,max_len=5000):
    super(PositionalEncoding,self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len,d_model)
    position = torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0,d_model,2).float() * (-math.log(10000.0) / d_model))
    pe[:,0::2] = torch.sin(position * div_term)
    pe[:,1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0).transpose(0,1)
    self.register_buffer('pe',pe)

def forward(self,x):
    x = x + self.pe[:x.size(0),:]
    return self.dropout(x)
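
For reference, this positional encoding (like the PyTorch tutorial it is adapted from) assumes sequence-first input of shape (S, N, E) as far as I can tell; a quick shape check with the sizes from my setup:

    pos = PositionalEncoding(d_model=512)
    x = torch.zeros(50, 32, 512)   # (S, N, E): 50 time steps, batch of 32
    print(pos(x).shape)            # torch.Size([50, 32, 512]) - pe broadcasts over the batch dim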

class MyTransformerModel(nn.Module):
# should implement init and forward function
# define separate functions for masks
# define forward function with
# implement:
#  embedding layer
#  positional encoding
#  encoder layer
#  decoder layer
#  final classification layer
# encoder -> forward once
# decoder -> forward multiple times (for one encoder forward)
# decoder output => concatenate to input, e.g. decoder_input = torch.cat([decoder_input, decoder_output])
# early stopping => all in batch reach <eos> token
def __init__(self,vocab_length = 30,sequence_length = 512,num_encoder_layers = 3,num_decoder_layers = 2,num_hidden_dimension = 256,Feed_forward_dimensions = 1024,attention_heads = 8,dropout = 0.1,pad_idx = 3,device = "cpu",batch_size = 32):
    super(MyTransformerModel,self).__init__()
    self.src_embedding = nn.Embedding(vocab_length,sequence_length)
    self.pos_encoder = PositionalEncoding(sequence_length,dropout)
    self.src_mask = None # attention mask
    self.memory_mask = None # attention mask
    self.pad_idx = pad_idx        
    self.device = device        
    self.batch_size = batch_size
    self.transformer = nn.Transformer(
        sequence_length,attention_heads,num_encoder_layers,num_decoder_layers,Feed_forward_dimensions,dropout,)
    
def src_att_mask(self,src_len):
    mask = (torch.triu(torch.ones(src_len,src_len)) == 1).transpose(0,1)
    mask = mask.float().masked_fill(mask == 0,float('-inf')).masked_fill(mask == 1,float(0.0))
    return mask


def no_peak_att_mask(self,batch_size,src_len,time_step):
    mask = np.zeros((batch_size,src_len),dtype=bool)
    mask[:,time_step: ] = 1 # np.NINF
    mask = torch.from_numpy(mask)
    return mask

def make_src_key_padding_mask(self,src):
    # mask "<pad>"
    src_mask = src.transpose(0,1) == self.pad_idx
    return src_mask.to(self.device)

def make_trg_key_padding_mask(self,trg):
    tgt_mask = trg.transpose(0,1) == self.pad_idx
    return tgt_mask.to(self.device)


def forward(self,src,trg):
    src_seq_length,N = src.shape
    trg_seq_length,N = trg.shape
    embed_src = self.src_embedding(src)
    position_embed_src =  self.pos_encoder(embed_src)
    embed_trg = self.src_embedding(trg)
    position_embed_trg = self.pos_encoder(embed_trg)        
    src_padding_mask = self.make_src_key_padding_mask(src)
    trg_padding_mask = self.make_trg_key_padding_mask(trg)
    trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)
    time_step = 1
    att_mask = self.no_peak_att_mask(self.batch_size,src_seq_length,time_step).to(self.device)
    encoder_output = self.transformer.encoder.forward(position_embed_src,src_key_padding_mask = src_padding_mask)
    # TODO: implement loop for transformer decoder forward fn, implement early stopping
    # where to feed decoder_output?
    decoder_output = self.transformer.decoder.forward(position_embed_trg,encoder_output,trg_mask,att_mask,trg_padding_mask,src_padding_mask)
    return decoder_output
    

Can anyone point out where I went wrong?

Solution

It looks like I messed up the dimension order: nn.Transformer has no batch-first option, so it expects sequence-first tensors, i.e. src as (S, N, E) and tgt as (T, N, E).
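
A minimal repro of the mismatch as I understand it (shapes taken from the question; the exact error text may differ between PyTorch versions):

    import torch
    import torch.nn as nn

    model = nn.Transformer(d_model=512, nhead=8)
    src = torch.rand(32, 50, 512)   # what I was feeding: batch-first (N, S, E)
    trg = torch.rand(32, 3, 512)    # batch-first (N, T, E)

    memory = model.encoder(src)     # silently treats dim 0 (the batch) as the sequence
    # model.decoder(trg, memory)    # -> RuntimeError: shape '[-1,24,64]' is invalid for input of size 819200
    #                               #    (24 = 3 * 8 heads, 819200 = 32 * 50 * 512)

    memory = model.encoder(src.transpose(0, 1))        # (S, N, E)
    out = model.decoder(trg.transpose(0, 1), memory)   # (T, N, E)
    print(out.shape)                                   # torch.Size([3, 32, 512])

The corrected model below just applies this reordering consistently (with einops rearrange) and adds a final classification layer: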

from einops import rearrange
import torch.nn.functional as F

class MyTransformerModel(nn.Module):
def __init__(self,d_model = 512,vocab_length = 30,sequence_length = 512,num_encoder_layers = 3,num_decoder_layers = 2,num_hidden_dimension = 256,feed_forward_dimensions = 1024,attention_heads = 8,dropout = 0.1,pad_idx = 3,device = "cpu",batch_size = 32):
    #,ninp,device,nhead=8,nhid=2048,nlayers=2,dropout=0.1,src_pad_idx = 1,max_len=5000,forward_expansion= 4):
    super(MyTransformerModel,self).__init__()
    self.src_embedding = nn.Embedding(vocab_length,d_model)
    self.pos_encoder = PositionalEncoding(d_model,dropout)
    self.vocab_length = vocab_length
    self.d_model = d_model
    self.src_mask = None # attention mask
    self.memory_mask = None # attention mask
    self.pad_idx = pad_idx        
    self.device = device        
    self.batch_size = batch_size
    self.transformer = nn.Transformer(
        d_model,attention_heads,num_encoder_layers,num_decoder_layers,feed_forward_dimensions,dropout,)

    self.fc = nn.Linear(d_model,vocab_length)
    # self.init_weights() <= used in tutorial

def src_att_mask(self,src_len):
    mask = (torch.triu(torch.ones(src_len,src_len)) == 1).transpose(0,1)
    mask = mask.float().masked_fill(mask == 0,float('-inf')).masked_fill(mask == 1,float(0.0))
    return mask


def no_peak_att_mask(self,batch_size,src_len,time_step):
    mask = np.zeros((batch_size,src_len),dtype=bool)
    mask[:,time_step: ] = 1 # np.NINF
    mask = torch.from_numpy(mask)
    # mask = mask.float().masked_fill(mask == 0,float(0.0))
    return mask

def make_src_key_padding_mask(self,src):
    # mask "<pad>"
    src_mask = src.transpose(0,1) == self.pad_idx
    # src_mask = src == self.pad_idx
    # (N,src_len)
    return src_mask.to(self.device)

def make_trg_key_padding_mask(self,trg):
    # same as above -> expected tgt_key_padding_mask: (N,T)
    tgt_mask = trg.transpose(0,1) == self.pad_idx
    # tgt_mask = trg == self.pad_idx
    # (N,src_len)
    return tgt_mask.to(self.device)


def init_weights(self):
    initrange = 0.1
    nn.init.uniform_(self.encoder.weight,-initrange,initrange)
    nn.init.zeros_(self.decoder.weight)
    nn.init.uniform_(self.decoder.weight,-initrange,initrange)

def forward(self,src,trg):
    N,src_seq_length = src.shape
    N,trg_seq_length = trg.shape        
    #  S - source sequence length
    #  T - target sequence length
    #  N - batch size
    #  E - feature number
    #  src: (S,N,E) (sourceLen,batch,features)
    #  tgt: (T,N,E)
    #  src_mask: (S,S)
    #  tgt_mask: (T,T)
    #  memory_mask: (T,S)
    #  src_key_padding_mask: (N,S)
    #  tgt_key_padding_mask: (N,T)
    #  memory_key_padding_mask: (N,S)
    src = rearrange(src,'n s -> s n')
    trg = rearrange(trg,'n t -> t n')
    print("src shape {}".format(src.shape))
    print(src)
    print("trg shape {}".format(trg.shape))
    print(trg)

    embed_src = self.src_embedding(src)
    print("embed_src shape {}".format(embed_src.shape))
    print(embed_src)
    position_embed_src =  self.pos_encoder(embed_src)
    print("position_embed_src shape {}".format(position_embed_src.shape))
    print(position_embed_src)
    embed_trg = self.src_embedding(trg)
    print("embed_trg shape {}".format(embed_trg.shape))
    print(embed_trg)
    position_embed_trg = self.pos_encoder(embed_trg)
    # position_embed_trg = position_embed_trg.transpose(0,1)
    print("position_embed_trg shape {}".format(position_embed_trg.shape))
    print(position_embed_trg)
    src_padding_mask = self.make_src_key_padding_mask(src)
    print("KEY - src_padding_mask shape {}".format(src_padding_mask.shape))
    print("should be of shape: src_key_padding_mask: (N,S)")
    print(src_padding_mask)
    trg_padding_mask = self.make_trg_key_padding_mask(trg)
    print("KEY - trg_padding_mask shape {}".format(trg_padding_mask.shape))
    print("should be of shape: trg_key_padding_mask: (N,T)")
    print(trg_padding_mask)
    trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)
    print("trg_mask shape {}".format(trg_mask.shape))
    print("trg_mask should be of shape tgt_mask: (T,T)")
    print(trg_mask)
    # att_mask = self.src_att_mask(trg_seq_length).to(self.device)
    time_step = 1
    # error => memory_mask expects shape (T,S); it is not a key_padding_mask!
    # att_mask = self.no_peak_att_mask(self.batch_size,src_seq_length,time_step).to(self.device)
    # print("att_mask shape {}".format(att_mask.shape))
    # print("att_mask should be of shape  memory_mask: (T,S)")
    # print(att_mask)
    att_mask = None
    # get encoder output
    # forward(self,src: Tensor,mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None)
    # forward encoder just once for a batch
    # attention forward of encoder expects => src,src_mask,src_key_padding_mask +++ possible positional encoding error !!!
    encoder_output = self.transformer.encoder.forward(position_embed_src,src_key_padding_mask = src_padding_mask)
    print("encoder_output")  
    print("encoder_output shape {}".format(encoder_output.shape))
    print(encoder_output)  
    # forward decoder till all in batch did not reach <eos>?
    # def forward(self,tgt: Tensor,memory: Tensor,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None)
    # first forward
    decoder_output = self.transformer.decoder.forward(position_embed_trg,encoder_output,trg_mask,att_mask,trg_padding_mask,src_padding_mask)
    # TODO: target in => target out shifted by one; loop till all in batch meet stopping criteria or max len is reached
    # (a greedy-decoding sketch follows after this class)
    print("decoder_output")  
    print("decoder_output shape {}".format(decoder_output.shape))
    print(decoder_output)
    
    output = rearrange(decoder_output,'t n e -> n t e')
    output =  self.fc(output)
    print("output")  
    print("output shape {}".format(output.shape))
    print(output)

    predicted = F.log_softmax(output,dim=-1)
    print("predicted")  
    print("predicted shape {}".format(predicted.shape))
    print(predicted)
    # top k
    top_value,top_index = torch.topk(predicted,k=1)
    top_index = torch.squeeze(top_index)
    print("top_index")  
    print("top_index shape {}".format(top_index.shape))
    print(top_index)
    print("top_value")  
    print("top_value shape {}".format(top_value.shape))
    print(top_value)
    return top_index
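
The TODO above (running the decoder step by step until every sequence in the batch has produced <eos>) is still open. A rough greedy-decoding sketch of how I intend to do it, assuming sos_idx / eos_idx token ids and a max_len bound, none of which are defined in the code above:

    def greedy_decode(model, src, sos_idx, eos_idx, max_len=32):
        # src: (N, S) token ids, as they come out of the dataset
        # sos_idx / eos_idx / max_len are assumed to exist in the vocabulary (not defined above)
        model.eval()
        with torch.no_grad():
            src = rearrange(src, 'n s -> s n')                        # (S, N)
            src_padding_mask = model.make_src_key_padding_mask(src)   # (N, S)
            memory = model.transformer.encoder(
                model.pos_encoder(model.src_embedding(src)),
                src_key_padding_mask=src_padding_mask)                # encoder runs once: (S, N, E)
            N = src.shape[1]
            ys = torch.full((1, N), sos_idx, dtype=torch.long, device=src.device)  # growing target (T, N)
            finished = torch.zeros(N, dtype=torch.bool, device=src.device)
            for _ in range(max_len - 1):
                trg_mask = model.transformer.generate_square_subsequent_mask(ys.size(0)).to(src.device)
                dec = model.transformer.decoder(
                    model.pos_encoder(model.src_embedding(ys)), memory,
                    tgt_mask=trg_mask, memory_key_padding_mask=src_padding_mask)    # (T, N, E)
                next_token = model.fc(dec[-1]).argmax(dim=-1)          # greedy pick from the last step: (N,)
                ys = torch.cat([ys, next_token.unsqueeze(0)], dim=0)   # feed prediction back as decoder input
                finished |= next_token == eos_idx
                if finished.all():                                     # early stopping: whole batch reached <eos>
                    break
        return rearrange(ys, 't n -> n t')                             # (N, T) generated token ids

During training the target can still be fed in one shot with the shifted-by-one trick; this loop is only meant for inference.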