Problem description
I am trying to build a model that uses a ResNet as the encoder and a Transformer as the decoder, with COCO as the dataset, to generate captions for images.
After training the model for 10 epochs, it produces nothing but the word <pad>; that is, the only thing it ever outputs after a forward pass is token 0, which corresponds to <pad>.
After stepping through with a debugger, the failure appears at the argmax, where the output is all zeros rather than anything else. I don't know how to fix this: is something wrong with my model, or with the way it was trained?
If it helps, I based my model on this github.
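One thing I have been wondering about, though I have not tested it, so treat it as an assumption on my part: every caption in a batch is padded with token 0 (see collate_fn in data_loader.py below), and my criterion is a plain CrossEntropyLoss, so the loss also rewards predicting <pad> at every padded position. The change I have in mind is the ignore_index argument that CrossEntropyLoss supports:

# In train.py below, instead of:
#     criterion = nn.CrossEntropyLoss()
# I am considering (untested):
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is '<pad>' in my vocabulary

Would that be a sensible change, or is the problem elsewhere?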
The script for downloading the COCO data is here:
Download.sh
mkdir data
wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
wget http://images.cocodataset.org/zips/train2014.zip -P ./data/
wget http://images.cocodataset.org/zips/val2014.zip -P ./data/
unzip ./data/captions_train-val2014.zip -d ./data/
rm ./data/captions_train-val2014.zip
unzip ./data/train2014.zip -d ./data/
rm ./data/train2014.zip
unzip ./data/val2014.zip -d ./data/
rm ./data/val2014.zip
Any help is greatly appreciated.
Here is my code:
model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math,copy,time
import torchvision.models as models
from torch.nn import TransformerDecoderLayer,TransformerDecoder
from torch.nn.utils.rnn import pack_padded_sequence
from torch.autograd import Variable
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.conv1 = nn.Conv2d(2048, embed_size, 1)
        self.embed_size = embed_size
        self.fine_tune()

    def forward(self, images):
        features = self.resnet(images)                      # (batch, 2048, H, W)
        batch_size, _, _, _ = features.shape
        features = self.conv1(features)                     # (batch, embed_size, H, W)
        features = features.view(batch_size, self.embed_size, -1)
        features = features.permute(2, 0, 1)                # (H*W, batch, embed_size)
        return features

    def fine_tune(self, fine_tune=True):
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embed(x)
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, h, N, device, num_hidden=2048, dropout_dec=0.1, dropout_pos=0.1):
        super(Decoder, self).__init__()
        decoder_layers = TransformerDecoderLayer(d_model, h, num_hidden, dropout_dec)
        self.source_mask = None
        self.device = device
        self.d_model = d_model
        self.pos_decoder = PositionalEncoder(d_model, dropout_pos)
        self.decoder = TransformerDecoder(decoder_layers, N)
        self.embed = Embedder(vocab_size, d_model)
        self.linear = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def forward(self, source, mem):
        source = source.permute(1, 0)  # (seq_len, batch)
        if self.source_mask is None or self.source_mask.size(0) != len(source):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(len(source)).to(self.device)
        source = self.embed(source)
        source = source * math.sqrt(self.d_model)
        source = self.pos_decoder(source)
        output = self.decoder(source, mem, self.source_mask)
        output = self.linear(output)   # (seq_len, batch, vocab_size)
        return output

    def init_weights(self):
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def pred(self, memory, pred_len):
        batch_size = memory.size(1)
        src = torch.ones((pred_len, batch_size), dtype=int) * 2  # seed every position with token id 2
        if self.source_mask is None or self.source_mask.size(0) != len(src):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(self.device)
        output = torch.ones((pred_len, batch_size), dtype=int)
        src, output = src.cuda(), output.cuda()
        for i in range(pred_len):
            src_emb = self.embed(src)  # seq_len * batch_size * embed_size
            src_emb = src_emb * math.sqrt(self.d_model)
            src_emb = self.pos_decoder(src_emb)
            out = self.decoder(src_emb, memory, self.source_mask)
            out = out[i]
            out = self.linear(out)     # batch_size * vocab_size
            out = out.argmax(dim=1)
            if i < pred_len - 1:
                src[i + 1] = out
            output[i] = out
        return output
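For reference, here is a quick shape check I run on random tensors (CPU, with placeholder sizes; the vocab size of 1000 is arbitrary) to confirm how I believe the encoder output lines up with what the decoder expects:

# Smoke test with hypothetical sizes; only the shapes matter here.
device = torch.device('cpu')
encoder = EncoderCNN(embed_size=512)
decoder = Decoder(vocab_size=1000, d_model=512, h=8, N=6, device=device)
images = torch.randn(2, 3, 224, 224)        # (batch, channels, H, W)
captions = torch.randint(0, 1000, (2, 12))  # (batch, seq_len)
memory = encoder(images)                    # (49, batch, 512) for 224x224 inputs
output = decoder(captions, memory)          # (seq_len, batch, vocab_size)
print(memory.shape, output.shape)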
data_loader.py
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO
class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']
        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    We should build a custom collate_fn rather than using the default one,
    because merging captions (including padding) is not supported by default.
    Args:
        data: list of tuples (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort the data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)
    # Merge images (from tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)
    # Merge captions (from tuple of 1D tensors to a 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths
def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)
    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating the valid length of each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=coco, batch_size=batch_size,
                                              shuffle=shuffle, num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
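For completeness, this is how I build the loader (assuming resize.py and build_vocab.py below have already been run, so data/resized2014 and data/vocab.pkl exist):

import pickle
from torchvision import transforms
from build_vocab import Vocabulary

with open('data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json',
                    vocab, transform, batch_size=64, shuffle=True, num_workers=2)
images, targets, lengths = next(iter(loader))
print(images.shape, targets.shape, len(lengths))  # (64, 3, 224, 224), (64, max_len), 64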
build_vocab.py
import nltk
import pickle
import argparse
from collections import Counter
from pycocotools.coco import COCO
class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
def build_vocab(json, threshold):
    coco = COCO(json)
    counter = Counter()
    ids = coco.anns.keys()
    for i, id in enumerate(ids):
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)
        if (i + 1) % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i + 1, len(ids)))
    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]
    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')
    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab
def main(args):
    vocab = build_vocab(json=args.caption_path, threshold=args.threshold)
    vocab_path = args.vocab_path
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path', type=str, default='./data/annotations/captions_train2014.json', help='path for train annotation file')
    parser.add_argument('--vocab_path', default='./data/vocab.pkl', help='path for saving vocabulary wrapper')
    parser.add_argument('--threshold', type=int, default=4, help='minimum word count threshold')
    args = parser.parse_args()
    main(args)
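Because the special tokens are added before anything else, their indices are fixed by the add_word order; I use this quick check after building the vocabulary (the seed token in Decoder.pred above depends on these indices):

import pickle
from build_vocab import Vocabulary

# '<pad>' should map to 0, '<start>' to 1, '<end>' to 2, '<unk>' to 3.
with open('data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
for tok in ('<pad>', '<start>', '<end>', '<unk>'):
    print(tok, vocab(tok))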
train.py
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import math
from tqdm import tqdm
from data_loader import get_loader
from build_vocab import Vocabulary
from model import EncoderCNN,Decoder
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    num_epoch = 5
    lr = 1e-3
    load = False
    # Create model directory
    if not os.path.exists('models/'):
        os.makedirs('models/')
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    # Load vocabulary wrapper
    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)
    # Build data loader
    data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json',
                             vocab, transform, batch_size, shuffle=True, num_workers=num_workers)
    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Decoder(len(vocab), embed_size, num_heads, num_layers, device).to(device)
    if load:
        encoder.load_state_dict(torch.load(os.path.join('models/', 'encoder-{}-{}.ckpt'.format(5, 5000))))
        decoder.load_state_dict(torch.load(os.path.join('models/', 'decoder-{}-{}.ckpt'.format(5, 5000))))
        print("Load Successful")
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr=lr)
    # Train the models
    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader), total=len(data_loader), leave=False):
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            # Forward, backward and optimize
            features = encoder(images)
            cap_input = captions[:, :-1]
            cap_target = captions[:, 1:]
            outputs = decoder(cap_input, features)   # (seq_len, batch, vocab_size)
            outputs = outputs.permute(1, 0, 2)       # (batch, seq_len, vocab_size)
            outputs_shape = outputs.reshape(-1, len(vocab))
            loss = criterion(outputs_shape, cap_target.reshape(-1))
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()
            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    'models/', 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    'models/', 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_step', default=10, help='step size for printing log info')
    parser.add_argument('--save_step', default=1000, help='step size for saving trained models')
    args = parser.parse_args()
    print(args)
    main(args)
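To separate a model-wiring bug from a training problem, a check I have sketched (but not run to completion) is to overfit a single batch inside main(), right after the optimizers are created; if the loss cannot be pushed close to zero on one fixed batch, the wiring itself is suspect:

    # Debugging sketch: repeatedly fit one fixed batch. A correctly wired
    # model should memorize these few captions within a few hundred steps.
    images, captions, lengths = next(iter(data_loader))
    images, captions = images.to(device), captions.to(device)
    for step in range(200):
        features = encoder(images)
        outputs = decoder(captions[:, :-1], features).permute(1, 0, 2)
        loss = criterion(outputs.reshape(-1, len(vocab)), captions[:, 1:].reshape(-1))
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        encoder_optim.step()
        decoder_optim.step()
        if step % 20 == 0:
            print(step, loss.item())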
sample.py
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from data_loader import get_loader
from model import EncoderCNN,Decoder
from PIL import Image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
def token_sentence(decoder_out, itos):
    tokens = decoder_out
    tokens = tokens.transpose(1, 0)
    tokens = tokens.cpu().numpy()
    results = []
    for instance in tokens:
        result = ' '.join([itos[x] for x in instance])
        results.append(''.join(result.partition('<eos>')[0]))  # Cut before '<eos>' (note: build_vocab.py defines '<end>', not '<eos>')
    return results

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image
def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json',
                             vocab, transform, batch_size, shuffle=False, num_workers=num_workers)
    # Build models
    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Decoder(len(vocab), embed_size, num_heads, num_layers, device).to(device)
    # Load trained models
    encoder.load_state_dict(torch.load(os.path.join('models/', 'encoder-{}-{}.ckpt'.format(1, 4000))))
    decoder.load_state_dict(torch.load(os.path.join('models/', 'decoder-{}-{}.ckpt'.format(1, 4000))))
    encoder.eval()
    decoder.eval()
    itos = vocab.idx2word
    pred_len = 100
    result_collection = []
    # Decode with greedy
    # with torch.no_grad():
    #     for i, (images, captions, lengths) in enumerate(data_loader):
    #         images = images.to(device)
    #         features = encoder(images)
    #         output = decoder.generator(features, pred_len)
    #         result_caption = token_sentence(output, itos)
    #         result_collection.extend(result_caption)
    # Decode with greedy
    with torch.no_grad():
        for batch_index, (inputs, captions, caplens) in enumerate(data_loader):
            inputs, captions = inputs.cuda(), captions.cuda()
            enc_out = encoder(inputs)
            captions_input = captions[:, :-1]
            captions_target = captions[:, 1:]
            output = decoder.pred(enc_out, pred_len)
            result_caption = token_sentence(output, itos)
            result_collection.extend(result_caption)
print("Prediction-greedy:",result_collection[1])
print("Prediction-greedy:",result_collection[2])
print("Prediction-greedy:",result_collection[3])
print("Prediction-greedy:",result_collection[4])
print("Prediction-greedy:",result_collection[5])
print("Prediction-greedy:",result_collection[6])
print("Prediction-greedy:",result_collection[7])
print("Prediction-greedy:",result_collection[8])
print("Prediction-greedy:",result_collection[9])
print("Prediction-greedy:",result_collection[10])
print("Prediction-greedy:",result_collection[11])
    # # Prepare an image
    # image = load_image(args.image, transform)
    # image_tensor = image.to(device)
    # # Generate a caption from the image
    # feature = encoder(image_tensor)
    # sampled_ids = decoder.generator(feature, pred_len)
    # sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)
    # # Convert word_ids to words
    # sampled_caption = []
    # for word_id in sampled_ids:
    #     word = vocab.idx2word[word_id]
    #     sampled_caption.append(word)
    #     if word == '<end>':
    #         break
    # sentence = ' '.join(sampled_caption)
    # # Print out the image and the generated caption
    # print(sentence)
    # image = Image.open(args.image)
    # plt.imshow(np.asarray(image))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', required=False, help='input image for generating caption')
    parser.add_argument('--vocab_path', default='data/vocab.pkl', help='path for vocabulary wrapper')
    args = parser.parse_args()
    main(args)
resize.py
import argparse
import os
from PIL import Image
def resize_image(image, size):
    """Resize an image to the given size."""
    return image.resize(size, Image.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        with open(os.path.join(image_dir, image), 'r+b') as f:
            with Image.open(f) as img:
                img = resize_image(img, size)
                img.save(os.path.join(output_dir, image), img.format)
        if (i + 1) % 100 == 0:
            print("[{}/{}] Resized the images and saved into '{}'."
                  .format(i + 1, num_images, output_dir))

def main(args):
    image_dir = args.image_dir
    output_dir = args.output_dir
    image_size = [args.image_size, args.image_size]
    resize_images(image_dir, output_dir, image_size)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', default='./data/train2014/', help='directory for train images')
    parser.add_argument('--output_dir', default='./data/resized2014/', help='directory for saving resized images')
    parser.add_argument('--image_size', default=256, help='size for image after processing')
    args = parser.parse_args()
    main(args)