TensorFlow Variational Autoencoder custom training loop: `tape.gradient(loss, decoder_model.trainable_weights)` always returns a list of None

Problem description

I am trying to write a custom training loop for a variational autoencoder (VAE) composed of two separate tf.keras.Model objects. The target of this VAE is multi-class classification. As usual, the outputs of the encoder model are fed as inputs to the decoder model, and the decoder is a recurrent decoder. Also as usual, two loss functions are involved: a reconstruction loss (categorical cross-entropy) and a latent loss. My current architecture is inspired by a pytorch implementation on this github.

Problem: Whenever I compute gradients for the decoder model with tape.gradient(loss, decoder.trainable_weights), the returned list contains nothing but None entries. I assume I am making some mistake with the reconstruction_tensor that appears near the bottom of the code below. Since I need an iterative decoding process, how can I use something like reconstruction_tensor without getting back a list of None gradients? You can run the code in this colab notebook if you like.

To further clarify what the tensors in this question look like, here are the original input, the zeros tensor to which the predicted "tokens" are assigned, and that zeros tensor after a single update with a predicted "token" from the decoder:

Example original input tensor of shape (batch_size,max_seq_length,num_classes):
[[[1 0 0 0], [0 1 0 0], [0 0 1 0]],
 [[0 1 0 0], [1 0 0 0], [0 0 0 1]],
 [[0 0 0 1], [1 0 0 0], [0 1 0 0]]]

Initial zeros tensor:
[[[0 0 0 0], [0 0 0 0], [0 0 0 0]],
 [[0 0 0 0], [0 0 0 0], [0 0 0 0]],
 [[0 0 0 0], [0 0 0 0], [0 0 0 0]]]

Example zeros tensor after a single iteration of the decoding loop:
[[[0.2 0.4  0.1  0.3], [0 0 0 0], [0 0 0 0]],
 [[0.1 0.2  0.6  0.1], [0 0 0 0], [0 0 0 0]],
 [[0.7 0.05 0.05 0.2], [0 0 0 0], [0 0 0 0]]]

Here is the code that reproduces the problem:

import numpy as np
import tensorflow as tf

# Arbitrary data
batch_size = 3
max_seq_length = 3
num_classes = 4
original_inputs = tf.one_hot(tf.argmax(np.random.randn(batch_size,max_seq_length,num_classes),axis=2),depth=num_classes)
latent_dims = 5  # Must be less than (max_seq_length * num_classes)

def sampling(inputs):
    """Reparametrization function. Used for Lambda layer"""

    mus,log_vars = inputs
    epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mus))
    z = mus + tf.keras.backend.exp(log_vars/2) * epsilon

    return z

def latent_loss_fxn(mus,log_vars):
    """Return latent loss for means and log variance."""

    return -0.5 * tf.keras.backend.mean(1. + log_vars - tf.keras.backend.exp(log_vars) - tf.keras.backend.pow(mus,2))

class DummyEncoder(tf.keras.Model):
    def __init__(self,latent_dimension):
        """Define the hidden layer (bottleneck) and sampling layers"""

        super().__init__()
        self.hidden = tf.keras.layers.Dense(units=32)
        self.dense_mus = tf.keras.layers.Dense(units=latent_dimension)
        self.dense_log_vars = tf.keras.layers.Dense(units=latent_dimension)
        self.sampling = tf.keras.layers.Lambda(function=sampling)

    def call(self,inputs):
        """Define forward computation that outputs z,mu,log_var of input."""

        dense_projection = self.hidden(inputs)

        mus = self.dense_mus(dense_projection)
        log_vars = self.dense_log_vars(dense_projection)
        z = self.sampling([mus,log_vars])

        return z,mus,log_vars
        

class DummyDecoder(tf.keras.Model):
    def __init__(self,num_classes):
        """Define GRU layer and the Dense output layer"""

        super().__init__()
        self.gru = tf.keras.layers.GRU(units=1,return_sequences=True,return_state=True)
        self.dense = tf.keras.layers.Dense(units=num_classes,activation='softmax')

    def call(self,x,hidden_states=None):
        """Define forward computation"""

        outputs,h_t = self.gru(x,hidden_states)

        # The purpose of this computation is to use the unnormalized log
        # probabilities from the GRU to produce normalized probabilities via
        # the softmax activation function in the Dense layer
        reconstructions = self.dense(outputs)

        return reconstructions,h_t

# Instantiate the models
encoder_model = DummyEncoder(latent_dimension=5)
decoder_model = DummyDecoder(num_classes=num_classes)

# Instantiate reconstruction loss function
cce_loss_fxn = tf.keras.losses.CategoricalCrossentropy()

# Begin tape
with tf.GradientTape(persistent=True) as tape:
    # Flatten the inputs for the encoder
    reshaped_inputs = tf.reshape(original_inputs,shape=(tf.shape(original_inputs)[0],-1))

    # Encode the input
    z,mus,log_vars = encoder_model(reshaped_inputs,training=True)

    # Expand dimensions of z so it meets recurrent decoder requirements of
    # (batch,timesteps,features)
    z = tf.expand_dims(z,axis=1)

    ################################
    # SUSPECTED CAUSE OF PROBLEM
    ################################

    # A tensor that will be modified based on model outputs
    reconstruction_tensor = tf.Variable(tf.zeros_like(original_inputs))

    ################################
    # END SUSPECTED CAUSE OF PROBLEM
    ################################

    # A decoding loop to iteratively generate the next token (i.e., output)
    # in the sequence
    hidden_states = None
    for ith_token in range(max_seq_length):

        # Reconstruct the ith_token for a given sample in the batch
        reconstructions,hidden_states = decoder_model(z,hidden_states,training=True)

        # Reshape the reconstructions to allow assigning to reconstruction_tensor
        reconstructions = tf.squeeze(reconstructions)

        # After the loop is done iterating, this tensor is the model's
        # prediction of the original inputs. Therefore, after a single
        # iteration of the loop, a single token prediction for each sample
        # in the batch is assigned to this tensor.
        reconstruction_tensor = reconstruction_tensor[:,ith_token,:].assign(reconstructions)

    # Calculates losses
    recon_loss = cce_loss_fxn(original_inputs,reconstruction_tensor)
    latent_loss = latent_loss_fxn(mus,log_vars)
    loss = recon_loss + latent_loss

# Calculate gradients
encoder_gradients = tape.gradient(loss,encoder_model.trainable_weights)
decoder_gradients = tape.gradient(loss,decoder_model.trainable_weights)

# Release tape
del tape

# Inspect gradients
print('Valid Encoder Gradients:',not(None in encoder_gradients))
print('Valid Decoder Gradients:',not(None in decoder_gradients),' -- ',decoder_gradients)

>>> Valid Encoder Gradients: True
>>> Valid Decoder Gradients: False -- [None,None,None]
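
For reference, a minimal sketch (not part of the original post) that appears to isolate the same behaviour: Variable.assign is a stateful write that the tape does not differentiate through, so a value routed into a variable loses its connection to the computation that produced it.

x = tf.constant(3.0)
v = tf.Variable(0.0)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * 2.0      # differentiable op that depends on x
    v.assign(y)      # stateful write; the tape does not trace through assign
    loss = v * 5.0   # loss depends on the variable's value, not (as far as the tape knows) on y

print(tape.gradient(loss, x))  # -> None: the path from x to loss is broken at the assign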

Workaround

Found a "solution" to my problem:

There is clearly some issue with using a tf.Variable inside the GradientTape() context manager. While I don't know exactly what that issue is, replacing the reconstruction tensor with a list, appending to that list during the decoding iterations, and then stacking the list allows the gradients to be computed without any problem. The colab notebook reflects these changes. See the code snippet below for the fix:

....
....
with tf.GradientTape(persistent=True) as tape:
    ....
    ....

    # FIX
    reconstructions_tensor = []

    hidden_states = None
    for ith_token in range(max_seq_length):
        # Reconstruct the ith_token for a given sample in the batch
        reconstructions,hidden_states = decoder_model(z,hidden_states,training=True)

        # Reshape the reconstructions
        reconstructions = tf.squeeze(reconstructions)

        # FIX
        # Appending to the list which will eventually be stacked
        reconstructions_tensor.append(reconstructions)
    
    # FIX
    # Stack the reconstructions along axis=1 to get same result as previous assignment with zeros tensor
    reconstructions_tensor = tf.stack(reconstructions_tensor,axis=1)
....
....
# Successful gradient computations and subsequent optimization of models
# ....
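
For completeness, a minimal sketch (not from the original post) of how the gradients from the fixed loop might be applied, assuming a single Adam optimizer shared by both models (a hypothetical choice, not specified in the original):

# Hypothetical optimizer; the original post does not specify one
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

encoder_gradients = tape.gradient(loss, encoder_model.trainable_weights)
decoder_gradients = tape.gradient(loss, decoder_model.trainable_weights)
del tape  # release the persistent tape once both gradient calls are done

optimizer.apply_gradients(zip(encoder_gradients, encoder_model.trainable_weights))
optimizer.apply_gradients(zip(decoder_gradients, decoder_model.trainable_weights))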

Edit 1:

I don't think this "solution" is ideal if a model is meant to run in graph mode. My limited understanding is that graph mode does not work with python objects such as lists.
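
If graph mode is a concern, here is a minimal sketch (not from the original post) that replaces the Python list with a tf.TensorArray, which is supported inside tf.function. It assumes the same decoder_model, z, and max_seq_length as above:

@tf.function
def decode_and_reconstruct(z):
    # tf.TensorArray plays the role of the Python list in graph mode
    reconstructions_ta = tf.TensorArray(dtype=tf.float32, size=max_seq_length)
    hidden_states = None
    # max_seq_length is a Python int, so this loop is unrolled at trace time
    for ith_token in range(max_seq_length):
        reconstructions, hidden_states = decoder_model(z, hidden_states, training=True)
        reconstructions_ta = reconstructions_ta.write(ith_token, tf.squeeze(reconstructions))
    # stack() yields (max_seq_length, batch_size, num_classes); transpose back to
    # (batch_size, max_seq_length, num_classes) to match original_inputs
    return tf.transpose(reconstructions_ta.stack(), perm=[1, 0, 2])

Note that appending to a Python list also works inside tf.function when the loop bound is a Python integer (the loop is simply unrolled during tracing); tf.TensorArray is the pattern that additionally supports tensor-valued loop bounds.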
