Hello World
Train a Simple TensorFlow Lite Model
This notebook is a direct port of train_hello_world.ipynb to Nx/Axon.
WARNING: This is a work in progress!
When it’s complete, it will demonstrate the process of training a simple model for TensorFlow Lite for Microcontrollers in Elixir.
To do:
- Graphs showing training and validation loss and error. Not sure how to get these yet…
- Generate TensorFlow Lite model. Not supported by Axon yet.
- Model quantization. Not supported yet.
- Demo that uses the model
- Simplify plotting code. It works, but really seems like it could be improved.
Setup
This notebook requires Nx, Axon, EXLA, VegaLite, and Kino, which are installed below with Mix.install.
Mix.install([
{:nx, "~> 0.1.0", override: true},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla", override: true},
{:vega_lite, "~> 0.1"},
{:kino, "~> 0.4"}
])
alias VegaLite, as: Vl
# Set the default options for this notebook
Nx.Defn.default_options(compiler: EXLA)
Dataset
The code in the following cell will generate a set of random x values, calculate their sine values, and display them on a graph.
sample_count = 1000
# Generate a uniformly distributed set of random numbers from
# 0 to 2π, which covers a complete sine wave oscillation
samples = Nx.random_uniform({sample_count, 1}, 0, 2 * :math.pi())
# Calculate the corresponding sine values
targets = Nx.sin(samples)
Vl.new(width: 400, height: 400)
|> Vl.data_from_series(
x: Nx.to_flat_list(samples),
y: Nx.to_flat_list(targets),
color: List.duplicate("sin(x)", sample_count)
)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", title: nil)
Add noise
Since it was generated directly by the sine function, our data fits a nice, smooth curve.
…
targets = Nx.multiply(targets, Nx.random_uniform({sample_count, 1}, 0.9, 1.1))
Vl.new(width: 400, height: 400)
|> Vl.data_from_series(
x: Nx.to_flat_list(samples),
y: Nx.to_flat_list(targets),
color: List.duplicate("sin(x) + noise", sample_count)
)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", title: nil)
Split the data
The data is split as follows:
- Training: 60%
- Validation: 20%
- Testing: 20%
training_sample_count = floor(0.6 * sample_count)
validation_sample_count = floor(0.2 * sample_count)
testing_sample_count = sample_count - training_sample_count - validation_sample_count
validation_index = training_sample_count
testing_index = validation_index + validation_sample_count
training_samples = Nx.slice(samples, [0, 0], [training_sample_count, 1])
training_targets = Nx.slice(targets, [0, 0], [training_sample_count, 1])
validation_samples = Nx.slice(samples, [validation_index, 0], [validation_sample_count, 1])
validation_targets = Nx.slice(targets, [validation_index, 0], [validation_sample_count, 1])
testing_samples = Nx.slice(samples, [testing_index, 0], [testing_sample_count, 1])
testing_targets = Nx.slice(targets, [testing_index, 0], [testing_sample_count, 1])
Vl.new(width: 400, height: 400)
|> Vl.data_from_series(
x:
Nx.to_flat_list(training_samples) ++
Nx.to_flat_list(validation_samples) ++ Nx.to_flat_list(testing_samples),
y:
Nx.to_flat_list(training_targets) ++
Nx.to_flat_list(validation_targets) ++ Nx.to_flat_list(testing_targets),
color:
List.duplicate("Train", training_sample_count) ++
List.duplicate("Validate", validation_sample_count) ++
List.duplicate("Test", testing_sample_count)
)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", title: nil)
Training
Design the model
We’re going to build a simple neural network model that will take an input value (in this case, x) and use it to predict a numeric output value, sin(x). This type of problem is called a regression. The model will use layers of neurons to attempt to learn any patterns underlying the training data, so it can make predictions.
To begin with, we’ll define two layers. The first layer takes a single input (our x value) and runs it through 8 neurons. Based on this input, each neuron will become activated to a certain degree based on its internal state (its weight and bias values). A neuron’s degree of activation is expressed as a number.
The activation numbers from our first layer will be fed as inputs to our second layer, which is a single neuron. It will apply its own weights and bias to these inputs and calculate its own activation, which will be output as our y value.
model =
Axon.input({nil, 1})
|> Axon.dense(8, activation: :relu)
|> Axon.dense(1)
data = [{training_samples, training_targets}]
validation_data = [{validation_samples, validation_targets}]
model_state =
model
|> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(0.005))
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(data, epochs: 500, iterations: 100)
Plot metrics
TBD
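One possible way to fill this in (an untested sketch: Axon.Loop.handle/3, the :epoch_completed event, and the "loss" key in the loop state's metrics map are assumptions about this dev version of Axon) is to re-run the training loop with a handler attached that copies the running metrics into an Agent after every epoch, then plot the collected losses with VegaLite:
# Untested sketch: attach an event handler that records the loop's running
# metrics after every epoch, so the training loss can be plotted over time.
# Axon.Loop.handle/3, the :epoch_completed event, and the "loss" metric key
# are assumptions about this dev version of Axon.
{:ok, history} = Agent.start_link(fn -> [] end)

model
|> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(0.005))
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.handle(:epoch_completed, fn state ->
  Agent.update(history, &[state.metrics | &1])
  {:continue, state}
end)
|> Axon.Loop.run(data, epochs: 500, iterations: 100)

# Pull the recorded loss out of each per-epoch metrics map
losses =
  history
  |> Agent.get(& &1)
  |> Enum.reverse()
  |> Enum.map(fn metrics -> metrics["loss"] |> Nx.to_flat_list() |> hd() end)

Vl.new(width: 400, height: 400)
|> Vl.data_from_series(epoch: Enum.to_list(1..length(losses)), loss: losses)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "epoch", type: :quantitative)
|> Vl.encode_field(:y, "loss", type: :quantitative)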
Actual vs. predicted outputs
To get more insight into what is happening, let’s check its predictions against the test dataset we set aside earlier:
require Axon
result = Axon.predict(model, model_state, testing_samples)
Vl.new(width: 400, height: 400)
|> Vl.data_from_series(
x: Nx.to_flat_list(testing_samples) ++ Nx.to_flat_list(testing_samples),
y: Nx.to_flat_list(testing_targets) ++ Nx.to_flat_list(result),
color:
List.duplicate("Actual values", testing_sample_count) ++
List.duplicate("Predicted", testing_sample_count)
)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", title: nil)
The graph makes it clear that our network has learned to approximate the sine function in a very limited way.
The rigidity of this fit suggests that the model does not have enough capacity to learn the full complexity of the sine wave function, so it’s only able to approximate it in an overly simplistic way. By making our model bigger, we should be able to improve performance.
Training a larger model
Design the model
To make our model bigger, let’s add an additional layer of neurons.
model =
Axon.input({nil, 1})
|> Axon.dense(16, activation: :relu)
|> Axon.dense(16, activation: :relu)
|> Axon.dense(1)
Train the model
We’ll now train the new model and capture its trained state.
data = [{training_samples, training_targets}]
validation_data = [{validation_samples, validation_targets}]
model_state =
model
|> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam(0.005))
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(data, epochs: 500, iterations: 100)
Plot metrics
TBD
result = Axon.predict(model, model_state, testing_samples)
Vl.new(width: 400, height: 400)
|> Vl.data_from_series(
x: Nx.to_flat_list(testing_samples) ++ Nx.to_flat_list(testing_samples),
y: Nx.to_flat_list(testing_targets) ++ Nx.to_flat_list(result),
color:
List.duplicate("Actual values", testing_sample_count) ++
List.duplicate("Predicted", testing_sample_count)
)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", title: nil)
Much better! The validation metrics printed during training show that the model has a low loss and mean absolute error, and its predictions line up visually with the test data fairly well.
The model isn’t perfect; its predictions don’t form a smooth sine curve. For instance, the line is almost straight when x is between 4.2 and 5.2. If we wanted to go further, we could increase the capacity of the model even more, perhaps applying some techniques to defend against overfitting.
However, an important part of machine learning is knowing when to stop. This model is good enough for our use case - which is to make some LEDs blink in a pleasing pattern.
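For that use case, getting a single prediction out of the trained model is just another Axon.predict call. For example, for x = π/2 the output should be close to 1:
# Predict sin(x) for one input value; the result should be near 1.0
x = Nx.tensor([[:math.pi() / 2]])
Axon.predict(model, model_state, x)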
Generate a TensorFlow Lite model
TODO
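Until a converter exists, one stopgap is to dump the trained parameters so they could be rebuilt in a TensorFlow environment and converted there. This is an untested sketch; it assumes model_state is a map of layer name to a map of parameter name to tensor, which is how this dev version of Axon appears to return trained parameters.
# Untested sketch: write each trained parameter tensor to a raw binary file
# (float32, row-major by Nx default), so it could be reloaded elsewhere for conversion.
# Assumes model_state is a nested map: layer name => parameter name => tensor.
for {layer_name, params} <- model_state, {param_name, tensor} <- params do
  File.write!("#{layer_name}_#{param_name}.bin", Nx.to_binary(tensor))
end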