Human-Scale AI: 7-segment display
Mix.install([
{:axon, "~> 0.6"},
{:exla, "~> 0.7"},
{:kino, "~> 0.12.3"},
{:phoenix_html, "~> 4.0"},
{:phoenix_live_view, "~> 0.20"}
])
Data
defmodule SevenSegment.Number do
  @moduledoc """
  Data representation for both the 0-9 digits and the associated 7-segment bit patterns.

  There's no number data structure per se - this module just converts (in both directions)
  between 0-9 integers (digits) and 7-element lists of 0/1 (bitlists).
  """

  # Indexed by digit: the element at position `n` is the bitlist for digit `n`.
  @bitlists [
    [1, 1, 1, 0, 1, 1, 1],
    [0, 0, 1, 0, 0, 1, 0],
    [1, 0, 1, 1, 1, 0, 1],
    [1, 0, 1, 1, 0, 1, 1],
    [0, 1, 1, 1, 0, 1, 0],
    [1, 1, 0, 1, 0, 1, 1],
    [1, 1, 0, 1, 1, 1, 1],
    [1, 0, 1, 0, 0, 1, 0],
    [1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 1, 1]
  ]

  @doc """
  Return the bitlist for a given digit (0-9).

  Raises `ArgumentError` if `digit` is not a single (0-9) integer digit.

      iex> SevenSegment.Number.encode_digit!(1)
      [0, 0, 1, 0, 0, 1, 0]

      iex> SevenSegment.Number.encode_digit!(5)
      [1, 1, 0, 1, 0, 1, 1]
  """
  # The guard matters: `Enum.at/2` accepts negative indices (counting from the
  # end) and returns `nil` past the end, so out-of-range input must be rejected
  # before the lookup rather than relying on the lookup to fail.
  def encode_digit!(digit) when digit in 0..9, do: Enum.at(@bitlists, digit)

  def encode_digit!(digit) do
    raise ArgumentError, "digit must be 0-9, got: #{inspect(digit)}"
  end

  @doc """
  Return the digit (0-9) for a given bitlist.

  Raises `RuntimeError` if the bitlist doesn't correspond to a single (0-9) digit.

      iex> SevenSegment.Number.decode_digit!([1, 1, 1, 1, 1, 1, 1])
      8

      iex> SevenSegment.Number.decode_digit!([1, 1, 0, 1, 0, 1, 1])
      5
  """
  def decode_digit!(bitlist) do
    # The position of the bitlist in the table is the digit itself.
    digit = Enum.find_index(@bitlists, fn bp -> bp == bitlist end)
    # `0` is truthy in Elixir, so only a missing match (`nil`) reaches the raise.
    digit || raise "bitlist did not correspond to a digit 0-9"
  end
end
Model definition
defmodule SevenSegment.Model do
  @moduledoc """
  Helper module for defining fully-connected networks of different sizes.

  This module is a leaky abstraction - the returned models are [Axon](https://hexdocs.pm/axon/)
  data structures. If you just follow this notebook you (probably) don't need to understand
  how they work.
  """

  @doc """
  Create a fully-connected model.

  The model has a 7-dimensional input (for the bitlists) and a 10-dimensional
  output (for the softmax predictions; one for each digit 0-9).

  `hidden_layer_sizes` should be a list of sizes for the hidden layers, in order
  from the input side to the output side. Each hidden layer uses a ReLU activation.

  Example: create a network with a single hidden layer of 2 neurons:

      iex> SevenSegment.Model.new([2])
      #Axon<
        inputs: %{"bitlist" => {nil, 7}}
        outputs: "softmax_0"
        nodes: 5
      >
  """
  def new(hidden_layer_sizes) do
    bitlist_input = Axon.input("bitlist", shape: {nil, 7})

    # Stack one ReLU dense layer per requested size on top of the input.
    hidden_stack =
      for units <- hidden_layer_sizes, reduce: bitlist_input do
        network -> Axon.dense(network, units, activation: :relu)
      end

    # Final layer: one softmax output per digit 0-9.
    Axon.dense(hidden_stack, 10, activation: :softmax)
  end
end
Training
defmodule SevenSegment.Train do
  @moduledoc """
  Create datasets and train models.
  """

  @doc """
  Create the tensor of input bitlists used as the training set.

  Compared to most AI problems this is _extremely_ trivial; there are only
  10 digits, and each one has one unambiguous bitlist representation, so
  there are only 10 pairs in the training set. Toy problems ftw :)

  The output won't be a list of lists, it'll be an [Nx](https://hexdocs.pm/nx/) tensor,
  because that's what's expected by the training code.

  Note that the returned tensor won't include the digits explicitly, but the digits can be
  used to index into the `:digit` axis to get the correct bitlist, e.g.

      iex> train_data = SevenSegment.Train.inputs()
      iex> train_data[[digit: 0]]
      #Nx.Tensor<
        u8[bitlist: 7]
        [1, 1, 1, 0, 1, 1, 1]
      >
  """
  def inputs() do
    0..9
    |> Enum.map(&SevenSegment.Number.encode_digit!/1)
    |> Nx.tensor(names: [:digit, :bitlist], type: :u8)
  end

  @doc """
  Return a tensor of the (one-hot-encoded) digits 0-9 (one per row).

  Row `n` (along the `:digit` axis) has a 1 at position `n` of the `:one_hot` axis
  and 0 everywhere else.
  """
  def targets() do
    0..9
    |> Enum.to_list()
    |> Nx.tensor(type: :u8, names: [:digit])
    # Turn the digits into a column so that comparing against the row 0..9
    # broadcasts to a 10x10 grid which is 1 only where row digit == column index.
    |> Nx.new_axis(-1, :one_hot)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  @doc "Convenience function for building an `{inputs, targets}` tuple of tensors for use in training."
  def training_set() do
    {inputs(), targets()}
  end

  @doc """
  Run the training procedure, returning the result of `Axon.Loop.run/4`
  (the trained params for the model).

  `opts` are passed through to `Axon.Loop.run/4`. When not given, `:epochs`
  defaults to `1000` and `:compiler` defaults to `EXLA`; any value supplied by
  the caller takes precedence over these defaults.
  """
  def run(model, inputs, targets, opts \\ []) do
    # since this training set is so small, use batches of size 1
    data = Enum.zip(Nx.to_batched(inputs, 1), Nx.to_batched(targets, 1))

    # Defaults go first: `Keyword.merge/2` keeps the value from its second
    # argument on duplicate keys, so caller-supplied options win.
    opts = Keyword.merge([epochs: 1000, compiler: EXLA], opts)

    model
    |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(data, %{}, opts)
  end
