Programming Machine Learning - Chapter 6
Mix.install([
{:scidata, "~> 0.1.9"},
{:exla, "~> 0.4.1"},
{:nx, "~> 0.4.1"}
])
Section: Load and prepare the MNIST training data
# Download the MNIST training set. Each element is a {binary, type, shape} triple.
{train_images, train_labels} = Scidata.MNIST.download()
# {test_images, test_labels} = Scidata.MNIST.download_test()

{images_binary, images_type, images_shape} = train_images
IO.inspect(images_shape, label: "train images_shape")

{labels_binary, labels_type, _shape} = train_labels

# Bias column: prepend x0 = 1 to every example.
ones = Nx.broadcast(1, {60_000, 1})

# Flatten each 28x28 image into a row of 784 raw pixel values.
# NOTE: pixels are deliberately left unnormalized (0..255); the tiny
# learning rate passed to Ch06.train/4 (1.0e-5) compensates for the scale.
images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({60_000, 784})

x = Nx.concatenate([ones, images], axis: 1)

# Binary labels for the "is this digit a 5?" classifier, shaped {60_000, 1}.
# The Nx.new_axis(-1) is essential: without it `labels` is a rank-1
# {60_000} tensor, and `y * Nx.log(y_hat)` in the loss would broadcast it
# against the {60_000, 1} predictions into a {60_000, 60_000} matrix
# (the same bug would inflate the accuracy count in Ch06.test/3).
labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(5)
defmodule Ch06 do
  import Nx.Defn

  @doc """
  Trains a logistic-regression classifier by plain gradient descent.

  Starts from an all-zero `{n_features, 1}` weight tensor, performs
  `iterations` update steps with learning rate `lr`, printing the loss
  before each step, and returns the final weights.
  """
  def train(x, y, iterations, lr) do
    {_n_examples, n_features} = Nx.shape(x)
    initial_weights = Nx.broadcast(0, {n_features, 1})

    Enum.reduce(1..iterations, initial_weights, fn step, weights ->
      IO.puts("#{step} => Loss: #{Nx.to_number(loss(x, y, weights))}")
      update(x, y, lr, weights)
    end)
  end

  @doc "Rounds the sigmoid activations into hard 0/1 class predictions."
  defn classify(x, w) do
    Nx.round(forward(x, w))
  end

  @doc """
  Evaluates the weights against labelled data.

  Returns `{correct_count, total_examples}` where `correct_count` is the
  number of predictions that match `y` exactly.
  """
  def test(x, y, w) do
    {total_examples, _n_features} = Nx.shape(x)

    correct =
      x
      |> classify(w)
      |> Nx.equal(y)
      |> Nx.sum()

    {Nx.to_number(correct), total_examples}
  end

  # -- Private

  # Logistic-regression forward pass: sigmoid(x . w).
  # Was previously named `predict`.
  defnp forward(x, w) do
    x |> Nx.dot(w) |> Nx.sigmoid()
  end

  # Binary cross-entropy loss averaged over all examples.
  defnp loss(x, y, w) do
    y_hat = forward(x, w)
    -Nx.mean(y * Nx.log(y_hat) + (1 - y) * Nx.log(1 - y_hat))
  end

  # Gradient of the loss with respect to the weights only.
  defnp gradient(x, y, w) do
    grad(w, &loss(x, y, &1))
  end

  # One gradient-descent step: move against the gradient, scaled by lr.
  defnp update(x, y, lr, w) do
    w - gradient(x, y, w) * lr
  end
end
# Raw (unnormalized) pixels demand a very small step size.
learning_rate = 1.0e-5
w = Ch06.train(x, labels, 100, learning_rate)