Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Neural Network

learn_neural_network.livemd

Neural Network

Mix.install([
  {:axon, "~> 0.6.0"},
  {:nx, "~> 0.6.4"},
  {:exla, "~> 0.6.4"},
  {:scidata, "~> 0.1.11"},
  {:kino, "~> 0.12.3"}
])

Data Preparation

# Download the MNIST training split. Each element is a
# {binary, type, shape} triple as returned by Scidata.
{images, labels} = Scidata.MNIST.download()
{images_data, images_type, images_shape} = images

# Decode the raw bytes, scale by 1/255, and flatten every image into a
# single row. The number of rows is read from the shape reported by
# Scidata (first dimension) instead of being hard-coded, so this cell keeps
# working if the split size differs.
images_tensor =
  images_data
  |> Nx.from_binary(images_type)
  |> Nx.divide(255)
  |> Nx.reshape({elem(images_shape, 0), :auto})

{labels_data, labels_type, labels_shape} = labels

# Decode the labels and append a trailing axis so that each label is a
# length-1 row; the training cell later broadcasts this against
# Nx.iota({10}) to build one-hot targets.
labels_tensor =
  labels_data
  |> Nx.from_binary(labels_type)
  |> Nx.reshape(labels_shape)
  |> Nx.new_axis(-1)

The Model

# Input layer: batches (nil = any batch size) of 784-element feature rows.
input = Axon.input("feature", shape: {nil, 784})

# Hidden block: dense layer with 128 units followed by a ReLU activation.
hidden = Axon.relu(Axon.dense(input, 128))

# Output block: 10 units passed through softmax; the final layer is named
# "labels".
model =
  hidden
  |> Axon.dense(10)
  |> Axon.softmax(name: "labels")

# Render the network as a graph using a single-row f32 template as input.
template = Nx.template({1, 784}, :f32)
Axon.Display.as_graph(model, template)

Training

# Split features and labels into batches of 32 and pair them up lazily.
images_train_data = Nx.to_batched(images_tensor, 32)
labels_train_data = Nx.to_batched(labels_tensor, 32)
train_data = Stream.zip(images_train_data, labels_train_data)

# Peek at the first {images, labels} batch.
Enum.take(train_data, 1)

# One-hot encode each label batch: comparing the label column against
# Nx.iota({10}) broadcasts to a row of 10 flags per label, set only at the
# position equal to the label.
train_data =
  for {image_batch, label_batch} <- train_data do
    {image_batch, Nx.equal(label_batch, Nx.iota({10}))}
  end

# Build a supervised training loop (categorical cross-entropy loss, Adam
# optimizer) and run it for 20 epochs starting from an empty state,
# compiling with EXLA.
training_loop = Axon.Loop.trainer(model, :categorical_cross_entropy, :adam)

trained_model_state =
  Axon.Loop.run(training_loop, train_data, %{}, compiler: EXLA, epochs: 20)

Evaluation

# Evaluate on the held-out MNIST *test* split. Re-downloading the training
# split here (Scidata.MNIST.download/0) would report accuracy on the very
# data the model was fitted on, which overstates how well it generalizes.
{images, labels} = Scidata.MNIST.download_test()
{images_data, images_type, images_shape} = images

# Same preprocessing as for training: decode, scale by 1/255, flatten each
# image into one row. The row count comes from the reported shape because
# the test split does not have the same number of samples as the training
# split.
images_tensor =
  images_data
  |> Nx.from_binary(images_type)
  |> Nx.divide(255)
  |> Nx.reshape({elem(images_shape, 0), :auto})

{labels_data, labels_type, labels_shape} = labels

# Decode the labels and one-hot encode them by broadcasting the label
# column against Nx.iota({10}).
labels_tensor =
  labels_data
  |> Nx.from_binary(labels_type)
  |> Nx.reshape(labels_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({10}))

# Batch features and targets in groups of 32 and pair them up lazily.
images_test_data = Nx.to_batched(images_tensor, 32)
labels_test_data = Nx.to_batched(labels_tensor, 32)
test_data = Stream.zip(images_test_data, labels_test_data)

# Run an evaluation loop over the test batches with the trained state,
# tracking the accuracy metric and compiling with EXLA.
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)