end
Now that we’ve prepared our infrastructure, we can run the training procedure to get the weights for our trained network.
# Build a network with a single hidden layer of 4 neurons.
model = SevenSegment.Model.new([4])
# Fetch the input (bitlist) and target (one-hot digit) tensors.
{inputs, targets} = SevenSegment.Train.training_set()
# Train the model; `model`, `inputs` and `params` are reused further down.
params = SevenSegment.Train.run(model, inputs, targets)
Predictions/inference
defmodule SevenSegment.Predict do
  @moduledoc """
  Run single-shot inference for a trained model.

  Intended use:

  - `model` comes from `SevenSegment.Model.new/1`
  - `params` comes from `SevenSegment.Train.run/4`
  """

  @doc """
  For a given `digit` 0-9, return the predicted class distribution under `model`.

  The digit is first converted to its bitlist with
  `SevenSegment.Number.encode_digit!/1`, then passed through the model
  using the given `params`.
  """
  def from_digit(model, params, digit) do
    bitlist = SevenSegment.Number.encode_digit!(digit)

    # Add a leading axis so the single bitlist is presented as a batch of one.
    batch = Nx.new_axis(Nx.tensor(bitlist), 0)

    Axon.predict(model, params, batch)
  end
end
Visualisation
# Visualise the confusion matrix: run the trained model over all training inputs
# and render the predictions as a heatmap. A perfect model is all white on the
# diagonal and black everywhere else.
model
|> Axon.predict(params, inputs)
|> Nx.to_heatmap()
defmodule SevenSegment.Vis do
  @moduledoc """
  Render 7-segment bitlists as inline SVG markup.
  """

  # SVG path data for the seven segments, in bitlist order. Judging by the
  # coordinates: top bar, upper-left, upper-right, middle bar, lower-left,
  # lower-right, bottom bar.
  @segment_paths [
    "M 190.79731,72.5 L 175.58534,88 L 116.58535,88 L 101.06756,72.5 L 116.58535,57 L 175.58534,57 L 190.79731,72.5 z ",
    "M 98,75.38513 L 113.5,90.59709 L 113.5,135.59708 L 98,151.11487 L 82.5,135.59708 L 82.5,90.59709 L 98,75.38513 z ",
    "M 194,75.38513 L 209.5,90.59709 L 209.5,135.59708 L 194,151.11487 L 178.5,135.59708 L 178.5,90.59709 L 194,75.38513 z ",
    "M 190.79731,154 L 175.58534,169.5 L 116.58535,169.5 L 101.06756,154 L 116.58535,138.5 L 175.58534,138.5 L 190.79731,154 z",
    "M 98,157.44257 L 113.5,172.65453 L 113.5,217.65452 L 98,233.1723 L 82.5,217.65452 L 82.5,172.65453 L 98,157.44257 z ",
    "M 194,157.44257 L 209.5,172.65453 L 209.5,217.65452 L 194,233.1723 L 178.5,217.65452 L 178.5,172.65453 L 194,157.44257 z ",
    "M 190.79731,236.05743 L 175.58534,251.55743 L 116.58535,251.55743 L 101.06756,236.05743 L 116.58535,220.55743 L 175.58534,220.55743 L 190.79731,236.05743 z "
  ]

  @doc """
  Return an SVG string drawing the 7-segment display for `segments`.

  `segments` is a bitlist as produced by `SevenSegment.Number.encode_digit!/1`:
  a `1` draws the corresponding segment lit (red), a `0` draws it as a faint
  outline. Any other value raises `CaseClauseError`.

  `transform`, when given, is used verbatim as the SVG `transform` attribute of
  the group that wraps all segments (e.g. `"scale(0.5)"`). It is not escaped, so
  it should not come from untrusted input.
  """
  def digit(segments, transform \\ nil) do
    paths =
      segments
      |> Enum.zip(@segment_paths)
      |> Enum.map_join("\n", fn {s, d} ->
        color_attrs =
          case s do
            0 -> ~s|fill="none" stroke="#EEE"|
            1 -> ~s|fill="red" stroke="red"|
          end

        ~s|<path d="#{d}" #{color_attrs} />|
      end)

    # Only emit the attribute when a transform was actually requested.
    group_attrs = if transform, do: ~s| transform="#{transform}"|, else: ""

    # The viewBox is chosen to enclose the segment coordinates above
    # (x from 82.5 to 209.5, y from 57 to ~251.6) with a small margin.
    """
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="70 45 152 219" width="76" height="110">
    <g#{group_attrs}>
    #{paths}
    </g>
    </svg>
    """
  end
end
# Render the digit 0 on the 7-segment display as an HTML output.
Kino.HTML.new(SevenSegment.Vis.digit(SevenSegment.Number.encode_digit!(0)))