问题描述
我正在尝试将模型更改为在 colab 中使用谷歌 CloudTPU。我试图运行的代码在 https://github.com/marcoppasini/MelGAN-VC/blob/master/MelGAN_VC.ipynb 中。我将仅在此处发布修改后的代码,我已按照此处的说明进行操作:https://www.tensorflow.org/guide/tpu#train_a_model_using_keras_high_level_apis。
顶部:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
def build_generator(input_shape):
with strategy.scope():
h,w,c = input_shape
inp = Input(shape=input_shape)
#downscaling
g0 = tf.keras.layers.ZeroPadding2D((0,1))(inp)
g1 = conv2d(g0,256,kernel_size=(h,3),strides=1,padding='valid')
g2 = conv2d(g1,kernel_size=(1,9),strides=(1,2))
g3 = conv2d(g2,7),2))
#upscaling
g4 = deconv2d(g3,g2,2))
g5 = deconv2d(g4,g1,2),bnorm=False)
g6 = ConvSN2DTranspose(1,1),kernel_initializer=init,padding='valid',activation='tanh')(g5)
opt_gen = Adam(0.0001,0.5)
return Model(inp,g6,name='G'),opt_gen
#Siamese Network
def build_siamese(input_shape):
with strategy.scope():
h,c = input_shape
inp = Input(shape=input_shape)
g1 = conv2d(inp,sn=False)
g2 = conv2d(g1,sn=False)
g3 = conv2d(g2,sn=False)
g4 = Flatten()(g3)
g5 = Dense(vec_len)(g4)
return Model(inp,g5,name='S')
#discriminator (Critic) Network
def build_critic(input_shape):
with strategy.scope():
h,512,bnorm=False)
g2 = conv2d(g1,bnorm=False)
g3 = conv2d(g2,bnorm=False)
g4 = Flatten()(g3)
g4 = DenseSN(1,kernel_initializer=init)(g4)
opt_disc = Adam(0.0001,0.5)
return Model(inp,g4,name='C'),opt_disc
最后,在 train_all() 和 train_d 函数中,我将梯度和优化器放在了 strategy.scope() 下。
with strategy.scope():
grad_gen = tape_gen.gradient(lossgtot,gen.trainable_variables+siam.trainable_variables)
opt_gen.apply_gradients(zip(grad_gen,gen.trainable_variables+siam.trainable_variables))
grad_disc = tape_disc.gradient(loss_d,critic.trainable_variables)
opt_disc.apply_gradients(zip(grad_disc,critic.trainable_variables))
RuntimeError: `apply_gradients() cannot be called in cross-replica context. Use `tf.distribute.Strategy.run` to enter replica context.
我会非常感谢你们的帮助!
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)