Using tf.function

Problem description

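How can I implement micro-batching in TensorFlow 2.x? That is, I want to accumulate gradients over several batches and then update the weights with these accumulated gradients (which effectively increases my batch size to accumulation steps * batch size).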
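The claim that accumulation behaves like one batch of num_accumulate * batch_size holds under three conditions: the loss is a mean over samples, the micro-batches are equally sized, and no layer depends on batch statistics (for example BatchNormalization). Below is a minimal check under those assumptions. It uses a hypothetical linear model with an MAE loss, like the toy model in the question. The names grads_for, full and accumulated are illustrative only. It compares one gradient over 100 samples with the scaled sum of four gradients over 25 samples each.

```python
import numpy as np
import tensorflow as tf

# Hypothetical linear model y = x @ w + b with an MAE loss, mirroring the question's toy setup.
rng = np.random.default_rng(0)
x = tf.constant(rng.random((100, 10)), dtype=tf.float32)
y = tf.constant(rng.random((100, 1)), dtype=tf.float32)
w = tf.Variable(tf.random.normal([10, 1], seed=0))
b = tf.Variable(tf.zeros([1]))

def grads_for(inputs, targets):
    """Gradients of the mean absolute error over the given batch w.r.t. [w, b]."""
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.abs(targets - (tf.matmul(inputs, w) + b)))
    return tape.gradient(loss, [w, b])

num_accumulate, batch_size = 4, 25

# One gradient over the full batch of num_accumulate * batch_size = 100 samples.
full = grads_for(x, y)

# Four micro-batch gradients, each scaled by 1 / num_accumulate and summed.
accumulated = [tf.zeros_like(w), tf.zeros_like(b)]
for i in range(num_accumulate):
    start, stop = i * batch_size, (i + 1) * batch_size
    micro = grads_for(x[start:stop], y[start:stop])
    accumulated = [acc + g / num_accumulate for acc, g in zip(accumulated, micro)]

# Equal up to float32 rounding.
for acc, g in zip(accumulated, full):
    np.testing.assert_allclose(acc.numpy(), g.numpy(), rtol=1e-5, atol=1e-6)
print("accumulated gradients match the full-batch gradient")
```

Dividing each micro-batch gradient by num_accumulate is what keeps the result equal to the gradient of the mean over the large batch. Without it, the accumulated gradient would be num_accumulate times larger.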

I tried the following code:

import numpy as np
import tensorflow as tf

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


class Trainer:
    def __init__(self, model, num_accumulate):
        self.model = model
        self.num_accumulate = num_accumulate
        self.optimizer = tf.keras.optimizers.Adam()
        self.accumulated_gradients = None

    def _init_accumulated_gradients_maybe(self):
        if self.accumulated_gradients is None:
            self.accumulated_gradients = [tf.Variable(var, dtype=var.dtype, trainable=False) for var in self.model.trainable_weights]
            self._reset_gradients()

    def _reset_gradients(self):
        for grad in self.accumulated_gradients:
            grad.assign(tf.zeros_like(grad))

    def _accumulate_gradients(self, gradients):
        for acc_grad, grad in zip(self.accumulated_gradients, gradients):
            acc_grad.assign_add(grad / self.num_accumulate)

    def get_mae(self, targets, mean_pred):
        return tf.reduce_mean(tf.abs(targets - mean_pred))

    @tf.function
    def train_on_batch(self, dataset_iter):

        for _ in range(self.num_accumulate):  # problematic
            inputs, target = next(dataset_iter)

            with tf.GradientTape() as tape:
                prediction = self.model(inputs, training=True)
                loss = self.get_mae(target, prediction)

            gradients = tape.gradient(loss, self.model.trainable_weights)

            self._init_accumulated_gradients_maybe()
            self._accumulate_gradients(gradients)
            gradients = self.accumulated_gradients

        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
        self._reset_gradients()

        return loss

class DataProvider:
    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size
        self.in_data = np.random.rand(100, 10)
        self.out_data = np.random.rand(100, 1)

    def get_dataset(self):
        def generator():
            while True:
                yield (tf.constant(self.in_data, dtype=tf.float32), tf.constant(self.out_data, dtype=tf.float32))

        return tf.data.Dataset.from_generator(
                generator, output_types=(tf.float32, tf.float32), output_shapes=([None, 10], [None, 1])
                )


num_accumulate = 4
batch_size = 25
nSteps = 10

model = Model()
trainer = Trainer(model, num_accumulate)
dataset_iter = iter(DataProvider(batch_size).get_dataset())

for step in range(1, nSteps):
    trainer.train_on_batch(dataset_iter)

However, I run into two different problems, depending on whether I use tf.range or range inside the tf.function-decorated function.

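  1. Using range: it works for the small model provided here. In my use case, however, the model is considerably larger (2.6 million parameters), and accumulating the gradients this way produces the following warning: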

2021-04-24 18:19:28.349940: W tensorflow/core/common_runtime/process_function_library_runtime.cc:733] Ignoring multi-device function optimization failure: Deadline exceeded: meta_optimizer exceeded deadline.

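My guess, based on my understanding of how tf.function works, is this: with range, every gradient accumulation step is added to the graph separately. The alternative would be to add that part once and repeat it.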
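The guess above can be checked on a toy example. The sketch below assumes TensorFlow 2.x; the function names python_range_loop, tf_range_loop and op_types are hypothetical. It traces two functions that differ only in the loop construct and then inspects the op types in each top-level graph. With Python range, the loop body is copied once per iteration. With tf.range, AutoGraph emits a single while op whose body is stored as a separate function.

```python
import tensorflow as tf

@tf.function
def python_range_loop(x):
    # Plain Python loop: executed while tracing, so the multiply is copied 4 times.
    for _ in range(4):
        x = x * 2.0
    return x

@tf.function
def tf_range_loop(x):
    # Loop over a tensor: AutoGraph converts it into one tf.while_loop.
    for _ in tf.range(4):
        x = x * 2.0
    return x

def op_types(fn):
    """Op types in the top-level graph traced for a scalar float32 argument."""
    graph = fn.get_concrete_function(tf.TensorSpec([], tf.float32)).graph
    return [op.type for op in graph.get_operations()]

# TF2 control flow emits either a "While" or a "StatelessWhile" op.
WHILE_OPS = ("While", "StatelessWhile")

print(op_types(python_range_loop).count("Mul"))  # 4: one copy per iteration
print(any(t in WHILE_OPS for t in op_types(tf_range_loop)))  # True: a single loop op
print(float(python_range_loop(tf.constant(1.0))), float(tf_range_loop(tf.constant(1.0))))  # 16.0 16.0
```

Applied to the question's train_on_batch, this suggests the following: with range, the model's forward pass, backward pass and accumulation ops are duplicated num_accumulate times. The graph handed to TensorFlow's Grappler optimizer (whose meta_optimizer appears in the warning) grows accordingly. That is consistent with the deadline warning, though it does not prove the cause.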

  2. Replacing range with tf.range raises the following error:
Traceback (most recent call last):
  File "/mydirectory/model/test_train copy.py", line 89, in <module>
    trainer.train_on_batch(dataset_iter)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, line 2657, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /mydirectory/model/test_train copy.py:40 train_on_batch  *
        for _ in tf.range(self.num_accumulate):
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:343 for_stmt
        _tf_range_for_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:526 _tf_range_for_stmt
        _tf_while_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:862 _tf_while_stmt
        _verify_loop_init_vars(init_vars, symbol_names)
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:119 _verify_loop_init_vars
        raise ValueError('"{}" must be defined before the loop.'.format(name))

    ValueError: "loss" must be defined before the loop.

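So I initialized all the variables that appear in the loop, such as the gradients, the loss and the prediction, before the loop. With that change it works, but it is slow in my use case. Why is that?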
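The workaround described above (defining loop-carried values before a tf.range loop) can be sketched as follows. This is a minimal sketch under stated assumptions, not the asker's code: the model is a single Dense layer, the data is random, and train_step and accumulated are hypothetical names. It differs from the question's Trainer in one respect: the accumulator variables are created eagerly before the first call, because tf.function only allows variables to be created during its first trace.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model and data standing in for the question's Model and DataProvider.
num_accumulate = 4
model = tf.keras.Sequential([tf.keras.Input(shape=(10,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

rng = np.random.default_rng(0)
features = rng.random((400, 10)).astype("float32")
labels = rng.random((400, 1)).astype("float32")
dataset_iter = iter(tf.data.Dataset.from_tensor_slices((features, labels)).batch(25).repeat())

# Accumulators are created eagerly, once, outside the traced function.
accumulated = [tf.Variable(tf.zeros_like(w), trainable=False) for w in model.trainable_weights]

@tf.function
def train_step(iterator):
    # Reset the accumulators at the start of every optimizer step.
    for acc in accumulated:
        acc.assign(tf.zeros_like(acc))
    # Defined before the loop so AutoGraph can carry it as a tf.while_loop variable.
    loss = tf.constant(0.0)
    for _ in tf.range(num_accumulate):  # converted to a single tf.while_loop
        inputs, targets = next(iterator)
        with tf.GradientTape() as tape:
            prediction = model(inputs, training=True)
            loss = tf.reduce_mean(tf.abs(targets - prediction))
        grads = tape.gradient(loss, model.trainable_weights)
        for acc_grad, grad in zip(accumulated, grads):
            acc_grad.assign_add(grad / num_accumulate)
    # Apply the averaged gradients once per num_accumulate micro-batches.
    optimizer.apply_gradients(zip([v.read_value() for v in accumulated], model.trainable_weights))
    return loss  # loss of the last micro-batch, as in the question's code

for step in range(1, 4):
    print(step, float(train_step(dataset_iter)))
```

In this sketch only loss is read after the loop, so, as far as I understand AutoGraph's rules, it is the only value that has to exist before the loop. The prediction and the per-step gradients stay local to the loop body.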

What am I missing? Any help is greatly appreciated.

Solution

No working solution to this problem has been found yet.
