Hands On: Tricky Digits
# Install Nx plus the EXLA compiler, and make EXLA the default tensor backend
# so every tensor created in this notebook is backed by EXLA.Backend.
Mix.install([{:nx, "~> 0.6.3"}, {:exla, "~> 0.6.4"}],
  config: [nx: [default_backend: EXLA.Backend]]
)
MNIST from Chapter 6
defmodule C6.MNIST do
  @moduledoc """
  Helpers for loading gzipped MNIST image/label files into Nx tensors and
  preparing them for a binary classifier.
  """

  @doc """
  Reads the gzipped image file at `path` and returns a tensor of unsigned
  8-bit pixels with one row per image (`{n_images, n_rows * n_cols}`).

  Raises if the file cannot be read or if the decompressed data does not start
  with the four 32-bit header fields.
  """
  def load_images(path) do
    decompressed = :zlib.gunzip(File.read!(path))

    # The header is four 32-bit integers: magic number, image count, rows, columns.
    # Everything after it is raw pixel data, one byte per pixel.
    <<_magic_number::32, n_images::32, n_rows::32, n_cols::32, pixels::binary>> = decompressed

    flat = Nx.from_binary(pixels, {:u, 8})
    Nx.reshape(flat, {n_images, n_cols * n_rows})
  end

  @doc """
  Inserts a column of 1s at position 0 of the 2-D tensor `x` (the bias column).
  """
  def prepend_bias(x) do
    {n_rows, _n_cols} = Nx.shape(x)

    1
    |> Nx.tensor()
    |> Nx.broadcast({n_rows, 1})
    |> then(&Nx.concatenate([&1, x], axis: 1))
  end

  @doc """
  Reads the gzipped label file `filename` and returns the labels as a
  one-column tensor of unsigned 8-bit integers (`{n_items, 1}`).

  Raises if the file cannot be read or if the decompressed data does not start
  with the two 32-bit header fields.
  """
  def load_labels(filename) do
    decompressed = :zlib.gunzip(File.read!(filename))

    # The header is two 32-bit integers: magic number and label count.
    # Everything after it is one byte per label.
    <<_magic_number::32, n_items::32, raw_labels::binary>> = decompressed

    flat = Nx.from_binary(raw_labels, {:u, 8})
    Nx.reshape(flat, {n_items, 1})
  end

  @doc """
  Encodes the labels `y` for a single-digit classifier: every element equal to
  `encode_to` becomes 1 and every other element becomes 0.
  """
  def encode_to(y, encode_to) do
    Nx.equal(y, encode_to)
  end
end
{:module, C6.MNIST, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:encode_to, 2}}
# Dataset archives live in the `files` directory next to this notebook's directory.
files_path = Path.expand("../files", __DIR__)

training_images_path = Path.join(files_path, "train-images-idx3-ubyte.gz")
training_labels_path = Path.join(files_path, "train-labels-idx1-ubyte.gz")
test_images_path = Path.join(files_path, "t10k-images-idx3-ubyte.gz")
test_labels_path = Path.join(files_path, "t10k-labels-idx1-ubyte.gz")

import C6.MNIST

