在keras中自定义fit函数会导致Evaluate函数返回空列表

问题描述

我正在使用Keras教程https://keras.io/examples/generative/vae/训练VAE。这涉及创建VAE类并指定自定义训练过程,在这里https://keras.io/guides/customizing_what_happens_in_fit/中对此进行了详细说明。在按照本教程中的说明创建编码器和解码器并训练了模型之后,我通过以下步骤创建了VAE模型:

vae = VAE(encoder,decoder) 
vae.compile(optimizer=keras.optimizers.Adam())

vae.fit(x=x_train,y=None,epochs=epochs,batch_size=batch_size,verbose=False,validation_data=(x_test,None))

我想在训练后在单独的数据集上评估模型(因为我有多个评估数据集,所以我不将其用作validation_data)。但是,当我尝试运行vae.evaluate(data)时,它将返回一个空列表[]

注意:我可以通过vae.history.history轻松获得培训和验证指标,但是问题是当我尝试在培训后进行评估时。但是,当我尝试返回指标vae.metrics时,它也会返回一个空列表。如何使model.evaluate自定义训练过程一起使用,该训练过程返回损失指标的指标?我需要定义一些自定义的评估方式吗?

这是VAE类的定义方式。更多细节可以在上面的教程中找到。

class VAE(keras.Model):
    def __init__(self,encoder,decoder,**kwargs):
        super(VAE,self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self,data):
        if isinstance(data,tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean,z_log_var,z = encoder(data)
            reconstruction = decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data,reconstruction)
            )
            reconstruction_loss *= 28 * 28
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss,self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads,self.trainable_weights))
        return {
            "loss": total_loss,"reconstruction_loss": reconstruction_loss,"kl_loss": kl_loss,}

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...