Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us
Notesclub

Human-Scale AI: 7-segment display

7-segment-display.livemd

Human-Scale AI: 7-segment display

Mix.install([
  {:axon, "~> 0.6"},
  {:exla, "~> 0.7"},
  {:kino, "~> 0.12.3"},
  {:phoenix_html, "~> 4.0"},
  {:phoenix_live_view, "~> 0.20"}
])

Data

defmodule SevenSegment.Number do
  @moduledoc """
  Data representation for both the 0-9 digits and the associated 7-segment bit patterns.

  There's no number data structure per se - this module just converts (in both directions)
  between 0-9 integers (digits) and 7-element lists of 0/1 (bitlists).
  """

  @typedoc "A 7-element list of 0/1 values, one per display segment."
  @type bitlist :: [0 | 1]

  # One bitlist per digit; the list index is the digit it encodes.
  @bitlists [
    [1, 1, 1, 0, 1, 1, 1],
    [0, 0, 1, 0, 0, 1, 0],
    [1, 0, 1, 1, 1, 0, 1],
    [1, 0, 1, 1, 0, 1, 1],
    [0, 1, 1, 1, 0, 1, 0],
    [1, 1, 0, 1, 0, 1, 1],
    [1, 1, 0, 1, 1, 1, 1],
    [1, 0, 1, 0, 0, 1, 0],
    [1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 1, 1]
  ]

  @doc """
  Return the bitlist for a given digit (0-9).

  Raises a `RuntimeError` if `digit` is not an integer in `0..9`.

      iex> SevenSegment.Number.encode_digit!(1)
      [0, 0, 1, 0, 0, 1, 0]

      iex> SevenSegment.Number.encode_digit!(5)
      [1, 1, 0, 1, 0, 1, 1]
  """
  @spec encode_digit!(0..9) :: bitlist()
  # The guard rejects out-of-range input up front; without it `Enum.at/2` would
  # return `nil` for 10+ and silently index from the end for negative values.
  def encode_digit!(digit) when digit in 0..9, do: Enum.at(@bitlists, digit)
  def encode_digit!(_digit), do: raise("digit must be 0-9")

  @doc """
  Return the digit (0-9) for a given bitlist.

  Raises a `RuntimeError` if the bitlist doesn't correspond to a single (0-9) digit.

      iex> SevenSegment.Number.decode_digit!([1, 1, 1, 1, 1, 1, 1])
      8

      iex> SevenSegment.Number.decode_digit!([1, 1, 0, 1, 0, 1, 1])
      5
  """
  @spec decode_digit!(bitlist()) :: 0..9
  def decode_digit!(bitlist) do
    # The position of the matching pattern in @bitlists is the digit itself.
    case Enum.find_index(@bitlists, &(&1 == bitlist)) do
      nil -> raise "bitlist did not correspond to a digit 0-9"
      digit -> digit
    end
  end
end

Model definition

defmodule SevenSegment.Model do
  @moduledoc """
  Helpers for building fully-connected networks of different sizes.

  This module is a leaky abstraction: what it hands back are plain
  [Axon](https://hexdocs.pm/axon/) data structures. If you just follow this
  notebook you (probably) don't need to understand how they work.
  """

  @doc """
  Create a fully-connected model.

  The model takes a 7-dimensional input (the bitlist) and produces a
  10-dimensional output (softmax predictions; one per digit 0-9).

  `hidden_layer_sizes` is a list with the size of each hidden layer, in order.

  Example: create a network with a single hidden layer of 2 neurons:

      iex> SevenSegment.Model.new([2])
      #Axon<
        inputs: %{"bitlist" => {nil, 7}}
        outputs: "softmax_0"
        nodes: 5
      >

  """
  def new(hidden_layer_sizes) do
    bitlist_input = Axon.input("bitlist", shape: {nil, 7})

    # Stack one ReLU-activated dense layer per requested size on top of the input.
    hidden_stack =
      for size <- hidden_layer_sizes, reduce: bitlist_input do
        graph -> Axon.dense(graph, size, activation: :relu)
      end

    # Final layer: one softmax output per digit.
    Axon.dense(hidden_stack, 10, activation: :softmax)
  end
end

Training

defmodule SevenSegment.Train do
  @moduledoc """
  Create datasets and train models.
  """

  # Defaults for `run/4`; anything the caller passes in `opts` takes precedence.
  @default_run_opts [epochs: 1000, compiler: EXLA]

  @doc """
  Create the tensor of input bitlists for use as a training set.

  Compared to most AI problems this is _extremely_ trivial; there are only
  10 digits, and each one has one unambiguous bitlist representation, so
  there are only 10 pairs in the training set. Toy problems ftw :)

  The output won't be a list of lists, it'll be an [Nx](https://hexdocs.pm/nx/) tensor,
  because that's what's expected by the training code.

  Note that the returned tensor won't include the digits explicitly, but the digits can be
  used to index into the `:digit` axis to get the correct bitlist, e.g.

      iex> train_data = SevenSegment.Train.inputs()
      iex> train_data[[digit: 0]]
      #Nx.Tensor<
        u8[bitlist: 7]
        [1, 1, 1, 0, 1, 1, 1]
      >
  """
  def inputs() do
    0..9
    |> Enum.map(&SevenSegment.Number.encode_digit!/1)
    |> Nx.tensor(names: [:digit, :bitlist], type: :u8)
  end

  @doc """
  Return a tensor of the (one-hot-encoded) digits 0-9 (one per row).
  """
  def targets() do
    0..9
    |> Enum.to_list()
    |> Nx.tensor(type: :u8, names: [:digit])
    # Add a trailing axis so that comparing against the row of digits 0-9
    # broadcasts to one row per digit.
    |> Nx.new_axis(-1, :one_hot)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  @doc "Convenience function for building an {inputs, targets} tuple of tensors for use in training."
  def training_set() do
    {inputs(), targets()}
  end

  @doc """
  Run the training procedure, returning a map of (trained) params.

  `opts` are forwarded to `Axon.Loop.run/4`. When not given, `:epochs` defaults
  to `1000` and `:compiler` to `EXLA`; values supplied in `opts` override these
  defaults.
  """
  def run(model, inputs, targets, opts \\ []) do
    # since this training set is so small, use batches of size 1
    data = Enum.zip(Nx.to_batched(inputs, 1), Nx.to_batched(targets, 1))

    # Defaults go first so that caller-supplied options win on conflict.
    opts = Keyword.merge(@default_run_opts, opts)

    model
    |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(data, %{}, opts)
  end
