问题描述
我目前正在写学士论文，尝试使用 GAN 来扩充我的数据集。到目前为止，我使用的是结合了渐进式增长（Progressive Growing of GANs）和辅助分类器（AC-GAN）的 WGAN-GP。
要实现辅助分类器，需要为判别器额外添加一个损失函数，即分类交叉熵（categorical cross-entropy）损失。
我已经尝试把它加入到损失函数中，但不确定实现方式是否正确。
def gradient_penalty(self, batch_size, real_images, fake_images):
    """Calculate the WGAN-GP gradient penalty.

    The penalty is evaluated on images interpolated at a random point
    between each real sample and the corresponding generated sample; the
    caller adds it (scaled by a weight) to the discriminator loss.

    Args:
        batch_size: Number of samples in the batch (Python int or scalar
            int tensor, e.g. ``tf.shape(real_images)[0]``).
        real_images: Batch of real images. The gradient norm below reduces
            over axes 1-3, so a rank-4 tensor is required (presumably
            (batch, height, width, channels) -- confirm against the model).
        fake_images: Generated images with the same shape as ``real_images``.

    Returns:
        Scalar tensor: the batch mean of ``(||grad||_2 - 1) ** 2``.
    """
    # One interpolation coefficient per sample. It needs one singleton axis
    # per non-batch image axis (three, matching the reduction axes below);
    # a rank-3 shape [batch_size, 1, 1] would broadcast the batch axis
    # against the image height instead.
    alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        outputs = self.discriminator(interpolated, training=True)
        # The auxiliary-classifier discriminator returns (critic_logits,
        # class_prediction). Only the critic output should be constrained;
        # differentiating the whole tuple would also add the classifier
        # head's gradients to the penalty.
        critic_output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

    grads = tape.gradient(critic_output, interpolated)
    # Small epsilon keeps sqrt differentiable when a gradient is exactly zero;
    # this penalty is itself differentiated when it is part of the
    # discriminator loss, and d/dx sqrt(x) is infinite at 0.
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((norm - 1.0) ** 2)
def train_step(self, data):
    """Run one training step of the auxiliary-classifier WGAN-GP.

    The discriminator (critic + auxiliary classifier) is updated
    ``self.d_steps`` times, then the generator is updated once; WGAN-GP
    trains the critic more often than the generator.

    Discriminator loss per step:
        Wasserstein loss + gp_weight * gradient penalty
        + drift_weight * drift + auxiliary label loss.
    Generator loss:
        -mean(critic score of generated images) + auxiliary label loss.

    Args:
        data: ``(real_images, labels)`` or
            ``(real_images, labels, sample_weight)``. ``labels`` are scored
            with ``CategoricalCrossentropy`` and so are expected to be one-hot
            (presumably shape (batch, num_classes) -- confirm). Sample
            weights are accepted but not used.

    Returns:
        dict with the last discriminator loss (``'d_loss'``) and the
        generator loss (``'g_loss'``).
    """
    if len(data) == 3:
        real_images, labels, _unused_sample_weight = data
    else:
        real_images, labels = data

    # Dynamic batch size: when Keras traces train_step into a graph, the
    # static batch dimension (real_images.shape[0]) is usually None.
    batch_size = tf.shape(real_images)[0]
    # Number of classes taken from the real labels instead of hard-coding it.
    num_classes = labels.shape[-1]
    if num_classes is None:
        num_classes = tf.shape(labels)[-1]
    cce = tf.keras.losses.CategoricalCrossentropy()

    def sample_generator_inputs():
        """Draw latent vectors and random one-hot class labels for the generator."""
        latent = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Use TF random ops, not NumPy: inside a traced train_step NumPy
        # randomness runs only once (at trace time), so every step would reuse
        # the same labels. One-hot (instead of independent 0/1 entries) is the
        # target format CategoricalCrossentropy expects.
        class_ids = tf.random.uniform(
            (batch_size,), minval=0, maxval=num_classes, dtype=tf.int32
        )
        return latent, tf.one_hot(class_ids, num_classes)

    # Train the discriminator self.d_steps times per generator step.
    d_loss = tf.constant(0.0)  # keeps d_loss defined even if d_steps == 0
    for _ in range(self.d_steps):
        random_latent_vectors, random_fake_labels = sample_generator_inputs()
        with tf.GradientTape() as tape:
            fake_images = self.generator(
                [random_latent_vectors, random_fake_labels], training=True
            )
            # Critic scores and class predictions for fake and real images.
            fake_logits, fake_label_pred = self.discriminator(fake_images, training=True)
            real_logits, real_label_pred = self.discriminator(real_images, training=True)

            # Auxiliary classifier loss, averaged over the fake and real batches.
            fake_label_loss = cce(random_fake_labels, fake_label_pred)
            real_label_loss = cce(labels, real_label_pred)
            label_loss = 0.5 * (fake_label_loss + real_label_loss)

            # Wasserstein critic loss.
            d_cost = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)
            # The penalty interpolates between real and fake images, so it
            # needs both batches.
            gp = self.gradient_penalty(batch_size, real_images, fake_images)
            # Drift term: penalizes large squared critic scores on real images.
            drift = tf.reduce_mean(tf.square(real_logits))

            d_loss = (
                d_cost
                + self.gp_weight * gp
                + self.drift_weight * drift
                + label_loss
            )
        d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(d_gradient, self.discriminator.trainable_variables)
        )

    # Train the generator once.
    random_latent_vectors, random_fake_labels = sample_generator_inputs()
    with tf.GradientTape() as tape:
        # The generator is conditioned on the sampled labels, as in the
        # discriminator steps above.
        generated_images = self.generator(
            [random_latent_vectors, random_fake_labels], training=True
        )
        gen_img_logits, gen_label_pred = self.discriminator(generated_images, training=True)
        # The generator is rewarded when the classifier recognizes the
        # class it was asked to generate.
        gen_label_loss = cce(random_fake_labels, gen_label_pred)
        g_loss = -tf.reduce_mean(gen_img_logits) + gen_label_loss
    g_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_optimizer.apply_gradients(zip(g_gradient, self.generator.trainable_variables))

    return {'d_loss': d_loss, 'g_loss': g_loss}
解决方法
暂未找到可以解决该程序问题的有效方法，小编正在努力寻找和整理中！
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)