Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Handwriting Recognition

notebook.livemd

Handwriting Recognition

# Install the notebook's dependencies directly from GitHub, each pinned to a
# fixed commit so the notebook keeps resolving to the same code.
Mix.install([
  # exla and nx are both fetched from the elixir-nx/nx repository (sparse
  # checkouts of the "exla" and "nx" subdirectories) at the same ref.
  {:exla, ">= 0.0.0", github: "elixir-nx/nx", ref: "6fda0ce5ec", sparse: "exla"},
  # override: true lets this nx entry win over nx requirements declared by
  # other dependencies (presumably axon's; confirm).
  {:nx, ">= 0.0.0", github: "elixir-nx/nx", ref: "6fda0ce5ec", sparse: "nx", override: true},
  {:axon, ">= 0.0.0", github: "elixir-nx/axon", ref: "be94bba7a"}
])

Download data

# Start the OTP applications :httpc relies on for HTTPS requests.
:inets.start()
:ssl.start()

# Download the gzipped MNIST training images. The pattern match requires a 200
# status, so any other outcome raises a MatchError right here.
#
# body_format: :binary makes :httpc hand back the body as a binary instead of
# its default list of bytes, which is far more compact for a multi-megabyte
# file. :zlib.gunzip/1 accepts either form, so later cells are unaffected.
# The URL is passed as a charlist, the form :httpc accepts on every OTP release.
{:ok, {{_, 200, _}, _headers, train_body}} =
  :httpc.request(
    :get,
    {~c"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", []},
    [],
    body_format: :binary
  )

# Finish the cell with :ok so the downloaded body is not its result.
:ok
# Download the gzipped MNIST training labels the same way.
{:ok, {{_, 200, _}, _headers, label_body}} =
  :httpc.request(
    :get,
    {~c"https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", []},
    [],
    body_format: :binary
  )

:ok
# Decompress the image file, then split off its header: the first 32-bit field
# is skipped, the next three give the image count, row count and column count,
# and everything after that is raw pixel data.
image_file = :zlib.gunzip(train_body)

<<_::32, n_images::32, n_rows::32, n_cols::32, train_body::binary>> = image_file

# Read the pixels as unsigned 8-bit values, arrange them as one
# {n_rows, n_cols} grid per image, and divide by 255 to scale into 0..1.
pixels = Nx.from_binary(train_body, {:u, 8})
train_tensor = Nx.divide(Nx.reshape(pixels, {n_images, n_rows, n_cols}), 255)

# Render the images as a heatmap for a quick visual check.
Nx.to_heatmap(train_tensor)
# Decompress the label file. The pinned ^n_images makes the match fail unless
# the label count in this header equals the image count read earlier.
label_file = :zlib.gunzip(label_body)

<<_::32, ^n_images::32, label_body::binary>> = label_file

# Row vector holding the ten possible digit values 0..9.
digits = Nx.tensor(Enum.to_list(0..9))

# Reshape the labels into a column and compare against the digit row:
# broadcasting yields one row per label with a 1 in the matching digit's
# column and 0 elsewhere (a one-hot encoding).
label_tensor =
  Nx.equal(Nx.reshape(Nx.from_binary(label_body, {:u, 8}), {n_images, 1}), digits)

Create Model

require Axon

# Input takes batches of any size (nil) of n_rows x n_cols images.
input = Axon.input({nil, n_rows, n_cols})

# Flatten each image, then pass it through a 128-unit sigmoid layer.
hidden = Axon.dense(Axon.flatten(input), 128, activation: :sigmoid)

# Output layer: 10 units with softmax, one per digit class.
model = Axon.dense(hidden, 10, activation: :softmax)

Train Model

# Split images and labels into matching lists of 32-example batches.
train_batch = Nx.to_batched_list(train_tensor, 32)
label_batch = Nx.to_batched_list(label_tensor, 32)

# Training step: categorical cross-entropy loss, SGD with learning rate 0.01.
optimizer = Axon.Optimizers.sgd(0.01)
train_step = Axon.Training.step(model, :categorical_cross_entropy, optimizer)

# Train for 10 epochs, compiling with EXLA, and keep only the learned params.
%{params: trained_params} =
  Axon.Training.train(train_step, train_batch, label_batch, epochs: 10, compiler: EXLA)

Results

# Run the trained model on the first training batch only.
first_batch = hd(train_batch)
probabilities = Axon.predict(model, trained_params, first_batch)

# For each image pick the index of the highest softmax output (the predicted
# digit) and return the predictions as a plain Elixir list.
results = Nx.to_flat_list(Nx.argmax(probabilities, axis: 1))

Show Input and Result

# One single-image tensor per training example, in dataset order.
inputs = Nx.to_batched_list(train_tensor, 1)

# Pair each prediction with the image at the same position and render that
# image as a heatmap, giving a list of {heatmap, predicted_digit} tuples.
# Enum.zip/2 walks both lists once and stops at the shorter one (results),
# which avoids re-traversing inputs with Enum.at/2 for every prediction.
results
|> Enum.zip(inputs)
|> Enum.map(fn {result, image} -> {Nx.to_heatmap(image), result} end)