PyTorch runs out of GPU memory in the test loop

Problem description

With the training script below, training and validation both run fine. As soon as execution reaches the test method, I get a CUDA out of memory error. What should I change so that there is enough memory to run the test?

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.autograd import Variable
from torchvision import datasets,transforms

class CnnLstm(nn.Module):
    def __init__(self):
        super(CnnLstm,self).__init__()
        self.cnn = CNN()
        self.rnn = nn.LSTM(input_size=180000,hidden_size=256,num_layers=2,batch_first=True)#stacked LSTM with 2 layers
        #print(num_classes)
        self.linear = nn.Linear(256,num_classes)
        #print('after num_classes')

    def forward(self,x):
        #print(x.shape)
        batch_size,time_steps,channels,height,width = x.size()
        c_in = x.view(batch_size * time_steps,channels,height,width)
        _,c_out = self.cnn(c_in)
        r_in = c_out.view(batch_size,time_steps,-1)
        r_out,(_,_) = self.rnn(r_in)
        r_out2 = self.linear(r_out[:,-1,:])
        return f.log_softmax(r_out2,dim=1)


class TrainCNNLSTM:
    def __init__(self):
        self.seed = 1
        self.batch_size = 8
        self.validate_batch_size = 8
        self.test_batch_size = 1
        self.epoch = 20
        self.learning_rate = 0.01
        self.step = 100
        self.train_loader = None
        self.validate_loader = None
        self.test_loader = None
        #print('before')
        self.model = CnnLstm().to(device)
        #print('after')
        self.criterion = nn.CrossEntropyLoss()

    def load_data(self):
        data_loader = DataLoader()
        self.train_loader = data_loader.get_train_data(self.batch_size)
        self.validate_loader = data_loader.get_validate_data(self.validate_batch_size)
        self.test_loader = data_loader.get_test_data(self.test_batch_size)

    def train(self):  
        optimizer = torch.optim.SGD(self.model.parameters(),lr=self.learning_rate,momentum=0.9)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr=self.learning_rate/100.0,max_lr=self.learning_rate,step_size_up=13)
        #optimizer = torch.optim.SGD(self.model.parameters(),lr=self.learning_rate)
        for epoch in range(self.epoch):
            t_losses=[]
            for iteration,(data,target) in enumerate(self.train_loader):
                data = np.expand_dims(data,axis=1)
                data = torch.FloatTensor(data)
                data,target = data.cuda(),target.cuda()
                data,target = Variable(data),Variable(target)
                optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output,target)
                #loss = f.nll_loss(output,target)
                t_losses.append(loss)
                loss.backward()
                optimizer.step()
                scheduler.step() 
                if iteration % self.step == 0:
                   print('Epoch: {} | train loss: {:.4f}'.format(epoch,loss.item()))
            avgd_trainloss = sum(t_losses)/len(t_losses)
            self.validate(epoch,avgd_trainloss)

    def validate(self,epoch,avg_tloss):
        v_losses=[]
        with torch.no_grad():
            for iteration,(data,target) in enumerate(self.validate_loader):
                data = np.expand_dims(data,axis=1)
                data = torch.FloatTensor(data)
                data,target = data.cuda(),target.cuda()
                data,target = Variable(data),Variable(target)
                output = self.model(data)
                loss = self.criterion(output,target)
                v_losses.append(loss)
        avgd_validloss = sum(v_losses)/len(v_losses)
        print('Epoch: {} | train loss: {:.4f} | validate loss: {:.4f}'.format(epoch,avg_tloss,avgd_validloss))

    def test(self):
        test_loss = []
        correct = 0
        for data,target in self.test_loader:
            data = np.expand_dims(data,axis=1)
            data = torch.FloatTensor(data)
            data,target = data.cuda(),target.cuda()
            data,target = Variable(data,volatile=True),Variable(target)
            output = self.model(data)
            loss = self.criterion(output,target)
            #f.nll_loss(output,target,size_average=False).item()  # sum up batch loss
            test_loss.append(loss)
            pred = torch.max(output,1)[1].data.squeeze()
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

        test_loss = sum(test_loss)/len(test_loss)
        print('\nTest set: Average loss: {:.4f},Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss,correct,len(self.test_loader.dataset),100. * correct / len(self.test_loader.dataset)))


train = TrainCNNLSTM()
train.load_data()
train.train()
train.test() 

Solution

You should call .item() on the loss when appending it to the list of losses:

loss = self.criterion(output,target)
test_loss.append(loss.item())

This avoids accumulating tensors in a list while they are still attached to the computation graph; each stored loss otherwise keeps its whole graph (and the activations behind it) alive on the GPU, which is what eventually exhausts the memory. I would say the same for your accuracy count.
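As a minimal sketch (not the answer author's exact code; wrapping the loop in torch.no_grad() the way validate() already does is my addition, not part of the original answer), the test method could look like this once .item() is used for both the loss and the accuracy count:

    def test(self):
        test_loss = []
        correct = 0
        with torch.no_grad():  # assumption: skip building the autograd graph at test time, as validate() already does
            for data,target in self.test_loader:
                data = np.expand_dims(data,axis=1)
                data = torch.FloatTensor(data)
                data,target = data.cuda(),target.cuda()
                output = self.model(data)
                loss = self.criterion(output,target)
                test_loss.append(loss.item())  # store a plain Python float, not a graph-attached tensor
                pred = torch.max(output,1)[1].data.squeeze()
                correct += pred.eq(target.data.view_as(pred)).long().cpu().sum().item()  # plain int

        avg_loss = sum(test_loss)/len(test_loss)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avg_loss,correct,len(self.test_loader.dataset),
            100. * correct / len(self.test_loader.dataset)))

With only floats and ints kept across iterations, each batch's graph is freed as soon as loss goes out of scope, so GPU memory stays flat during the test loop.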