end

Now that we’ve prepared our infrastructure, we can run the training procedure to get the weights for our trained network.

# Build the dataset first, then a network with one hidden layer of 4 neurons,
# and finally fit that network to the dataset.
{inputs, targets} = SevenSegment.Train.training_set()
model = SevenSegment.Model.new([4])
params = SevenSegment.Train.run(model, inputs, targets)

Predictions/inference

defmodule SevenSegment.Predict do
  @moduledoc """
  Single-shot inference with a trained model.

  Intended use:
  - `model` comes from `SevenSegment.Model.new/1`
  - `params` comes from `SevenSegment.Train.run/4`
  """

  @doc """
  Return the class distribution that `model` predicts for a given `digit` 0-9.
  """
  def from_digit(model, params, digit) do
    # Encode the digit as a bitlist tensor and prepend a leading axis of size 1.
    single_example =
      digit
      |> SevenSegment.Number.encode_digit!()
      |> Nx.tensor()
      |> Nx.new_axis(0)

    Axon.predict(model, params, single_example)
  end
end

Visualisation

# Visualise the confusion matrix: run every training input through the model
# and draw the predictions as a heatmap (a perfect model will have all white
# on the diagonal, black otherwise).
model
|> Axon.predict(params, inputs)
|> Nx.to_heatmap()
defmodule SevenSegment.Vis do
  @moduledoc """
  Render a 7-segment bitlist as an SVG string.

  NOTE(review): the SVG/path markup in this module was missing from the source
  (the `<...>` tags appear to have been stripped), so the wrapper elements, the
  `viewBox` and the `width` below are a reconstruction - confirm against the
  original notebook.
  """

  # One SVG path per segment, in bitlist order: top, upper-left, upper-right,
  # middle, lower-left, lower-right, bottom.
  @segment_paths [
    "M 190.79731,72.5 L 175.58534,88 L 116.58535,88 L 101.06756,72.5 L 116.58535,57 L 175.58534,57 L 190.79731,72.5 z ",
    "M 98,75.38513 L 113.5,90.59709 L 113.5,135.59708 L 98,151.11487 L 82.5,135.59708 L 82.5,90.59709 L 98,75.38513 z ",
    "M 194,75.38513 L 209.5,90.59709 L 209.5,135.59708 L 194,151.11487 L 178.5,135.59708 L 178.5,90.59709 L 194,75.38513 z ",
    "M 190.79731,154 L 175.58534,169.5 L 116.58535,169.5 L 101.06756,154 L 116.58535,138.5 L 175.58534,138.5 L 190.79731,154 z",
    "M 98,157.44257 L 113.5,172.65453 L 113.5,217.65452 L 98,233.1723 L 82.5,217.65452 L 82.5,172.65453 L 98,157.44257 z ",
    "M 194,157.44257 L 209.5,172.65453 L 209.5,217.65452 L 194,233.1723 L 178.5,217.65452 L 178.5,172.65453 L 194,157.44257 z ",
    "M 190.79731,236.05743 L 175.58534,251.55743 L 116.58535,251.55743 L 101.06756,236.05743 L 116.58535,220.55743 L 175.58534,220.55743 L 190.79731,236.05743 z "
  ]

  # Bounding box of the paths above (x: 82.5-209.5, y: 57-251.6) plus a small margin.
  @view_box "75 50 142 209"

  @doc """
  Return an SVG string showing the 7-segment display for `segments`.

  `segments` is a list of 0/1 values, one per segment: a `1` draws the segment
  filled in red, a `0` draws it as a faint outline. Any other value raises a
  `CaseClauseError`.

  `transform`, when given, is used as the SVG `transform` attribute of the
  group wrapping the segments (e.g. `"translate(10 0)"`).
  """
  def digit(segments, transform \\ nil) do
    paths =
      segments
      |> Enum.zip(@segment_paths)
      |> Enum.map_join("\n", fn {s, d} ->
        color_attrs =
          case s do
            0 -> ~s|fill="none" stroke="#EEE"|
            1 -> ~s|fill="red" stroke="red"|
          end

        ~s|<path d="#{d}" #{color_attrs} />|
      end)

    # Only emit the attribute when a transform was actually requested.
    group_attrs = if transform, do: ~s| transform="#{transform}"|, else: ""

    """
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="#{@view_box}" width="100">
    <g#{group_attrs}>
    #{paths}
    </g>
    </svg>
    """
  end
end

# Render the digit 0: digit -> bitlist -> SVG markup -> Kino HTML output.
0
|> SevenSegment.Number.encode_digit!()
|> SevenSegment.Vis.digit()
|> Kino.HTML.new()