Transformer image captioning model produces only padding, not captions

Problem description

I am trying to build a model that uses a ResNet as the encoder and a Transformer as the decoder, with COCO as the dataset, to generate captions for images.

After training the model for 10 epochs, it produces nothing but the word <pad>; in other words, the only thing that comes out of the model is token 0, which corresponds to <pad>.

After stepping through with a debugger, the problem seems to happen at the argmax, where the output is always zero and never any other value, but I don't know how to fix it. Is something wrong with my model, or with the way it is trained?
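
To make the symptom concrete, here is a minimal debugging sketch (a hypothetical helper, not part of my model code) that prints the top-scoring vocabulary entries right before the argmax in pred(); in my case the highest-scoring id is always 0, i.e. <pad>:

# Hypothetical debugging helper: `logits` stands for the (batch_size, vocab_size)
# tensor returned by self.linear(out) inside pred(), right before the argmax.
def inspect_logits(logits, idx2word, k=5):
    """Print the k highest-scoring vocabulary entries for the first sample in the batch."""
    values, indices = logits[0].topk(k)
    for score, token_id in zip(values.tolist(), indices.tolist()):
        print("{:>10} (id {}): {:.3f}".format(idx2word[token_id], token_id, score))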

If it helps, I based my model on this GitHub repository.

The script I use to download the COCO data is here:

Download.sh

mkdir data
wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
wget http://images.cocodataset.org/zips/train2014.zip -P ./data/
wget http://images.cocodataset.org/zips/val2014.zip -P ./data/

unzip ./data/captions_train-val2014.zip -d ./data/
rm ./data/captions_train-val2014.zip
unzip ./data/train2014.zip -d ./data/
rm ./data/train2014.zip 
unzip ./data/val2014.zip -d ./data/ 
rm ./data/val2014.zip 

Thank you very much for any help.

Here is my code.

model.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math,copy,time
import torchvision.models as models
from torch.nn import TransformerDecoderLayer,TransformerDecoder
from torch.nn.utils.rnn import pack_padded_sequence
from torch.autograd import Variable

class EncoderCNN(nn.Module):
    def __init__(self,embed_size):
        super(EncoderCNN,self).__init__()
        resnet = models.resnet152(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.conv1 = nn.Conv2d(2048,embed_size,1)
        self.embed_size = embed_size

        self.fine_tune()
        
    def forward(self,images):
        features = self.resnet(images)                           # (batch, 2048, H/32, W/32)
        batch_size,_,_,_ = features.shape
        features = self.conv1(features)                          # (batch, embed_size, H/32, W/32)
        features = features.view(batch_size,self.embed_size,-1)
        features = features.permute(2,0,1)                       # (S, batch, embed_size), decoder memory

        return features

    def fine_tune(self,fine_tune=True):
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning,only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

class PositionEncoder(nn.Module):
    def __init__(self,d_model,dropout,max_len=5000):
        super(PositionEncoder,self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len,d_model)
        position = torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2).float() * (-math.log(10000.0) / d_model))
        pe[:,0::2] = torch.sin(position * div_term)
        pe[:,1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0,1)
        self.register_buffer('pe',pe)

    def forward(self,x):
        x = x + self.pe[:x.size(0),:]
        return self.dropout(x)
    
class Embedder(nn.Module):
    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size,d_model)
    def forward(self,x):
        return self.embed(x)


class Transformer(nn.Module):
    def __init__(self,vocab_size,d_model,h,num_hidden,N,device,dropout_dec=0.1,dropout_pos=0.1):
        super(Transformer,self).__init__()
        decoder_layers = TransformerDecoderLayer(d_model,h,num_hidden,dropout_dec)
        self.source_mask = None
        self.device = device
        self.d_model = d_model
        self.pos_decoder = PositionEncoder(d_model,dropout_pos)
        self.decoder = TransformerDecoder(decoder_layers,N)
        self.embed = Embedder(vocab_size,d_model)
        self.linear = nn.Linear(d_model,vocab_size)

        self.init_weights()

    def forward(self,source,mem):
        source = source.permute(1,0) 
        if self.source_mask is None or self.source_mask.size(0) != len(source):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self,sz=len(source)).to(self.device)

        source = self.embed(source) 
        source = source*math.sqrt(self.d_model)  
        source = self.pos_decoder(source)
        output = self.decoder(source,mem,self.source_mask)
        output = self.linear(output)
        return output

    def init_weights(self):
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange,initrange)

    def pred(self,memory,pred_len):
        batch_size = memory.size(1)
        src = torch.ones((pred_len,batch_size),dtype=int) * 2
        if self.source_mask is None or self.source_mask.size(0) != len(src):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self,sz=len(src)).to(self.device)
        output = torch.ones((pred_len,batch_size),dtype=int)
        src,output = src.cuda(),output.cuda()
        for i in range(pred_len):
            src_emb = self.embed(src) # src_len * batch size * embed size
            src_emb = src_emb*math.sqrt(self.d_model)
            src_emb = self.pos_decoder(src_emb)
            out = self.decoder(src_emb,memory,self.source_mask)
            out = out[i]
            out = self.linear(out) # batch_size * vocab_size
            out = out.argmax(dim=1)
            if i < pred_len-1:
                src[i+1] = out
            output[i] = out
        return output
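
To illustrate the tensor shapes I expect, here is a rough sanity-check sketch (not part of my training code; it assumes the classes above, the constructor arguments I use in train.py below, and an older 1.x PyTorch where the generate_square_subsequent_mask call in forward() is an instance method):

# Rough shape check with dummy data (no COCO needed; downloads the pretrained ResNet-152 weights).
import torch

device = torch.device('cpu')
vocab_size, embed_size = 100, 512

encoder = EncoderCNN(embed_size)
decoder = Transformer(vocab_size, embed_size, 8, 2048, 6, device)

images = torch.randn(2, 3, 224, 224)              # batch of 2 images
captions = torch.randint(0, vocab_size, (2, 12))  # batch of 2 captions of length 12

memory = encoder(images)             # expected: (49, 2, 512), i.e. (S, batch, embed)
output = decoder(captions, memory)   # expected: (12, 2, vocab_size)
print(memory.shape, output.shape)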

Data_Loader.py

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO


class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self,root,json,vocab,transform=None):
        """Set the path for images,captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self,index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root,path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image,target

    def __len__(self):
        return len(self.ids)


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image,caption).
    
    We should build custom collate_fn rather than using default collate_fn,because merging caption (including padding) is not supported in default.
    Args:
        data: list of tuple (image,caption). 
            - image: torch tensor of shape (3,256,256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size,3,256,256).
        targets: torch tensor of shape (batch_size,padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]),reverse=True)
    images,captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images,0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions),max(lengths)).long()
    for i,cap in enumerate(captions):
        end = lengths[i]
        targets[i,:end] = cap[:end]        
    return images,targets,lengths

def get_loader(root,json,vocab,transform,batch_size,shuffle,num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root,json=json,vocab=vocab,transform=transform)
    
    # Data loader for COCO dataset
    # This will return (images,captions,lengths) for each iteration.
    # images: a tensor of shape (batch_size,3,224,224).
    # captions: a tensor of shape (batch_size,padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=coco,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,collate_fn=collate_fn)
    return data_loader
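
For reference, the captions come out of collate_fn zero-padded, and 0 is the index of <pad> in my vocabulary. A toy illustration with made-up tensors (no COCO needed):

# Toy illustration of collate_fn with fake (image, caption) pairs.
import torch

fake_batch = [
    (torch.zeros(3, 256, 256), torch.tensor([1, 7, 9, 4, 2])),  # <start> ... <end>
    (torch.zeros(3, 256, 256), torch.tensor([1, 5, 2])),
]
images, targets, lengths = collate_fn(fake_batch)
print(targets)
# tensor([[1, 7, 9, 4, 2],
#         [1, 5, 2, 0, 0]])   # the shorter caption is padded with 0 (= '<pad>')
print(lengths)                # [5, 3]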

Build_vocab.py

import nltk
import pickle
import argparse
from collections import Counter
from pycocotools.coco import COCO

class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self,word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self,word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]


    def __len__(self):
        return len(self.word2idx)

def build_vocab(json,threshold):
    coco = COCO(json)
    counter = Counter()
    ids = coco.anns.keys()
    for i,id in enumerate(ids):
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if (i+1) % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i+1,len(ids)))

    # If the word frequency is less than 'threshold',then the word is discarded.
    words = [word for word,cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i,word in enumerate(words):
        vocab.add_word(word)
    return vocab

def main(args):
    vocab = build_vocab(json=args.caption_path,threshold=args.threshold)
    vocab_path = args.vocab_path
    with open(vocab_path,'wb') as f:
        pickle.dump(vocab,f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path',type=str,default='./data/annotations/captions_train2014.json',help='path for train annotation file')
    parser.add_argument('--vocab_path',default='./data/vocab.pkl',help='path for saving vocabulary wrapper')
    parser.add_argument('--threshold',type=int,default=4,help='minimum word count threshold')
    args = parser.parse_args()
    main(args)
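
So after build_vocab runs, the special tokens get the first four indices, which is why token 0 is <pad>. A quick check (no COCO needed):

# Quick check of the special-token indices produced by Vocabulary.
vocab = Vocabulary()
for w in ('<pad>', '<start>', '<end>', '<unk>'):
    vocab.add_word(w)

print(vocab('<pad>'), vocab('<start>'), vocab('<end>'), vocab('<unk>'))  # 0 1 2 3
print(vocab('giraffe'))  # 3, unknown words fall back to '<unk>'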

train.py

import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import math
from tqdm import tqdm
from data_loader import get_loader 
from build_vocab import Vocabulary
from model import EncoderCNN,Transformer
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    num_epoch = 5
    lr = 1e-3
    load = False
    # Create model directory
    if not os.path.exists('models/'):
        os.makedirs('models/')
    
    # Image preprocessing,normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))])
    
    # Load vocabulary wrapper
    with open('data/vocab.pkl','rb') as f:
        vocab = pickle.load(f)
    
    # Build data loader
    data_loader = get_loader('data/resized2014','data/annotations/captions_train2014.json',vocab,transform,batch_size,shuffle=True,num_workers=num_workers)

    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Transformer(len(vocab),embed_size,num_heads,2048,num_layers,device).to(device)  # 2048 = feed-forward size
    
    if(load):
        encoder.load_state_dict(torch.load(os.path.join('models/','encoder-{}-{}.ckpt'.format(5,5000))))
        decoder.load_state_dict(torch.load(os.path.join('models/','decoder-{}-{}.ckpt'.format(5,5000))))
        print("Load Successful")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    encoder_optim = torch.optim.Adam(encoder.parameters(),lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(),lr=lr)
    
    # Train the models
    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        for i,(images,captions,lengths) in tqdm(enumerate(data_loader),total=len(data_loader),leave=False):
            
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)

            # Forward,backward and optimize
            features = encoder(images)
            cap_input = captions[:,:-1]
            cap_target = captions[:,1:]
            outputs = decoder(cap_input,features)
            outputs = outputs.permute(1,0,2)  # (batch, seq, vocab)
            outputs_shape = outputs.reshape(-1,len(vocab))
            loss = criterion(outputs_shape,cap_target.reshape(-1))
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()
                
            # Save the model checkpoints
            if (i+1) % args.save_step == 0:
                torch.save(decoder.state_dict(),os.path.join(
                    'models/','decoder-{}-{}.ckpt'.format(epoch+1,i+1)))
                torch.save(encoder.state_dict(),os.path.join(
                    'models/','encoder-{}-{}.ckpt'.format(epoch+1,i+1)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_step',default=10,help='step size for printing log info')
    parser.add_argument('--save_step',default=1000,help='step size for saving trained models')
        
    args = parser.parse_args()
    print(args)
    main(args)
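
One thing I am unsure about: the targets from collate_fn are padded with 0 (<pad>), and the CrossEntropyLoss above counts those positions as well. Below is a small diagnostic sketch to measure how much of a batch is padding, plus a variation I have seen elsewhere that ignores padded positions (this is not what my training loop above does):

# Diagnostic sketch: fraction of '<pad>' tokens in a target batch
# (assumes a (batch, seq) LongTensor such as cap_target in the loop above).
import torch
import torch.nn as nn

def pad_fraction(cap_target, pad_idx=0):
    return (cap_target == pad_idx).float().mean().item()

# Variation seen in other captioning code (NOT used above): skip padded positions in the loss.
criterion_ignore_pad = nn.CrossEntropyLoss(ignore_index=0)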

sample.py


import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torchvision import transforms 
from build_vocab import Vocabulary
from data_loader import get_loader 
from model import EncoderCNN,Transformer
from PIL import Image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#

def token_sentence(decoder_out,itos):
    tokens = decoder_out
    tokens = tokens.transpose(1,0)
    tokens = tokens.cpu().numpy()
    results = []
    for instance in tokens:
        result = ' '.join([itos[x] for x in instance])
        results.append(''.join(result.partition('<eos>')[0])) # Cut before '<eos>'
    return results

def load_image(image_path,transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224,224],Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))])
    
    # Load vocabulary wrapper
    with open(args.vocab_path,'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader('data/resized2014','data/annotations/captions_train2014.json',vocab,transform,batch_size,shuffle=False,num_workers=num_workers)

    # Build models
    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Transformer(len(vocab),embed_size,num_heads,2048,num_layers,device).to(device)

    # Load trained models
    encoder.load_state_dict(torch.load(os.path.join('models/','encoder-{}-{}.ckpt'.format(1,4000))))
    decoder.load_state_dict(torch.load(os.path.join('models/','decoder-{}-{}.ckpt'.format(1,4000))))
    encoder.eval()
    decoder.eval()
    
    itos = vocab.idx2word
    pred_len = 100
    result_collection = []

    # Decode with greedy
    # with torch.no_grad():
    #     for i,(images,captions,lengths) in enumerate(data_loader):
    #         images = images.to(device)
    #         features = encoder(images)
    #         output = decoder.generator(features,pred_len)
    #         result_caption = token_sentence(output,itos)
    #         result_collection.extend(result_caption)

    # Decode with greedy
    with torch.no_grad():
        for batch_index,(inputs,captions,caplens) in enumerate(data_loader):
            inputs,captions = inputs.cuda(),captions.cuda()
            enc_out = encoder(inputs)
            captions_input = captions[:,:-1]
            captions_target = captions[:,1:]
            output = decoder.pred(enc_out,pred_len)
            result_caption = token_sentence(output,itos)
            result_collection.extend(result_caption)
        
            
    print("Prediction-greedy:",result_collection[1])
    print("Prediction-greedy:",result_collection[2])
    print("Prediction-greedy:",result_collection[3])
    print("Prediction-greedy:",result_collection[4])
    print("Prediction-greedy:",result_collection[5])
    print("Prediction-greedy:",result_collection[6])
    print("Prediction-greedy:",result_collection[7])
    print("Prediction-greedy:",result_collection[8])
    print("Prediction-greedy:",result_collection[9])
    print("Prediction-greedy:",result_collection[10])
    print("Prediction-greedy:",result_collection[11])

    # # Prepare an image
    # image = load_image(args.image,transform)
    # image_tensor = image.to(device)
    
    # # Generate an caption from the image
    # feature = encoder(image_tensor)
    # sampled_ids = decoder.generator(feature,pred_len)
    # sampled_ids = sampled_ids[0].cpu().numpy()          # (1,max_seq_length) -> (max_seq_length)
    
    # # Convert word_ids to words
    # sampled_caption = []
    # for word_id in sampled_ids:
    #     word = vocab.idx2word[word_id]
    #     sampled_caption.append(word)
    #     if word == '<end>':
    #         break
    # sentence = ' '.join(sampled_caption)
    
    # # Print out the image and the generated caption
    # print (sentence)
    # image = Image.open(args.image)
    # plt.imshow(np.asarray(image))
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image',required=False,help='input image for generating caption')
    parser.add_argument('--vocab_path',default='data/vocab.pkl',help='path for vocabulary wrapper')
    args = parser.parse_args()
    main(args)
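
For completeness, a toy illustration of what token_sentence does with the decoder output (fake ids and a tiny id-to-word mapping, just to show the transpose and the join):

# Toy illustration of token_sentence (fake ids, tiny vocabulary).
import torch

itos = {0: '<pad>', 1: '<start>', 2: '<end>', 3: '<unk>', 4: 'a', 5: 'dog'}
fake_out = torch.tensor([[1, 1],
                         [4, 5],
                         [5, 2],
                         [2, 0]])   # shape (pred_len, batch_size)
print(token_sentence(fake_out, itos))
# ['<start> a dog <end>', '<start> dog <end> <pad>']
# note: the sentences are cut at '<eos>', while my vocabulary uses '<end>'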

resize.py


import argparse
import os
from PIL import Image


def resize_image(image,size):
    """Resize an image to the given size."""
    return image.resize(size,Image.ANTIALIAS)

def resize_images(image_dir,output_dir,size):
    """Resize the images in 'image_dir' and save into 'output_dir'."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i,image in enumerate(images):
        with open(os.path.join(image_dir,image),'r+b') as f:
            with Image.open(f) as img:
                img = resize_image(img,size)
                img.save(os.path.join(output_dir,image),img.format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1,num_images,output_dir))

def main(args):
    image_dir = args.image_dir
    output_dir = args.output_dir
    image_size = [args.image_size,args.image_size]
    resize_images(image_dir,output_dir,image_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir',default='./data/train2014/',help='directory for train images')
    parser.add_argument('--output_dir',default='./data/resized2014/',help='directory for saving resized images')
    parser.add_argument('--image_size',default=256,help='size for image after processing')
    args = parser.parse_args()
    main(args)
