我有一个 GAN 输出,我可以将以下输出序列中的图像保存为单个图像吗?

问题描述

For example, consider the image below — I want to save each part of it as a single image.

这是代码,它在每个训练周期(epoch)后输出一张由多个小图拼接而成的预览图,我想把每个小图单独保存为一张图像。我该如何继续?

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

from tensorflow.keras.layers import Input, Reshape, Dropout, Dense
# BUG FIX: the class is `BatchNormalization` — the original
# `Batchnormalization` raises ImportError.
from tensorflow.keras.layers import Flatten, BatchNormalization
from tensorflow.keras.layers import Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.optimizers import Adam

from google.colab import drive

drive.mount('/content/drive')

 GENERATE_RES = 3
 GENERATE_SQUARE = 32 * GENERATE_RES # rows/cols (should be square)
 IMAGE_CHANNELS = 3
 # Preview image 
 PREVIEW_ROWS = 4
 PREVIEW_COLS = 7
 PREVIEW_MARGIN = 16

 # Size vector to generate images from
 SEED_SIZE = 100

 # Configuration
 DATA_PATH = '/content/drive/MyDrive/cars/images'
 EPOCHS = 50
 BATCH_SIZE = 32
 BUFFER_SIZE = 60000

 print(f"Will generate {GENERATE_SQUARE}px square images.")
 def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return "{}:{:>02}:{:>05.2f}".format(h,m,s)



 training_binary_path = os.path.join(DATA_PATH,f'training_data_{GENERATE_SQUARE}_{GENERATE_SQUARE}.npy')

 print(f"Looking for file: {training_binary_path}")

 if not os.path.isfile(training_binary_path):
     start = time.time()
     print("Loading training images...")

     training_data = []
     faces_path = os.path.join(DATA_PATH)
     for filename in tqdm(os.listdir(faces_path)):
         path = os.path.join(faces_path,filename)
         image = Image.open(path).resize((GENERATE_SQUARE,GENERATE_SQUARE),Image.ANTIALIAS)
         training_data.append(np.asarray(image))
     training_data = np.reshape(training_data,(-1,GENERATE_SQUARE,IMAGE_CHANNELS))
     training_data = training_data.astype(np.float32)
     training_data = training_data / 127.5 - 1.


     print("Saving training image binary...")
     np.save(training_binary_path,training_data)
     elapsed = time.time()-start
     print (f'Image preprocess time: {hms_string(elapsed)}')
else:
     print("Loading prevIoUs training pickle...")
     training_data = np.load(training_binary_path)

#shuffle the data
 train_dataset = tf.data.Dataset.from_tensor_slices(training_data) \
.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

