TensorFlow: cannot call `tf.keras.Model.add_metric` when using `tf.distribute.MirroredStrategy`

Problem description

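I have a model class that subclasses tf.keras.Model. I can train, evaluate, and export it on 8 GPUs, distributing it with tf.distribute.MirroredStrategy. However, I need custom metrics, and once I call the add_metric method, exporting the model raises this error: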

Traceback (most recent call last):
  File "repro/vae.py", line 80, in <module>
    vae.save("vae")
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1979, in save
    signatures, options)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 134, in save_model
    signatures, options)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", in save
    save_lib.save(model, filepath, signatures, options)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 976, in save
    obj, export_dir, options, meta_graph_def)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1047, in _build_meta_graph
    checkpoint_graph_view)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 145, in list_functions
    self._serialization_cache)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 2590, in _list_functions_for_serialization
    Model, self)._list_functions_for_serialization(serialization_cache)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3019, in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 79, in functions_to_serialize
    serialization_cache).functions_to_serialize)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 95, in _get_serialized_attributes
    serialization_cache)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 51, in _get_serialized_attributes_internal
    default_signature = save_impl.default_save_signature(self.obj)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 205, in default_save_signature
    fn.get_concrete_function()
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py", in _wrapped_model
    outputs = model(inputs, training=False)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 985, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "repro/vae.py", line 63, in call
    self.add_metric([0.], name="foo")
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1705, in add_metric
    metric_obj(value)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
    replica_local_fn, **kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, line 176, in update_state_fn
    return ag_update_state(*args, line 373, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 1015, in assign_add
    self, value, read_value=read_value)
  File "/Users/acarlson/anaconda3/envs/ed-autocoder-dev/lib/python3.7/site-packages/tensorflow/python/distribute/values_util.py", in on_read_assign_add_cross_replica
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.
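The last frames show where this goes wrong: while `vae.save` traces `call` to build the default serving signature (`default_save_signature`), `add_metric` updates the metric's `total` variable, and that variable is a `SyncOnReadVariable` with `SUM` aggregation. The pure-Python toy below is a hypothetical model of those semantics, not TensorFlow's API (`ToySyncOnReadVariable` and its methods are made up for illustration). It assumes each replica keeps its own local copy and that a cross-replica read sums the copies. Under that assumption, an `assign_add` from cross-replica context has no consistent meaning for `SUM`: adding the value to every copy would inflate the summed read by a factor of the replica count, which is one plausible reading of why TF rejects it.

```python
class ToySyncOnReadVariable:
    """Hypothetical pure-Python model of a sync-on-read variable (not TF's API)."""

    def __init__(self, num_replicas, aggregation="SUM"):
        if aggregation not in ("SUM", "MEAN"):
            raise ValueError(f"unsupported aggregation: {aggregation}")
        self.aggregation = aggregation
        # One local copy per replica, like per-device metric accumulators.
        self.local_values = [0.0] * num_replicas

    def read(self):
        # Cross-replica read: aggregate every replica's local copy.
        total = sum(self.local_values)
        if self.aggregation == "SUM":
            return total
        return total / len(self.local_values)

    def assign_add(self, value, replica_id=None):
        if replica_id is not None:
            # Replica context: only this replica's local copy changes.
            self.local_values[replica_id] += value
            return
        # Cross-replica context: there is no single "current" replica.
        if self.aggregation == "SUM":
            # Adding `value` to every copy would raise read() by
            # num_replicas * value instead of value, so refuse the update.
            raise ValueError(
                "assign_add is not supported in cross-replica context "
                "when aggregation is SUM"
            )
        # For MEAN, adding `value` to every copy raises read() by exactly `value`.
        self.local_values = [v + value for v in self.local_values]


# Usage: eight replicas each record one update, as in a distributed training step.
total = ToySyncOnReadVariable(num_replicas=8, aggregation="SUM")
for replica_id in range(8):
    total.assign_add(1.0, replica_id=replica_id)
print(total.read())  # 8.0

# A cross-replica update, analogous to the one made while vae.save() traces call().
try:
    total.assign_add(1.0)
except ValueError as err:
    print("rejected:", err)
```

In this analogy, `fit` runs `call` once per replica (the `replica_id` path), which is why training works, while the save-time trace corresponds to the cross-replica path that raises.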

I created a simple reproduction that shows this error:

import tensorflow as tf


class Sampling(tf.keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(tf.keras.layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = tf.keras.layers.Dense(latent_dim)
        self.dense_log_var = tf.keras.layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        # Return the full triplet: the KL term in the model also needs z_log_var.
        return z_mean, z_log_var, z


class Decoder(tf.keras.layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = tf.keras.layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(tf.keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32,
                 name="autoencoder", **kwargs):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        # Removing this add_metric call makes vae.save() succeed.
        self.add_metric([0.], name="foo")
        return reconstructed


# Load MNIST and flatten each 28x28 image into a 784-dim vector scaled to [0, 1].
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

original_dim = 784

# Model, optimizer, and metrics are all created inside the strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    vae = VariationalAutoEncoder(original_dim, 64, 32)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())

vae.fit(x_train, x_train, epochs=3, batch_size=64)
vae.save("vae")

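Sorry for the amount of code, but most of it doesn't matter. The important parts are that the model is instantiated and compiled inside the tf.distribute.MirroredStrategy scope, and that the model calls self.add_metric([0.], name="foo"). If you remove the add_metric call, everything works and the model exports correctly.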

So it seems that the tf.keras.Model.add_metric method cannot be used together with tf.distribute.MirroredStrategy. I need to be able to add custom metrics to a distributed model.

Note: metrics should be created inside the strategy scope, as described in the docs:

"Typical things that create variables in TF: models, optimizers, metrics. These should always be created inside the scope."

As for versions, I'm using Google AI Platform runtime version 2.3.

Solution

This is a bug in the TF 2.3 release and is fixed in 2.4. I got that response on an issue I filed with TF.
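If a job is tied to a particular runtime, a cheap safeguard is to check the TensorFlow version before training starts, so a 2.3 runtime fails immediately instead of after `fit` finishes and `vae.save` raises. The helper below is a hypothetical sketch (`has_add_metric_fix` is not a TensorFlow function); its only input is a version string such as `tf.__version__`, and the 2.4 threshold is taken from the fix mentioned in this answer.

```python
def has_add_metric_fix(tf_version: str) -> bool:
    """Return True if tf_version is 2.4 or newer (hypothetical helper)."""
    # Only major and minor matter here; this also tolerates suffixes in the
    # patch component, such as "2.4.0-rc1" or "2.5.0-dev20201028".
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return (major, minor) >= (2, 4)


print(has_add_metric_fix("2.3.1"))  # False
print(has_add_metric_fix("2.4.0"))  # True
```

In the training script, one could call `has_add_metric_fix(tf.__version__)` right after importing TensorFlow and raise an error when it returns False. Comparing only the major and minor components keeps the check simple; `packaging.version` would be the more robust choice for full version parsing.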