Learning Without Supervision
Mix.install([
  {:scidata, "~> 0.1"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:kino, "~> 0.8"}
])
The Input Data
batch_size = 64
{{data, type, _shape}, _} = Scidata.MNIST.download()
train_data =
  data
  |> Nx.from_binary(type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
  |> Nx.to_batched(batch_size)
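As a quick sanity check (an optional aside, not part of the original pipeline), you can peek at the first batch; it should be a {64, 28, 28, 1} tensor of 32-bit floats scaled into the range 0 to 1:

first_batch = Enum.at(train_data, 0)
{Nx.shape(first_batch), Nx.type(first_batch)}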
Compressing Data with Autoencoders
defmodule Autoencoder do
  # Compress each 784-pixel image down to a 128-dimensional code.
  def encoder(input) do
    input
    |> Axon.flatten()
    |> Axon.dense(256, activation: :relu, name: "encoder_dense_0")
    |> Axon.dense(128, activation: :relu, name: "encoder_dense_1")
  end

  # Reconstruct a 28x28x1 image from the 128-dimensional code.
  def decoder(input) do
    input
    |> Axon.dense(256, activation: :relu, name: "decoder_dense_0")
    |> Axon.dense(784, activation: :sigmoid, name: "decoder_dense_1")
    |> Axon.reshape({:batch, 28, 28, 1})
  end
end
model =
  Axon.input("image")
  |> Autoencoder.encoder()
  |> Autoencoder.decoder()
model
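If you want a visual summary of the architecture (an optional aside), Axon.Display.as_graph/2 renders the layer graph given an input template:

template = Nx.template({1, 28, 28, 1}, :f32)
Axon.Display.as_graph(model, template)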
trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(1.0e-3))
  |> Axon.Loop.run(
    # Zipping the data with itself makes each image its own target:
    # the model learns to reconstruct its input.
    Stream.zip(train_data, train_data),
    %{},
    epochs: 5,
    compiler: EXLA
  )
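As an optional check of reconstruction quality, compare a batch against its output under the same mean squared error the trainer optimized; a well-trained autoencoder should report a small loss:

[batch] = Enum.take(train_data, 1)
reconstruction = Axon.predict(model, trained_model_state, batch, compiler: EXLA)
Axon.Losses.mean_squared_error(batch, reconstruction, reduction: :mean)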
[test_batch] = Enum.take(train_data, 1)
test_image = test_batch[0] |> Nx.new_axis(0)
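For a before-and-after comparison (a small addition), render the original digit that the callback below will reconstruct at the end of each epoch:

test_image
|> Nx.reshape({28, 28, 1})
|> Nx.multiply(255)
|> Nx.as_type(:u8)
|> Kino.Image.new()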
visualize_test_image = fn %Axon.Loop.State{step_state: step_state} = state ->
  out_image = Axon.predict(model, step_state[:model_state], test_image, compiler: EXLA)
  out_image = Nx.multiply(out_image, 255) |> Nx.as_type(:u8)
  Kino.Image.new(Nx.reshape(out_image, {28, 28, 1})) |> Kino.render()
  {:continue, state}
end
trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(1.0e-3))
  |> Axon.Loop.handle_event(:epoch_completed, visualize_test_image)
  |> Axon.Loop.run(Stream.zip(train_data, train_data), %{}, epochs: 5, compiler: EXLA)
decoder_only =
  Axon.input("noise")
  |> Autoencoder.decoder()
key = Nx.Random.key(42)
{noise, _key} = Nx.Random.normal(key, shape: {1, 128})
out_image = Axon.predict(decoder_only, trained_model_state, noise)
upsampled = Axon.Layers.resize(out_image, size: {512, 512})
out_image =
  upsampled
  |> Nx.reshape({512, 512, 1})
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)
Kino.Image.new(out_image)
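You can also decode several noise vectors at once. Here is a sketch (assuming Kino.Layout.grid/2, which ships with Kino) that renders four samples side by side; expect noisy, digit-like blobs, since nothing forced this latent space to be well-structured:

{noise, _key} = Nx.Random.normal(key, shape: {4, 128})
out = Axon.predict(decoder_only, trained_model_state, noise)

images =
  out
  |> Axon.Layers.resize(size: {512, 512})
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)
  |> then(fn tensor -> for i <- 0..3, do: Kino.Image.new(tensor[i]) end)

Kino.Layout.grid(images, columns: 2)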
Learning a Structured Latent
defmodule VAE do
  import Nx.Defn

  # The encoder maps each image to the mean and log-variance of a 2-dimensional
  # Gaussian posterior, then draws a latent z from that distribution.
  def encoder(input) do
    encoded =
      input
      |> Axon.conv(32,
        kernel_size: {3, 3},
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.conv(32,
        kernel_size: {3, 3},
        activation: :relu,
        strides: 2,
        padding: :same
      )
      |> Axon.flatten()
      |> Axon.dense(16, activation: :relu)

    z_mean = Axon.dense(encoded, 2)
    z_log_var = Axon.dense(encoded, 2)
    z = Axon.layer(&sample/3, [z_mean, z_log_var], op_name: :sample)

    Axon.container({z_mean, z_log_var, z})
  end
  # The reparameterization trick: z = mu + sigma * epsilon keeps the sampling
  # step differentiable with respect to the encoder outputs.
  defnp sample(z_mean, z_log_var, _opts \\ []) do
    noise_shape = Nx.shape(z_mean)
    epsilon = Nx.random_normal(noise_shape)
    z_mean + Nx.exp(0.5 * z_log_var) * epsilon
  end
  # The decoder upsamples a latent vector back into a 28x28x1 image.
  def decoder(input) do
    input
    |> Axon.dense(7 * 7 * 64, activation: :relu)
    |> Axon.reshape({:batch, 7, 7, 64})
    |> Axon.conv_transpose(64,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(32,
      kernel_size: {3, 3},
      activation: :relu,
      strides: [2, 2],
      padding: :same
    )
    |> Axon.conv_transpose(1,
      kernel_size: {3, 3},
      activation: :sigmoid,
      padding: :same
    )
  end
  # One custom training step: compute the joint loss, take gradients with
  # respect to both encoder and decoder parameters, and apply updates.
  defn train_step(encoder_fn, decoder_fn, optimizer_fn, batch, state) do
    {batch_loss, joint_param_grads} =
      value_and_grad(
        state[:model_state],
        &joint_objective(encoder_fn, decoder_fn, batch, &1)
      )

    {scaled_updates, new_optimizer_state} =
      optimizer_fn.(joint_param_grads, state[:optimizer_state], state[:model_state])

    new_model_state = Axon.Updates.apply_updates(state[:model_state], scaled_updates)

    # Running average of the loss across iterations.
    new_loss =
      state[:loss]
      |> Nx.multiply(state[:i])
      |> Nx.add(batch_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    %{
      state
      | i: Nx.add(state[:i], 1),
        loss: new_loss,
        model_state: new_model_state,
        optimizer_state: new_optimizer_state
    }
  end
  # Negative ELBO: reconstruction loss plus the KL divergence between the
  # approximate posterior N(mu, sigma^2) and the standard normal prior.
  defnp joint_objective(encoder_fn, decoder_fn, batch, joint_params) do
    %{prediction: {z_mean, z_log_var, z}} = encoder_fn.(joint_params["encoder"], batch)
    %{prediction: reconstruction} = decoder_fn.(joint_params["decoder"], z)

    recon_loss = Axon.Losses.binary_cross_entropy(batch, reconstruction, reduction: :mean)

    kl_loss = -0.5 * (1 + z_log_var - Nx.pow(z_mean, 2) - Nx.exp(z_log_var))
    kl_loss = Nx.mean(Nx.sum(kl_loss, axes: [1]))

    recon_loss + kl_loss
  end
  # Initialize both models and a single optimizer state over the joint
  # parameter map. The decoder is initialized from a {64, 2} latent template.
  defn init_step(encoder_init_fn, decoder_init_fn, optimizer_init_fn, batch, init_state) do
    encoder_params = encoder_init_fn.(batch, init_state)
    decoder_params = decoder_init_fn.(Nx.random_uniform({64, 2}), init_state)
    joint_params = %{"encoder" => encoder_params, "decoder" => decoder_params}
    optimizer_state = optimizer_init_fn.(joint_params)

    %{
      i: Nx.tensor(0),
      loss: Nx.tensor(0.0),
      model_state: joint_params,
      optimizer_state: optimizer_state
    }
  end
  # Decode a few fixed latent points after each epoch to watch training progress.
  def display_sample(%Axon.Loop.State{step_state: state} = out_state, decoder_fn) do
    latent = Nx.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    %{prediction: out} = decoder_fn.(state[:model_state]["decoder"], latent)
    out_image = Nx.multiply(out, 255) |> Nx.as_type(:u8)
    # The decoder output is channels-last ({3, 28, 28, 1}), so resize the
    # spatial axes accordingly.
    upsample = Axon.Layers.resize(out_image, size: {512, 512}, channels: :last)

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1})) |> Kino.render()
    end

    {:continue, out_state}
  end
end
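For reference, joint_objective/4 implements the negative evidence lower bound (ELBO): a reconstruction term plus a KL term that pulls each approximate posterior toward the standard normal prior. With the encoder producing $\mu$ and $\log \sigma^2$, the KL term has the closed form

$$
D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right) = -\frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

which is exactly the kl_loss expression in the module above.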
template = Nx.template({1, 2}, :f32)
Axon.Display.as_graph(VAE.decoder(Axon.input("latent")), template)
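You can render the encoder the same way (a sketch, assuming Axon.Display.as_graph/2 handles the container output {z_mean, z_log_var, z}):

Axon.Display.as_graph(VAE.encoder(Axon.input("image")), Nx.template({1, 28, 28, 1}, :f32))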
encoder = Axon.input("image") |> VAE.encoder()
decoder = Axon.input("latent") |> VAE.decoder()
{encoder_init_fn, encoder_fn} = Axon.build(encoder, mode: :train)
{decoder_init_fn, decoder_fn} = Axon.build(decoder, mode: :train)
{optimizer_init_fn, optimizer_fn} = Axon.Optimizers.adam(1.0e-3)
init_fn = &VAE.init_step(encoder_init_fn, decoder_init_fn, optimizer_init_fn, &1, &2)
step_fn = &VAE.train_step(encoder_fn, decoder_fn, optimizer_fn, &1, &2)
vae_state =
  step_fn
  |> Axon.Loop.loop(init_fn)
  |> Axon.Loop.handle_event(:epoch_completed, &VAE.display_sample(&1, decoder_fn))
  |> Axon.Loop.log(
    fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
      "\rEpoch: #{epoch}, batch: #{iter}, loss: #{Nx.to_number(state[:loss])}"
    end,
    event: :iteration_completed
  )
  |> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
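With the final loop state bound to vae_state (a sketch under the assumption that, for a custom loop with the default output transform, Axon.Loop.run/4 returns the final %Axon.Loop.State{} struct), you can decode a small grid of latent points to see how the 2-dimensional latent space organizes the digits:

%Axon.Loop.State{step_state: final_step_state} = vae_state

# Nine evenly spaced points across the latent plane.
points = for x <- [-2.0, 0.0, 2.0], y <- [-2.0, 0.0, 2.0], do: [x, y]
latent = Nx.tensor(points)

%{prediction: digits} = decoder_fn.(final_step_state[:model_state]["decoder"], latent)

images =
  digits
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)
  |> then(fn tensor -> for i <- 0..8, do: Kino.Image.new(tensor[i]) end)

Kino.Layout.grid(images, columns: 3)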
Generating with GANs
defmodule GAN do
  import Nx.Defn

  # The discriminator classifies 28x28x1 images as real (1) or fake (0).
  def discriminator(input) do
    input
    |> Axon.conv(32, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.conv(64, activation: :mish, kernel_size: 3, strides: 2, padding: :same)
    |> Axon.layer_norm()
    |> Axon.flatten()
    |> Axon.dropout(rate: 0.5)
    |> Axon.dense(1, activation: :sigmoid)
  end
  # The generator upsamples a 128-dimensional noise vector to a 28x28x1 image.
  def generator(input) do
    input
    |> Axon.dense(128 * 7 * 7, activation: :mish)
    |> Axon.reshape({:batch, 7, 7, 128})
    |> Axon.resize({14, 14})
    |> Axon.conv(128, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.resize({28, 28})
    |> Axon.conv(64, kernel_size: 3, padding: :same)
    |> Axon.layer_norm()
    |> Axon.relu()
    |> Axon.conv(1, activation: :tanh, kernel_size: 3, padding: :same)
  end
  # Initialize separate parameters and optimizer states for both networks.
  defn init_state(
         discriminator_init_fn,
         generator_init_fn,
         discriminator_optimizer_init,
         generator_optimizer_init,
         batch,
         init_state
       ) do
    d_params = discriminator_init_fn.(batch, init_state)
    g_params = generator_init_fn.(Nx.random_normal({64, 128}), init_state)
    d_optimizer_state = discriminator_optimizer_init.(d_params)
    g_optimizer_state = generator_optimizer_init.(g_params)

    model_state = %{"discriminator" => d_params, "generator" => g_params}
    optimizer_state = %{"discriminator" => d_optimizer_state, "generator" => g_optimizer_state}
    loss = %{"discriminator" => Nx.tensor(0.0), "generator" => Nx.tensor(0.0)}

    %{
      model_state: model_state,
      optimizer_state: optimizer_state,
      loss: loss,
      i: Nx.tensor(0)
    }
  end
  # Generator loss: generate fakes and score them against *real* targets, so
  # the generator is rewarded when the discriminator is fooled.
  defn g_objective(d_params, g_params, discriminator_fn, generator_fn, real_batch) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    latent = Nx.random_normal({batch_size, 128})

    %{prediction: fake_batch} = generator_fn.(g_params, latent)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    Axon.Losses.binary_cross_entropy(real_targets, fake_labels, reduction: :mean)
  end
  # Discriminator loss: the average of its error on real images (target 1)
  # and on generated images (target 0).
  defn d_objective(d_params, g_params, discriminator_fn, generator_fn, real_batch) do
    batch_size = Nx.axis_size(real_batch, 0)
    real_targets = Nx.broadcast(1, {batch_size, 1})
    fake_targets = Nx.broadcast(0, {batch_size, 1})
    latent = Nx.random_normal({batch_size, 128})

    %{prediction: fake_batch} = generator_fn.(g_params, latent)
    %{prediction: real_labels} = discriminator_fn.(d_params, real_batch)
    %{prediction: fake_labels} = discriminator_fn.(d_params, fake_batch)

    real_loss = Axon.Losses.binary_cross_entropy(real_targets, real_labels, reduction: :mean)
    fake_loss = Axon.Losses.binary_cross_entropy(fake_targets, fake_labels, reduction: :mean)

    0.5 * real_loss + 0.5 * fake_loss
  end
  defn train_step(
         discriminator_fn,
         generator_fn,
         discriminator_optimizer,
         generator_optimizer,
         batch,
         state
       ) do
    d_params = state[:model_state]["discriminator"]
    g_params = state[:model_state]["generator"]
    d_optimizer_state = state[:optimizer_state]["discriminator"]
    g_optimizer_state = state[:optimizer_state]["generator"]

    # Update discriminator
    {d_loss, d_grads} =
      value_and_grad(d_params, &d_objective(&1, g_params, discriminator_fn, generator_fn, batch))

    {d_updates, new_d_optimizer_state} =
      discriminator_optimizer.(d_grads, d_optimizer_state, d_params)

    new_d_params = Axon.Updates.apply_updates(d_params, d_updates)

    # Update generator
    {g_loss, g_grads} =
      value_and_grad(g_params, &g_objective(d_params, &1, discriminator_fn, generator_fn, batch))

    {g_updates, new_g_optimizer_state} =
      generator_optimizer.(g_grads, g_optimizer_state, g_params)

    new_g_params = Axon.Updates.apply_updates(g_params, g_updates)

    # Update losses
    new_d_loss =
      state[:loss]["discriminator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(d_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_g_loss =
      state[:loss]["generator"]
      |> Nx.multiply(state[:i])
      |> Nx.add(g_loss)
      |> Nx.divide(Nx.add(state[:i], 1))

    new_loss = %{"discriminator" => new_d_loss, "generator" => new_g_loss}
    new_model_state = %{"discriminator" => new_d_params, "generator" => new_g_params}

    new_optimizer_state = %{
      "discriminator" => new_d_optimizer_state,
      "generator" => new_g_optimizer_state
    }

    %{
      model_state: new_model_state,
      optimizer_state: new_optimizer_state,
      loss: new_loss,
      i: Nx.add(state[:i], 1)
    }
  end
  # Generate a few samples after each epoch to monitor training visually.
  def display_sample(%Axon.Loop.State{step_state: state} = out_state, generator_fn) do
    latent = Nx.random_normal({3, 128})
    %{prediction: out} = generator_fn.(state[:model_state]["generator"], latent)
    out_image = Nx.multiply(out, 255) |> Nx.as_type(:u8)
    upsample = Axon.Layers.resize(out_image, size: {512, 512})

    for i <- 0..2 do
      Kino.Image.new(Nx.reshape(upsample[i], {512, 512, 1})) |> Kino.render()
    end

    {:continue, out_state}
  end
end
discriminator = GAN.discriminator(Axon.input("image"))
generator = GAN.generator(Axon.input("latent"))
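As with the VAE decoder, you can sanity-check the generator's layer graph with a 128-dimensional noise template (an optional aside):

Axon.Display.as_graph(generator, Nx.template({1, 128}, :f32))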
{discriminator_init_fn, discriminator_fn} = Axon.build(discriminator, mode: :train)
{generator_init_fn, generator_fn} = Axon.build(generator, mode: :train)
{d_optimizer_init, d_optimizer} = Axon.Optimizers.adam(1.0e-4)
{g_optimizer_init, g_optimizer} = Axon.Optimizers.adam(1.0e-3)
init_fn =
  &GAN.init_state(
    discriminator_init_fn,
    generator_init_fn,
    d_optimizer_init,
    g_optimizer_init,
    &1,
    &2
  )
step_fn = &GAN.train_step(discriminator_fn, generator_fn, d_optimizer, g_optimizer, &1, &2)
gan_state =
  step_fn
  |> Axon.Loop.loop(init_fn)
  |> Axon.Loop.handle_event(:epoch_completed, &GAN.display_sample(&1, generator_fn))
  |> Axon.Loop.log(fn %Axon.Loop.State{epoch: epoch, iteration: iter, step_state: state} ->
    d_loss = state[:loss]["discriminator"]
    g_loss = state[:loss]["generator"]
    "\rEpoch: #{epoch}, batch: #{iter}, d_loss: #{Nx.to_number(d_loss)}, g_loss: #{Nx.to_number(g_loss)}"
  end)
  |> Axon.Loop.run(train_data, %{}, compiler: EXLA, epochs: 10)
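Once training finishes, the bound gan_state holds the final loop state, and you can sample the trained generator directly. A sketch, under the same output-transform assumption as the VAE sampling sketch above; note the tanh output is clipped before the unsigned cast:

%Axon.Loop.State{step_state: final_gan_step} = gan_state

latent = Nx.random_normal({4, 128})
%{prediction: generated} = generator_fn.(final_gan_step[:model_state]["generator"], latent)

images =
  generated
  |> Axon.Layers.resize(size: {512, 512})
  |> Nx.multiply(255)
  # tanh can produce negative values, so clip before casting to u8
  |> Nx.clip(0, 255)
  |> Nx.as_type(:u8)
  |> then(fn tensor -> for i <- 0..3, do: Kino.Image.new(tensor[i]) end)

Kino.Layout.grid(images, columns: 2)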