TensorFlow 2.x / Keras 中的梯度累积

问题描述

我正在尝试在 TF2.x 上实现梯度累积。我找到的所有实现要么针对 TF1.x,要么针对旧的 Keras 接口。我认为目前还没有适用于 TF2.x 的实现(如果有人能证明我错了,我会很高兴)。

这是我正在使用的:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Flatten,Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    """Keras model with a hand-written, single-pass ``train_step``.

    Serves as the baseline the gradient-accumulating variants are compared
    against: one forward pass, one backward pass and one optimizer update
    per batch.
    """

    def train_step(self, data):
        """Run one optimisation step on ``data`` and report the metrics.

        Parameters
        ----------
        data
            Either ``(x, y)`` or ``(x, y, sample_weight)``; what arrives here
            depends on what was handed to ``fit()``.

        Returns
        -------
        dict
            Maps every metric name in ``self.metrics`` to its current result.
        """
        # A 3-element batch carries sample weights; otherwise there are none.
        if len(data) == 3:
            features, targets, weights = data
        else:
            features, targets = data
            weights = None

        # Forward pass and loss are recorded on the tape for differentiation.
        with tf.GradientTape() as recorder:
            predictions = self(features, training=True)
            batch_loss = self.compiled_loss(
                targets,
                predictions,
                sample_weight=weights,
                regularization_losses=self.losses,
            )

        grads = recorder.gradient(batch_loss, self.trainable_variables)
        self.compiled_metrics.update_state(targets, predictions)

        # Single weight update for the whole batch.
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        report = {}
        for metric in self.metrics:
            report[metric.name] = metric.result()
        return report


class GradAccumModel(Model):
    """Keras model whose ``train_step`` processes each batch in ``grad_accum`` slices.

    ``fit`` records ``batch_size`` and ``grad_accum`` on the instance so that
    ``train_step`` can cut every incoming batch into slices of
    ``batch_size // grad_accum`` samples and run one forward/backward pass
    per slice inside a ``tf.while_loop``.
    """

    def fit(self, *args, batch_size=32, grad_accum=1, **kwargs):
        """Thin wrapper around the parent ``fit`` that captures the batch settings.

        Parameters
        ----------
        batch_size : int
            Forwarded to the parent ``fit``; also stored on the instance.
        grad_accum : int
            Number of slices each batch is split into (defaults to 1).

        Returns
        -------
        Whatever the parent ``fit`` returns.

        Raises
        ------
        ValueError
            If ``batch_size`` is not divisible by ``grad_accum``.
        """
        # Reset the cached train function (presumably so Keras rebuilds it
        # around the batch_size / grad_accum values stored below).
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps,dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        # TODO: validation_batch_size is not forwarded; an earlier draft
        # considered defaulting it to self.batch_size // grad_accum.
        return super(GradAccumModel, self).fit(*args, batch_size=self.batch_size, **kwargs)

    def train_step(self, data):
        """Run one training step, slice by slice.

        Parameters
        ----------
        data
            Either ``(x, y)`` or ``(x, y, sample_weight)``.

        Returns
        -------
        dict
            Maps every metric name in ``self.metrics`` to its current result.
        """
        # Unpack the data. Its structure depends on the model and on what
        # was passed to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum

        def weights_for(start, stop):
            # Sample weights must be cut to the same window as x and y.
            return None if sample_weight is None else sample_weight[start:stop]

        # First gradient: computed outside the loop so that `gradients`
        # already has the right structure when used as a loop variable.
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training=True)  # Forward pass
            loss = self.compiled_loss(
                y[:step],
                y_pred,
                sample_weight=weights_for(0, step),
                regularization_losses=self.losses,
            )
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        i = tf.constant(step)

        def cond(i, *args):
            return i < self.batch_size

        def body(i, grad):
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training=True)  # Forward pass
                loss = self.compiled_loss(
                    y[i:i + step],
                    y_pred,
                    sample_weight=weights_for(i, i + step),
                    regularization_losses=self.losses,
                )
            _grad = tape.gradient(loss, self.trainable_variables)

            # NOTE(review): `g += _g` only rebinds the loop-local name `g`;
            # the entries of `grad` are never replaced, so `grad` is returned
            # unchanged and nothing is accumulated. This is the behaviour the
            # surrounding question is asking about and is kept on purpose;
            # building a new list (`[g + _g for ...]`) is what accumulates.
            for g, _g in zip(grad, _grad):
                g += _g

            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations=1)

        # Tried with and without dividing by grad_accum to obtain the mean:
        # for g in gradients:
        #     g *= 1/self.grad_accum

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


