Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)

  • 13 Dec 2017

If you have played the game Prince of Persia, you know what it is like defending yourself from the ‘shadow’ which tries to kill you. It’s a conundrum: If you kill the shadow you die; if you don’t do anything, you definitely die!

For all their merits, Generative Adversarial Networks (GANs) face a similar dilemma. Most deep learning experts who endorse GANs temper their support with a note of caution: there is a stability issue!

You may call it a holistic convergence problem. The discriminator and the generator are at loggerheads, yet each depends on the other for efficient training. If one of them fails, the entire system fails. And you have got to ensure they don’t explode. The Prince of Persia is an interesting concept!

To begin with, neural networks were designed to replicate the human brain (albeit artificially). They have succeeded in recognizing objects and processing natural language. But thinking and acting like humans at that neurological level is, let us admit, still a far cry.

Which is why Generative Adversarial Networks became a hot topic in machine learning. It is a relatively new architecture, but it has gone on to revolutionize deep learning by modeling real-world data more accurately than earlier models had.

After all, they came up with a new model for training a neural net, with not one but two independent nets that work separately (and act as adversaries!) as Discriminator and Generator. Such a new architecture for an unsupervised neural network yields far better performance when compared to traditional nets.
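The adversarial relationship can be made concrete with the loss functions from the original GAN formulation. The following is a minimal pure-Python sketch, not TFGAN code; the function names and the sample discriminator scores are hypothetical and chosen only for illustration.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Minimax discriminator loss: -[log D(x) + log(1 - D(G(z)))], averaged."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return -(real_term + fake_term)

def generator_loss(d_fake):
    """Non-saturating generator loss: -log D(G(z)), averaged."""
    return -sum(math.log(p) for p in d_fake) / len(d_fake)

# Hypothetical discriminator outputs: probability that a sample is real.
d_real = [0.9, 0.8]  # scores on real samples
d_fake = [0.2, 0.1]  # scores on generated samples

print("discriminator loss:", discriminator_loss(d_real, d_fake))
print("generator loss:", generator_loss(d_fake))
```

The discriminator's loss falls as it scores real samples high and generated samples low, while the generator's loss falls only when the discriminator is fooled. That opposition is the source of both the power and the instability described above.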

But the fact is, we have barely scratched the surface. The challenge from here onwards is training GANs. Training comes with its own problems, such as failing to learn how many of a particular object should occur at a location, failing to adapt to 3D objects (the model doesn’t understand perspective, such as front and back views), and failing to capture real-life holistic structures. Substantial research is under way to address these problems, and new models have been proposed that give more accurate results than previous techniques.

Now Google intends to make Generative Adversarial Networks easier to experiment with! It has just open sourced TFGAN, a lightweight TensorFlow library designed to make it easy to train and evaluate GANs.

Video: https://www.youtube.com/watch?v=f2GF7TZpuGQ

According to Google, TFGAN provides the infrastructure to easily train a GAN, provides well-tested loss and evaluation metrics, and gives easy-to-use examples that highlight the expressiveness and flexibility of TFGAN.

"We’ve also released a tutorial that includes a high-level API to quickly get a model trained on your data," Google said in its announcement.

[Image: image compression with and without an adversarial loss]

Source: research.googleblog.com

The above image demonstrates the effect of an adversarial loss on image compression. The top row shows image patches from the ImageNet dataset. The middle row shows the results of compressing and uncompressing an image through an image compression neural network trained on a traditional loss. The bottom row shows the results from a network trained with a traditional loss and an adversarial loss. The GAN-loss images are sharper and more detailed, even if they are less like the original.

TFGAN offers simple function calls for the majority of GAN use cases (so users can run a model in a few lines of code), but it is also built in a modular way that covers sophisticated GAN designs. "You can just use the modules you want: loss, evaluation, features, training, etc. are all independent. TFGAN’s lightweight design also means you can use it alongside other frameworks, or with native TensorFlow code," Google says, adding that GAN models written using TFGAN will easily benefit from future infrastructure improvements, and that users can select from a large number of already-implemented losses and features without having to rewrite their own.

Most importantly, Google is assuring us that the code is well-tested: "You don’t have to worry about numerical or statistical mistakes that are easily made with GAN libraries."

[Image: text-to-speech spectrograms]

Source: research.googleblog.com

Most neural text-to-speech (TTS) systems produce over-smoothed spectrograms. When applied to the Tacotron TTS system, Google says, a GAN can recreate some of the realistic texture, which reduces artifacts in the resulting audio.

And then, there is no harm in reiterating that when Google open sources a project, it is expected to be production-ready! "When you use TFGAN, you’ll be using the same infrastructure that many Google researchers use, and you’ll have access to the cutting-edge improvements that we develop with the library," the tech giant added.

To Start With

import tensorflow as tf
tfgan = tf.contrib.gan

Why TFGAN?

  • Easily train generator and discriminator networks with well-tested, flexible library calls. You can mix TFGAN, native TF, and other custom frameworks
  • Use already-implemented GAN losses and penalties (e.g., Wasserstein loss, gradient penalty, mutual information penalty)
  • Monitor and visualize GAN progress during training, and evaluate the results
  • Use already-implemented tricks to stabilize and improve training
  • Develop based on examples of common GAN setups
  • Use the TFGAN-backed GANEstimator to easily train a GAN model
  • Improvements in TFGAN infrastructure will automatically benefit your TFGAN project
  • Stay up-to-date with research as we add more algorithms

What are the TFGAN components?

TFGAN is composed of several parts which were designed to exist independently. These include the following main pieces (explained in detail below).

  • core: provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase can be completed by custom-code or by using a TFGAN library call.
  • features: Many common GAN operations and normalization techniques are implemented for you to use, such as instance normalization and conditioning.
  • losses: Easily experiment with already-implemented and well-tested losses and penalties, such as the Wasserstein loss, gradient penalty, mutual information penalty, etc
  • evaluation: Use Inception Score or Frechet Distance with a pretrained Inception network to evaluate your unconditional generative model. You can also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models.
  • examples and tutorial: See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation.
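To give a feel for what the evaluation component measures, here is a minimal pure-Python sketch of the Inception Score formula, exp(E[KL(p(y|x) || p(y))]). It is not TFGAN's implementation: in practice the class probabilities come from a pretrained Inception network, whereas the rows below are hand-written stand-ins and the function name is hypothetical.

```python
import math

def inception_score(probs):
    """exp of the mean KL divergence between each p(y|x) and the marginal p(y).

    `probs` is a list of per-sample class-probability rows.
    """
    n = len(probs)
    num_classes = len(probs[0])
    # Marginal class distribution p(y), averaged over all samples.
    marginal = [sum(row[k] for row in probs) / n for k in range(num_classes)]
    total_kl = 0.0
    for row in probs:
        # Terms with p == 0 contribute 0 (the usual 0 * log 0 convention).
        total_kl += sum(p * math.log(p / m)
                        for p, m in zip(row, marginal) if p > 0)
    return math.exp(total_kl / n)

# Confident and diverse predictions score high ...
confident = [[1.0, 0.0], [0.0, 1.0]]
# ... while uninformative predictions score the minimum of 1.
uniform = [[0.5, 0.5], [0.5, 0.5]]

print(inception_score(confident))
print(inception_score(uniform))
```

The score rewards samples that are individually recognizable (sharp p(y|x)) and collectively varied (broad p(y)), which is why it is used for unconditional generators.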

Training a GAN model

Training in TFGAN typically consists of the following steps:

  1. Specify the input to your networks.
  2. Set up your generator and discriminator using a GANModel.
  3. Specify your loss using a GANLoss.
  4. Create your train ops using a GANTrainOps.
  5. Run your train ops.
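The five steps above can be sketched without TensorFlow at all. The toy below is an assumption-laden illustration, not TFGAN code: the generator is a single learned shift, G(z) = z + b, the discriminator is a logistic unit, D(x) = sigmoid(w * x + c), the real data is drawn from N(3, 1), and the gradients of the standard cross-entropy GAN losses are derived by hand. All names and hyperparameters are hypothetical.

```python
import math
import random

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def train_toy_gan(steps=200, batch_size=32, lr=0.05, seed=0):
    rng = random.Random(seed)
    b = 0.0          # generator parameter: G(z) = z + b
    w, c = 0.1, 0.0  # discriminator parameters: D(x) = sigmoid(w * x + c)
    for _ in range(steps):
        # Step 1: specify the inputs (real samples and generator noise).
        real = [rng.gauss(3.0, 1.0) for _ in range(batch_size)]
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]

        # Step 2: run the generator and the discriminator.
        fake = [z + b for z in noise]
        d_real = [sigmoid(w * x + c) for x in real]
        d_fake = [sigmoid(w * x + c) for x in fake]

        # Steps 3-4: gradients of the discriminator loss
        # -log D(x) - log(1 - D(G(z))), followed by one descent step.
        grad_w = (sum((p - 1.0) * x for p, x in zip(d_real, real))
                  + sum(p * x for p, x in zip(d_fake, fake))) / batch_size
        grad_c = (sum(p - 1.0 for p in d_real) + sum(d_fake)) / batch_size
        w -= lr * grad_w
        c -= lr * grad_c

        # Steps 3-4 again for the generator loss -log D(G(z)),
        # evaluated against the freshly updated discriminator.
        d_fake = [sigmoid(w * x + c) for x in fake]
        grad_b = sum((p - 1.0) * w for p in d_fake) / batch_size
        b -= lr * grad_b
    return b, w, c

# Step 5: run the train ops in an alternating scheme.
b, w, c = train_toy_gan()
print("learned generator shift:", b)
```

Alternating one discriminator update with one generator update is only one possible schedule; it is used here because it is the simplest to read.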

There are various types of GAN setups. For instance, you can train a generator to sample unconditionally from a learned distribution, or you can condition on extra information such as a class label. TFGAN is compatible with many setups, and a few are demonstrated below:

Examples

Unconditional MNIST generation

This example trains a generator to produce handwritten MNIST digits. The generator maps random draws from a multivariate normal distribution to MNIST digit images. See 'Generative Adversarial Networks' by Goodfellow et al.

# Set up the input.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  # you define
    discriminator_fn=mnist.unconditional_discriminator,  # you define
    real_data=images,
    generator_inputs=noise)

# Build the GAN loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
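For readers unfamiliar with the Wasserstein losses used in this example, the arithmetic behind them can be sketched in plain Python. These helpers are hypothetical stand-ins that operate on lists of critic scores; they are not the TFGAN functions themselves, which work on tensors and accept further options.

```python
def mean(values):
    return sum(values) / len(values)

def wgan_generator_loss(critic_on_generated):
    """The generator tries to raise the critic's score on generated data."""
    return -mean(critic_on_generated)

def wgan_critic_loss(critic_on_real, critic_on_generated):
    """The critic tries to score real data above generated data."""
    return mean(critic_on_generated) - mean(critic_on_real)

# Hypothetical unbounded critic scores (not probabilities).
real_scores = [2.0, 4.0]
generated_scores = [-1.0, 1.0]

print(wgan_critic_loss(real_scores, generated_scores))
print(wgan_generator_loss(generated_scores))
```

Unlike the original GAN loss, the critic's outputs here are raw scores rather than probabilities, so no logarithm or sigmoid is involved.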

Conditional MNIST generation

This example trains a generator to generate MNIST images of a given class. The generator maps random draws from a multivariate normal distribution and a one-hot label of the desired digit class to an MNIST digit image. See 'Conditional Generative Adversarial Nets' by Mirza and Osindero.

# Set up the input.
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.conditional_generator,  # you define
    discriminator_fn=mnist.conditional_discriminator,  # you define
    real_data=images,
    generator_inputs=(noise, one_hot_labels))

# The rest is the same as in the unconditional case.
...
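One common way to feed a class label to a generator is to concatenate a one-hot vector to the noise vector. The sketch below shows that idea in plain Python; it is an assumption about how `mnist.conditional_generator` might consume its inputs, not a description of the actual example network, and the helper names are hypothetical.

```python
import random

def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    return [1.0 if k == label else 0.0 for k in range(num_classes)]

def conditional_generator_input(noise, label, num_classes=10):
    """Concatenate the noise vector with the one-hot class label."""
    return list(noise) + one_hot(label, num_classes)

rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]  # 4 noise dims for brevity
gen_input = conditional_generator_input(noise, label=7)

print(len(gen_input))  # 4 noise values followed by 10 label values
```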

Adversarial loss

This example combines an L1 pixel loss and an adversarial loss to learn to autoencode images. The bottleneck layer can be used to transmit compressed representations of the image. Neural networks trained with a pixel-wise loss alone tend to produce blurry results, so the GAN can be used to make the reconstructions more plausible. See 'Full Resolution Image Compression with Recurrent Neural Networks' by Toderici et al. for an example of neural networks used for image compression, and 'Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network' by Ledig et al. for a more detailed description of how GANs can sharpen image output.

# Set up the input pipeline.
images = image_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=images,
    generator_inputs=images)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
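The idea of mixing a pixel loss with an adversarial loss can be illustrated with a plain weighted sum. This is a simplified sketch under that assumption; the helper names are hypothetical, and TFGAN's `combine_adversarial_loss` supports more than a fixed weight.

```python
def l1_pixel_loss(real, generated):
    """Sum of absolute pixel differences (the L1 norm of the residual)."""
    return sum(abs(r - g) for r, g in zip(real, generated))

def combine_losses(pixel_loss, adversarial_loss, weight_factor):
    """Pixel loss plus a weighted adversarial term."""
    return pixel_loss + weight_factor * adversarial_loss

real = [0.0, 0.5, 1.0, 0.25]        # hypothetical flattened image
generated = [0.5, 0.5, 0.5, 0.25]   # hypothetical reconstruction

pixel = l1_pixel_loss(real, generated)
total = combine_losses(pixel, adversarial_loss=2.0, weight_factor=0.5)
print(pixel, total)
```

A small weight keeps reconstructions close to the input, while a larger weight lets the adversarial term push toward sharper but less faithful output.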

Image-to-image translation

This example maps images in one domain to images of the same size in a different domain. For example, it can map segmentation masks to street images, or grayscale images to color. See 'Image-to-Image Translation with Conditional Adversarial Networks' by Isola et al. for more details.

# Set up the input pipeline.
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.generator,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=target_image,
    generator_inputs=input_image)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
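The least-squares losses used in this example replace the log terms of the original GAN with squared errors against target labels. Below is a plain-Python sketch of that arithmetic, assuming the usual targets of 1 for real and 0 for generated data; the function names are hypothetical and this is not the TFGAN implementation.

```python
def lsgan_discriminator_loss(d_real, d_fake, real_label=1.0, fake_label=0.0):
    """Push scores on real data toward real_label and on fakes toward fake_label."""
    real_term = sum((d - real_label) ** 2 for d in d_real) / len(d_real)
    fake_term = sum((d - fake_label) ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * (real_term + fake_term)

def lsgan_generator_loss(d_fake, real_label=1.0):
    """Push the discriminator's scores on generated data toward real_label."""
    return 0.5 * sum((d - real_label) ** 2 for d in d_fake) / len(d_fake)

# A perfect discriminator: zero discriminator loss, generator loss of 0.5.
print(lsgan_discriminator_loss([1.0], [0.0]), lsgan_generator_loss([0.0]))
# An undecided discriminator that outputs 0.5 everywhere.
print(lsgan_discriminator_loss([0.5], [0.5]), lsgan_generator_loss([0.5]))
```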

InfoGAN

Train a generator to generate specific MNIST digit images, and control for digit style without using any labels. See 'InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets' by Chen et al. for more details.

# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.infogan_model(
    generator_fn=mnist.infogan_generator,  # you define
    discriminator_fn=mnist.infogan_discriminator,  # you define
    real_data=images,
    unstructured_generator_inputs=unstructured_inputs,  # you define
    structured_generator_inputs=structured_inputs)  # you define

# Build the GAN loss with mutual information penalty.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0,
    mutual_information_penalty_weight=1.0)

# The rest is the same as in the unconditional case.
...
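The snippet leaves `unstructured_inputs` and `structured_inputs` for you to define. As a rough illustration of what they might contain, the sketch below samples plain noise plus one categorical code and two continuous codes, loosely following the MNIST setup described in the InfoGAN paper. The function name, dimensions, and return layout are assumptions, and real inputs would be tensors rather than Python lists.

```python
import random

def sample_infogan_inputs(noise_dims=8, num_categories=10,
                          num_continuous=2, seed=None):
    """Return (unstructured_noise, [categorical_code, continuous_code])."""
    rng = random.Random(seed)
    # Unstructured noise: no meaning is imposed on these dimensions.
    unstructured = [rng.gauss(0.0, 1.0) for _ in range(noise_dims)]
    # Structured codes: the generator is encouraged to make them meaningful.
    category = rng.randrange(num_categories)
    categorical_code = [1.0 if k == category else 0.0
                        for k in range(num_categories)]
    continuous_code = [rng.uniform(-1.0, 1.0) for _ in range(num_continuous)]
    return unstructured, [categorical_code, continuous_code]

unstructured_sample, structured_sample = sample_infogan_inputs(seed=0)
print(len(unstructured_sample), [len(code) for code in structured_sample])
```

The mutual information penalty then rewards the generator when the structured codes can be recovered from its output, which is what ties a code such as the categorical one to a visible property such as digit identity.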

Custom model creation

Train an unconditional GAN to generate MNIST digits, but manually construct the GANModel tuple for more fine-grained control.

# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Manually build the generator and discriminator.
# `generator_fn` and `discriminator_fn` are functions you define.
with tf.variable_scope('Generator') as gen_scope:
  generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_images)
with tf.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(images)
generator_variables = tf.contrib.framework.get_trainable_variables(gen_scope)
discriminator_variables = tf.contrib.framework.get_trainable_variables(dis_scope)

# Depending on what TFGAN features you use, you don't always need to supply
# every `GANModel` field. At a minimum, you need to include the discriminator
# outputs and variables if you want to use TFGAN to construct losses.
gan_model = tfgan.GANModel(
    noise,  # generator_inputs
    generated_images,  # generated_data
    generator_variables,
    gen_scope,
    generator_fn,
    images,  # real_data
    discriminator_real_outputs,
    discriminator_gen_outputs,
    discriminator_variables,
    dis_scope,
    discriminator_fn)

# The rest is the same as the unconditional case.
...
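The idea of a model tuple that carries everything the loss functions need can be illustrated with an ordinary named tuple. `ToyGANModel` below is a hypothetical stand-in written for this sketch; its field names are assumptions modeled on the argument order in the snippet above, and the toy generator and discriminator are plain functions rather than networks.

```python
import collections

# Hypothetical stand-in for a GANModel-like tuple (field names are assumed).
ToyGANModel = collections.namedtuple('ToyGANModel', [
    'generator_inputs', 'generated_data', 'generator_variables',
    'generator_scope', 'generator_fn', 'real_data',
    'discriminator_real_outputs', 'discriminator_gen_outputs',
    'discriminator_variables', 'discriminator_scope', 'discriminator_fn'])

def toy_generator(noise):
    return [2.0 * z for z in noise]

def toy_discriminator(data):
    return [x - 1.0 for x in data]

noise = [0.0, 1.0]
images = [2.0, 4.0]
generated = toy_generator(noise)

toy_model = ToyGANModel(
    generator_inputs=noise,
    generated_data=generated,
    generator_variables=None,  # nothing trainable in this toy
    generator_scope=None,
    generator_fn=toy_generator,
    real_data=images,
    discriminator_real_outputs=toy_discriminator(images),
    discriminator_gen_outputs=toy_discriminator(generated),
    discriminator_variables=None,
    discriminator_scope=None,
    discriminator_fn=toy_discriminator)

# A loss can be computed from the stored discriminator outputs alone.
gen_mean = sum(toy_model.discriminator_gen_outputs) / len(generated)
real_mean = sum(toy_model.discriminator_real_outputs) / len(images)
critic_loss = gen_mean - real_mean
print(critic_loss)
```

Fields that a given loss does not read are simply left as `None` here, which mirrors the comment in the snippet that not every field is always required.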

Google has opened the GitHub repositories to contributions from anyone, to facilitate code sharing among machine learning users. For more examples on TFGAN, see tensorflow/models on GitHub.