
Chapter 11

# EXLA compiles Nx operations for the CPU/GPU; setting it as the default
# backend makes every tensor operation in this notebook run through it.
Mix.install(
  [
    {:nx, "~> 0.6.4"},
    {:exla, "~> 0.6.4"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Building the network

defmodule MNIST do
  def load_images(:train), do: load_images(__DIR__ <> "/data/mnist/train-images-idx3-ubyte.gz")
  def load_images(:test), do: load_images(__DIR__ <> "/data/mnist/t10k-images-idx3-ubyte.gz")

  def load_images(filename) when is_binary(filename) do
    # IDX image files start with a big-endian header: the magic number 2051,
    # then the image count, rows, and columns, followed by raw u8 pixels.
    <<2051::32, count::32, rows::32, cols::32, pixels::binary>> =
      File.read!(filename)
      |> :zlib.gunzip()

    # Flatten each rows x cols image into a single row of pixels.
    Nx.from_binary(pixels, :u8)
    |> Nx.reshape({count, rows * cols})
  end

  def load_labels(:train), do: load_labels(__DIR__ <> "/data/mnist/train-labels-idx1-ubyte.gz")
  def load_labels(:test), do: load_labels(__DIR__ <> "/data/mnist/t10k-labels-idx1-ubyte.gz")

  def load_labels(filename) when is_binary(filename) do
    # IDX label files start with the magic number 2049 and the label count,
    # followed by one u8 label per image.
    <<2049::32, count::32, labels::binary>> =
      File.read!(filename)
      |> :zlib.gunzip()

    Nx.from_binary(labels, :u8)
    |> Nx.reshape({count})
  end
end
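
To spot-check the loaders, we can peek at the first training example (a minimal sketch, assuming the gzipped MNIST files sit under data/mnist/ next to this notebook):

images = MNIST.load_images(:train)
labels = MNIST.load_labels(:train)

labels[0] |> IO.inspect(label: "first label")
# Thresholding the 784 pixels back into a 28x28 grid shows a rough
# silhouette of the digit when the tensor is printed.
images[0] |> Nx.reshape({28, 28}) |> Nx.greater(127)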

defmodule OHE do
  import Nx.Defn

  # Turn a vector of labels into a matrix with one row per label and a
  # single 1 in the column matching the label's class.
  defn one_hot_encode(y, opts \\ [classes: 10]) do
    Nx.new_axis(y, -1)
    |> Nx.equal(Nx.iota({opts[:classes]}))
  end
end
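
A minimal check of the encoder on a hand-made tensor:

OHE.one_hot_encode(Nx.tensor([0, 3, 9]))
# => one row per label, with a single 1 in the matching column:
# [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]
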
defmodule NN do
  import Nx.Defn

  # Prepend a column of 1s so the bias can be folded into the weight matrix.
  defn prepend_bias(x) do
    Nx.pad(x, 1, [{0, 0, 0}, {1, 0, 0}])
  end

  # Row-wise softmax: turn each row of logits into probabilities that sum to 1.
  defn softmax(x) do
    Nx.exp(x) / Nx.sum(Nx.exp(x), axes: [1], keep_axes: true)
  end

  # Derivative of the sigmoid, expressed in terms of its output s.
  defn sigmoid_gradient(s) do
    s * (1 - s)
  end

  # Cross-entropy loss, averaged over the examples in the batch.
  defn loss(y, y_hat) do
    -Nx.sum(y * Nx.log(y_hat)) / Nx.axis_size(y, 0)
  end

  defn forward(x, w1, w2) do
    # Hidden layer: linear combination followed by a sigmoid.
    h =
      prepend_bias(x)
      |> Nx.dot(w1)
      |> Nx.sigmoid()

    # Output layer: linear combination followed by a softmax.
    y_hat =
      prepend_bias(h)
      |> Nx.dot(w2)
      |> softmax()

    {y_hat, h}
  end

  defn back(x, y, y_hat, w2, h) do
    # Gradient of the loss with respect to w2, averaged over the batch.
    w2_gradient =
      prepend_bias(h)
      |> Nx.transpose()
      |> Nx.dot(y_hat - y)
      |> Nx.divide(Nx.axis_size(x, 0))

    # Backpropagate through the output layer, dropping w2's bias row
    # (the bias does not connect back to the hidden activations),
    # then through the sigmoid.
    a_gradient =
      (y_hat - y)
      |> Nx.dot(Nx.transpose(w2[1..-1//1]))
      |> Nx.multiply(sigmoid_gradient(h))

    # Gradient of the loss with respect to w1.
    w1_gradient =
      prepend_bias(x)
      |> Nx.transpose()
      |> Nx.dot(a_gradient)
      |> Nx.divide(Nx.axis_size(x, 0))

    {w1_gradient, w2_gradient}
  end

  # Predict the class with the highest probability for each example.
  defn classify(x, w1, w2) do
    forward(x, w1, w2)
    |> elem(0)
    |> Nx.argmax(axis: 1)
  end

  def initialize_weights(n_input_variables, n_hidden_nodes, n_classes) do
    # One extra row per matrix for the bias column added by prepend_bias/1.
    {w1_rows, w2_rows} = {n_input_variables + 1, n_hidden_nodes + 1}

    key = Nx.Random.key(42)
    {w1_dist, key} = Nx.Random.normal(key, shape: {w1_rows, n_hidden_nodes})
    {w2_dist, _key} = Nx.Random.normal(key, shape: {w2_rows, n_classes})

    # Scale each matrix by 1/sqrt(rows) to keep the initial activations small.
    {Nx.multiply(w1_dist, Nx.sqrt(Nx.divide(1, w1_rows))),
     Nx.multiply(w2_dist, Nx.sqrt(Nx.divide(1, w2_rows)))}
  end

  def train(x_train, y_train, x_test, y_test, n_hidden_nodes, iterations, lr) do
    n_input_variables = Nx.axis_size(x_train, 1)
    n_classes = Nx.axis_size(y_train, 1)
    init_weights = initialize_weights(n_input_variables, n_hidden_nodes, n_classes)

    Enum.reduce(1..iterations, init_weights, fn i, {w1, w2} ->
      {y_hat, h} = forward(x_train, w1, w2)
      {w1_gradient, w2_gradient} = back(x_train, y_train, y_hat, w2, h)

      # Gradient descent step: move both matrices against their gradients.
      {w1, w2} =
        {Nx.subtract(w1, Nx.multiply(w1_gradient, lr)),
         Nx.subtract(w2, Nx.multiply(w2_gradient, lr))}

      report(i, x_train, y_train, x_test, y_test, w1, w2)
      {w1, w2}
    end)
  end

  def report(i, x_train, y_train, x_test, y_test, w1, w2) do
    {y_hat, _h} = forward(x_train, w1, w2)
    training_loss = loss(y_train, y_hat)
    accuracy = Nx.equal(classify(x_test, w1, w2), y_test) |> Nx.mean()

    IO.puts("#{i} - Loss #{Nx.to_number(training_loss)}, Accuracy #{Nx.to_number(accuracy)}")
  end
end
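
Before training on the full dataset, a quick sanity check that the shapes flow through the network as expected (a sketch on random data; the dimensions mirror MNIST with 200 hidden nodes):

key = Nx.Random.key(0)
{x_demo, _key} = Nx.Random.uniform(key, shape: {4, 784})
{w1, w2} = NN.initialize_weights(784, 200, 10)

{y_hat, h} = NN.forward(x_demo, w1, w2)
Nx.shape(h) |> IO.inspect(label: "hidden")     # {4, 200}
Nx.shape(y_hat) |> IO.inspect(label: "output") # {4, 10}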
x_train = MNIST.load_images(:train)
x_test = MNIST.load_images(:test)

# The training labels are one-hot encoded for the cross-entropy loss, while
# the test labels stay as plain class indices, because report/7 compares
# predicted indices directly.
y_train = MNIST.load_labels(:train) |> OHE.one_hot_encode()
y_test = MNIST.load_labels(:test)

NN.train(x_train, y_train, x_test, y_test, 200, 10000, 0.01)

Starting off wrong

To see why the random, scaled initialization above matters, NNBad trains the same network with every weight starting at zero. With identical weights, every hidden node computes the same activation and receives the same gradient, so the nodes never differentiate and the network effectively collapses to a single hidden node.

defmodule NNBad do
  # Deliberately bad initialization: every weight starts at exactly zero.
  def initialize_weights(n_input_variables, n_hidden_nodes, n_classes) do
    {w1_rows, w2_rows} = {n_input_variables + 1, n_hidden_nodes + 1}

    w1 = Nx.broadcast(0, {w1_rows, n_hidden_nodes})
    w2 = Nx.broadcast(0, {w2_rows, n_classes})

    {w1, w2}
  end

  def train(x_train, y_train, x_test, y_test, n_hidden_nodes, iterations, lr) do
    n_input_variables = Nx.axis_size(x_train, 1)
    n_classes = Nx.axis_size(y_train, 1)
    init_weights = initialize_weights(n_input_variables, n_hidden_nodes, n_classes)

    Enum.reduce(1..iterations, init_weights, fn i, {w1, w2} ->
      {y_hat, h} = NN.forward(x_train, w1, w2)
      {w1_gradient, w2_gradient} = NN.back(x_train, y_train, y_hat, w2, h)

      {w1, w2} =
        {Nx.subtract(w1, Nx.multiply(w1_gradient, lr)),
         Nx.subtract(w2, Nx.multiply(w2_gradient, lr))}

      NN.report(i, x_train, y_train, x_test, y_test, w1, w2)
      {w1, w2}
    end)
  end
end

NNBad.train(x_train, y_train, x_test, y_test, 10, 10000, 0.01)
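
The symmetry problem is visible in the very first gradients (a minimal sketch reusing the tensors loaded above):

{w1, w2} = NNBad.initialize_weights(784, 10, 10)
{y_hat, h} = NN.forward(x_train, w1, w2)
{w1_gradient, w2_gradient} = NN.back(x_train, y_train, y_hat, w2, h)

# With w2 all zeros, nothing backpropagates to the first layer...
Nx.all(Nx.equal(w1_gradient, 0)) |> IO.inspect(label: "w1_gradient all zero?")

# ...and every hidden node's row of w2_gradient is identical, so the
# hidden nodes receive the same update and stay clones of each other.
Nx.all(Nx.equal(w2_gradient[1], w2_gradient[2])) |> IO.inspect(label: "identical rows?")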