
Learning Without Supervision

Mix.install([
  {:scidata, "~> 0.1"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:kino, "~> 0.8"}
])

The Input Data

batch_size = 64

{{data, type, shape}, _} = Scidata.MNIST.download()

train_data =
  data
  |> Nx.from_binary(type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
  |> Nx.to_batched(batch_size)
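
Each batch is a {64, 28, 28, 1} tensor of :f32 pixel values scaled into the [0, 1] range. As a quick sanity check (an added cell, not part of the original notebook), you can inspect the first batch:

train_data |> Enum.at(0) |> Nx.shape()
#=> {64, 28, 28, 1}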

Compressing Data with Autoencoders

defmodule Autoencoder do
  def encoder(input) do
    input
    |> Axon.flatten()
    |> Axon.dense(256, activation: :relu, name: "encoder_dense_0")
    |> Axon.dense(128, activation: :relu, name: "encoder_dense_1")
  end

  def decoder(input) do
    input
    |> Axon.dense(256, activation: :relu, name: "decoder_dense_0")
    |> Axon.dense(784, activation: :sigmoid, name: "decoder_dense_1")
    |> Axon.reshape({:batch, 28, 28, 1})
  end
end

model =
  Axon.input("image")
  |> Autoencoder.encoder()
  |> Autoencoder.decoder()

model
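
The encoder squeezes each 784-pixel image down to a 128-dimensional representation, and the decoder mirrors it back out to 784 pixels. To see the layer structure, you can render the model with Axon.Display.as_graph/2 against an input template (an added cell, mirroring the display call used later in this notebook):

Axon.Display.as_graph(model, Nx.template({1, 28, 28, 1}, :f32))

The graph makes the 784 → 256 → 128 → 256 → 784 bottleneck explicit.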

# Zipping the batches with themselves makes each image both the input and
# the reconstruction target.
trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(1.0e-3))
  |> Axon.Loop.run(
    Stream.zip(train_data, train_data),
    %{},
    epochs: 5,
    compiler: EXLA
  )

# Pull a single test image out of the first batch, keeping a batch axis so
# its shape is {1, 28, 28, 1}.
[test_batch] = Enum.take(train_data, 1)
test_image = test_batch[0] |> Nx.new_axis(0)

# After each epoch, reconstruct the held-out test image and render it so you
# can watch the reconstruction improve over training.
visualize_test_image = fn %Axon.Loop.State{step_state: step_state} = state ->
  out_image = Axon.predict(model, step_state[:model_state], test_image, compiler: EXLA)
  out_image = Nx.multiply(out_image, 255) |> Nx.as_type(:u8)
  Kino.Image.new(Nx.reshape(out_image, {28, 28, 1})) |> Kino.render()
  {:continue, state}
end

trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(1.0e-3))
  |> Axon.Loop.handle_event(:epoch_completed, visualize_test_image)
  |> Axon.Loop.run(Stream.zip(train_data, train_data), %{}, epochs: 5, compiler: EXLA)

# Because the decoder layers carry explicit names ("decoder_dense_0" and
# "decoder_dense_1"), Axon can look their trained parameters up in
# trained_model_state even though this is a brand-new graph.
decoder_only =
  Axon.input("noise")
  |> Autoencoder.decoder()

key = Nx.Random.key(42)
{noise, _key} = Nx.Random.normal(key, shape: {1, 128})

out_image = Axon.predict(decoder_only, trained_model_state, noise)
upsampled = Axon.Layers.resize(out_image, size: {512, 512})

out_image =
  upsampled
  |> Nx.reshape({512, 512, 1})
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)

Kino.Image.new(out_image)
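
Passing random noise through the decoder rarely yields a convincing digit: nothing in the reconstruction objective forces the 128-dimensional latent space to be smooth or organized, so most latent points decode to blurry smears. That shortcoming motivates the structured latent space explored next.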

Learning a Structured Latent

defmodule VAE do
  import Nx.Defn

  def encoder(input) do
    encoded =
      input
      |> Axon.conv(32,
        kernel_size: {3, 3},
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.conv(32,
        kernel_size: {3, 3},
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.flatten()
      |> Axon.dense(16, activation: :relu)

    z_mean = Axon.dense(encoded, 2)
    z_log_var = Axon.dense(encoded, 2)
    z = Axon.layer(&sample/3, [z_mean, z_log_var], op_name: :sample)

    Axon.container({z_mean, z_log_var, z})
  end

  # The reparameterization trick: z = mean + std * noise, with
  # std = exp(0.5 * log_var), keeping the sampling step differentiable
  # with respect to the mean and log-variance.
  defnp sample(z_mean, z_log_var, _opts \\ []) do
    noise_shape = Nx.shape(z_mean)
    epsilon = Nx.random_normal(noise_shape)
    z_mean + Nx.exp(0.5 * z_log_var) * epsilon
  end

  def decoder(input) do
    input
    |> Axon.dense(7 * 7 * 64, activation: :relu)
    |> Axon.reshape({:batch, 7, 7, 64})
    |> Axon.conv_transpose(64,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(32,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(1,
      kernel_size: {3, 3},
      activation: :sigmoid,
      padding: :same
    )
  end

  defn train_step(encoder_fn, decoder_fn, optimizer_fn, batch, state) do
    {batch_loss, joint_param_grads} =
      value_and_grad(
        state[:model_state],
        &joint_objective(encoder_fn, decoder_fn, batch, &1)
      )

    {scaled_updates, new_optimizer_state} =
      optimizer_fn.(joint_param_grads, state[:optimizer_state], state[:model_state])

    new_model_state = Axon.Updates.apply_updates(state[:model_state], scaled_updates)

    # Keep a running average of the loss across iterations.
    new_loss =
      state[:loss]
      |> Nx.multiply(state[:i])
      |> Nx.add(batch_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    %{
      state
      | i: Nx.add(state[:i], 1),
        loss: new_loss,
        model_state: new_model_state,
        optimizer_state: new_optimizer_state
    }
  end

  defnp joint_objective(encoder_fn, decoder_fn, batch, joint_params) do
    %{prediction: {z_mean, z_log_var, z}} = encoder_fn.(joint_params["encoder"], batch)
    %{prediction: reconstruction} = decoder_fn.(joint_params["decoder"], z)

    # Reconstruction term plus the closed-form KL divergence between the
    # approximate posterior N(z_mean, exp(z_log_var)) and the standard normal
    # prior, summed over latent dimensions and averaged over the batch.
    recon_loss = Axon.Losses.binary_cross_entropy(batch, reconstruction, reduction: :mean)
    kl_loss = -0.5 * (1 + z_log_var - Nx.pow(z_mean, 2) - Nx.exp(z_log_var))
    kl_loss = Nx.mean(Nx.sum(kl_loss, axes: [1]))

    recon_loss + kl_loss
  end

  defn init_step(encoder_init_fn, decoder_init_fn, optimizer_init_fn, batch, init_state) do
    encoder_params = encoder_init_fn.(batch, init_state)
    # The decoder has no real input at init time, so a dummy batch of 2-D
    # latent vectors is used to infer its parameter shapes.
    decoder_params = decoder_init_fn.(Nx.random_uniform({64, 2}), init_state)
    joint_params = %{"encoder" => encoder_params, "decoder" => decoder_params}
    optimizer_state = optimizer_init_fn.(joint_params)

    %{
      i: Nx.tensor(0),
      loss: Nx.tensor(0.0),
      model_state: joint_params,
      optimizer_state: optimizer_state
    }
  end

  def display_sample(%Axon.Loop.State{step_state: state} = out_state, decoder_fn) do
    latent = Nx.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    %{prediction: out} = decoder_fn.(state[:model_state]["decoder"], latent)
    out_image = Nx.multiply(out, 255) |> Nx.as_type(:u8)
    upsample = Axon.Layers.resize(out_image, size: {512, 512}, channels: :last)

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1})) |> Kino.render()
    end

    {:continue, out_state}
  end
end
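
For reference, the KL term computed in joint_objective is the closed-form divergence between a diagonal Gaussian posterior and a standard normal prior (the formula is implicit in the code above):

$$
D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

where $\mu$ corresponds to z_mean and $\log \sigma^2$ to z_log_var.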

template = Nx.template({1, 2}, :f32)
Axon.Display.as_graph(VAE.decoder(Axon.input("latent")), template)

encoder = Axon.input("image") |> VAE.encoder()
decoder = Axon.input("latent") |> VAE.decoder()

{encoder_init_fn, encoder_fn} = Axon.build(encoder, mode: :train)
{decoder_init_fn, decoder_fn} = Axon.build(decoder, mode: :train)
{optimizer_init_fn, optimizer_fn} = Axon.Optimizers.adam(1.0e-3)
init_fn = &VAE.init_step(encoder_init_fn, decoder_init_fn, optimizer_init_fn, &1, &2)
step_fn = &VAE.train_step(encoder_fn, decoder_fn, optimizer_fn, &1, &2)

step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(:epoch_completed, &VAE.display_sample(&1, decoder_fn))
|> Axon.Loop.log(
  fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
    "\rEpoch: #{epoch}, batch: #{iter}, loss: #{Nx.to_number(state[:loss])}"
  end,
  event: :iteration_completed
)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
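
Axon.Loop.run/4 returns the final step state. If you bind it to a variable, say vae_state (a name introduced here for illustration, not in the original cell), you can decode any point of the 2-D latent space into a digit:

# Assumes the result of the training run above was bound to `vae_state`.
latent = Nx.tensor([[-1.0, 1.0]])
%{prediction: digit} = decoder_fn.(vae_state[:model_state]["decoder"], latent)

digit
|> Nx.multiply(255)
|> Nx.as_type(:u8)
|> Nx.reshape({28, 28, 1})
|> Kino.Image.new()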

Generating with GANs

defmodule GAN do
  import Nx.Defn

  def discriminator(input) do
    input
    |> Axon.conv(32, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.conv(64, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.flatten()
    |> Axon.dropout(rate: 0.5)
    |> Axon.dense(1, activation: :sigmoid)
  end

  def generator(input) do
    input
    |> Axon.dense(128 * 7 * 7, activation: :mish)
    |> Axon.reshape({:batch, 7, 7, 128})
    |> Axon.resize({14, 14})
    |> Axon.conv(128, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.resize({28, 28})
    |> Axon.conv(64, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.conv(1, activation: :tanh, kernel_size: 3, padding: :same)
  end

  defn init_state(
         discriminator_init_fn,
         generator_init_fn,
         discriminator_optimizer_init,
         generator_optimizer_init,
         batch,
         init_state
       ) do
    d_params = discriminator_init_fn.(batch, init_state)
    g_params = generator_init_fn.(Nx.random_normal({64, 128}), init_state)
    d_optimizer_state = discriminator_optimizer_init.(d_params)
    g_optimizer_state = generator_optimizer_init.(g_params)

    model_state = %{"discriminator" => d_params, "generator" => g_params}
    optimizer_state = %{"discriminator" => d_optimizer_state, "generator" => g_optimizer_state}
    loss = %{"discriminator" => Nx.tensor(0.0), "generator" => Nx.tensor(0.0)}

    %{
      model_state: model_state,
      optimizer_state: optimizer_state,
      loss: loss,
      i: Nx.tensor(0)
    }
  end

  # Generator objective: produce images the discriminator labels as real,
  # i.e. binary cross-entropy against all-ones targets.
  defn g_objective(d_params, g_params, discriminator_fn, generator_fn, real_batch) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    latent = Nx.random_normal({batch_size, 128})

    %{prediction: fake_batch} = generator_fn.(g_params, latent)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    Axon.Losses.binary_cross_entropy(real_targets, fake_labels, reduction: :mean)
  end

  # Discriminator objective: average the loss on real images (target 1)
  # and generated images (target 0).
  defn d_objective(d_params, g_params, discriminator_fn, generator_fn, real_batch) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    fake_targets = Nx.broadcast(0, {batch_size, 1})
    latent = Nx.random_normal({batch_size, 128})
    %{prediction: fake_batch} = generator_fn.(g_params, latent)

    %{prediction: real_labels} = discriminator_fn.(d_params, real_batch)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    real_loss = Axon.Losses.binary_cross_entropy(real_targets, real_labels, reduction: :mean)
    fake_loss = Axon.Losses.binary_cross_entropy(fake_targets, fake_labels, reduction: :mean)

    0.5 * real_loss + 0.5 * fake_loss
  end

  defn train_step(
         discriminator_fn,
         generator_fn,
         discriminator_optimizer,
         generator_optimizer,
         batch,
         state
       ) do
    d_params = state[:model_state]["discriminator"]
    g_params = state[:model_state]["generator"]
    d_optimizer_state = state[:optimizer_state]["discriminator"]
    g_optimizer_state = state[:optimizer_state]["generator"]

    # Update discriminator
    {d_loss, d_grads} =
      value_and_grad(d_params, &d_objective(&1, g_params, discriminator_fn, generator_fn, batch))

    {d_updates, new_d_optimizer_state} =
      discriminator_optimizer.(d_grads, d_optimizer_state, d_params)

    new_d_params = Axon.Updates.apply_updates(d_params, d_updates)

    # Update generator
    {g_loss, g_grads} =
      value_and_grad(g_params, &g_objective(d_params, &1, discriminator_fn, generator_fn, batch))

    {g_updates, new_g_optimizer_state} =
      generator_optimizer.(g_grads, g_optimizer_state, g_params)

    new_g_params = Axon.Updates.apply_updates(g_params, g_updates)

    # Update Losses
    new_d_loss =
      state[:loss]["discriminator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(d_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_g_loss =
      state[:loss]["generator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(g_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_loss = %{"discriminator" => new_d_loss, "generator" => new_g_loss}
    new_model_state = %{"discriminator" => new_d_params, "generator" => new_g_params}

    new_optimizer_state = %{
      "discriminator" => new_d_optimizer_state,
      "generator" => new_g_optimizer_state
    }

    %{
      model_state: new_model_state,
      optimizer_state: new_optimizer_state,
      loss: new_loss,
      i: Nx.add(state[:i], 1)
    }
  end

  def display_sample(%Axon.Loop.State{step_state: state} = out_state, generator_fn) do
    latent = Nx.random_normal({3, 128})
    %{prediction: out} = generator_fn.(state[:model_state]["generator"], latent)
    out_image = Nx.multiply(out, 255) |> Nx.as_type(:u8)
    upsample = Axon.Layers.resize(out_image, size: {512, 512})

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1})) |> Kino.render()
    end

    {:continue, out_state}
  end
end

discriminator = GAN.discriminator(Axon.input("image"))
generator = GAN.generator(Axon.input("latent"))
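
As before, the generator's upsampling pipeline can be inspected against a latent template (an added cell, mirroring the earlier Axon.Display usage):

Axon.Display.as_graph(generator, Nx.template({1, 128}, :f32))

The summary shows the two resize-then-convolve stages growing the 7×7 feature maps to 14×14 and then 28×28.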
{discriminator_init_fn, discriminator_fn} = Axon.build(discriminator, mode: :train)
{generator_init_fn, generator_fn} = Axon.build(generator, mode: :train)
{d_optimizer_init, d_optimizer} = Axon.Optimizers.adam(1.0e-4)
{g_optimizer_init, g_optimizer} = Axon.Optimizers.adam(1.0e-3)

init_fn =
  &GAN.init_state(
    discriminator_init_fn,
    generator_init_fn,
    d_optimizer_init,
    g_optimizer_init,
    &1,
    &2
  )

step_fn = &GAN.train_step(discriminator_fn, generator_fn, d_optimizer, g_optimizer, &1, &2)

step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(:epoch_completed, &GAN.display_sample(&1, generator_fn))
|> Axon.Loop.log(fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
  d_loss = state[:loss]["discriminator"]
  g_loss = state[:loss]["generator"]

  "\rEpoch: #{epoch}, batch: #{iter}, d_loss: #{Nx.to_number(d_loss)}, g_loss: #{Nx.to_number(g_loss)}"
end)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
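
As with the VAE, the final step state comes back from Axon.Loop.run/4. Assuming it is bound to gan_state (a name introduced here for illustration), sampling a new digit is a single pass through the generator:

# Assumes the result of the training run above was bound to `gan_state`.
key = Nx.Random.key(100)
{latent, _key} = Nx.Random.normal(key, shape: {1, 128})
%{prediction: fake} = generator_fn.(gan_state[:model_state]["generator"], latent)

fake
|> Nx.multiply(255)
|> Nx.as_type(:u8)
|> Nx.reshape({28, 28, 1})
|> Kino.Image.new()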