问题描述
我编写了一个基于 this paper 的自定义模型,用于 TF2.0 中的样式转换。 简而言之,所提出算法的损失函数需要评估 3 个损失分量。该模型接受 2 个输入图像,比如说 Ic、Is(c 代表内容,s 代表样式),然后弹出一个拼贴图像 O。
在单个训练步骤中,网络接收以下对作为输入并弹出相应的图像:
- Ic,Is -> O(期望)
- Ic,Ic -> O 身份丢失1
- 是,是 -> O 身份丢失2
然后特征网络评估不同的损失分量(因为主网络需要 3 次前向传递,但不是可训练网络的一部分,因此它具有不可训练的权重)
代码如下:
class StyleTranasfer(keras.Model):
def __init__(self,autoencoder,net):
super(StyleTranasfer,self).__init__()
self.autoencoder = autoencoder
self.net=net
def compile(self,optimizer,loss):
super(StyleTranasfer,self).compile()
self.optimizer = optimizer
self.loss_fn=loss
@tf.function
def call(self,input,training=False):
content_images,style_images=input
return self.autoencoder((tf.image.resize(content_images,[224,224]),tf.image.resize(style_images,224])))
@tf.function
def test_step(self,data):
stylyzed_output=self.call(data)
stylyzed_content_output=self.call((content_images,content_images))
stylyzed_style_output=self.call((style_images,style_images))
d_loss,style_loss,content_loss,identity_loss_1,identity_loss_2 = self.loss_fn(
self.prepro(stylyzed_output),self.prepro(content_images),self.prepro(style_images),self.prepro(stylyzed_style_output),self.prepro(stylyzed_content_output),self.net)
return {"loss": loss,"style_loss": style_loss,"content_loss": content_loss,"identity_loss_1": identity_loss_1,"identity_loss_2": identity_loss_2}
@tf.function
def train_step(self,data):
content_images,style_images=data
with tf.GradientTape() as tape:
stylyzed_output=self.call(data)
stylyzed_content_output=self.call((content_images,content_images))
stylyzed_style_output=self.call((style_images,style_images))
loss,identity_loss_2 = self.loss_fn(
self.prepro(stylyzed_output),self.net)
grads = tape.gradient(loss,self.autoencoder.trainable_weights)
self.optimizer.apply_gradients(zip(grads,self.autoencoder.trainable_weights))
return {"loss": loss,"identity_loss_2": identity_loss_2}
def prepro(self,img):
return tf.keras.applications.vgg19.preprocess_input(128.0*img)
我可以很容易地训练模型,但是当我尝试 save_weights 时,我得到:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py",line 2882,in run_code
exec(code_obj,self.user_global_ns,self.user_ns)
File "<ipython-input-20-785ae311d57f>",line 4,in <module>
model.save_weights('prova.h5')
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py",line 2108,in save_weights
hdf5_format.save_weights_to_hdf5_group(f,self.layers)
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/hdf5_format.py",line 642,in save_weights_to_hdf5_group
param_dset = g.create_dataset(name,val.shape,dtype=val.dtype)
File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/group.py",line 139,in create_dataset
self[name] = dset
File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/group.py",line 373,in __setitem__
h5o.link(obj.id,self.id,name,lcpl=lcpl,lapl=self._lapl)
File "h5py/_objects.pyx",line 54,in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx",line 55,in h5py._objects.with_phil.wrapper
File "h5py/h5o.pyx",line 202,in h5py.h5o.link
RuntimeError: Unable to create link (name already exists)
... 有人知道吗?
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)