问题描述
我正在尝试在 TF2.x 上实现梯度累积。我发现的所有实现都适用于 TF1.x 或旧的 keras 接口。我不认为有实现(尽管我很高兴被证明是错误的)。
这是我正在使用的:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Flatten,Dense
from tqdm import tqdm
import matplotlib.pyplot as plt
class SimpleTrainStepModel(Model):
def train_step(self,data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x,y,sample_weight = data
else:
(x,y),sample_weight = data,None
# FirsT GRADIENT
with tf.GradientTape() as tape:
y_pred = self(x,training = True) # Forward pass
loss = self.compiled_loss(y,y_pred,sample_weight = sample_weight,regularization_losses = self.losses)
gradients = tape.gradient(loss,self.trainable_variables)
self.compiled_metrics.update_state(y,y_pred)
self.optimizer.apply_gradients(zip(gradients,self.trainable_variables))
return {m.name: m.result() for m in self.metrics}
class GradAccumModel(Model):
def fit(self,*args,batch_size = 32,grad_accum = 1,**kwargs):
self.train_function = None
if batch_size % grad_accum != 0:
raise ValueError('Batch size must be divisible by the Gradient accumulation steps,dummy!')
self.grad_accum = grad_accum
self.batch_size = batch_size
return super(GradAccumModel,self).fit(*args,batch_size = self.batch_size,#validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,**kwargs)
def train_step(self,None
step = self.batch_size // self.grad_accum
# def _slice_nested(obj,i,j):
# if type(obj) is list:
# return [o[i:j] for o in obj]
# else:
# return obj[i:j]
# FirsT GRADIENT
with tf.GradientTape() as tape:
y_pred = self(x[:step],training = True) # Forward pass
loss = self.compiled_loss(y[:step],self.trainable_variables)
self.compiled_metrics.update_state(y[:step],y_pred)
i = tf.constant(step)
# tf.print('TF - HERE!')
def cond(i,*args):
return i < self.batch_size
def body(i,grad):
# tf.print('\tTF - HERE!')
with tf.GradientTape() as tape:
y_pred = self(x[i:i + step],training = True) # Forward pass
loss = self.compiled_loss(y[i:i + step],regularization_losses = self.losses)
_grad = tape.gradient(loss,self.trainable_variables)
for g,_g in zip(grad,_grad):
g += _g
self.compiled_metrics.update_state(y[i:i + step],y_pred)
return [i + step,grad]
i,gradients = tf.while_loop(cond,body,[i,gradients],parallel_iterations = 1)
# for g in gradients: # I tried with and without division co calculate the mean
# g *= 1/self.grad_accum #
# Update weights
self.optimizer.apply_gradients(zip(gradients,self.trainable_variables))
# Update metrics (includes the metric that tracks the loss)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
if __name__ == '__main__':
(x_train,y_train),(x_valid,y_valid) = tf.keras.datasets.mnist.load_data()
for MODEL,ga_kwarg,colour in list(zip([Model,SimpleTrainStepModel,GradAccumModel,GradAccumModel],[{},{},{'grad_accum': 1},{'grad_accum': 6}],['blue','green','yellow','red'])):
for _ in tqdm(range(10)):
# tf.random.set_seed(0)
x = Input((28,28))
y = x
y = Flatten()(y)
y = Dense(128,activation = 'sigmoid')(y)
y = Dense(10,activation = 'softmax')(y)
model = MODEL(x,y)
model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),optimizer = tf.keras.optimizers.Adam(1e-4),metrics = ['acc'])
hist = model.fit(x_train,y_train,validation_data = (x_valid,y_valid),verbose = 0,batch_size = 6000,epochs = 100,**ga_kwarg)
plt.plot(hist.history['val_acc'],color = colour,alpha = .25)
plt.title('')
plt.xscale('symlog')
plt.yscale('logit')
plt.show()
我已经能够验证它确实节省了 GPU 内存。但是,最终结果与正常的 Model.fit
不同。
如您所见,前三个 Model.fit
聚类良好并给出相同的结果。但是当 while
循环开始时,训练就完全不同了。
有人知道为什么会这样吗?
解决方法
经过多次尝试后,我找到了解决方案,看来主要问题是梯度的复合分配,它不像我预期的那样工作。这是我为可能感兴趣的任何人提供的最终解决方案。它包括用于分布式、混合精度训练和嵌套输入/输出的额外内容。
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model
class Model(_Model):
def fit(self,*args,batch_size: int = 32,grad_accum_steps: int = 1,**kwargs):
"""
Shallow wrapper of Model.fit that captures batch_size and additional kwarg: grad_accum.
Parameters
----------
batch_size : int
same as in Model.fit
grad_accum_steps : int
Number of steps to split batch_size into. The `batch_size` should be divisible by `grad_accum` (defaults to 1).
"""
if grad_accum_steps == 1:
super().fit(*args,batch_size = batch_size,**kwargs)
self.train_function = None
num_workers = ds_context.get_strategy().num_replicas_in_sync
if batch_size % (grad_accum_steps * num_workers) != 0:
raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}),and the number of replicas ({num_workers}),dummy!')
self._grad_accum_ = grad_accum_steps
self._batch_size_ = batch_size
self._num_workers_ = num_workers
train_step_backup = self.train_step
self.train_step = self._train_step_
out = super(self).fit(*args,batch_size = self._batch_size_,# TODO maybe consider validation batch size
**kwargs)
del self._grad_accum_
del self._batch_size_
del self._num_workers_
self.train_step = train_step_backup
return out
def _train_step_(self,data):
"""
Custom training step taking into account gradient accumulation for low memory training
"""
if len(data) == 3:
x,y,sample_weight = data
else:
(x,y),sample_weight = data,None
def slice_map(struct,start,stop): # dealing with nasty nested structures
if struct is None:
return None # special case for sample_weight
return nest.map_structure(lambda x: x[start:stop],struct)
# ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
step = self._batch_size_ // self._num_workers_ // self._grad_accum_
x_ = slice_map(x,step)
y_ = slice_map(y,step)
w_ = slice_map(sample_weight,step)
with tf.GradientTape() as tape:
y_pred = self(x_,training = True) # Forward pass
loss = self.compiled_loss(y_,y_pred,sample_weight = w_,regularization_losses = self.losses)
if isinstance(self.optimizer,lso.LossScaleOptimizer):
loss = self.optimizer.get_scaled_loss(loss)
gradients = tape.gradient(loss,self.trainable_variables)
gradients = [gradient * (1./self._grad_accum_) for gradient in gradients]
self.compiled_metrics.update_state(y_,y_pred)
i = tf.constant(step)
def cond(i,*args):
return i < self._batch_size_
def body(i,grad):
x_ = slice_map(x,i,i + step)
y_ = slice_map(y,i + step)
w_ = slice_map(sample_weight,i + step)
with tf.GradientTape() as tape:
y_pred = self(x_,training = True) # Forward pass
loss = self.compiled_loss(y_,regularization_losses = self.losses)
if isinstance(self.optimizer,lso.LossScaleOptimizer):
loss = self.optimizer.get_scaled_loss(loss)
_grad = tape.gradient(loss,self.trainable_variables)
_grad = [_g * (1./self._grad_accum_) for _g in _grad]
grad = [g + _g for g,_g in zip(grad,_grad)]
self.compiled_metrics.update_state(y_,y_pred)
return [i + step,grad]
i,gradients = tf.while_loop(cond,body,[i,gradients],parallel_iterations = 1)
# --------------------------------------------------------------------------------------------------------------
# ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended,parameter_server_strategy.ParameterServerStrategyExtended))
if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling,due to the extra accumulation steps
gradients = self.optimizer._aggregate_gradients(zip(gradients,self.trainable_variables))
if isinstance(self.optimizer,lso.LossScaleOptimizer):
gradients = self.optimizer.get_unscaled_gradients(gradients)
gradients = self.optimizer._clip_gradients(gradients)
if self.trainable_variables:
if aggregate_grads_outside_optimizer:
self.optimizer.apply_gradients(zip(gradients,self.trainable_variables),experimental_aggregate_gradients = False)
else:
self.optimizer.apply_gradients(zip(gradients,self.trainable_variables))
# --------------------------------------------------------------------------------------------------------------
return {m.name: m.result() for m in self.metrics}