Ch 12: Unsupervised learning
Mix.install([
{:scidata, "~> 0.1"},
{:axon, "~> 0.5"},
# Polaris is used explicitly below for optimizers and updates
{:polaris, "~> 0.1"},
{:exla, "~> 0.5"},
{:nx, "~> 0.5"},
{:kino, "~> 0.8"}
])
Data prep
Nx.global_default_backend(EXLA.Backend)
batch_size = 64
{{data, type, shape}, _} = Scidata.MNIST.download()
train_data =
data
|> Nx.from_binary(type)
# imgs are 28px by 28px and grayscale (1 colour channel)
|> Nx.reshape({:auto, 28, 28, 1})
|> Nx.divide(255)
|> Nx.to_batched(batch_size)
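A quick sanity check before building anything: each element of train_data should be a {64, 28, 28, 1} tensor of floats in [0, 1]. A minimal sketch:
first_batch = Enum.at(train_data, 0)
# => {64, 28, 28, 1}
Nx.shape(first_batch)
# the pixel range should be {0.0, 1.0} after dividing by 255
{Nx.to_number(Nx.reduce_min(first_batch)), Nx.to_number(Nx.reduce_max(first_batch))}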
Autoencoder
defmodule Autoencoder do
def encoder(input) do
input
|> Axon.flatten()
|> Axon.dense(256, activation: :relu, name: "encoder_dense_0")
|> Axon.dense(128, activation: :relu, name: "encoder_dense_1")
end
def decoder(input) do
input
|> Axon.dense(256, activation: :relu, name: "decoder_dense_0")
|> Axon.dense(784, activation: :sigmoid, name: "decoder_dense_1")
|> Axon.reshape({:batch, 28, 28, 1})
end
end
model =
Axon.input("image")
|> Autoencoder.encoder()
|> Autoencoder.decoder()
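Before training, it's worth eyeballing the architecture. A quick sketch using Axon.Display.as_graph/2 (the same call used to inspect the VAE decoder later) with a single-image template:
Axon.Display.as_graph(model, Nx.template({1, 28, 28, 1}, :f32))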
test_batch = Enum.at(train_data, 0)
test_image =
test_batch[0]
|> Nx.new_axis(0)
visualize_test_image = fn %Axon.Loop.State{step_state: step_state} = state ->
out_image =
Axon.predict(
model,
step_state[:model_state],
test_image,
compiler: EXLA
)
out_image =
out_image
|> Nx.multiply(255)
# Kino.Image expects image tensors to be created from u8 tensors
|> Nx.as_type(:u8)
|> Nx.reshape({28, 28, 1})
Kino.Image.new(out_image)
|> Kino.render()
{:continue, state}
end
trained_model_state =
model
|> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 1.0e-3))
|> Axon.Loop.handle_event(:epoch_completed, visualize_test_image)
|> Axon.Loop.run(
Stream.zip(train_data, train_data),
%{},
epochs: 5,
compiler: EXLA
)
By treating each image as both the input and the label, you’ve turned this unsupervised learning problem into a supervised one.
On its own, the loss value is hard to interpret: it doesn’t directly tell you how well your model is reconstructing compressed inputs. A much more interpretable measure of progress is a periodic visualization of your model’s outputs on some example data, which comes from the anonymous function bound to visualize_test_image.
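If you also want a concrete number once training finishes, you can measure the reconstruction error on the held-out test image directly. A minimal sketch reusing model, trained_model_state, and test_image from above:
reconstruction = Axon.predict(model, trained_model_state, test_image, compiler: EXLA)
# mean squared error between the original image and its reconstruction
test_image
|> Nx.subtract(reconstruction)
|> Nx.pow(2)
|> Nx.mean()
|> Nx.to_number()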
In this example, your decoder is actually kind of a generative model—it takes a latent representation and produces an output. Hypothetically, you can skip the encoder altogether and give the decoder some random latent representation, and it should give you an output that resembles a handwritten digit:
decoder_only =
Axon.input("noise")
|> Autoencoder.decoder()
key = Nx.Random.key(42)
{noise, _key} = Nx.Random.normal(key, shape: {1, 128})
out_image = Axon.predict(decoder_only, trained_model_state, noise)
upsampled = Axon.Layers.resize(out_image, size: {512, 512})
out_image =
upsampled
|> Nx.reshape({512, 512, 1})
|> Nx.multiply(255)
|> Nx.as_type(:u8)
Kino.Image.new(out_image)
Unfortunately, you didn’t force your model to have a latent space with structure. The structure of your encoded representations is at the mercy of a neural network and gradient descent. You can’t pass random uniform or normal noise to your decoder and expect coherent output, because your decoder only knows how to handle latent representations produced by the encoder. But what if there were a way to force your encoder to learn a structured representation, one you can easily query later on? Fortunately, there is. Enter the variational autoencoder.
Variational autoencoder
defmodule VAE do
import Nx.Defn
def encoder(input) do
encoded =
input
|> Axon.conv(32,
kernel_size: 3,
activation: :relu,
strides: 2,
padding: :same
)
|> Axon.conv(32,
kernel_size: 3,
activation: :relu,
strides: 2,
padding: :same
)
|> Axon.flatten()
|> Axon.dense(16, activation: :relu)
z_mean = Axon.dense(encoded, 2)
z_log_var = Axon.dense(encoded, 2)
z = Axon.layer(&sample/3, [z_mean, z_log_var], op_name: :sample)
Axon.container({z_mean, z_log_var, z})
end
defnp sample(z_mean, z_log_var, _opts \\ []) do
noise_shape = Nx.shape(z_mean)
# Nx.random_normal was removed from Nx in 0.7; use Nx.Random instead.
# NB: a hard-coded key draws the same epsilon on every call; threading a
# key through the training state would give fresh noise per step.
key = Nx.Random.key(42)
{epsilon, _} = Nx.Random.normal(key, shape: noise_shape)
z_mean + Nx.exp(0.5 * z_log_var) * epsilon
end
def decoder(input) do
input
|> Axon.dense(7 * 7 * 64, activation: :relu)
|> Axon.reshape({:batch, 7, 7, 64})
|> Axon.conv_transpose(64,
kernel_size: {3, 3},
activation: :relu,
strides: [2, 2],
padding: :same
)
|> Axon.conv_transpose(32,
kernel_size: {3, 3},
activation: :relu,
strides: [2, 2],
padding: :same
)
|> Axon.conv_transpose(1,
kernel_size: {3, 3},
activation: :sigmoid,
padding: :same
)
end
defn train_step(encoder_fn, decoder_fn, optimizer_fn, batch, state) do
{batch_loss, joint_param_grads} =
value_and_grad(
state[:model_state],
&joint_objective(encoder_fn, decoder_fn, batch, &1)
)
{scaled_updates, new_optimizer_state} =
optimizer_fn.(
joint_param_grads,
state[:optimizer_state],
state[:model_state]
)
new_model_state =
Polaris.Updates.apply_updates(
state[:model_state],
scaled_updates
)
new_loss =
state[:loss]
|> Nx.multiply(state[:i])
|> Nx.add(batch_loss)
|> Nx.divide(Nx.add(state[:i], 1))
%{
state
| i: Nx.add(state[:i], 1),
loss: new_loss,
model_state: new_model_state,
optimizer_state: new_optimizer_state
}
end
defnp joint_objective(encoder_fn, decoder_fn, batch, joint_params) do
%{prediction: preds} = encoder_fn.(joint_params["encoder"], batch)
{z_mean, z_log_var, z} = preds
%{prediction: reconstruction} = decoder_fn.(joint_params["decoder"], z)
# Reconstruction loss: how well the decoder reconstructs the original batch of images from the encoded representation.
recon_loss =
Axon.Losses.binary_cross_entropy(
batch,
reconstruction,
reduction: :mean
)
# Regularization term that penalizes the encoder for drifting too far from a
# normal distribution. It measures the divergence of the distribution defined by
# z_mean and z_log_var from a standard normal distribution (mean 0, variance 1).
# This term helps coerce your encoder into encoding your data as a normal distribution.
kl_loss = -0.5 * (1 + z_log_var - Nx.pow(z_mean, 2) - Nx.exp(z_log_var))
kl_loss = Nx.mean(Nx.sum(kl_loss, axes: [1]))
recon_loss + kl_loss
end
defn init_step(
encoder_init_fn,
decoder_init_fn,
optimizer_init_fn,
batch,
init_state
) do
encoder_params = encoder_init_fn.(batch, init_state)
key = Nx.Random.key(28)
{uniform, _} = Nx.Random.uniform(key, shape: {64, 2})
decoder_params = decoder_init_fn.(uniform, init_state)
joint_params = %{
"encoder" => encoder_params,
"decoder" => decoder_params
}
optimizer_state = optimizer_init_fn.(joint_params)
%{
i: Nx.tensor(0),
loss: Nx.tensor(0.0),
model_state: joint_params,
optimizer_state: optimizer_state
}
end
def display_sample(
%Axon.Loop.State{step_state: state} = out_state,
decoder_fn
) do
latent = Nx.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
%{prediction: out} = decoder_fn.(state[:model_state]["decoder"], latent)
out_image =
Nx.multiply(out, 255)
|> Nx.as_type(:u8)
upsample =
Axon.Layers.resize(
out_image,
size: {512, 512},
channels: :last
)
for i <- 0..2 do
Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1}))
|> Kino.render()
end
{:continue, out_state}
end
end
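To build some intuition for the KL term before training: it is zero exactly when the encoder outputs a standard normal and grows as the encoded distribution drifts away. A quick check with a hypothetical KLCheck module that reuses the formula from joint_objective/4:
defmodule KLCheck do
import Nx.Defn
# same KL divergence formula as in VAE.joint_objective/4
defn kl(z_mean, z_log_var) do
kl = -0.5 * (1 + z_log_var - Nx.pow(z_mean, 2) - Nx.exp(z_log_var))
Nx.mean(Nx.sum(kl, axes: [1]))
end
end
# a standard normal (mean 0, log-variance 0) incurs no penalty
KLCheck.kl(Nx.tensor([[0.0, 0.0]]), Nx.tensor([[0.0, 0.0]]))
# => 0.0
# means far from zero are penalized: -0.5 * (1 + 0 - 4 - 1) = 2.0 per dimension
KLCheck.kl(Nx.tensor([[2.0, 2.0]]), Nx.tensor([[0.0, 0.0]]))
# => 4.0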
# the VAE's latent space is two-dimensional, so the decoder's template input is {1, 2}
template = Nx.template({1, 2}, :f32)
Axon.Display.as_graph(VAE.decoder(Axon.input("latent")), template)
Custom training loop. This needs two functions: one to initialise the loop state and one to update it between steps. These are defined in the VAE module above as init_step/5 and train_step/5 respectively.
encoder =
Axon.input("image")
|> VAE.encoder()
decoder =
Axon.input("latent")
|> VAE.decoder()
{encoder_init_fn, encoder_fn} = Axon.build(encoder, mode: :train)
{decoder_init_fn, decoder_fn} = Axon.build(decoder, mode: :train)
{optimizer_init_fn, optimizer_fn} = Polaris.Optimizers.adam(learning_rate: 1.0e-3)
init_fn =
&VAE.init_step(
encoder_init_fn,
decoder_init_fn,
optimizer_init_fn,
&1,
&2
)
step_fn =
&VAE.train_step(
encoder_fn,
decoder_fn,
optimizer_fn,
&1,
&2
)
step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(:epoch_completed, &VAE.display_sample(&1, decoder_fn))
|> Axon.Loop.log(
fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
"\rEpoch: #{epoch}, batch: #{iter}, loss: #{Nx.to_number(state[:loss])}"
end,
event: :iteration_completed,
device: :stdio
)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
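Because the latent space is only two-dimensional and the KL term pulls it toward a standard normal, you can walk a small grid of latent points and watch digits morph into one another. A sketch, assuming you bind the final step state returned by Axon.Loop.run/4 above to a (hypothetical) vae_state:
# nine latent points spanning the region the KL term pulls encodings into
grid = for x <- [-2.0, 0.0, 2.0], y <- [-2.0, 0.0, 2.0], do: [x, y]
latent = Nx.tensor(grid)
%{prediction: decoded} = decoder_fn.(vae_state[:model_state]["decoder"], latent)
images =
decoded
|> Nx.multiply(255)
|> Nx.as_type(:u8)
for i <- 0..8 do
Kino.Image.new(Nx.reshape(images[i], {28, 28, 1}))
|> Kino.render()
end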
GAN
defmodule GAN do
import Nx.Defn
def discriminator(input) do
input
|> Axon.conv(32, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
|> Axon.layer_norm()
|> Axon.conv(64, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
|> Axon.layer_norm()
|> Axon.flatten()
|> Axon.dropout(rate: 0.5)
|> Axon.dense(1, activation: :sigmoid)
end
def generator(input) do
input
|> Axon.dense(128 * 7 * 7, activation: :mish)
|> Axon.reshape({:batch, 7, 7, 128})
|> Axon.resize({14, 14})
|> Axon.conv(128, kernel_size: 3, padding: :same)
|> Axon.layer_norm()
|> Axon.relu()
|> Axon.resize({28, 28})
|> Axon.conv(64, kernel_size: 3, padding: :same)
|> Axon.layer_norm()
|> Axon.relu()
|> Axon.conv(1, activation: :tanh, kernel_size: 3, padding: :same)
end
defn train_step(
discriminator_fn,
generator_fn,
discriminator_optimizer,
generator_optimizer,
batch,
state
) do
d_params = state[:model_state]["discriminator"]
g_params = state[:model_state]["generator"]
d_optimizer_state = state[:optimizer_state]["discriminator"]
g_optimizer_state = state[:optimizer_state]["generator"]
# Update discriminator
{d_loss, d_grads} =
value_and_grad(d_params, fn d_params ->
d_objective(
d_params,
g_params,
discriminator_fn,
generator_fn,
batch
)
end)
{d_updates, new_d_optimizer_state} =
discriminator_optimizer.(
d_grads,
d_optimizer_state,
d_params
)
new_d_params = Polaris.Updates.apply_updates(d_params, d_updates)
# Update generator
{g_loss, g_grads} =
value_and_grad(g_params, fn g_params ->
g_objective(
d_params,
g_params,
discriminator_fn,
generator_fn,
batch
)
end)
{g_updates, new_g_optimizer_state} =
generator_optimizer.(
g_grads,
g_optimizer_state,
g_params
)
new_g_params = Polaris.Updates.apply_updates(g_params, g_updates)
# Update Losses
new_d_loss =
state[:loss]["discriminator"]
|> Nx.multiply(state[:i])
|> Nx.add(d_loss)
|> Nx.divide(Nx.add(state[:i], 1))
new_g_loss =
state[:loss]["generator"]
|> Nx.multiply(state[:i])
|> Nx.add(g_loss)
|> Nx.divide(Nx.add(state[:i], 1))
new_loss = %{
"discriminator" => new_d_loss,
"generator" => new_g_loss
}
new_model_state = %{
"discriminator" => new_d_params,
"generator" => new_g_params
}
new_optimizer_state = %{
"discriminator" => new_d_optimizer_state,
"generator" => new_g_optimizer_state
}
%{
model_state: new_model_state,
optimizer_state: new_optimizer_state,
loss: new_loss,
i: Nx.add(state[:i], 1)
}
end
defn d_objective(
d_params,
g_params,
discriminator_fn,
generator_fn,
real_batch
) do
batch_size = Nx.axis_size(real_batch, 0)
real_targets = Nx.broadcast(1, {batch_size, 1})
fake_targets = Nx.broadcast(0, {batch_size, 1})
# NB: a hard-coded key draws the same latent noise on every step; threading a
# key through the loop state would give fresh noise per batch.
key = Nx.Random.key(832)
{latent, _} = Nx.Random.normal(key, shape: {batch_size, 128})
%{prediction: fake_batch} = generator_fn.(g_params, latent)
%{prediction: real_labels} = discriminator_fn.(d_params, real_batch)
%{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)
real_loss =
Axon.Losses.binary_cross_entropy(
real_targets,
real_labels,
reduction: :mean
)
fake_loss =
Axon.Losses.binary_cross_entropy(
fake_targets,
fake_labels,
reduction: :mean
)
0.5 * real_loss + 0.5 * fake_loss
end
defn g_objective(
d_params,
g_params,
discriminator_fn,
generator_fn,
real_batch
) do
batch_size = Nx.axis_size(real_batch, 0)
real_targets = Nx.broadcast(1, {batch_size, 1})
key = Nx.Random.key(737)
{latent, _} = Nx.Random.normal(key, shape: {batch_size, 128})
%{prediction: fake_batch} = generator_fn.(g_params, latent)
%{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)
Axon.Losses.binary_cross_entropy(
real_targets,
fake_labels,
reduction: :mean
)
end
defn init_state(
discriminator_init_fn,
generator_init_fn,
discriminator_optimizer_init,
generator_optimizer_init,
batch,
init_state
) do
d_params = discriminator_init_fn.(batch, init_state)
key = Nx.Random.key(945)
{normal, _} = Nx.Random.normal(key, shape: {64, 128})
g_params = generator_init_fn.(normal, init_state)
d_optimizer_state = discriminator_optimizer_init.(d_params)
g_optimizer_state = generator_optimizer_init.(g_params)
model_state = %{
"discriminator" => d_params,
"generator" => g_params
}
optimizer_state = %{
"discriminator" => d_optimizer_state,
"generator" => g_optimizer_state
}
loss = %{
"discriminator" => Nx.tensor(0.0),
"generator" => Nx.tensor(0.0)
}
%{
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss,
i: Nx.tensor(0)
}
end
def display_sample(
%Axon.Loop.State{step_state: state} = out_state,
generator_fn
) do
key = Nx.Random.key(356)
{latent, _} = Nx.Random.normal(key, shape: {3, 128})
%{prediction: out} = generator_fn.(state[:model_state]["generator"], latent)
out_image =
out
# tanh can emit values below zero; clip to [0, 1] so the u8 cast doesn't wrap
|> Nx.clip(0, 1)
|> Nx.multiply(255)
|> Nx.as_type(:u8)
upsample =
Axon.Layers.resize(
out_image,
size: {512, 512},
channels: :last
)
for i <- 0..2 do
Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1}))
|> Kino.render()
end
{:continue, out_state}
end
end
discriminator = GAN.discriminator(Axon.input("image"))
generator = GAN.generator(Axon.input("latent"))
{discriminator_init_fn, discriminator_fn} =
Axon.build(discriminator, mode: :train)
{generator_init_fn, generator_fn} =
Axon.build(generator, mode: :train)
{d_optimizer_init, d_optimizer} = Polaris.Optimizers.adam(learning_rate: 1.0e-4)
{g_optimizer_init, g_optimizer} = Polaris.Optimizers.adam(learning_rate: 1.0e-3)
init_fn =
&GAN.init_state(
discriminator_init_fn,
generator_init_fn,
d_optimizer_init,
g_optimizer_init,
&1,
&2
)
step_fn =
&GAN.train_step(
discriminator_fn,
generator_fn,
d_optimizer,
g_optimizer,
&1,
&2
)
Notice that you’re using slightly different optimizers for each model. To stop the discriminator from dominating before the generator can get its bearings, you lower the discriminator’s learning rate.
step_fn
|> Axon.Loop.loop(init_fn)
|> Axon.Loop.handle_event(
:epoch_completed,
&GAN.display_sample(&1, generator_fn)
)
|> Axon.Loop.log(
fn
%Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
d_loss = state[:loss]["discriminator"]
g_loss = state[:loss]["generator"]
"\rEpoch: #{epoch}, batch: #{iter}," <>
" d_loss: #{Nx.to_number(d_loss)}," <>
" g_loss: #{Nx.to_number(g_loss)}"
end,
event: :iteration_completed,
device: :stdio
)
|> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
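As with the VAE, Axon.Loop.run/4 returns the final step state, so you can keep sampling digits from the trained generator afterwards. A sketch, assuming you bind the result of the run above to a (hypothetical) gan_state:
key = Nx.Random.key(999)
{latent, _} = Nx.Random.normal(key, shape: {1, 128})
%{prediction: digit} = generator_fn.(gan_state[:model_state]["generator"], latent)
digit
# clip the tanh output into [0, 1] before scaling, as in display_sample/2
|> Nx.clip(0, 1)
|> Nx.multiply(255)
|> Nx.as_type(:u8)
|> Nx.reshape({28, 28, 1})
|> Kino.Image.new()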