Author: fchollet
Date created: 2019/04/29
Last modified: 2021/01/01
Description: A simple DCGAN trained using fit() by overriding train_step on CelebA images.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile
We'll use face images from the CelebA dataset, resized to 64x64.
os.makedirs("celeba_gan", exist_ok=True)

url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
output = "celeba_gan/data.zip"
gdown.download(url, output, quiet=True)

with ZipFile("celeba_gan/data.zip", "r") as zipobj:
    zipobj.extractall("celeba_gan")
Create a dataset from our folder, and rescale the images to the [0-1] range:
dataset = keras.utils.image_dataset_from_directory(
"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
)
dataset = dataset.map(lambda x: x / 255.0)
Found 202599 files belonging to 1 classes.
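Optionally, if input throughput becomes a bottleneck, the rescaling step above can be written with parallel calls and prefetching instead. A minor variation, not part of the original example:

# Alternative to the plain `map` call above: parallelize rescaling and
# overlap preprocessing with training by prefetching batches.
dataset = dataset.map(lambda x: x / 255.0, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)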
Let's display a sample image:
for x in dataset:
    plt.axis("off")
    plt.imshow((x.numpy() * 255).astype("int32")[0])
    break
Now let's create the discriminator: it maps a 64x64 image to a binary classification score.
discriminator = keras.Sequential(
    [
        keras.Input(shape=(64, 64, 3)),
        layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 32, 32, 64) 3136
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 32, 32, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 16, 16, 128) 131200
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 128) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 8, 8, 128) 262272
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 8, 8, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 8192) 0
_________________________________________________________________
dropout (Dropout) (None, 8192) 0
_________________________________________________________________
dense (Dense) (None, 1) 8193
=================================================================
Total params: 404,801
Trainable params: 404,801
Non-trainable params: 0
_________________________________________________________________
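As a quick sanity check (a sketch, not part of the original example), you can confirm that the discriminator maps a batch of 64x64 RGB images to one score per image:

# Run a dummy batch through the discriminator and check the output shape.
dummy_images = tf.random.normal((4, 64, 64, 3))
print(discriminator(dummy_images).shape)  # (4, 1): one sigmoid score per image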
Next, let's create the generator. It mirrors the discriminator, replacing Conv2D layers with Conv2DTranspose layers.
latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(8 * 8 * 128),
        layers.Reshape((8, 8, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
    ],
    name="generator",
)
generator.summary()
Model: "generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 8192) 1056768
_________________________________________________________________
reshape (Reshape) (None, 8, 8, 128) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 16, 16, 128) 262272
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 16, 16, 128) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 32, 32, 256) 524544
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 64, 64, 512) 2097664
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 64, 64, 512) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 64, 64, 3) 38403
=================================================================
Total params: 3,979,651
Trainable params: 3,979,651
Non-trainable params: 0
_________________________________________________________________
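Similarly, we can sanity-check the generator (again a sketch, not in the original example): a batch of latent vectors should decode to 64x64 RGB images with values in [0, 1], thanks to the final sigmoid:

# Decode a dummy batch of latent vectors and inspect shape and value range.
dummy_latents = tf.random.normal((4, latent_dim))
fake_images = generator(dummy_latents)
print(fake_images.shape)  # (4, 64, 64, 3)
print(float(tf.reduce_min(fake_images)), float(tf.reduce_max(fake_images)))  # within [0, 1]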
Next, we override train_step to implement the full GAN training loop:
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }
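A note on the label-noise trick above: it keeps the discriminator from becoming overconfident. If you'd rather smooth labels deterministically, keras.losses.BinaryCrossentropy also accepts a label_smoothing argument. A minimal sketch of that variant (not what this example uses):

# Variant (not used in this example): deterministic label smoothing,
# which nudges 0/1 targets toward 0.5 inside the loss itself.
smoothed_loss_fn = keras.losses.BinaryCrossentropy(label_smoothing=0.1)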
Create a callback that periodically saves generated images:

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=3, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images *= 255
        generated_images = generated_images.numpy()
        for i in range(self.num_img):
            img = keras.utils.array_to_img(generated_images[i])
            img.save("generated_img_%03d_%d.png" % (epoch, i))
Finally, train the end-to-end model:

epochs = 1  # In practice, use ~100 epochs

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

gan.fit(
    dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
)
6332/6332 [==============================] - 605s 96ms/step - d_loss: 0.6113 - g_loss: 1.1976
<tensorflow.python.keras.callbacks.History at 0x7f4eb5d055d0>
Some of the last generated images around epoch 30 (results keep improving after that):
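After a longer training run, you can sample new faces straight from the trained generator. A minimal sketch (the 4x4 grid and figure size are arbitrary choices):

# Sample a grid of faces from the trained generator.
num_samples = 16
latents = tf.random.normal((num_samples, latent_dim))
samples = generator(latents).numpy()

plt.figure(figsize=(8, 8))
for i in range(num_samples):
    plt.subplot(4, 4, i + 1)
    plt.axis("off")
    plt.imshow((samples[i] * 255).astype("int32"))
plt.show()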