ValueError on the Keras Variational Autoencoder: code example does not work

Problem description

I am completely new to neural network programming and have a question about a code example from Keras.

Keras:https://keras.io/examples/generative/vae/
GitHub:https://github.com/keras-team/keras-io/blob/master/examples/generative/vae.py

"""
Title: Variational AutoEncoder
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/05/03
Last modified: 2020/05/03
Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
"""

"""
## Setup
"""

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

"""
## Create a sampling layer
"""


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


"""
## Build the encoder
"""

latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

"""
## Build the decoder
"""

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

"""
## Define the VAE as a `Model` with a custom `train_step`
"""


class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }


"""
## Train the VAE
"""

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=30, batch_size=128)

"""
## Display a grid of sampled digits
"""

import matplotlib.pyplot as plt


def plot_latent(encoder, decoder):
    # display a n*n 2D manifold of digits
    n = 30
    digit_size = 28
    scale = 2.0
    figsize = 15
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)

"""
## Display how the latent space clusters different digit classes
"""


def plot_label_clusters(encoder, data, labels):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255

plot_label_clusters(encoder, x_train, y_train)

This is the VAE (variational autoencoder) example built with Keras on the MNIST dataset. When I copy the example code from GitHub, I always get the following failure (I did not change the code):

"ValueError: The model cannot be compiled because it has no loss to optimize."

I also get the following warning:

"WARNING:tensorflow:Output output_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to output_1."
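For context on why this error appears: the example calls `vae.compile(optimizer=...)` with no `loss` argument, because the loss is computed inside the overridden `train_step`. Keras versions that predate `train_step` overriding (TF before 2.2) still insist on a loss at compile time. The class below is a toy re-creation of that check, purely for illustration; it is my own sketch, not TensorFlow code:

```python
# Toy illustration (hypothetical class, NOT TensorFlow code): older Keras
# rejected compile() calls that supplied no loss, which is exactly what
# the VAE example does, since its loss lives in the custom train_step.
class ToyModel:
    def compile(self, optimizer=None, loss=None):
        if loss is None:
            raise ValueError(
                "The model cannot be compiled because it has no loss to optimize."
            )
        self.optimizer, self.loss = optimizer, loss


try:
    ToyModel().compile(optimizer="adam")  # mimics vae.compile(optimizer=...)
except ValueError as err:
    print(err)  # -> The model cannot be compiled because it has no loss to optimize.
```

On TF 2.2+, `Model.compile` accepts a missing loss when `train_step` is overridden, so the same call succeeds there.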

There are more warnings at the start:

"WARNING:tensorflow:AutoGraph Could not transform <bound method Sampling.call of <__main__.Sampling object at 0x000002CB451262E8>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 
WARNING: AutoGraph Could not transform <bound method Sampling.call of <__main__.Sampling object at 0x000002CB451262E8>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output."

So far I have tried Python 3.6 and Python 3.7 on Windows 10. Has anyone run into this error, and does anyone know a solution?

Thanks in advance!

Solution

I ran into a similar problem. The example assumes TF 2.3.0; check your TF version and upgrade it if possible.
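You can check the installed version with `print(tf.__version__)`. If you want to guard against this programmatically, comparing the version string against a minimum works; the example targets TF 2.3, and `train_step` overriding requires at least TF 2.2. A minimal sketch (the `meets_minimum` helper is my own, not a TF API, and assumes a plain `x.y.z` version string):

```python
# Hypothetical helper (not part of TensorFlow): compare an installed
# version string against the minimum the VAE example needs.
def meets_minimum(version, minimum=(2, 3, 0)):
    # Assumes a plain dotted version like "2.3.0"; pre-release suffixes
    # such as "2.3.0-rc1" are not handled here.
    parts = tuple(int(p) for p in version.split(".")[:3])
    return parts >= minimum


# With TensorFlow installed you would pass tf.__version__ here:
print(meets_minimum("2.3.0"))  # True: the example should run
print(meets_minimum("2.1.0"))  # False: upgrade first
```

If the check fails, `pip install --upgrade tensorflow` brings you to a recent release.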

Best