Error: apply_gradients() cannot be called in cross-replica context when using TPUStrategy — use tf.distribute.Strategy.run to enter replica context

Problem description

I am trying to modify a model to use a Google Cloud TPU in Colab. The code I am trying to run is from https://github.com/marcoppasini/MelGAN-VC/blob/master/MelGAN_VC.ipynb. I will only post the modified code here; I followed the instructions at https://www.tensorflow.org/guide/tpu#train_a_model_using_keras_high_level_apis.

At the top:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
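As a quick sanity check after this initialization (the linked TPU guide does the same), one can list the logical TPU devices; on a Colab TPU runtime this should show the TPU cores, while on a machine without a TPU the list is simply empty:

```python
import tensorflow as tf

# On a Colab TPU runtime this lists the TPU cores (typically 8);
# without a TPU the list is empty, which signals a connection problem.
tpu_devices = tf.config.list_logical_devices('TPU')
print('TPU cores visible:', len(tpu_devices))
```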

Then I modified the build functions of the generator, the discriminator (critic), and the siamese network:

def build_generator(input_shape):
  with strategy.scope():
    h,w,c = input_shape
    inp = Input(shape=input_shape)
    #downscaling
    g0 = tf.keras.layers.ZeroPadding2D((0,1))(inp)
    g1 = conv2d(g0,256,kernel_size=(h,3),strides=1,padding='valid')
    g2 = conv2d(g1,256,kernel_size=(1,9),strides=(1,2))
    g3 = conv2d(g2,256,kernel_size=(1,7),strides=(1,2))
    #upscaling
    g4 = deconv2d(g3,g2,256,kernel_size=(1,7),strides=(1,2))
    g5 = deconv2d(g4,g1,256,kernel_size=(1,9),strides=(1,2),bnorm=False)
    g6 = ConvSN2DTranspose(1,kernel_size=(h,1),strides=(1,1),kernel_initializer=init,padding='valid',activation='tanh')(g5)
    opt_gen = Adam(0.0001,0.5)
  return Model(inp,g6,name='G'),opt_gen

#Siamese Network
def build_siamese(input_shape):
  with strategy.scope():
      h,w,c = input_shape
      inp = Input(shape=input_shape)
      g1 = conv2d(inp,256,kernel_size=(h,9),strides=(1,2),sn=False)
      g2 = conv2d(g1,256,kernel_size=(1,9),strides=(1,2),sn=False)
      g3 = conv2d(g2,256,kernel_size=(1,7),strides=(1,2),sn=False)
      g4 = Flatten()(g3)
      g5 = Dense(vec_len)(g4)
  
      return Model(inp,g5,name='S')

#discriminator (Critic) Network
def build_critic(input_shape):
  with strategy.scope():
      h,w,c = input_shape
      inp = Input(shape=input_shape)
      g1 = conv2d(inp,512,kernel_size=(h,3),strides=(1,2),bnorm=False)
      g2 = conv2d(g1,512,kernel_size=(1,9),strides=(1,2),bnorm=False)
      g3 = conv2d(g2,512,kernel_size=(1,7),strides=(1,2),bnorm=False)
      g4 = Flatten()(g3)
      g4 = DenseSN(1,kernel_initializer=init)(g4)
      opt_disc = Adam(0.0001,0.5)
      return Model(inp,g4,name='C'),opt_disc

Finally, in the train_all() and train_d() functions, I put the gradient computation and the optimizer calls under strategy.scope():

with strategy.scope():
    grad_gen = tape_gen.gradient(lossgtot,gen.trainable_variables+siam.trainable_variables)
    opt_gen.apply_gradients(zip(grad_gen,gen.trainable_variables+siam.trainable_variables))

    grad_disc = tape_disc.gradient(loss_d,critic.trainable_variables)
    opt_disc.apply_gradients(zip(grad_disc,critic.trainable_variables))

When I execute the train function, I get this error:

RuntimeError: `apply_gradients() cannot be called in cross-replica context. Use `tf.distribute.Strategy.run` to enter replica context.
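The error message itself hints at what the runtime expects: apply_gradients() has to run in *replica* context, i.e. inside a function passed to strategy.run, not merely under strategy.scope() (which only controls where variables are created). Below is a minimal, generic sketch of that pattern — the toy model and names are illustrative, not the notebook's actual networks, and the default strategy stands in for TPUStrategy so the sketch runs on any machine:

```python
import tensorflow as tf

# On a TPU this would be tf.distribute.TPUStrategy(resolver); the default
# strategy is used here so the sketch runs anywhere.
strategy = tf.distribute.get_strategy()

with strategy.scope():
    # strategy.scope() is for variable creation: models and optimizers.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    opt = tf.keras.optimizers.Adam(1e-4)

@tf.function
def train_step(x, y):
    def step_fn(x, y):
        # Replica context: the gradient tape and apply_gradients go here.
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    # strategy.run enters replica context (one invocation per TPU core).
    per_replica_loss = strategy.run(step_fn, args=(x, y))
    # Combine the per-replica losses into a single scalar.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN,
                           per_replica_loss, axis=None)

loss = train_step(tf.random.normal((8, 4)), tf.random.normal((8, 1)))
print('step loss:', float(loss))
```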

Any help would be greatly appreciated!

Solution

No working solution has been found for this problem yet.
