如何使用 Chainer 中的 Standard Updater 类大幅增加 mini-batch 的数量?

问题描述

如何使用 Chainer 中的 Standard Updater 类大幅增加 mini-batch 的数量

在 PyTorch 的情况下, 我可以大幅增加 mini-batch 的数量

  • 每次都执行 loss.backward()。
  • 每三次执行 optimizer.step() / optimizer.zero_grad() 一次。 这有效地显着增加了 mini-batch 的数量

问题 1。 在 Chainer 的情况下, 是否可以大幅增加 mini-batch 的数量

  • 每次都执行 loss.backward()。
  • 每 3 次执行 net.cleargrads() / optimizer.update() 一次。 这会大大增加小批量的数量吗?

问题 2。 事实上,我正在使用 StandardUpdater 类。 是否可以使用任何超参数大幅增加小批量的数量? 或者我应该让我的类继承自 StandardUpdater 类并更改上面的实现?

如果问题已经被问到,我很抱歉。

希望大家给点建议。

解决方法

(这个问题似乎很老了,但我偶然发现了它并想分享我对问题的解决方案)

您基本上可以像在 PyTorch 中那样做。不幸的是,StandardUpdater 既没有支持它的超参数,也没有“小批量更新”的实现。但这是我的实现,我是如何做到的(基本上正如您在问题中提到的:从 StandardUpdater 继承并重新实现 update_core 方法):

from chainer.training import StandardUpdater
from chainer.dataset import convert

class MiniBatchUpdater(StandardUpdater):
    """
        The iterator outputs batches in mini-batch sizes. This updater
        cummulates the gradients of these mini-batches until the
        update_size is reached. Then a parameter update is performed
    """
    def __init__(self,update_size=32,*args,**kwargs):
        super(MiniBatchUpdater,self).__init__(*args,**kwargs)
        self.update_size = update_size
        self.iteration_counter = 0

    def update_core(self):
        optimizer = self.get_optimizer('main')
        loss_func = self.loss_func or optimizer.target
        it = self.get_iterator('main')

        batch = it.next()
        data = convert._call_converter(self.converter,batch,self.device)

        use_cleargrads = getattr(optimizer,'_use_cleargrads',True)
        if use_cleargrads and self.iteration_counter == 0:
            optimizer.target.cleargrads()

        self.iteration_counter += it.batch_size
        loss = loss_func(*data)
        loss.backward()

        if self.iteration_counter >= self.update_size:
            self.iteration_counter = 0
            optimizer.update()

实现已经很老了(我认为是 chainer 4 或 5),但我也为 chainer 7.8 工作。可以更新一些行以匹配 update_core 方法的较新实现,但正如我所说,它对我有用。希望它有所帮助;)