Powered by AppSignal & Oban Pro

Ch 12: Unsupervised

Ch12 - Unsupervised.livemd

Ch 12: Unsupervised

# Install notebook dependencies: Scidata (MNIST download), Axon (NN library),
# EXLA (XLA compiler backend), Nx (tensors), and Kino (Livebook rendering).
Mix.install([
  {:scidata, "~> 0.1"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:kino, "~> 0.8"}
])

Data prep

# Run all tensor operations on the EXLA (XLA) backend by default.
Nx.global_default_backend(EXLA.Backend)

batch_size = 64

# Download MNIST; only the images are needed here, not the labels.
{{images_binary, images_type, _images_shape}, _labels} = Scidata.MNIST.download()

# Decode the raw bytes, shape them into 28x28 single-channel (grayscale)
# images, scale pixel values into [0, 1], and split into mini-batches.
train_data =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
  |> Nx.to_batched(batch_size)

Autoencoder

defmodule Autoencoder do
  @moduledoc """
  A simple dense autoencoder for 28x28 grayscale MNIST images.

  The encoder compresses a flattened 784-pixel image into a
  128-dimensional latent vector; the decoder reconstructs the image
  from that latent vector.
  """

  @doc "Encoder half: flatten, then dense 256 -> 128 with ReLU activations."
  def encoder(input) do
    flattened = Axon.flatten(input)

    flattened
    |> Axon.dense(256, activation: :relu, name: "encoder_dense_0")
    |> Axon.dense(128, activation: :relu, name: "encoder_dense_1")
  end

  @doc "Decoder half: dense 256 -> 784 (sigmoid), reshaped back to 28x28x1."
  def decoder(input) do
    input
    |> Axon.dense(256, activation: :relu, name: "decoder_dense_0")
    |> Axon.dense(784, activation: :sigmoid, name: "decoder_dense_1")
    |> Axon.reshape({:batch, 28, 28, 1})
  end
end
# Wire the encoder and decoder together into the full autoencoder model.
model =
  "image"
  |> Axon.input()
  |> Autoencoder.encoder()
  |> Autoencoder.decoder()

# Keep one fixed image (with a leading batch axis of 1) from the first
# batch so reconstruction progress can be visualized during training.
test_batch = Enum.at(train_data, 0)
test_image = Nx.new_axis(test_batch[0], 0)

# Epoch hook: reconstruct the fixed test image with the current model
# parameters and render it, so training progress is visible in Livebook.
visualize_test_image = fn %Axon.Loop.State{step_state: step_state} = state ->
  reconstruction =
    Axon.predict(model, step_state[:model_state], test_image, compiler: EXLA)

  # Kino.Image expects image tensors to be created from u8 tensors.
  rendered =
    reconstruction
    |> Nx.multiply(255)
    |> Nx.as_type(:u8)
    |> Nx.reshape({28, 28, 1})

  rendered
  |> Kino.Image.new()
  |> Kino.render()

  {:continue, state}
end
# Train the autoencoder as a supervised problem: each batch is zipped
# with itself as the target, so MSE measures reconstruction error.
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-3)

trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, optimizer)
  |> Axon.Loop.handle_event(:epoch_completed, visualize_test_image)
  |> Axon.Loop.run(Stream.zip(train_data, train_data), %{}, epochs: 5, compiler: EXLA)

By treating the input as both the input and label, you’ve turned this unsupervised learning problem into a supervised learning problem.

It’s difficult to interpret how this loss corresponds to how well your model is reconstructing compressed inputs. A much more interpretable measure of progress is a periodic visualization of your model’s outputs on some example data, which comes from the anonymous function bound to visualize_test_image.

In this example, your decoder is actually kind of a generative model—it takes a latent representation and produces an output. Hypothetically, you can skip the encoder altogether and give the decoder some random latent representation, and it should give you an output that resembles a handwritten digit:

# Build a model from just the decoder half; its layer names match the
# trained autoencoder, so the trained parameters can be reused directly.
decoder_only =
  Axon.input("noise")
  |> Autoencoder.decoder()

# Feed random normal noise straight into the decoder (skipping the
# encoder entirely) and see what it produces.
key = Nx.Random.key(42)
{noise, _key} = Nx.Random.normal(key, shape: {1, 128})

# Upsample 28x28 -> 512x512 for a larger preview, then convert to u8
# pixel values for Kino rendering.
out_image =
  decoder_only
  |> Axon.predict(trained_model_state, noise)
  |> Axon.Layers.resize(size: {512, 512})
  |> Nx.reshape({512, 512, 1})
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)