if __name__ == '__main__':
    # Compare validation accuracy on MNIST across the training variants.
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    # One row per experiment: (model class, extra fit() kwargs, curve colour).
    experiments = [
        (Model, {}, 'blue'),
        (SimpleTrainStepModel, {}, 'green'),
        (GradAccumModel, {'grad_accum': 1}, 'yellow'),
        (GradAccumModel, {'grad_accum': 6}, 'red'),
    ]

    for model_cls, extra_fit_kwargs, curve_colour in experiments:
        # Ten independent runs per variant, each drawn as a faint curve.
        for _ in tqdm(range(10)):
            # tf.random.set_seed(0)
            inputs = Input((28, 28))
            flat = Flatten()(inputs)
            hidden = Dense(128, activation='sigmoid')(flat)
            outputs = Dense(10, activation='softmax')(hidden)

            net = model_cls(inputs, outputs)
            net.compile(
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(1e-4),
                metrics=['acc'],
            )

            history = net.fit(
                x_train,
                y_train,
                validation_data=(x_valid, y_valid),
                verbose=0,
                batch_size=6000,
                epochs=100,
                **extra_fit_kwargs,
            )
            plt.plot(history.history['val_acc'], color=curve_colour, alpha=.25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()

我已经能够验证它确实节省了 GPU 内存。但是,最终结果与正常的 Model.fit 不同。

Validation

Close-up

如您所见,前三种配置的训练曲线紧密聚集在一起,结果基本一致。但一旦 while 循环真正参与计算(即 grad_accum 大于 1 时),训练结果就明显不同了。

有人知道为什么会这样吗?

解决方法

经过多次尝试后,我找到了解决方案。看来主要问题出在梯度的复合赋值(`g += _g`)上,它并不像我预期的那样工作。下面是我的最终解决方案,供感兴趣的人参考。它还额外包含了对分布式训练、混合精度训练以及嵌套输入/输出的支持。

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model


class Model(_Model):
    """Keras ``Model`` subclass whose ``fit`` accepts ``grad_accum_steps``.

    With ``grad_accum_steps > 1`` the regular ``train_step`` is temporarily
    replaced by ``_train_step_``, which splits every batch into
    ``grad_accum_steps`` slices, averages the per-slice gradients inside a
    ``tf.while_loop`` and applies a single optimizer update.
    """

    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """
        Shallow wrapper of Model.fit that captures batch_size and the additional kwarg grad_accum_steps.

        Parameters
        ----------
        batch_size : int
            same as in Model.fit
        grad_accum_steps : int
            Number of steps to split batch_size into. The `batch_size` should be divisible by
            `grad_accum_steps` times the number of replicas in sync (defaults to 1).

        Returns
        -------
        Whatever the parent ``fit`` returns.

        Raises
        ------
        ValueError
            If `batch_size` is not divisible by `grad_accum_steps * num_replicas_in_sync`.
        """
        if grad_accum_steps == 1:
            # Nothing to accumulate: hand over to the stock implementation and
            # return its result instead of falling through to a second fit.
            return super().fit(*args, batch_size = batch_size, **kwargs)

        # Reset the cached train function (presumably so Keras rebuilds it
        # around the overridden train_step installed below).
        self.train_function = None
        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}),and the number of replicas ({num_workers}),dummy!')

        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        self.train_step = self._train_step_
        try:
            # TODO maybe consider validation batch size
            return super().fit(*args, batch_size = self._batch_size_, **kwargs)
        finally:
            # Undo the temporary state even when fit() raises, so the model
            # is left usable for a later plain fit().
            del self._grad_accum_
            del self._batch_size_
            del self._num_workers_
            self.train_step = train_step_backup
            # The cached function was built around _train_step_; drop it so a
            # later fit() does not reuse it.
            self.train_function = None

    def _train_step_(self, data):
        """
        Custom training step taking into account gradient accumulation for low memory training

        Parameters
        ----------
        data
            Either ``(x, y)`` or ``(x, y, sample_weight)``; each element may be
            a nested structure.

        Returns
        -------
        dict
            Maps every metric name in ``self.metrics`` to its current result.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        def slice_map(struct, start, stop):  # dealing with nasty nested structures
            if struct is None:
                return None  # special case for sample_weight

            return nest.map_structure(lambda t: t[start:stop], struct)

        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        # Number of samples covered by the accumulation loop on this replica.
        # NOTE(review): assumes every per-replica batch holds exactly this
        # many samples - confirm for a final partial batch.
        local_batch = step * self._grad_accum_
        scale = 1. / self._grad_accum_
        uses_loss_scaling = isinstance(self.optimizer, lso.LossScaleOptimizer)

        def slice_gradients(start, stop):
            """Forward/backward pass on samples [start, stop); returns gradients scaled by 1/grad_accum."""
            x_ = slice_map(x, start, stop)
            y_ = slice_map(y, start, stop)
            w_ = slice_map(sample_weight, start, stop)

            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if uses_loss_scaling:
                    loss = self.optimizer.get_scaled_loss(loss)

            grads = tape.gradient(loss, self.trainable_variables)
            self.compiled_metrics.update_state(y_, y_pred)
            # Pre-scale so that the sum over all slices is the mean gradient.
            return [g * scale for g in grads]

        # The first slice is handled outside the loop so that `gradients`
        # already has its final structure when passed as a loop variable.
        gradients = slice_gradients(0, step)

        def cond(i, *args):
            return i < local_batch

        def body(i, grad):
            _grad = slice_gradients(i, i + step)
            # Build a new list: an in-place `g += _g` would only rebind the
            # loop-local name and leave `grad` untouched.
            grad = [g + _g for g, _g in zip(grad, _grad)]
            return [i + step, grad]

        i = tf.constant(step)
        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------

        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        # NOTE(review): relies on private optimizer members (_HAS_AGGREGATE_GRAD,
        # _aggregate_gradients, _clip_gradients); verify against the installed
        # TensorFlow version.
        aggregate_grads_outside_optimizer = (
            self.optimizer._HAS_AGGREGATE_GRAD
            and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended)
        )

        if aggregate_grads_outside_optimizer:  # TODO there might be some issues with the scaling,due to the extra accumulation steps
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if uses_loss_scaling:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------

        return {m.name: m.result() for m in self.metrics}

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...