RNN中的梯度累积

问题描述

在运行大型RNN网络时,我遇到了一些内存问题(GPU),但我想使批处理大小保持合理,因此我想尝试进行梯度累积。在一个可以一口气预测输出的网络中,这似乎是不言而喻的,但是在RNN中,您需要为每个输入步骤进行多次前向传递。因此,我担心我的实现无法按预期工作。我从用户albanD的优秀示例here 开始,但是我认为在使用RNN时应对其进行修改。我认为这是因为您对每个序列进行多次前向积累了更多的梯度。

我当前的实现看起来像这样,同时允许在PyTorch 1.6中使用AMP,这似乎很重要-一切都需要在正确的地方调用。请注意,这只是一个抽象版本,可能看起来像很多代码,但主要是注释。

def train(epochs):
    """Main training loop. Loops for `epoch` number of epochs. Calls `process`."""
    for epoch in range(1,epochs + 1):
        train_loss = process("train")
        valid_loss = process("valid")
        # ... check whether we improved over earlier epochs
        if lr_scheduler:
            lr_scheduler.step(valid_loss)
        
def process(do):
    """Do a single epoch run through the dataloader of the training or validation set. 
       Also takes care of optimizing the model after every `gradient_accumulation_steps` steps.
       Calls `step` for each batch where it gets the loss from."""
    if do == "train":
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    
    loss = 0.
    for batch_idx,batch in enumerate(dataloaders[do]):
        step_loss,avg_step_loss = step(batch)
        loss += avg_step_loss

        if do == "train":
            if amp:
                scaler.scale(step_loss).backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # clip in-place
                    clip_grad_norm_(model.parameters(),2.0)
                    scaler.step(optimizer)
                    scaler.update()
                    model.zero_grad()
            else:
                step_loss.backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    clip_grad_norm_(model.parameters(),2.0)
                    optimizer.step()
                    model.zero_grad()
        
        # return average loss
        return loss / len(dataloaders[do])

    def step():
        """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
        # do stuff... init hidden state and first input etc.
        loss = torch.tensor([0.]).to(device)
        
        for i in range(target_len):
            with torch.cuda.amp.autocast(enabled=amp):
                # overwrite previous decoder_hidden
                output,decoder_hidden = model(decoder_input,decoder_hidden)

                # compute loss between predicted classes (bs x classes) and correct classes for _this word_
                item_loss = criterion(output,target_tensor[i])

                # We calculate the gradients for the average step so that when
                # we do take an optimizer.step,it takes into account the mean step_loss
                # across batches. So basically (A+B+C)/3 = A/3 + B/3 + C/3
                loss += (item_loss / gradient_accumulation_steps)

            topv,topi = output.topk(1)
            decoder_input = topi.detach()
        
        return loss,loss.item() / target_len

以上内容似乎并不像我希望的那样起作用,也就是说,它仍然很快会遇到内存不足的问题。我认为原因是step已经积累了很多信息,但我不确定。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)