Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Programming Machine Learning - Chapter 6

livebook/ch06.livemd

Programming Machine Learning - Chapter 6

# Install the Livebook's dependencies:
#   :scidata - downloads standard ML datasets (MNIST here)
#   :exla    - XLA-compiled backend for Nx
#   :nx      - numerical/tensor computing library
Mix.install([
  {:scidata, "~> 0.1.9"},
  {:exla, "~> 0.4.1"},
  {:nx, "~> 0.4.1"}
])

Section

{train_images, train_labels} = Scidata.MNIST.download()
# {test_images, test_labels} = Scidata.MNIST.download_test()

# Unpack the raw training images: a flat binary plus its element type
# and original shape.
{images_binary, images_type, images_shape} = train_images

# IO.inspect(images_binary)
IO.inspect(images_shape, label: "train images_shape")

# Alternative mini-batch pipeline (unused in this chapter):
# batched_images =
#   images_binary
#   |> Nx.from_binary(images_type)
#   |> Nx.reshape(images_shape)
#   |> Nx.divide(255)
#   |> Nx.to_batched(32)

# Unpack the raw training labels (their shape is unused here).
{labels_binary, labels_type, _shape} = train_labels

# IO.inspect(labels_type, label: "train labels_type")

# Alternative one-hot mini-batch pipeline (unused in this chapter):
# batched_labels =
#   labels_binary
#   |> Nx.from_binary(labels_type)
#   |> Nx.new_axis(-1)
#   |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
#   |> Nx.to_batched(32)

# Bias column: a leading 1 per example, prepended to the features below.
ones = Nx.broadcast(1, {60_000, 1})

# Flatten each 28x28 image into a row of 784 pixel values.
images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({60_000, 784})

# NOTE(review): pixels are deliberately left unscaled (0..255) as in the
# book's chapter 6; uncomment to normalize if training proves unstable.
# |> Nx.divide(255)

# IO.inspect(ones)
# IO.inspect(images)

# Design matrix: bias column + pixel columns -> shape {60_000, 785}.
x = Nx.concatenate([ones, images], axis: 1)

# Binary labels for a one-vs-all classifier: 1 where the digit is 5.
# Reshaped to a {60_000, 1} column so it aligns with the {60_000, 1}
# predictions; without the reshape, `y * Nx.log(y_hat)` in the loss
# broadcasts {60_000} against {60_000, 1} into a {60_000, 60_000}
# matrix and silently computes the wrong loss (and `test/3` the wrong
# accuracy).
labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.equal(5)
  |> Nx.reshape({60_000, 1})
defmodule Ch06 do
  @moduledoc """
  Binary logistic-regression classifier trained with plain gradient
  descent (Programming Machine Learning, chapter 6).
  """

  import Nx.Defn

  @doc """
  Trains weights on design matrix `x` and label column `y`.

  Runs `iterations` steps of full-batch gradient descent with learning
  rate `lr`, printing the loss at each step, and returns the final
  weight tensor of shape `{num_features, 1}`.
  """
  def train(x, y, iterations, lr) do
    {_x_rows, x_cols} = Nx.shape(x)

    # Start from all-zero weights, one per input column (incl. bias).
    w = Nx.broadcast(0, {x_cols, 1})

    for i <- 1..iterations, reduce: w do
      w_acc ->
        IO.puts("#{i} => Loss: #{Nx.to_number(loss(x, y, w_acc))}")
        update(x, y, lr, w_acc)
    end
  end

  @doc """
  Rounds the sigmoid output to a hard 0/1 prediction per example.
  """
  defn classify(x, w) do
    forward(x, w) |> Nx.round()
  end

  @doc """
  Returns `{correct, total}` for predictions on `x` against labels `y`.
  """
  def test(x, y, w) do
    {total_examples, _} = Nx.shape(x)
    correct_results = Nx.sum(Nx.equal(classify(x, w), y))

    {Nx.to_number(correct_results), total_examples}
  end

  # -- Private

  # Linear combination squashed through a sigmoid.
  # Was previously `predict`.
  defnp forward(x, w) do
    Nx.sigmoid(Nx.dot(x, w))
  end

  # Binary cross-entropy (log loss), averaged over all examples.
  defnp loss(x, y, w) do
    y_hat = forward(x, w)
    first_term = y * Nx.log(y_hat)
    second_term = (1 - y) * Nx.log(1 - y_hat)

    -Nx.mean(first_term + second_term)
  end

  # Gradient of the loss w.r.t. the weights via Nx automatic
  # differentiation. The scraped source had this HTML-escaped as
  # `&amp;loss(x, y, &amp;1)`, which is a syntax error; the intended
  # capture is restored below.
  defnp gradient(x, y, w) do
    grad(w, &loss(x, y, &1))
  end

  # One gradient-descent step: move against the gradient, scaled by lr.
  defnp update(x, y, lr, w) do
    w - gradient(x, y, w) * lr
  end
end
# Train the digit-5 detector: 100 gradient-descent steps, lr = 1.0e-5.
# The resulting weights `w` can be fed to Ch06.test/3 and Ch06.classify/2.
w = Ch06.train(x, labels, 100, 0.00001)