# Load both image sets and add the bias column to each.
x_train = prepend_bias(load_images(training_images_path))
x_test = prepend_bias(load_images(test_images_path))
#Nx.Tensor<
s64[10000][785]
EXLA.Backend
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
...
]
>
Logistic Regression from Chapter 5
defmodule C5.LogisticRegression do
  @moduledoc """
  Binary logistic regression trained with plain gradient descent.

  `x` is a 2-D tensor with one example per row, and `w` is a one-column weight
  tensor with one row per column of `x`. `y` is presumably a one-column tensor
  of 0/1 labels with one row per example (confirm against the caller).
  """
  import Nx.Defn

  @doc """
  Element-wise logistic function `1 / (1 + e^-z)`.
  """
  defn sigmoid(z) do
    1 / (1 + Nx.exp(-z))
  end

  @doc """
  Computes the model output for `x` under weights `w`: the sigmoid of `x · w`.

  This is the function previously named `predict`.
  """
  def forward(x, w) do
    x
    |> Nx.dot(w)
    |> sigmoid()
  end

  @doc """
  Rounds the output of `forward/2` to the nearest integer to get a 0/1 prediction.
  """
  def classify(x, w) do
    x
    |> forward(w)
    |> Nx.round()
  end

  @doc """
  Log loss: the negated mean of `y * log(y_hat) + (1 - y) * log(1 - y_hat)`,
  where `y_hat` is `forward(x, w)`.
  """
  def loss(x, y, w) do
    y_hat = forward(x, w)
    first_term = Nx.multiply(y, Nx.log(y_hat))
    second_term = Nx.multiply(Nx.subtract(1, y), Nx.log(Nx.subtract(1, y_hat)))

    first_term
    |> Nx.add(second_term)
    |> Nx.mean()
    |> Nx.multiply(-1)
  end

  @doc """
  Gradient of the loss with respect to `w`:
  `transpose(x) · (forward(x, w) - y)` divided by the number of rows in `x`.
  """
  def gradient(x, y, w) do
    {num_samples, _} = Nx.shape(x)

    x
    |> forward(w)
    |> Nx.subtract(y)
    |> then(&Nx.dot(Nx.transpose(x), &1))
    |> Nx.divide(num_samples)
  end

  @doc """
  Runs exactly `iterations` gradient-descent steps with learning rate `lr`,
  starting from all-zero weights, and returns the final weight tensor.

  With `iterations` equal to 0 the zero weights are returned unchanged.
  """
  def train(x, y, iterations, lr) do
    {_, x_cols} = Nx.shape(x)
    w = Nx.broadcast(0, {x_cols, 1})

    # `1..iterations//1` yields exactly `iterations` elements. The explicit step
    # keeps the range empty when `iterations` is 0 instead of counting down.
    Enum.reduce(1..iterations//1, w, fn _i, w ->
      gradient = gradient(x, y, w)
      Nx.subtract(w, Nx.multiply(gradient, lr))
    end)
  end

  @doc """
  Returns, as a float, the percentage of rows in `x` whose classification
  under `w` equals the corresponding label in `y`.
  """
  def test(x, y, w) do
    {total_examples, _} = Nx.shape(x)

    correct_results =
      x
      |> classify(w)
      |> Nx.equal(y)
      |> Nx.sum()
      |> Nx.to_number()

    # `correct_results` is already a plain number, so ordinary arithmetic is
    # enough; `/` always returns a float.
    correct_results * 100 / total_examples
  end
end
{:module, C5.LogisticRegression, <<70, 79, 82, 49, 0, 0, 20, ...>>, {:test, 3}}
Wild guess
My guess is that the hardest digit to recognize is 4, because of its resemblance to 9, and that the easiest is 2. We can check this guess by training and testing a classifier for each digit:
# Read and decompress each label file once. Every digit re-encodes the same raw
# labels, so there is no need to reload the files for each classifier.
train_labels = load_labels(training_labels_path)
test_labels = load_labels(test_labels_path)

# Trains a classifier that separates `digit` from every other digit, then
# returns `{digit, success_percent}` measured on the test set.
execute = fn digit ->
  y_train = encode_to(train_labels, digit)
  y_test = encode_to(test_labels, digit)
  w = C5.LogisticRegression.train(x_train, y_train, 100, 1.0e-5)
  success_percent = C5.LogisticRegression.test(x_test, y_test, w)
  {digit, success_percent}
end

# Run the ten experiments concurrently, then sort by increasing success rate.
0..9
|> Task.async_stream(execute, timeout: :infinity)
|> Enum.map(fn {:ok, result} -> result end)
|> Enum.sort_by(fn {_digit, success_percent} -> success_percent end)
[
{8, 93.83999633789062},
{9, 95.58000183105469},
{5, 96.38999938964844},
{3, 96.9800033569336},
{2, 97.37000274658203},
{4, 97.5999984741211},
{6, 98.06999969482422},
{7, 98.13999938964844},
{0, 98.98999786376953},
{1, 99.02999877929688}
]
The result above is a list of digits sorted by increasing corresponding success rate.
According to the experiment above, the hardest digit to recognize is the number 8,
which has a success rate of 93.84%,
compared to the easiest to classify, which is the number 1,
with a success rate of 99.03%.