Implementing a custom loss function

Problem description

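I'm currently writing my bachelor's thesis, and I'm trying to augment my dataset using a GAN. So far I'm using a WGAN-GP combined with progressive growing of GANs and an auxiliary classifier.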

Adding the auxiliary classifier means I have to add an extra loss term to the discriminator, namely the categorical cross-entropy loss.
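As a minimal illustration of this extra term (plain Python instead of TensorFlow, with made-up probabilities for a hypothetical 3-class classifier head), categorical cross-entropy averages -Σ y·log p over the batch. It assumes every target row is one-hot. A multi-hot row such as [1, 0, 1], which sampling each entry independently from {0, 1} can produce, is not a valid class label and inflates the loss.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum_k y_true[k] * log(y_pred[k]).

    y_pred rows are class probabilities (e.g. softmax outputs); they are
    clipped to [eps, 1 - eps] before taking the log, similar to Keras.
    """
    total = 0.0
    for t_row, p_row in zip(y_true, y_pred):
        total += -sum(t * math.log(min(max(p, eps), 1 - eps))
                      for t, p in zip(t_row, p_row))
    return total / len(y_true)

# Hypothetical classifier-head outputs for 2 images and 3 failure classes
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1]]

one_hot = [[1, 0, 0], [0, 1, 0]]    # valid targets: exactly one class per image
multi_hot = [[1, 0, 1], [0, 1, 1]]  # independent 0/1 draws per entry

print(round(categorical_crossentropy(one_hot, probs), 4))
print(round(categorical_crossentropy(multi_hot, probs), 4))
```

With these numbers the one-hot loss is about 0.29 and the multi-hot loss about 2.59: each extra 1 in a target row adds a -log p term for a class the image was never meant to belong to.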

I've tried to implement this in my loss function, but I'm not sure whether I've done it correctly.

So my question is: is this custom loss function implemented correctly?
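For reference, the losses that the `train_step` below is meant to compute can be written out as follows. The notation is ours, not the original poster's: D_w is the critic (Wasserstein) output, D_c the class-probability head, x the real images with one-hot labels y, x̃ = G(z, ỹ) the generated images with sampled labels ỹ, λ = `self.gp_weight`, ε = `self.drift_weight`, and CE(y, p) = -Σ_k y_k log p_k.

```latex
\begin{aligned}
\hat{x} &= x + \alpha\,(\tilde{x} - x), \qquad \alpha \sim \mathcal{U}(0, 1) \text{ per image} \\
\mathrm{GP} &= \mathbb{E}\big[(\lVert \nabla_{\hat{x}} D_w(\hat{x}) \rVert_2 - 1)^2\big] \\
\mathcal{L}_D &= \mathbb{E}[D_w(\tilde{x})] - \mathbb{E}[D_w(x)] + \lambda\,\mathrm{GP}
  + \varepsilon\,\mathbb{E}[D_w(x)^2]
  + \tfrac{1}{2}\big(\mathrm{CE}(\tilde{y}, D_c(\tilde{x})) + \mathrm{CE}(y, D_c(x))\big) \\
\mathcal{L}_G &= -\mathbb{E}[D_w(\tilde{x})] + \mathrm{CE}(\tilde{y}, D_c(\tilde{x}))
\end{aligned}
```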

import tensorflow as tf

def gradient_penalty(self, batch_size, real_images, fake_images):
    """Calculates the gradient penalty.

    This loss is calculated on an interpolated image
    and added to the discriminator loss.
    """
    # Get the interpolated image. alpha needs one axis per image dimension
    # (batch, height, width, channels) so that it broadcasts over each image.
    alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    diff = fake_images - real_images
    interpolated = real_images + alpha * diff

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        # 1. Get the discriminator output for this interpolated image.
        #    The discriminator has two heads; only the critic (Wasserstein)
        #    output enters the penalty, the class prediction is ignored.
        pred, _ = self.discriminator(interpolated, training=True)

    # 2. Calculate the gradients w.r.t. this interpolated image.
    grads = tape.gradient(pred, [interpolated])[0]
    # 3. Calculate the per-image L2 norm of the gradients (the small epsilon
    #    avoids a NaN gradient of sqrt when the gradient is exactly zero).
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

# Changed the input: `data` now also contains the labels
# (one-hot encoded, as CategoricalCrossentropy expects)
def train_step(self, data):
    if len(data) == 3:
        real_images, labels, sample_weight = data
    else:
        sample_weight = None
        real_images, labels = data

    # Use the dynamic batch size so the step still works when the static
    # batch dimension is unknown (None) or the last batch is smaller
    batch_size = tf.shape(real_images)[0]
    # Number of failure classes, read from the one-hot labels (3 in this setup)
    num_classes = labels.shape[-1]

    # For each batch, we perform the following steps,
    # as laid out in the original WGAN-GP paper:
    # 1. Train the discriminator and get the discriminator loss
    # 2. Calculate the gradient penalty
    # 3. Multiply this gradient penalty with a constant weight factor
    # 4. Add the gradient penalty (plus the drift term and the auxiliary
    #    label loss) to the discriminator loss
    # 5. Train the generator and get the generator loss
    # 6. Return the generator and discriminator losses as a loss dictionary

    # The discriminator is trained first. The original paper recommends
    # training the discriminator for `x` more steps (typically 5) per
    # generator step; here the number of steps is set by `self.d_steps`.
    for i in range(self.d_steps):
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Sample one class per image and one-hot encode it, matching the
        # CategoricalCrossentropy target format. TF ops (not NumPy) are used
        # so that fresh labels are drawn on every step inside tf.function.
        random_fake_labels = tf.one_hot(
            tf.random.uniform([batch_size], 0, num_classes, dtype=tf.int32), num_classes)

        # The tape records the forward computations so that we can take the derivatives.

        # Wasserstein loss, gradient penalty and categorical cross-entropy loss
        with tf.GradientTape() as tape:
            # Generate fake images from the latent vector and the label
            fake_images = self.generator([random_latent_vectors,random_fake_labels],training=True)
            
            # Get the logits for the fake images as well as a prediction of the failure type
            fake_logits, fake_Label_pred = self.discriminator(fake_images, training=True)

            # Get the logits for the real images as well as a prediction of the failure type on each image
            real_logits, real_Label_pred = self.discriminator(real_images, training=True)

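            # Calculate the categorical cross-entropy loss. The default
            # from_logits=False assumes the classifier head ends in a softmax;
            # use CategoricalCrossentropy(from_logits=True) if it outputs raw logits.
            cce = tf.keras.losses.CategoricalCrossentropy()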
            fake_Label_loss = cce(random_fake_labels,fake_Label_pred)
            real_Label_loss = cce(labels,real_Label_pred)
            label_loss = 0.5*(fake_Label_loss + real_Label_loss)
            
            # Calculate the discriminator loss using the fake and real image logits
            #Wasserstein loss function
            d_cost = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)
            
            # Calculate the gradient penalty (it needs both the real and the fake images)
            gp = self.gradient_penalty(batch_size, real_images, fake_images)

            # Calculate the drift for regularization
            drift = tf.reduce_mean(tf.square(real_logits))

            # Add the gradient penalty, the drift term and the label loss to the original discriminator loss
            d_loss = d_cost + self.gp_weight * gp + self.drift_weight * drift + label_loss
            
        
        # Get the gradients w.r.t the discriminator loss
        d_gradient = tape.gradient(d_loss,self.discriminator.trainable_variables)
        # Update the weights of the discriminator using the discriminator optimizer
        self.d_optimizer.apply_gradients(zip(d_gradient,self.discriminator.trainable_variables))

    # Train the generator
    # Get the latent vector and a random one-hot label for each image
    random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
    num_classes = labels.shape[-1]
    random_fake_labels = tf.one_hot(
        tf.random.uniform([batch_size], 0, num_classes, dtype=tf.int32), num_classes)

    # Generator loss
    with tf.GradientTape() as tape:
        # Generate fake images using the generator and the associated label
        generated_images = self.generator([random_latent_vectors, random_fake_labels], training=True)

        # Get the discriminator logits for the fake images as well as a prediction of the failure type
        gen_img_logits, gen_Label_pred = self.discriminator(generated_images, training=True)

        # Initialize an instance of the Keras categorical cross-entropy loss
        cce = tf.keras.losses.CategoricalCrossentropy()

        # Label loss between the labels the generator was conditioned on
        # and the discriminator's class prediction for the generated images
        fake_Label_loss = cce(random_fake_labels, gen_Label_pred)

        # Generator loss = Wasserstein generator loss + label loss
        g_loss = -tf.reduce_mean(gen_img_logits) + fake_Label_loss
        
    # Get the gradients w.r.t the generator loss
    g_gradient = tape.gradient(g_loss,self.generator.trainable_variables)
    
    # Update the weights of the generator using the generator optimizer
    self.g_optimizer.apply_gradients(zip(g_gradient,self.generator.trainable_variables))
    return {'d_loss': d_loss,'g_loss': g_loss}
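To make the arithmetic in `gradient_penalty` concrete, here is a minimal plain-Python sketch with toy 2×2×1 "images" flattened to 4 values and hand-picked gradients (all values hypothetical; no real discriminator is involved). It shows that alpha is drawn once per image and applied to every pixel, and that the gradient norm is taken per image over height, width and channels (the `axis=[1, 2, 3]` reduction) before (‖g‖ - 1)² is averaged over the batch.

```python
import math
import random

def interpolate(real, fake, alpha):
    """x_hat = real + alpha * (fake - real), with one scalar alpha per image."""
    return [r + alpha * (f - r) for r, f in zip(real, fake)]

def gradient_penalty_from_grads(grads_per_image):
    """Mean over the batch of (||g||_2 - 1)^2.

    Each entry is one image's gradient flattened over height, width and
    channels, i.e. the axes reduced in the TensorFlow code.
    """
    penalties = []
    for g in grads_per_image:
        norm = math.sqrt(sum(v * v for v in g))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)

# Two toy images, each flattened to 4 pixels (hypothetical values)
real_batch = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
fake_batch = [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]

alphas = [random.uniform(0.0, 1.0) for _ in real_batch]  # one alpha per image
x_hat = [interpolate(r, f, a) for r, f, a in zip(real_batch, fake_batch, alphas)]

# Hand-picked gradients: the first has norm 1 (no penalty), the second norm 2
grads = [[0.5, 0.5, 0.5, 0.5], [1.0, 1.0, 1.0, 1.0]]
print(gradient_penalty_from_grads(grads))
```

Here the printed penalty is 0.5: the first image contributes (1 - 1)² = 0 and the second (2 - 1)² = 1. A gradient norm of exactly 1 is what the penalty pushes the critic towards.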
