If you have ever heard of paintings created by artificial intelligence or deepfake apps and wondered how these things work, it is worth knowing that all of these applications use some variant of a Generative Adversarial Network, or GAN.
GANs have been a buzzword in the deep learning world, and they have been an active field of research ever since their introduction in 2014 by Ian J. Goodfellow. So, what are GANs, and why are they popular? Before we get into the details, let’s start with a definition: a GAN is a deep learning architecture for training a generative model, most famously for image synthesis. This ability of these networks to generate brand-new, unseen data is what makes them interesting. We will look at a high-level intuition of how they work, the underlying concepts, the math behind them, and some of their interesting applications.
Fig 1. GAN framework
A GAN is based on a min-max game between two adversarial neural network models: a generative model, G, and a discriminative model, D. Both are neural networks, but they have different objectives.
The generative model maps random variables z, sampled from a latent space with a prior noise distribution Pz, to the data space, i.e., toward the original data distribution. The discriminative model, D(x), estimates the probability that a sample x comes from the real training data rather than from G. This adversarial interplay between the generator and the discriminator is the reason GANs perform so well.
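This competition between G and D can be written down as a two-player min-max game over a value function V(D, G), as formulated in the original 2014 paper:

```latex
\min_G \max_D V(D, G) =
    \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator D tries to maximize V by assigning high probability to real samples x and low probability to generated samples G(z), while the generator G tries to minimize V by making D(G(z)) as close to 1 as possible.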
A discriminative model is one that can only be used to discriminate between, or classify, data points. It models P(y|x), where y is the target label and x is the feature vector. Typically, y is either 0 (x is a fake, generated image) or 1 (x is a real image). The discriminative model yields a probability between 0 and 1 for a given value of x.
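As a minimal sketch of this idea (assuming, for illustration, a tiny logistic model instead of a full neural network, with hypothetical weights), a discriminator is just a function that maps a feature vector x to a probability P(y=1|x):

```python
import numpy as np

def discriminator(x, w, b):
    """Toy discriminator: a logistic model of P(y=1 | x).

    Returns a probability in (0, 1): close to 1 means "x looks real",
    close to 0 means "x looks fake".
    """
    score = np.dot(w, x) + b             # linear score for the feature vector x
    return 1.0 / (1.0 + np.exp(-score))  # sigmoid squashes the score into (0, 1)

# Hypothetical weights and a sample feature vector, just for illustration
w = np.array([0.5, -0.2, 0.1])
b = 0.0
x = np.array([1.0, 2.0, 3.0])
p_real = discriminator(x, w, b)
```

In a real GAN the linear score is replaced by a deep (often convolutional) network, but the output is still a single probability of the input being real.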
A generative model is one that can generate data. It takes some random noise as input and is expected to generate images similar to the original data. Formally, a generative model models the conditional probability of the input data X given a target label y, i.e., P(X|Y=y).
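A correspondingly minimal sketch of a generator (again assuming a single affine layer with hypothetical weights, where a real GAN would use a deep network) maps a latent noise vector z to a point in data space:

```python
import numpy as np

def generator(z, W, b):
    """Toy generator: maps a latent noise vector z to a point in data space.

    A single affine layer with a tanh nonlinearity stands in for the
    mapping G(z) that a deep network would learn.
    """
    return np.tanh(W @ z + b)  # outputs bounded to (-1, 1), like a normalized image

rng = np.random.default_rng(0)
z = rng.standard_normal(4)              # z ~ Pz, the latent prior (standard normal here)
W = rng.standard_normal((8, 4)) * 0.1   # hypothetical generator weights
b = np.zeros(8)
fake = generator(z, W, b)               # a "fake" 8-dimensional sample
```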
Let’s now look at the architecture of GANs and see how the generative and discriminative models interact and improve over time.
Fig 2. GAN architecture
The discriminative model receives both real images and images produced by the generative model as input, and it is expected to become good at distinguishing real from fake data. The generative model takes a sample from the latent space z as input, generates an image, and feeds this generated image to the discriminative model. The generative model tries to fool the discriminative model into classifying fake images as real, while the discriminative model strives to get better at separating real images from the fakes produced by the generator.
Backpropagation in discriminative model
Let’s now take a look at how the discriminative model updates its weights during the training process.
Fig 3. Backpropagation in the discriminative model
- An input image is selected: either a real image from the training set or a fake image generated by the generative model from the latent space (z)
- The input goes through the discriminative model; the loss is computed with respect to the label y, and this information is fed back to the network via backpropagation
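The two steps above can be sketched as a single discriminator update, again using the toy logistic discriminator (with binary cross-entropy loss, where the gradient with respect to the linear score conveniently simplifies to p − y):

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def discriminator_step(w, b, x, y, lr=0.1):
    """One gradient-descent update for a toy logistic discriminator.

    x: feature vector; y: label (1 = real, 0 = fake).
    With a binary cross-entropy loss and a sigmoid output, the gradient of
    the loss w.r.t. the linear score is simply (p - y).
    """
    p = sigmoid(w @ x + b)       # D(x), predicted probability of "real"
    grad_score = p - y           # dL/dscore for BCE + sigmoid
    w = w - lr * grad_score * x  # backpropagate to the weights
    b = b - lr * grad_score      # ...and to the bias
    return w, b

# Hypothetical data: one "real" and one "fake" sample, for illustration
w, b = np.zeros(3), 0.0
real = np.array([1.0, 1.0, 1.0])
fake = np.array([-1.0, -1.0, -1.0])
for _ in range(100):
    w, b = discriminator_step(w, b, real, 1)  # label real images as 1
    w, b = discriminator_step(w, b, fake, 0)  # label fake images as 0
```

After training on these separable toy samples, the discriminator assigns high probability to the real sample and low probability to the fake one.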
Backpropagation in a generative model
In the training phase, the generative model tries to mimic the original training data. Based on the feedback received through backpropagation, it keeps updating its weights until the discriminative network classifies the fake images it generates from noise as original images.
Fig 4. Backpropagation in a generative model
- Select random noise from the latent space (z), generate a fake image, and label it
- Feed the generated image to the discriminator
- Based on the discriminator’s loss, backpropagate the result to the generator
When the generator is updated via backpropagation, the discriminator’s weights are locked (frozen), so that the loss computed on the generated images updates only the generator and the two networks stay in synchronization.
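A minimal sketch of one such generator update (assuming, for illustration, a linear generator and the toy frozen logistic discriminator from before) shows the gradient flowing through the locked discriminator into the generator's weights only:

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def generator_step(Wg, z, wd, bd, lr=0.1):
    """One generator update with the discriminator frozen.

    The toy linear generator G(z) = Wg @ z is trained to make the frozen
    logistic discriminator output 1 ("real") on its samples: only Wg is
    updated; wd and bd are treated as constants.
    """
    x_fake = Wg @ z                         # G(z)
    p = sigmoid(wd @ x_fake + bd)           # D(G(z)), discriminator frozen
    grad_score = p - 1.0                    # BCE gradient with target label 1
    grad_Wg = grad_score * np.outer(wd, z)  # chain rule through the frozen D
    return Wg - lr * grad_Wg                # update generator weights only

# Hypothetical frozen discriminator that prefers positive features
wd, bd = np.array([1.0, 1.0]), 0.0
Wg = np.zeros((2, 3))
z = np.array([1.0, 0.5, -0.5])
for _ in range(200):
    Wg = generator_step(Wg, z, wd, bd)
fooled = sigmoid(wd @ (Wg @ z) + bd)  # D(G(z)) after training the generator
```

Because the discriminator's parameters never change inside `generator_step`, the loss can only be reduced by moving the generator's output toward regions the discriminator scores as real.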
GAN models are notoriously difficult to converge because of the vast number of hyperparameters to be tweaked. In an ideal scenario, the output of a trained discriminative model settles at 0.5, as it can no longer distinguish between real and fake images; that is when we can say the model has converged. Once the model has converged, we can discard the discriminative network and use the generative model to generate/synthesize new data similar to the original data in the training set.
Applications of GANs
GANs can be used for a variety of purposes. Let’s look at the potential applications in the healthcare domain.
- GANs can be used to improve medical imaging/segmentation
- GANs can be used to generate high-resolution images from low-resolution data
- GANs can be used as an excellent augmentation strategy when training deep neural networks
- GANs can also help simulate medical images with a variety of clinical conditions
- GANs are being used in various drug discovery experiments