def build_generator(seed_size, channels):
  """Build the DCGAN generator: latent vector -> (GENERATE_SQUARE, GENERATE_SQUARE, channels) image.

  Upsamples 4x4 -> 8x8 -> 16x16 -> 32x32, then by GENERATE_RES to the final
  resolution. Output activation is tanh, so pixels are in [-1, 1].

  BUG FIXES vs. original:
  - `Batchnormalization` is not a Keras class (ImportError); the correct
    name is `BatchNormalization`.
  - `Conv2D(256, padding=...)` / `Conv2D(channels, padding=...)` omit the
    required `kernel_size` argument (TypeError); kernel_size=3 matches the
    one call that did specify it.
  """
  model = Sequential()

  model.add(Dense(4 * 4 * 256, activation="relu", input_dim=seed_size))
  model.add(Reshape((4, 4, 256)))

  model.add(UpSampling2D())
  model.add(Conv2D(256, kernel_size=3, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(Activation("relu"))

  model.add(UpSampling2D())
  model.add(Conv2D(256, kernel_size=3, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(Activation("relu"))

  # Output resolution, additional upsampling
  model.add(UpSampling2D())
  model.add(Conv2D(128, kernel_size=3, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(Activation("relu"))

  if GENERATE_RES > 1:
    model.add(UpSampling2D(size=(GENERATE_RES, GENERATE_RES)))
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
    model.add(Activation("relu"))

  # Final CNN layer
  model.add(Conv2D(channels, kernel_size=3, padding="same"))
  model.add(Activation("tanh"))

  return model


def build_discriminator(image_shape):
  """Build the DCGAN discriminator: image -> sigmoid probability of 'real'.

  BUG FIXES vs. original:
  - `Batchnormalization` is not a Keras class (ImportError); correct name
    is `BatchNormalization`.
  - Every `Conv2D` call omitted the required `kernel_size` argument
    (TypeError); kernel_size=3 added throughout.
  """
  model = Sequential()

  model.add(Conv2D(32, kernel_size=3, strides=2,
                   input_shape=image_shape, padding="same"))
  model.add(LeakyReLU(alpha=0.2))

  model.add(Dropout(0.25))
  model.add(Conv2D(64, kernel_size=3, padding="same"))
  model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(LeakyReLU(alpha=0.2))

  model.add(Dropout(0.25))
  model.add(Conv2D(128, kernel_size=3, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(LeakyReLU(alpha=0.2))

  model.add(Dropout(0.25))
  model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(LeakyReLU(alpha=0.2))

  model.add(Dropout(0.25))
  model.add(Conv2D(512, kernel_size=3, padding="same"))
  model.add(tf.keras.layers.BatchNormalization(momentum=0.8))
  model.add(LeakyReLU(alpha=0.2))

  model.add(Dropout(0.25))
  model.add(Flatten())
  model.add(Dense(1, activation='sigmoid'))

  return model

def save_images(cnt, noise, save_individual=False):
  """Generate images from `noise` and save a preview montage to disk.

  Writes a PREVIEW_ROWS x PREVIEW_COLS grid (with PREVIEW_MARGIN-pixel gaps)
  to DATA_PATH/output/train-{cnt}.png. With save_individual=True, each tile
  is ALSO saved on its own as train-{cnt}-{i}.png — this answers the
  original question of extracting every small output as a single image.

  Uses the module-level `generator`. BUG FIX: the original body was
  dedented out of the function (broken indentation), and hard-coded 16
  where PREVIEW_MARGIN was meant.
  """
  image_array = np.full(
      (PREVIEW_MARGIN + (PREVIEW_ROWS * (GENERATE_SQUARE + PREVIEW_MARGIN)),
       PREVIEW_MARGIN + (PREVIEW_COLS * (GENERATE_SQUARE + PREVIEW_MARGIN)),
       3),
      255, dtype=np.uint8)

  generated_images = generator.predict(noise)
  # Map tanh output [-1, 1] back to [0, 1].
  generated_images = 0.5 * generated_images + 0.5

  output_path = os.path.join(DATA_PATH, 'output')
  os.makedirs(output_path, exist_ok=True)

  image_count = 0
  for row in range(PREVIEW_ROWS):
    for col in range(PREVIEW_COLS):
      r = row * (GENERATE_SQUARE + PREVIEW_MARGIN) + PREVIEW_MARGIN
      c = col * (GENERATE_SQUARE + PREVIEW_MARGIN) + PREVIEW_MARGIN
      tile = (generated_images[image_count] * 255).astype(np.uint8)
      image_array[r:r + GENERATE_SQUARE, c:c + GENERATE_SQUARE] = tile
      if save_individual:
        Image.fromarray(tile).save(
            os.path.join(output_path, f"train-{cnt}-{image_count}.png"))
      image_count += 1

  filename = os.path.join(output_path, f"train-{cnt}.png")
  im = Image.fromarray(image_array)
  im.save(filename)

generator = build_generator(SEED_SIZE, IMAGE_CHANNELS)

# Smoke-test the untrained generator on a single noise vector.
noise = tf.random.normal([1, SEED_SIZE])
generated_image = generator(noise, training=False)

# BUG FIX: output is (batch, H, W, C); the original `[0,:,0]` slice is rank-2
# over the wrong axes. Show channel 0 of the first image.
plt.imshow(generated_image[0, :, :, 0])

# BUG FIX: discriminator input must be (H, W, C); the original
# (GENERATE_SQUARE, IMAGE_CHANNELS) omits one spatial dimension.
image_shape = (GENERATE_SQUARE, GENERATE_SQUARE, IMAGE_CHANNELS)

discriminator = build_discriminator(image_shape)
decision = discriminator(generated_image)
print(decision)

# Binary cross-entropy drives both adversarial losses.
cross_entropy = tf.keras.losses.BinaryCrossentropy()


def discriminator_loss(real_output, fake_output):
  """Discriminator loss: push real scores toward 1 and fake scores toward 0."""
  loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
  loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
  return loss_on_real + loss_on_fake


def generator_loss(fake_output):
  """Generator loss: reward fakes the discriminator scores as real (1)."""
  return cross_entropy(tf.ones_like(fake_output), fake_output)


# Adam(learning_rate=1.5e-4, beta_1=0.5) for both networks.
generator_optimizer = tf.keras.optimizers.Adam(1.5e-4, 0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1.5e-4, 0.5)

@tf.function
def train_step(images):
  """One adversarial training step on a batch of real images.

  Returns (gen_loss, disc_loss) for the batch.

  BUG FIX: the original computed and applied gradients INSIDE the
  GradientTape `with` block; differentiation and the optimizer update
  belong outside the recording context.
  """
  seed = tf.random.normal([BATCH_SIZE, SEED_SIZE])

  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(seed, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

  gradients_of_generator = gen_tape.gradient(
      gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(
      disc_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(
      zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(
      zip(gradients_of_discriminator, discriminator.trainable_variables))
  return gen_loss, disc_loss

def train(dataset, epochs):
  """Run the full GAN training loop, saving a preview grid after each epoch.

  BUG FIXES vs. original:
  - The inner `for image_batch` loop was indented with 3 spaces instead of 4
    (IndentationError).
  - The epoch log's second line was a plain string, so `{hms_string(...)}`
    printed literally instead of being interpolated.
  - The final elapsed-time lines were dedented to module level, where
    `start` is undefined (NameError); they belong inside this function.
  """
  # Fixed seed so preview grids are comparable across epochs.
  fixed_seed = np.random.normal(0, 1, (PREVIEW_ROWS * PREVIEW_COLS, SEED_SIZE))
  start = time.time()

  for epoch in range(epochs):
    epoch_start = time.time()

    gen_loss_list = []
    disc_loss_list = []

    for image_batch in dataset:
      t = train_step(image_batch)
      gen_loss_list.append(t[0])
      disc_loss_list.append(t[1])

    g_loss = sum(gen_loss_list) / len(gen_loss_list)
    d_loss = sum(disc_loss_list) / len(disc_loss_list)

    epoch_elapsed = time.time() - epoch_start
    print(f'Epoch {epoch+1}, gen loss={g_loss}, disc loss={d_loss}, '
          f'{hms_string(epoch_elapsed)}')
    save_images(epoch, fixed_seed)

  elapsed = time.time() - start
  print(f'Training time: {hms_string(elapsed)}')


train(train_dataset, EPOCHS)

解决方法

将预览参数设置为 PREVIEW_ROWS = 1、PREVIEW_COLS = 1、PREVIEW_MARGIN = 16,这样每次保存的预览图就只包含一张生成图像。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...