Introduction

A Generative Adversarial Network (GAN) is a network with two players of a game: a generator that produces fake images and a discriminator that tries to tell the real images from the fake ones. As training proceeds, the two ultimately reach an equilibrium in which the generator produces images close to real ones and the discriminator can no longer distinguish the forgeries. Training results in a sophisticated generator, which is then used to generate realistic images, even art. Many GAN variants have appeared since the original paper.

A plain GAN does not support generating images of a chosen category; it only generates generic images from whatever categories it was trained on. The conditional GAN (cGAN) is a variant that feeds an additional embedded label vector y as input to both the generator and the discriminator. Trained this way, a cGAN generator can generate images of a given category.

In cGAN, the value function becomes dependent on the input y (hence the word conditional):

\[\min_G \max_D V(D,G) = E_x {[\log D(x \mid y)]} + E_z {[\log(1-D(G(z \mid y)))]}\]
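
In practice the two expectation terms are estimated on minibatches. Below is a minimal numerical sketch (my own illustration, not code from the cGAN paper) of the two players' losses, where d_real stands for D(x | y) on real images and d_fake for the discriminator's output on generated images G(z | y):

import numpy as np

def d_loss(d_real, d_fake):
	# the discriminator maximizes E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)))],
	# i.e. it minimizes the negative of the minibatch estimate
	return -(np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake)))

def g_loss(d_fake):
	# the generator minimizes E_z[log(1 - D(G(z|y)))]; in practice it is
	# common to instead maximize log D(G(z|y)) (the non-saturating trick)
	return -np.mean(np.log(d_fake))

# d_real and d_fake would come from a discriminator like the one defined below
print(d_loss(np.array([0.9, 0.8]), np.array([0.1, 0.2])))
print(g_loss(np.array([0.1, 0.2])))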

Apart from generating images of a certain kind, the cGAN paper also introduces an application to multi-modal learning, in which the model generates tags (text) for image captioning. This is multi-modal because one image can have more than one tag, and those tags can also be synonyms.

There are several ways we can add the label y to the generator:

  • As an embedding layer

  • Add it as an additional channel of the image

  • Keep the embedding dimension low, then upsample it to match the image size (the code example below combines all three)

Code example

Following is an example of the discriminator and the generator. When you run it on MNIST, Fashion-MNIST, or CIFAR-10, you need to adapt the input and some layers to fit the dimensions of the images: the two MNIST datasets contain grayscale images (one channel, 28 pixels wide and 28 pixels high, i.e. 28x28x1), while CIFAR-10 contains color images with 3 channels (RGB, 32 pixels wide and 32 pixels high, i.e. 32x32x3). The discriminator and generator below are adapted for CIFAR-10. To adapt between the different tensor shapes, run model.summary() and fix the numbers accordingly.

# imports for the cGAN example (Keras functional API)
from keras.models import Model
from keras.layers import Input, Embedding, Dense, Reshape, Concatenate
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Flatten, Dropout
from keras.optimizers import Adam

# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3), n_classes=10):
	# label input
	in_label = Input(shape=(1,))
	# embedding for categorical input
	li = Embedding(n_classes, 50)(in_label)
	# scale up to image dimensions with linear activation
	n_nodes = in_shape[0] * in_shape[1]
	li = Dense(n_nodes)(li)
	# reshape to additional channel
	li = Reshape((in_shape[0], in_shape[1], 1))(li)
	# image input
	in_image = Input(shape=in_shape)
	# concat label as a channel
	merge = Concatenate()([in_image, li])
	# downsample
	fe = Conv2D(64, (3,3), strides=(2,2), padding='same')(merge)
	fe = LeakyReLU(alpha=0.2)(fe)
	# downsample
	fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
	fe = LeakyReLU(alpha=0.2)(fe)
	# flatten feature maps
	fe = Flatten()(fe)
	# dropout
	fe = Dropout(0.4)(fe)
	# output
	out_layer = Dense(1, activation='sigmoid')(fe)
	# define model
	model = Model([in_image, in_label], out_layer)
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
	return model

# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
	# label input
	in_label = Input(shape=(1,))
	# embedding for categorical input
	li = Embedding(n_classes, 50)(in_label)
	# linear multiplication
	n_nodes = 8 * 8
	li = Dense(n_nodes)(li)
	# reshape to additional channel
	li = Reshape((8, 8, 1))(li)
	# image generator input
	in_lat = Input(shape=(latent_dim,))
	# foundation for 8x8 feature maps
	n_nodes = 128 * 8 * 8
	gen = Dense(n_nodes)(in_lat)
	gen = LeakyReLU(alpha=0.2)(gen)
	gen = Reshape((8, 8, 128))(gen)
	# merge image gen and label input
	merge = Concatenate()([gen, li])
	# upsample to 16x16
	gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same')(merge)
	gen = LeakyReLU(alpha=0.2)(gen)
	# upsample to 32x32
	gen = Conv2DTranspose(32, (4,4), strides=(2,2), padding='same')(gen)
	gen = LeakyReLU(alpha=0.2)(gen)
	# output
	out_layer = Conv2D(3, (8,8), activation='tanh', padding='same')(gen)
	# define model
	model = Model([in_lat, in_label], out_layer)
	return model

latent_dim = 100

d_model = define_discriminator()
d_model.summary()
g_model = define_generator(latent_dim)
g_model.summary()

Model: "model_10" (discriminator)

Layer (type) Output Shape Param # Connected to
input_17 (InputLayer) [(None, 1)] 0 []
embedding_8 (Embedding) (None, 1, 50) 500 ['input_17[0][0]']
dense_16 (Dense) (None, 1, 1024) 52224 ['embedding_8[0][0]']
input_18 (InputLayer) [(None, 32, 32, 3)] 0 []
reshape_11 (Reshape) (None, 32, 32, 1) 0 ['dense_16[0][0]']
concatenate_8 (Concatenate) (None, 32, 32, 4) 0 ['input_18[0][0]', 'reshape_11[0][0]']
conv2d_13 (Conv2D) (None, 16, 16, 64) 2368 ['concatenate_8[0][0]']
leaky_re_lu_19 (LeakyReLU) (None, 16, 16, 64) 0 ['conv2d_13[0][0]']
conv2d_14 (Conv2D) (None, 8, 8, 128) 73856 ['leaky_re_lu_19[0][0]']
leaky_re_lu_20 (LeakyReLU) (None, 8, 8, 128) 0 ['conv2d_14[0][0]']
flatten_5 (Flatten) (None, 8192) 0 ['leaky_re_lu_20[0][0]']
dropout_5 (Dropout) (None, 8192) 0 ['flatten_5[0][0]']
dense_17 (Dense) (None, 1) 8193 ['dropout_5[0][0]']

Total params: 137,141

Trainable params: 137,141

Non-trainable params: 0

Model: "model_11" (generator)

Layer (type) Output Shape Param # Connected to
input_20 (InputLayer) [(None, 100)] 0 []
input_19 (InputLayer) [(None, 1)] 0 []
dense_19 (Dense) (None, 8192) 827392 ['input_20[0][0]']
embedding_9 (Embedding) (None, 1, 50) 500 ['input_19[0][0]']
leaky_re_lu_21 (LeakyReLU) (None, 8192) 0 ['dense_19[0][0]']
dense_18 (Dense) (None, 1, 64) 3264 ['embedding_9[0][0]']
reshape_13 (Reshape) (None, 8, 8, 128) 0 ['leaky_re_lu_21[0][0]']
reshape_12 (Reshape) (None, 8, 8, 1) 0 ['dense_18[0][0]']
concatenate_9 (Concatenate) (None, 8, 8, 129) 0 ['reshape_13[0][0]', 'reshape_12[0][0]']
conv2d_transpose_6 (Conv2DTranspose) (None, 16, 16, 64) 132160 ['concatenate_9[0][0]']
leaky_re_lu_22 (LeakyReLU) (None, 16, 16, 64) 0 ['conv2d_transpose_6[0][0]']
conv2d_transpose_7 (Conv2DTranspose) (None, 32, 32, 32) 32800 ['leaky_re_lu_22[0][0]']
leaky_re_lu_23 (LeakyReLU) (None, 32, 32, 32) 0 ['conv2d_transpose_7[0][0]']
conv2d_15 (Conv2D) (None, 32, 32, 3) 6147 ['leaky_re_lu_23[0][0]']

Total params: 1,002,263

Trainable params: 1,002,263

Non-trainable params: 0
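
The summaries above only define the two networks. To train them, a common recipe is to stack the generator and the (frozen) discriminator into a composite model and alternate updates between the discriminator and this composite. The sketch below is my own outline of that setup (define_cgan and generate_latent_points are hypothetical helper names; it reuses the imports and the d_model, g_model, and latent_dim defined above, and the actual training loop is omitted):

from numpy.random import randn, randint
import numpy as np

# composite model: (z, label) -> generator -> (fake image, label) -> frozen discriminator
def define_cgan(g_model, d_model):
	d_model.trainable = False
	gen_noise, gen_label = g_model.input            # latent vector and label inputs
	gan_output = d_model([g_model.output, gen_label])
	model = Model([gen_noise, gen_label], gan_output)
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt)
	return model

# sample latent points with random class labels
def generate_latent_points(latent_dim, n_samples, n_classes=10):
	z = randn(n_samples, latent_dim)
	labels = randint(0, n_classes, n_samples).reshape(-1, 1)
	return [z, labels]

gan_model = define_cgan(g_model, d_model)
# (alternate d_model / gan_model updates here to train)

# after training, ask the generator for a specific class, e.g. class 3
z, _ = generate_latent_points(latent_dim, 16)
images = g_model.predict([z, np.full((16, 1), 3)])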

Note that the following samples were generated after only 10 epochs of training:

MNIST: cgan_mnist

FASHION MNIST: cgan_fashion_mnist

CIFAR: cgan_cifar

Pix2Pix GAN

Pix2Pix is a conditional GAN architecture used for image-to-image translation, for example rendering a photo in a certain painting style (Monet, Van Gogh, ...) or turning a summer scene into a winter one.


Whereas a GAN learns a mapping from noise to an image, \(G: z \rightarrow y\), a conditional GAN learns a mapping from an observed image and noise to an output image, \(G: \{x,z\} \rightarrow y\).

In the Pix2Pix paper, the objective function is modified: a regularization term is added, the L1 distance for the generator (how far the generated image is from the target one). They use L1 rather than L2 (Euclidean distance) because minimizing L2 tends to average out the pixel values, which encourages blurring of the images.

\[Loss_{L_1}(G) = E_{x,y,z} {[\| y - G(x,z) \|_1]}\]

This term is added to the final objective with a scaling factor \(\alpha\):

\[G^* = \arg \min_G \max_D V(G,D) + \alpha \, Loss_{L_1}(G)\]
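
In the Keras code further below, this combined objective is implemented by compiling the composite model with a binary cross-entropy loss for the adversarial term and a mean absolute error (L1) loss for the reconstruction term, weighted [1, 100], i.e. \(\alpha = 100\). As a rough standalone sketch of the generator's objective for one batch (my own illustration, not the paper's code):

import numpy as np

def pix2pix_generator_loss(d_fake_patches, y_target, y_generated, alpha=100.0):
	# adversarial term: the generator wants the patch discriminator to
	# output 1 for its fakes (binary cross-entropy against all-ones labels)
	adversarial = -np.mean(np.log(d_fake_patches))
	# L1 term: mean absolute error between target and generated images
	l1 = np.mean(np.abs(y_target - y_generated))
	return adversarial + alpha * l1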

Architecture

In the original model, they don't use input noise, only dropout. The reason is that in their experiments the output did not depend much on the initial random point in the latent space. For both the generator and the discriminator, they use modules of the form Convolution-BatchNorm-ReLU. They also use skip connections in the generator: a skip connection simply means that we concatenate all the channels at layer i with those at the mirrored layer j. As in the U-Net architecture, where an encoder compresses the information and a decoder expands it again, they connect symmetric layers with these skip connections.
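
As a toy, standalone illustration of such a skip connection (my own sketch, not the Pix2Pix generator itself, which appears in full further below), the decoder feature map is concatenated channel-wise with the feature map of the symmetric encoder layer before being upsampled further:

from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate

inp = Input(shape=(64, 64, 3))
e = Conv2D(32, (3,3), strides=(2,2), padding='same', activation='relu')(inp)         # encoder: 32x32x32
b = Conv2D(64, (3,3), strides=(2,2), padding='same', activation='relu')(e)           # bottleneck: 16x16x64
d = Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu')(b)  # decoder: 32x32x32
d = Concatenate()([d, e])                                                             # skip connection from encoder layer i to decoder layer j
out = Conv2DTranspose(3, (3,3), strides=(2,2), padding='same', activation='tanh')(d) # back to 64x64x3
model = Model(inp, out)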

Since the L1 term enforces low-frequency correctness (pixels at corresponding positions should look like each other), it blurs the image. The authors add a technique that enforces high frequencies (to get a crisp image back) by only looking at the structure of local patches. This is called PatchGAN: the GAN term only penalizes structure at the scale of a patch, so the discriminator only tries to classify whether, averaged over all patches in the image, the patches look real (the discriminator code below is fully convolutional and outputs one prediction per patch). The idea resembles a Markov random field: pixels are assumed independent of pixels outside their patch. A small patch size still produces high-quality results and reduces the computation.

Some other minor notes: rather than training the generator to minimize \(\log(1-D)\), they train it to maximize \(\log D\) on the fakes, and they divide the discriminator's objective by two, which slows the rate at which the discriminator learns relative to the generator (this is why the discriminator below is compiled with loss_weights=[0.5]). They use minibatch SGD with the Adam solver and a learning rate of 0.0002. One difficulty with evaluating synthetic images is that the usual Euclidean distance doesn't really work: the mean squared error only measures the total per-pixel distance and doesn't capture spatial structure (images are two-dimensional and each pixel location has its own importance).

Here is the discriminator:

Source: https://machinelearningmastery.com/how-to-implement-pix2pix-gan-models-from-scratch-with-keras/

The generator, with encoder and decoder blocks:

Source: https://machinelearningmastery.com/how-to-implement-pix2pix-gan-models-from-scratch-with-keras/

# example of defining a composite model for training the generator model
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.utils.vis_utils import plot_model

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model
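
# quick sanity check (my own addition, not part of the original source):
# for a 256x256x3 input, the PatchGAN discriminator above outputs a
# 16x16x1 grid of per-patch real/fake predictions (four stride-2 halvings),
# so training labels must be patch-shaped arrays of ones and zeros
import numpy as np
d_check = define_discriminator((256, 256, 3))
print(d_check.output_shape)        # (None, 16, 16, 1): one prediction per patch
y_real = np.ones((8, 16, 16, 1))   # "all patches real" labels for a batch of 8
y_fake = np.zeros((8, 16, 16, 1))  # "all patches fake" labels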

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model: C64-C128-C256-C512-C512-C512-C512-C512
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck: no batch norm, ReLU activation
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	for layer in d_model.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model
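
Finally, here is a minimal sketch of how these pieces fit together in one training step (my own outline, not part of the original source; X_src and X_tgt are placeholder names for a batch of paired source and target images, and the 16x16 patch size assumes 256x256 inputs):

import numpy as np

# build the three models (assuming 256x256x3 paired images)
image_shape = (256, 256, 3)
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

def train_step(d_model, g_model, gan_model, X_src, X_tgt, n_patch=16):
	n = len(X_src)
	# patch-shaped labels: ones for real pairs, zeros for generated pairs
	y_real = np.ones((n, n_patch, n_patch, 1))
	y_fake = np.zeros((n, n_patch, n_patch, 1))
	# generate fake target images for the same source images
	X_fake = g_model.predict(X_src, verbose=0)
	# update the discriminator on real and on generated pairs
	d_loss_real = d_model.train_on_batch([X_src, X_tgt], y_real)
	d_loss_fake = d_model.train_on_batch([X_src, X_fake], y_fake)
	# update the generator via the composite model: adversarial loss against
	# all-ones patch labels plus L1 (mae) loss against the real target
	g_loss, _, _ = gan_model.train_on_batch(X_src, [y_real, X_tgt])
	return d_loss_real, d_loss_fake, g_loss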