Kino.Image.new(out_image)

> Unfortunately, you didn’t force your model to have a latent space with structure. The structure of your encoded representations is at the mercy of a neural network and gradient descent. You can’t pass random uniform or normal noise to your decoder and expect coherent output because your decoder only knows how to handle latent representations produced by the encoder. But what if there was a way to force your encoder to learn a structured representation, which you can easily query later on? Fortunately, there is. Enter the variational autoencoder.

Variational autoencoder

defmodule VAE do
  import Nx.Defn

  # Convolutional encoder: maps a 28x28x1 image to a 2-dimensional latent
  # Gaussian, returning an Axon container of {z_mean, z_log_var, z}, where
  # z is a reparameterized sample from that distribution.
  def encoder(input) do
    encoded =
      input
      |> Axon.conv(32,
        kernel_size: 3,
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.conv(32,
        kernel_size: 3,
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.flatten()
      |> Axon.dense(16, activation: :relu)

    # Two separate heads predict the parameters of the latent Gaussian.
    z_mean = Axon.dense(encoded, 2)
    z_log_var = Axon.dense(encoded, 2)
    # Custom layer implementing the reparameterization trick (sample/3 below).
    z = Axon.layer(&sample/3, [z_mean, z_log_var], op_name: :sample)
    Axon.container({z_mean, z_log_var, z})
  end

  # Reparameterization trick: z = mean + exp(0.5 * log_var) * epsilon with
  # epsilon ~ N(0, 1), so gradients can flow through the sampling step.
  # NOTE(review): the PRNG key is fixed at 42, so epsilon is identical on
  # every invocation — consider threading a fresh key through training.
  defnp sample(z_mean, z_log_var, _opts \\ []) do
    noise_shape = Nx.shape(z_mean)
    # Nx.random_normal removed from lib in 0.7
    # epsilon = Nx.random_normal(noise_shape)
    key = Nx.Random.key(42)
    {epsilon, _} = Nx.Random.normal(key, shape: noise_shape)
    z_mean + Nx.exp(0.5 * z_log_var) * epsilon
  end

  # Decoder: projects a latent vector up to a 7x7x64 feature map, then
  # upsamples twice with transposed convolutions (stride 2) back to a
  # 28x28x1 image with sigmoid-activated pixel values.
  def decoder(input) do
    input
    |> Axon.dense(7 * 7 * 64, activation: :relu)
    |> Axon.reshape({:batch, 7, 7, 64})
    |> Axon.conv_transpose(64,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(32,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(1,
      kernel_size: {3, 3},
      activation: :sigmoid,
      padding: :same
    )
  end

  # One custom training step: compute the joint (encoder + decoder) loss and
  # its gradients, apply the optimizer update to both sets of parameters,
  # and fold the batch loss into a running average keyed by iteration i.
  defn train_step(encoder_fn, decoder_fn, optimizer_fn, batch, state) do
    {batch_loss, joint_param_grads} =
      value_and_grad(
        state[:model_state],
        &joint_objective(encoder_fn, decoder_fn, batch, &1)
      )

    {scaled_updates, new_optimizer_state} =
      optimizer_fn.(
        joint_param_grads,
        state[:optimizer_state],
        state[:model_state]
      )

    new_model_state =
      Axon.Updates.apply_updates(
        state[:model_state],
        scaled_updates
      )

    # Running mean: loss * i accumulates the previous total, then the new
    # batch loss is added and the sum renormalized by i + 1.
    new_loss =
      state[:loss]
      |> Nx.multiply(state[:i])
      |> Nx.add(batch_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    %{
      state
      | i: Nx.add(state[:i], 1),
        loss: new_loss,
        model_state: new_model_state,
        optimizer_state: new_optimizer_state
    }
  end

  # VAE loss = reconstruction loss + KL-divergence regularizer, computed
  # over the joint parameter map with "encoder" / "decoder" keys.
  defnp joint_objective(encoder_fn, decoder_fn, batch, joint_params) do
    %{prediction: preds} = encoder_fn.(joint_params["encoder"], batch)
    {z_mean, z_log_var, z} = preds
    %{prediction: reconstruction} = decoder_fn.(joint_params["decoder"], z)

    # how well your decoder reconstructs the original batch of images from the encoded representation.
    recon_loss =
      Axon.Losses.binary_cross_entropy(
        batch,
        reconstruction,
        reduction: :mean
      )

    # Regularization which penalizes the model from drifting too far away from a normal distribution. 
    # It basically measures the difference of the distribution defined by the given parameters 
    #   z_mean and z_log_var from a standard normal distribution with mean 0 and variance 1. 
    # This term helps coerce your encoder into learning to encode your data as a normal distribution.
    kl_loss = -0.5 * (1 + z_log_var - Nx.pow(z_mean, 2) - Nx.exp(z_log_var))
    kl_loss = Nx.mean(Nx.sum(kl_loss, axes: [1]))
    recon_loss + kl_loss
  end

  # Loop initialization: build encoder params from a real batch, decoder
  # params from a dummy {64, 2} latent batch (matching the 2-dim latent
  # space), and seed the optimizer state plus running-loss counters.
  defn init_step(
         encoder_init_fn,
         decoder_init_fn,
         optimizer_init_fn,
         batch,
         init_state
       ) do
    encoder_params = encoder_init_fn.(batch, init_state)
    key = Nx.Random.key(28)
    {uniform, _} = Nx.Random.uniform(key, shape: {64, 2})
    decoder_params = decoder_init_fn.(uniform, init_state)

    joint_params = %{
      "encoder" => encoder_params,
      "decoder" => decoder_params
    }

    optimizer_state = optimizer_init_fn.(joint_params)

    %{
      i: Nx.tensor(0),
      loss: Nx.tensor(0.0),
      model_state: joint_params,
      optimizer_state: optimizer_state
    }
  end

  # Epoch hook: decode three fixed latent points, upsample each decoded
  # image to 512x512, and render them in Livebook as training progresses.
  def display_sample(
        %Axon.Loop.State{step_state: state} = out_state,
        decoder_fn
      ) do
    latent = Nx.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    %{prediction: out} = decoder_fn.(state[:model_state]["decoder"], latent)

    out_image =
      Nx.multiply(out, 255)
      |> Nx.as_type(:u8)

    upsample =
      Axon.Layers.resize(
        out_image,
        size: {512, 512},
        channels: :last
      )

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1}))
      |> Kino.render()
    end

    {:continue, out_state}
  end
end
# Visualize the decoder's architecture. The VAE's latent space is
# 2-dimensional (see the `Axon.dense(encoded, 2)` heads in VAE.encoder/1
# and the {64, 2} init batch in VAE.init_step/5), so the template input
# is a single 2-element latent vector — not the autoencoder's 128.
template = Nx.template({1, 2}, :f32)
Axon.Display.as_graph(VAE.decoder(Axon.input("latent")), template)

Custom training loop. This needs 2 functions defined, one for state initialisation and another for state updates between steps.

These we’ve defined in the VAE module above as init_step/5 and train_step/5 respectively.

# Build the encoder and decoder as two separate Axon models so each can be
# compiled and parameterized independently in the custom training loop.
encoder =
  Axon.input("image")
  |> VAE.encoder()

decoder =
  Axon.input("latent")
  |> VAE.decoder()

# mode: :train makes the built functions return %{prediction: ...} maps,
# which VAE.train_step/joint_objective pattern-match on.
{encoder_init_fn, encoder_fn} = Axon.build(encoder, mode: :train)
{decoder_init_fn, decoder_fn} = Axon.build(decoder, mode: :train)

# Adam returns {init_fn, update_fn}; one optimizer drives both halves.
{optimizer_init_fn, optimizer_fn} = Polaris.Optimizers.adam(learning_rate: 1.0e-3)

# Partially apply the build/init functions; the loop supplies the first
# batch (&1) and the initial state (&2). Note: the original text contained
# HTML-escaped `&amp;` entities, which are not valid Elixir — fixed to `&`.
init_fn =
  &VAE.init_step(
    encoder_init_fn,
    decoder_init_fn,
    optimizer_init_fn,
    &1,
    &2
  )

# Same partial application for the per-batch step function.
step_fn =
  &VAE.train_step(
    encoder_fn,
    decoder_fn,
    optimizer_fn,
    &1,
    &2
  )

# Run the custom VAE training loop: decode and render sample latents after
# each epoch, and log the running loss after every batch. (The original
# text contained HTML-escaped `&amp;` entities — fixed to the `&` capture.)
step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(:epoch_completed, &VAE.display_sample(&1, decoder_fn))
|> Axon.Loop.log(
  fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
    "\rEpoch: #{epoch}, batch: #{iter}, loss: #{Nx.to_number(state[:loss])}"
  end,
  event: :iteration_completed,
  device: :stdio
)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)

GAN

defmodule GAN do
  import Nx.Defn

  # Discriminator: convolutional binary classifier producing the sigmoid
  # probability that its 28x28x1 input image is real.
  def discriminator(input) do
    input
    |> Axon.conv(32, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.conv(64, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.flatten()
    |> Axon.dropout(rate: 0.5)
    |> Axon.dense(1, activation: :sigmoid)
  end

  # Generator: projects a 128-dim latent vector to a 7x7x128 feature map,
  # then upsamples 7 -> 14 -> 28 via resize + conv, ending in one channel.
  # NOTE(review): the final :tanh yields values in [-1, 1] while the MNIST
  # batches are scaled to [0, 1] — confirm this range mismatch is intended.
  def generator(input) do
    input
    |> Axon.dense(128 * 7 * 7, activation: :mish)
    |> Axon.reshape({:batch, 7, 7, 128})
    |> Axon.resize({14, 14})
    |> Axon.conv(128, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.resize({28, 28})
    |> Axon.conv(64, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.conv(1, activation: :tanh, kernel_size: 3, padding: :same)
  end

  # One adversarial training step: first update the discriminator on real
  # and generated batches, then update the generator against the
  # pre-update discriminator params, and fold both losses into running
  # averages keyed by iteration i.
  defn train_step(
         discriminator_fn,
         generator_fn,
         discriminator_optimizer,
         generator_optimizer,
         batch,
         state
       ) do
    d_params = state[:model_state]["discriminator"]
    g_params = state[:model_state]["generator"]
    d_optimizer_state = state[:optimizer_state]["discriminator"]
    g_optimizer_state = state[:optimizer_state]["generator"]

    # Update discriminator 
    {d_loss, d_grads} =
      value_and_grad(d_params, fn d_params ->
        d_objective(
          d_params,
          g_params,
          discriminator_fn,
          generator_fn,
          batch
        )
      end)

    {d_updates, new_d_optimizer_state} =
      discriminator_optimizer.(
        d_grads,
        d_optimizer_state,
        d_params
      )

    new_d_params = Axon.Updates.apply_updates(d_params, d_updates)

    # Update generator (note: graded against the OLD d_params bound above,
    # not new_d_params — the generator sees the pre-update discriminator).
    {g_loss, g_grads} =
      value_and_grad(g_params, fn g_params ->
        g_objective(
          d_params,
          g_params,
          discriminator_fn,
          generator_fn,
          batch
        )
      end)

    {g_updates, new_g_optimizer_state} =
      generator_optimizer.(
        g_grads,
        g_optimizer_state,
        g_params
      )

    new_g_params = Axon.Updates.apply_updates(g_params, g_updates)

    # Update Losses: running means over all iterations so far.
    new_d_loss =
      state[:loss]["discriminator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(d_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_g_loss =
      state[:loss]["generator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(g_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_loss = %{
      "discriminator" => new_d_loss,
      "generator" => new_g_loss
    }

    new_model_state = %{
      "discriminator" => new_d_params,
      "generator" => new_g_params
    }

    new_optimizer_state = %{
      "discriminator" => new_d_optimizer_state,
      "generator" => new_g_optimizer_state
    }

    %{
      model_state: new_model_state,
      optimizer_state: new_optimizer_state,
      loss: new_loss,
      i: Nx.add(state[:i], 1)
    }
  end

  # Discriminator objective: average BCE of labeling real images 1 and
  # generator output 0.
  # NOTE(review): the PRNG key is fixed at 832, so the same latent batch is
  # drawn every step (likewise keys 737/945/356 below) — consider threading
  # fresh keys through the loop state.
  defn d_objective(
         d_params,
         g_params,
         discriminator_fn,
         generator_fn,
         real_batch
       ) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    fake_targets = Nx.broadcast(0, {batch_size, 1})
    key = Nx.Random.key(832)
    {latent, _} = Nx.Random.normal(key, shape: {batch_size, 128})

    %{prediction: fake_batch} = generator_fn.(g_params, latent)

    %{prediction: real_labels} = discriminator_fn.(d_params, real_batch)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    real_loss =
      Axon.Losses.binary_cross_entropy(
        real_targets,
        real_labels,
        reduction: :mean
      )

    fake_loss =
      Axon.Losses.binary_cross_entropy(
        fake_targets,
        fake_labels,
        reduction: :mean
      )

    0.5 * real_loss + 0.5 * fake_loss
  end

  # Generator objective: BCE of the discriminator labeling generated
  # images as real (targets of 1) — the generator improves by fooling it.
  defn g_objective(
         d_params,
         g_params,
         discriminator_fn,
         generator_fn,
         real_batch
       ) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    key = Nx.Random.key(737)
    {latent, _} = Nx.Random.normal(key, shape: {batch_size, 128})

    %{prediction: fake_batch} = generator_fn.(g_params, latent)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    Axon.Losses.binary_cross_entropy(
      real_targets,
      fake_labels,
      reduction: :mean
    )
  end

  # Loop initialization: build discriminator params from a real batch and
  # generator params from a dummy {64, 128} latent batch; seed both
  # optimizer states and the running-loss counters.
  defn init_state(
         discriminator_init_fn,
         generator_init_fn,
         discriminator_optimizer_init,
         generator_optimizer_init,
         batch,
         init_state
       ) do
    d_params = discriminator_init_fn.(batch, init_state)
    key = Nx.Random.key(945)
    {normal, _} = Nx.Random.normal(key, shape: {64, 128})
    g_params = generator_init_fn.(normal, init_state)
    d_optimizer_state = discriminator_optimizer_init.(d_params)
    g_optimizer_state = generator_optimizer_init.(g_params)

    model_state = %{
      "discriminator" => d_params,
      "generator" => g_params
    }

    optimizer_state = %{
      "discriminator" => d_optimizer_state,
      "generator" => g_optimizer_state
    }

    loss = %{
      "discriminator" => Nx.tensor(0.0),
      "generator" => Nx.tensor(0.0)
    }

    %{
      model_state: model_state,
      optimizer_state: optimizer_state,
      loss: loss,
      i: Nx.tensor(0)
    }
  end

  # Epoch hook: generate three images from fixed latent noise, upsample
  # each to 512x512, and render them in Livebook.
  def display_sample(
        %Axon.Loop.State{step_state: state} = out_state,
        generator_fn
      ) do
    key = Nx.Random.key(356)
    {latent, _} = Nx.Random.normal(key, shape: {3, 128})
    %{prediction: out} = generator_fn.(state[:model_state]["generator"], latent)

    out_image =
      Nx.multiply(out, 255)
      |> Nx.as_type(:u8)

    upsample =
      Axon.Layers.resize(
        out_image,
        size: {512, 512},
        channels: :last
      )

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1}))
      |> Kino.render()
    end

    {:continue, out_state}
  end
end
# Build both networks and compile them in train mode (so predictions come
# back as %{prediction: ...} maps, as GAN.train_step expects).
discriminator = GAN.discriminator(Axon.input("image"))
generator = GAN.generator(Axon.input("latent"))

{discriminator_init_fn, discriminator_fn} =
  Axon.build(discriminator, mode: :train)

{generator_init_fn, generator_fn} =
  Axon.build(generator, mode: :train)

# The discriminator gets a 10x smaller learning rate so it doesn't
# overpower the generator early in training.
{d_optimizer_init, d_optimizer} = Polaris.Optimizers.adam(learning_rate: 1.0e-4)
{g_optimizer_init, g_optimizer} = Polaris.Optimizers.adam(learning_rate: 1.0e-3)

# Partially apply the build/init functions; the loop supplies the batch
# (&1) and state (&2). The original text contained HTML-escaped `&amp;`
# entities, which are not valid Elixir — fixed to the `&` capture operator.
init_fn =
  &GAN.init_state(
    discriminator_init_fn,
    generator_init_fn,
    d_optimizer_init,
    g_optimizer_init,
    &1,
    &2
  )

# Same partial application for the per-batch adversarial step.
step_fn =
  &GAN.train_step(
    discriminator_fn,
    generator_fn,
    d_optimizer,
    g_optimizer,
    &1,
    &2
  )

Notice that you’re using slightly different optimizers for each model. To prevent the discriminator from dominating the generator before it can get its bearings during training, you need to lower the learning rate of the discriminator.

# Run the GAN training loop: render generated samples after each epoch and
# log both running losses after every batch. (The original text contained
# HTML-escaped `&amp;` entities — fixed to the `&` capture operator.)
step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(
  :epoch_completed,
  &GAN.display_sample(&1, generator_fn)
)
|> Axon.Loop.log(
  fn
    %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
      d_loss = state[:loss]["discriminator"]
      g_loss = state[:loss]["generator"]

      "\rEpoch: #{epoch}, batch: #{iter}," <>
        " d_loss: #{Nx.to_number(d_loss)}," <>
        " g_loss: #{Nx.to_number(g_loss)}"
  end,
  event: :iteration_completed,
  device: :stdio
)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)