PyTorch with CUDA throws RuntimeError when using pack_padded_sequence

Problem description

I am trying to train a BiLSTM-CRF with PyTorch to detect new NER entities. To do so, I am using a piece of code derived from the PyTorch Advanced tutorial. This snippet implements batch training.

I followed the README to provide the data in the required format. On the CPU everything works fine, but when I try to run it on the GPU I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-794982510db6> in <module>
      4         batch_input,batch_input_lens,batch_mask,batch_target = batch_info
      5 
----> 6         loss_train = model.neg_log_likelihood(batch_input,batch_input_lens,batch_mask,batch_target)
      7         optimizer.zero_grad()
      8         loss_train.backward()

<ipython-input-11-e44ffbf7d75f> in neg_log_likelihood(self,batch_input,batch_input_lens,batch_mask,batch_target)
    185 
    186     def neg_log_likelihood(self,batch_input,batch_input_lens,batch_mask,batch_target):
--> 187         feats = self.bilstm(batch_input,batch_input_lens,batch_mask)
    188         gold_score = self.CRF.score_sentence(feats,batch_target)
    189         forward_score = self.CRF.score_z(feats,batch_input_lens)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self,*input,**kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input,**kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks,non_full_backward_hooks = [],[]

<ipython-input-11-e44ffbf7d75f> in forward(self,batch_input,batch_input_lens,batch_mask)
     46         batch_input = self.word_embeds(batch_input)  # size: #batch * padding_length * embedding_dim
     47         batch_input = rnn_utils.pack_padded_sequence(
---> 48             batch_input,batch_input_lens,batch_first=True)
     49         batch_output,self.hidden = self.lstm(batch_input,self.hidden)
     50         self.repackage_hidden(self.hidden)

/opt/conda/lib/python3.7/site-packages/torch/nn/utils/rnn.py in pack_padded_sequence(input,lengths,batch_first,enforce_sorted)
    247 
    248     data,batch_sizes = \
--> 249         _VF._pack_padded_sequence(input,lengths,batch_first)
    250     return _packed_sequence_init(data,batch_sizes,sorted_indices,None)
    251 

RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor

If I understand correctly, pack_padded_sequence requires the lengths tensor to be on the CPU rather than on the GPU. Unfortunately, my forward function calls pack_padded_sequence, and I don't see any way to satisfy this without moving the whole training back to the CPU.
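As far as I can tell, the constraint can be reduced to the following minimal sketch (independent of my model, assuming a CUDA device is available):

import torch
import torch.nn.utils.rnn as rnn_utils

x = torch.randn(2, 3, 4).cuda()      # padded batch on the GPU
lengths_cpu = torch.tensor([3, 2])   # 1D CPU int64 tensor, sorted descending
lengths_gpu = lengths_cpu.cuda()

rnn_utils.pack_padded_sequence(x, lengths_cpu, batch_first=True)  # works
rnn_utils.pack_padded_sequence(x, lengths_gpu, batch_first=True)  # raises the RuntimeError above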

Here is the full code.

Class definitions

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils


class BiLSTM(nn.Module):
    def __init__(self,vocab_size,tagset,embedding_dim,hidden_dim,num_layers,bidirectional,dropout,pretrained=None):
        super(BiLSTM,self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tagset_size = len(tagset)
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.word_embeds = nn.Embedding(vocab_size+2,embedding_dim)
        if pretrained is not None:
            self.word_embeds = nn.Embedding.from_pretrained(pretrained)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2 if bidirectional else hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.hidden2tag = nn.Linear(hidden_dim,self.tagset_size)
        self.hidden = None

    def init_hidden(self,batch_size,device):
        init_hidden_dim = self.hidden_dim // 2 if self.bidirectional else self.hidden_dim
        init_first_dim = self.num_layers * 2 if self.bidirectional else self.num_layers
        self.hidden = (
            torch.randn(init_first_dim,batch_size,init_hidden_dim).to(device),
            torch.randn(init_first_dim,batch_size,init_hidden_dim).to(device)
        )

    def repackage_hidden(self,hidden):
        """Wraps hidden states in new Tensors,to detach them from their history."""
        if isinstance(hidden,torch.Tensor):
            return hidden.detach_().to(device)
        else:
            return tuple(self.repackage_hidden(h) for h in hidden)

    def forward(self,batch_input,batch_input_lens,batch_mask):
        batch_size,padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # size: #batch * padding_length * embedding_dim
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input,batch_input_lens,batch_first=True)
        batch_output,self.hidden = self.lstm(batch_input,self.hidden)
        self.repackage_hidden(self.hidden)
        batch_output,_ = rnn_utils.pad_packed_sequence(batch_output,batch_first=True)
        batch_output = batch_output.contiguous().view(batch_size * padding_length,-1)
        batch_output = batch_output[batch_mask,...]
        out = self.hidden2tag(batch_output)
        return out

    def neg_log_likelihood(self,batch_input,batch_input_lens,batch_mask,batch_target):
        loss = nn.CrossEntropyLoss(reduction='mean')
        feats = self(batch_input,batch_input_lens,batch_mask)
        batch_target = torch.cat(batch_target,0).to(device)
        return loss(feats,batch_target)

    def predict(self,batch_input,batch_input_lens,batch_mask):
        feats = self(batch_input,batch_input_lens,batch_mask)
        val,pred = torch.max(feats,1)
        return pred


class CRF(nn.Module):
    def __init__(self,tagset,start_tag,end_tag,device):
        super(CRF,self).__init__()
        self.tagset_size = len(tagset)
        self.START_TAG_IDX = tagset.index(start_tag)
        self.END_TAG_IDX = tagset.index(end_tag)
        self.START_TAG_TENSOR = torch.LongTensor([self.START_TAG_IDX]).to(device)
        self.END_TAG_TENSOR = torch.LongTensor([self.END_TAG_IDX]).to(device)
        # trans: (tagset_size,tagset_size) trans (i,j) means state_i -> state_j
        self.trans = nn.Parameter(
            torch.randn(self.tagset_size,self.tagset_size)
        )
        # self.trans.data[...] = 1
        self.trans.data[:,self.START_TAG_IDX] = -10000
        self.trans.data[self.END_TAG_IDX,:] = -10000
        self.device = device

    def init_alpha(self,batch_size,tagset_size):
        return torch.full((batch_size,tagset_size,1),-10000,dtype=torch.float,device=self.device)

    def init_path(self,size_shape):
        # Initialization Path - LongTensor + Device + Full_value=0
        return torch.full(size_shape,0,dtype=torch.long,device=self.device)

    def _iter_legal_batch(self,batch_input_lens,reverse=False):
        index = torch.arange(0,batch_input_lens.sum(),dtype=torch.long)
        packed_index = rnn_utils.pack_sequence(
            torch.split(index,batch_input_lens.tolist())
        )
        batch_iter = torch.split(packed_index.data,packed_index.batch_sizes.tolist())
        batch_iter = reversed(batch_iter) if reverse else batch_iter
        for idx in batch_iter:
            yield idx,idx.size()[0]

    def score_z(self,feats,batch_input_lens):
        # mimic the pack/pad process
        tagset_size = feats.shape[1]
        batch_size = len(batch_input_lens)
        alpha = self.init_alpha(batch_size,tagset_size)
        alpha[:,self.START_TAG_IDX,:] = 0  # Initialization
        for legal_idx,legal_batch_size in self._iter_legal_batch(batch_input_lens):
            feat = feats[legal_idx,].view(legal_batch_size,1,tagset_size)  # 
            # #batch * 1 * |tag| + #batch * |tag| * 1 + |tag| * |tag| = #batch * |tag| * |tag|
            legal_batch_score = feat + alpha[:legal_batch_size,] + self.trans
            alpha_new = torch.logsumexp(legal_batch_score,1).unsqueeze(2).to(device)
            alpha[:legal_batch_size,] = alpha_new
        alpha = alpha + self.trans[:,self.END_TAG_IDX].unsqueeze(1)
        score = torch.logsumexp(alpha,1).sum().to(device)
        return score

    def score_sentence(self,feats,batch_target):
        # CRF Batched Sentence score
        # feats: (#batch_state(#words),tagset_size)
        # batch_target: list<torch.LongTensor> At least One LongTensor
        # Warning: words order =  batch_target order
        def _add_start_tag(target):
            return torch.cat([self.START_TAG_TENSOR,target]).to(device)

        def _add_end_tag(target):
            return torch.cat([target,self.END_TAG_TENSOR]).to(device)

        from_state = [_add_start_tag(target) for target in batch_target]
        to_state = [_add_end_tag(target) for target in batch_target]
        from_state = torch.cat(from_state).to(device)  
        to_state = torch.cat(to_state).to(device)  
        trans_score = self.trans[from_state,to_state]

        gather_target = torch.cat(batch_target).view(-1,1).to(device)
        emit_score = torch.gather(feats,1,gather_target).to(device)

        return trans_score.sum() + emit_score.sum()

    def viterbi(self,feats,batch_input_lens):
        word_size,tagset_size = feats.shape
        batch_size = len(batch_input_lens)
        viterbi_path = self.init_path(feats.shape)  # use feats.shape to init path.shape
        alpha = self.init_alpha(batch_size,tagset_size)
        alpha[:,self.START_TAG_IDX,:] = 0  # Initialization
        for legal_idx,legal_batch_size in self._iter_legal_batch(batch_input_lens):
            feat = feats[legal_idx,:].view(legal_batch_size,1,tagset_size)
            legal_batch_score = feat + alpha[:legal_batch_size,] + self.trans
            alpha_new,best_tag = torch.max(legal_batch_score,1)
            alpha[:legal_batch_size,] = alpha_new.unsqueeze(2)
            viterbi_path[legal_idx,] = best_tag
        alpha = alpha + self.trans[:,self.END_TAG_IDX].unsqueeze(1)
        path_score,best_tag = torch.max(alpha,1)
        path_score = path_score.squeeze()  # path_score=#batch

        best_paths = self.init_path((word_size,1))
        for legal_idx,legal_batch_size in self._iter_legal_batch(batch_input_lens,reverse=True):
            best_paths[legal_idx,] = best_tag[:legal_batch_size,]  # 
            backword_path = viterbi_path[legal_idx,]  # 1 * |Tag|
            this_tag = best_tag[:legal_batch_size,]  # 1 * |legal_batch_size|
            backword_tag = torch.gather(backword_path,1,this_tag).to(device)
            best_tag[:legal_batch_size,] = backword_tag
            # never computing <START>

        # best_paths = #words
        return path_score.view(-1),best_paths.view(-1)


class BiLSTM_CRF(nn.Module):
    def __init__(self,vocab_size,tagset,embedding_dim,hidden_dim,num_layers,bidirectional,dropout,start_tag,end_tag,device,pretrained=None):
        super(BiLSTM_CRF,self).__init__()
        self.bilstm = BiLSTM(vocab_size,tagset,embedding_dim,hidden_dim,num_layers,bidirectional,dropout,pretrained)
        self.CRF = CRF(tagset,start_tag,end_tag,device)

    def init_hidden(self,batch_size,device):
        self.bilstm.init_hidden(batch_size,device)

    def forward(self,batch_input,batch_input_lens,batch_mask):
        feats = self.bilstm(batch_input,batch_input_lens,batch_mask)
        score,path = self.CRF.viterbi(feats,batch_input_lens)
        return path

    def neg_log_likelihood(self,batch_input,batch_input_lens,batch_mask,batch_target):
        feats = self.bilstm(batch_input,batch_input_lens,batch_mask)
        gold_score = self.CRF.score_sentence(feats,batch_target)
        forward_score = self.CRF.score_z(feats,batch_input_lens)
        return forward_score - gold_score

    def predict(self,batch_input,batch_input_lens,batch_mask):
        return self(batch_input,batch_input_lens,batch_mask)

Training cell:

import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


def prepare_sequence(seq,to_ix,device):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs,dtype=torch.long).to(device)

def prepare_labels(lab,tag_to_ix,device):
    idxs = [tag_to_ix[w] for w in lab]
    return torch.tensor(idxs,dtype=torch.long).to(device)


class PadSequence:
    def __call__(self,batch):
        device = torch.device('cuda')
        # Let's assume that each element in "batch" is a tuple (data,label).
        # Sort the batch in the descending order
        sorted_batch = sorted(batch,key=lambda x: len(x[0]),reverse=True)
        # Get each sequence and pad it
        sequences = [x[0] for x in sorted_batch]
        sentence_in =[prepare_sequence(x,word_to_ix,device) for x in sequences]
        sequences_padded = torch.nn.utils.rnn.pad_sequence(sentence_in,padding_value = len(word_to_ix) +1,batch_first=True).to(device)
        
        lengths = torch.LongTensor([len(x) for x in sequences]).to(device)
        
        masks = [True if index_word!=len(word_to_ix)+1 else False for sentence in sequences_padded for index_word in sentence ]
        
        labels = [x[1] for x in sorted_batch]
        labels_in = [prepare_sequence(x,tag_to_ix,device) for x in labels]
        return sequences_padded,lengths,masks,labels_in


{ .... code to get the data formatted...}


device = torch.device("cuda")
batch_size = 64


START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMbedDING_DIM = 200
HIDDEN_DIM = 20
NUM_LAYER = 3
BIDIRECTIONNAL = True
DROPOUT = 0.1

train_iter = DataLoader(dataset=training_data,collate_fn=PadSequence(),batch_size=64,shuffle=True) 




model = BiLSTM_CRF(len(word_to_ix),tagset,EMbedDING_DIM,HIDDEN_DIM,NUM_LAYER,BIDIRECTIONNAL,DROPOUT,START_TAG,STOP_TAG,device).to(device)
optimizer = optim.SGD(model.parameters(),lr=0.01,weight_decay=1e-4)
model.init_hidden(batch_size,device)
with tqdm(total=len(train_iter)) as progress_bar:
    for batch_info in train_iter:
        batch_input,batch_input_lens,batch_mask,batch_target = batch_info

        loss_train = model.neg_log_likelihood(batch_input,batch_input_lens,batch_mask,batch_target)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
        progress_bar.update(1) # update progress

Solution

In the PadSequence function (which acts as a collate_fn that gathers the samples and builds a batch out of them), you explicitly cast to the cuda device, namely:

class PadSequence:
    def __call__(self,batch):
        device = torch.device('cuda')
        
        # Left rest of the code for brevity
        ...
        lengths = torch.LongTensor([len(x) for x in sequences]).to(device)
        ...
        return sequences_padded,lengths,masks,labels_in
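A sketch of the same collate function without those casts, reusing word_to_ix, tag_to_ix and prepare_sequence from your code (those names are assumed to be in scope), could look like this:

class PadSequence:
    def __call__(self, batch):
        cpu = torch.device('cpu')
        # Sort the batch in descending order of sequence length
        sorted_batch = sorted(batch, key=lambda x: len(x[0]), reverse=True)
        sequences = [x[0] for x in sorted_batch]
        sentence_in = [prepare_sequence(x, word_to_ix, cpu) for x in sequences]
        sequences_padded = torch.nn.utils.rnn.pad_sequence(
            sentence_in, padding_value=len(word_to_ix) + 1, batch_first=True)

        # lengths stay on the CPU: pack_padded_sequence wants a 1D CPU int64 tensor
        lengths = torch.LongTensor([len(x) for x in sequences])

        masks = [True if index_word != len(word_to_ix) + 1 else False
                 for sentence in sequences_padded for index_word in sentence]

        labels = [x[1] for x in sorted_batch]
        labels_in = [prepare_sequence(x, tag_to_ix, cpu) for x in labels]
        return sequences_padded, lengths, masks, labels_in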

There is no need to cast the data when creating the batch; we usually do that right before pushing the examples through the neural network.
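For example (a sketch, assuming the batch layout returned by the collate function above), the training loop can move the tensors to the GPU right before the forward pass, while the lengths stay on the CPU:

for batch_info in train_iter:
    batch_input, batch_input_lens, batch_mask, batch_target = batch_info

    # move only what the model needs onto the GPU; lengths remain a CPU tensor
    batch_input = batch_input.to(device)
    batch_target = [t.to(device) for t in batch_target]

    loss_train = model.neg_log_likelihood(batch_input, batch_input_lens, batch_mask, batch_target)
    optimizer.zero_grad()
    loss_train.backward()
    optimizer.step()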

In addition, you should at least define the device like this:

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

Or, even better, choose the device in the part of the code where you set everything up, so you/the user can select it.
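For instance (a sketch with a hypothetical --device flag), the choice could live in the setup code:

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--device',
                    default='cuda' if torch.cuda.is_available() else 'cpu',
                    help='device to train on, e.g. "cuda", "cuda:1" or "cpu"')
args = parser.parse_args()

device = torch.device(args.device)  # reuse this single device object everywhere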
