Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Deep Learning And Axon

deep-learning-axon.livemd

Deep Learning And Axon

# Notebook dependencies, fetched and compiled at runtime by Mix.install/1.
Mix.install([
  {:nx, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:scidata, "~> 0.1"},
  {:kino, "~> 0.8"},
  {:table_rex, "~> 3.1.1"}
])

# Use EXLA.Backend as the default Nx backend for the code that follows.
Nx.default_backend(EXLA.Backend)

Building a Neural Network

defmodule NeuralNetwork do
  @moduledoc """
  A minimal two-layer feed-forward network built from numerical definitions.

  Both layers apply a dense transform (`Nx.dot/2` followed by `Nx.add/2`) and
  then a sigmoid activation.
  """

  import Nx.Defn

  @doc """
  Computes the dot product of `input` and `weight`, then adds `bias`.
  """
  defn dense(input, weight, bias) do
    Nx.add(Nx.dot(input, weight), bias)
  end

  @doc """
  Applies the sigmoid function to `input`.
  """
  defn activation(input) do
    Nx.sigmoid(input)
  end

  @doc """
  Hidden layer: a dense transform followed by the activation.
  """
  defn hidden(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  @doc """
  Output layer: a dense transform followed by the activation.
  """
  defn output(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  @doc """
  Runs `input` through the hidden layer (`w1`, `b1`) and then the output
  layer (`w2`, `b2`), returning the output layer's result.
  """
  defn predict(input, w1, b1, w2, b2) do
    output(hidden(input, w1, b1), w2, b2)
  end
end
# Seed the pseudo-random generator with a fixed key.
key = Nx.Random.key(42)

# Draw one uniform sample per network parameter; each draw returns a new key
# that is threaded into the next draw.
{w1, new_key} = Nx.Random.uniform(key)
{b1, new_key} = Nx.Random.uniform(new_key)
{w2, new_key} = Nx.Random.uniform(new_key)
{b2, new_key} = Nx.Random.uniform(new_key)

# Sample a scalar input (shape: {}) between 0 and 1.0 from the last key and
# feed it through the network.
NeuralNetwork.predict(
  Nx.Random.uniform_split(new_key, 0, 1.0, shape: {}),
  w1,
  b1,
  w2,
  b2
)

Getting Started With Axon

# Download the MNIST dataset through Scidata and bind its two parts.
{images, labels} = Scidata.MNIST.download()

Transforming the Data

# Each dataset part is a {binary, type, shape} tuple.
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels

# Decode the raw bytes, divide every value by 255, and flatten each image into
# one row. The row count comes from the dataset's own shape metadata instead of
# a hard-coded 60000, so it stays correct if the number of samples differs.
# Assumes the leading dimension of `image_shape` is the sample count
# (Scidata documents MNIST as {60000, 1, 28, 28}).
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.divide(255)
  |> Nx.reshape({elem(image_shape, 0), :auto})

# One-hot encode the labels: add a trailing axis, then compare each label
# against the indices 0..9 so every row holds ten 0/1 flags.
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
# Rows 0..49_999 are used for training; row 50_000 through the last row
# (index -1) are held out for testing.
train_range = 0..49_999//1
test_range = 50_000..-1//1

# Slice images and labels with the same range so they stay aligned.
{train_images, train_labels} = {images[train_range], labels[train_range]}
{test_images, test_labels} = {images[test_range], labels[test_range]}
# Number of samples per batch.
batch_size = 64

# Lazily pair every image batch with the label batch at the same position.
train_data =
  Stream.zip(
    Nx.to_batched(train_images, batch_size),
    Nx.to_batched(train_labels, batch_size)
  )

test_data =
  Stream.zip(
    Nx.to_batched(test_images, batch_size),
    Nx.to_batched(test_labels, batch_size)
  )
# Network: a 784-feature input named "images", a 128-unit dense layer with
# ReLU, and a 10-unit dense layer with softmax.
input_layer = Axon.input("images", shape: {nil, 784})
hidden_layer = Axon.dense(input_layer, 128, activation: :relu)
model = Axon.dense(hidden_layer, 10, activation: :softmax)

# A {1, 784} f32 template gives the display helpers a concrete input shape.
template = Nx.template({1, 784}, :f32)

# Show the model as a graph, as a printed table, and as a raw struct dump.
Axon.Display.as_graph(model, template)
IO.puts(Axon.Display.as_table(model, template))
IO.inspect(model, structs: false)

Run the Training Loop

# Supervised training loop: categorical cross-entropy loss, SGD optimizer,
# with accuracy tracked as a metric.
training_loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)

# Run for 10 epochs over the training batches, starting from an empty state
# and compiling with EXLA.
trained_model_state =
  Axon.Loop.run(training_loop, train_data, %{}, epochs: 10, compiler: EXLA)

Testing the Accuracy

# Evaluation loop that tracks accuracy.
evaluation_loop =
  model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

# Run it over the held-out batches using the trained state.
Axon.Loop.run(evaluation_loop, test_data, trained_model_state, compiler: EXLA)
# Take the first {images, labels} batch from the test stream; only the images
# are needed here.
[{test_batch, _labels}] = Enum.take(test_data, 1)

# First image of that batch.
test_image = test_batch[0]

# Reshape the flat image to 28x28 and render it as a heatmap.
Nx.to_heatmap(Nx.reshape(test_image, {28, 28}))
# Build the model with the EXLA compiler; only the prediction function is used.
{_init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

# Add a leading axis so the single image forms a batch of one, then predict.
probabilities = predict_fn.(trained_model_state, Nx.new_axis(test_image, 0))

# Index of the largest probability.
Nx.argmax(probabilities)