RNN variational autoencoder produces good reconstructions but poor generations

Problem description

I am trying to reproduce this result by training an RNN-based variational autoencoder. Reconstruction of the original text works well, but generation of new text is poor. My model architecture is given below; it is roughly based on this paper.

import torch
import torch.nn as nn

class SentenceVAE(nn.Module):
    def __init__(self, embedding_size, vocab_size, hidden_size, latent_dim, dropout, device,
                 max_len=50, pad_idx=0, start_idx=1, end_idx=2, unk_idx=3):
        super(SentenceVAE, self).__init__()

        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        # nn.Embedding expects (num_embeddings, embedding_dim); pad_idx is the padding_idx
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
        self.hidden_to_mu = nn.Linear(hidden_size, latent_dim)
        self.hidden_to_logvar = nn.Linear(hidden_size, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.word_dropout = dropout  # probability of replacing decoder input tokens with <unk>
        # nn.GRU requires both input_size and hidden_size
        self.encoder_gru = nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder_gru = nn.GRU(embedding_size, hidden_size, batch_first=True)
        # Maps the latent vector z to the decoder's initial hidden state
        self.flow_fc = nn.Sequential(
            nn.Linear(latent_dim, 1024), nn.GELU(), nn.Linear(1024, hidden_size)
        )
        self.out = nn.Linear(hidden_size, vocab_size)
        self.device = device
        self.max_len = max_len
        self.latent_dim = latent_dim
        self.unk_idx = unk_idx
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.pad_idx = pad_idx
        
    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar
        eps = torch.randn_like(logvar)
        std = torch.exp(0.5 * logvar)
        return mu + eps * std
    
    def decode(self, hidden, dec_in):
        decoder_input = self.embed(dec_in)
        # GRU expects hidden of shape (num_layers, batch, hidden_size)
        if len(hidden.size()) < 3:
            hidden = hidden.unsqueeze(0)
        outputs, hidden = self.decoder_gru(decoder_input, hidden)
        out = self.out(outputs)
        return out, hidden
    
    def sample_sentence(self, z=None):
        max_len = 20
        batch = 1
        if z is None:  # `z == None` does elementwise comparison on tensors
            z = torch.randn((batch, self.latent_dim))
            z = z.to(self.device)
        hidden = self.flow_fc(z)
        pred = [[self.start_idx]]
        out_sent = []
        # Greedy autoregressive decoding: feed the previous prediction back in
        for i in range(max_len):
            pred_tensor = torch.tensor(pred)
            pred_tensor = pred_tensor.to(self.device)
            preds, hidden = self.decode(hidden, pred_tensor)
            preds = preds[:, -1, :]
            pred_index = torch.argmax(preds, dim=-1)
            pred[0] = [pred_index.item()]
            out_sent.append(pred_index.item())
            if pred_index.item() == self.end_idx:
                break
        return out_sent
     
          
    def forward(self, x):
        enc_in, dec_in = x, x
        encoder_input = self.embed(enc_in)
        _, rnn_hidden = self.encoder_gru(encoder_input)
        rnn_hidden = rnn_hidden.squeeze(0)
        mu = self.hidden_to_mu(rnn_hidden)
        logvar = self.hidden_to_logvar(rnn_hidden)
        z = self.reparameterize(mu, logvar)
        ## Randomly replace words with <unk> (word dropout), but never the special tokens
        dec_in_copy = dec_in.clone()
        prob = torch.rand(dec_in.size(), device=dec_in.device)
        prob[(dec_in == self.start_idx) | (dec_in == self.end_idx) | (dec_in == self.pad_idx)] = 1
        dec_in_copy[prob < self.word_dropout] = self.unk_idx
        hidden = self.flow_fc(z)
        out, _ = self.decode(hidden, dec_in_copy)

        return mu, logvar, out

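The training loop is not shown above. For context, here is a minimal sketch of how a sentence VAE of this shape is typically trained, with the usual ELBO objective (token-level cross-entropy plus KL divergence). The data loader, optimizer, and the linear KL-annealing schedule are illustrative assumptions, not part of the original code.

# Minimal training-loop sketch (not from the original post). Assumes `model` is a
# SentenceVAE and `loader` yields batches of token-id tensors of shape (batch, seq_len).
import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer, device, step=0, anneal_steps=10000):
    ce = nn.CrossEntropyLoss(ignore_index=model.pad_idx)
    for batch in loader:
        batch = batch.to(device)
        mu, logvar, out = model(batch)
        # Reconstruction: logits at position t predict the token at position t+1
        recon = ce(out[:, :-1, :].reshape(-1, out.size(-1)), batch[:, 1:].reshape(-1))
        # KL(q(z|x) || N(0, I)) in closed form, averaged over the batch
        kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))
        kl_weight = min(1.0, step / anneal_steps)  # linear KL annealing (assumption)
        loss = recon + kl_weight * kl
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
    return step
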
Here are some sample reconstruction outputs; the first line is the input and the second line is the output.

#1

<SOS> gondry 's direction is adequate ... but what gives human nature its unique feel is kaufman 's script . <EOS> 
<SOS> it 's direction is adequate ... but what gives human nature its unique feel is kaufman 's approach . <EOS> 

#2

<SOS> there seems to be no clear path as to where the story 's going,or how long it 's going to take to get there . <EOS> 
<SOS> there seems to be no amount path,to where the most 's going,or even long it 's going to take to get there . <EOS> 

Now, if I use the sample_sentence method of the SentenceVAE class to generate new sentences, the output is always:

<SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> 

While debugging, I noticed that the output is always the token placed inside pred = [[self.start_idx]], repeated max_len times. In the case above, <SOS> is the input token in sample_sentence.
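
For completeness, this is how generation is invoked. The id2word mapping used to print tokens is illustrative, not part of the model.

# Illustrative invocation; `model` is a trained SentenceVAE and `id2word` is an
# assumed mapping from token ids back to strings.
model.eval()
with torch.no_grad():
    ids = model.sample_sentence()  # z defaults to a sample from N(0, I)
print(" ".join(id2word[i] for i in ids))  # produces the repeated <SOS> output above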
