问题描述
如何使用 Chainer 中的 Standard Updater 类大幅增加 mini-batch 的数量?
在 PyTorch 的情况下, 我可以大幅增加 mini-batch 的数量。
问题 1。 在 Chainer 的情况下, 是否可以大幅增加 mini-batch 的数量?
问题 2。 事实上,我正在使用 StandardUpdater 类。 是否可以使用任何超参数大幅增加小批量的数量? 或者我应该让我的类继承自 StandardUpdater 类并更改上面的实现?
如果问题已经被问到,我很抱歉。
希望大家给点建议。
解决方法
(这个问题似乎很老了,但我偶然发现了它并想分享我对问题的解决方案)
您基本上可以像在 PyTorch 中那样做。不幸的是,StandardUpdater 既没有支持它的超参数,也没有“小批量更新”的实现。但这是我的实现,我是如何做到的(基本上正如您在问题中提到的:从 StandardUpdater 继承并重新实现 update_core 方法):
from chainer.training import StandardUpdater
from chainer.dataset import convert
class MiniBatchUpdater(StandardUpdater):
"""
The iterator outputs batches in mini-batch sizes. This updater
cummulates the gradients of these mini-batches until the
update_size is reached. Then a parameter update is performed
"""
def __init__(self,update_size=32,*args,**kwargs):
super(MiniBatchUpdater,self).__init__(*args,**kwargs)
self.update_size = update_size
self.iteration_counter = 0
def update_core(self):
optimizer = self.get_optimizer('main')
loss_func = self.loss_func or optimizer.target
it = self.get_iterator('main')
batch = it.next()
data = convert._call_converter(self.converter,batch,self.device)
use_cleargrads = getattr(optimizer,'_use_cleargrads',True)
if use_cleargrads and self.iteration_counter == 0:
optimizer.target.cleargrads()
self.iteration_counter += it.batch_size
loss = loss_func(*data)
loss.backward()
if self.iteration_counter >= self.update_size:
self.iteration_counter = 0
optimizer.update()
实现已经很老了(我认为是 chainer 4 或 5),但我也为 chainer 7.8 工作。可以更新一些行以匹配 update_core 方法的较新实现,但正如我所说,它对我有用。希望它有所帮助;)