Genetic Algorithms in Nx
Preface
After writing Genetic Algorithms in Elixir, I had no real hopes or expectations that numerical computing would become a focus for Elixir. However, to my surprise, the Nx project developed rather quickly, and proved that Elixir could be used for practical machine learning and numerical computing applications. While I don’t have any plans (yet) of reworking the examples in my book to take advantage of the acceleration enabled by Nx, EXLA, and other projects, I feel it’s necessary to show how you could go about rewriting your genetic algorithms to take advantage of Elixir’s new numerical computing libraries.
Introduction
Nx is a numerical computing library for Elixir which supports the creation and manipulation of multi-dimensional arrays (called tensors in the API), automatic differentiation, and just-in-time (JIT) compilation to CPU, GPU, and other accelerators via pluggable backends and compilers. Nx opens up a realm of possibilities for Elixir developers, including the ability to conduct accelerated simulations, perform machine learning, and manipulate large amounts of data in ways that were otherwise not possible in Elixir.
The Nx API is inspired by Python’s NumPy, an array programming library. Manipulating arrays or tensors in libraries such as Nx and NumPy requires a different way of thinking. In Genetic Algorithms in Elixir, you created a framework that represented populations as lists and performed most of the computations using Elixir Enumerable types. Constructs such as map, reduce, filter, etc. were the fundamental parts of your genetic algorithms. At the time, using Elixir’s Enum API was pretty much the only option.
Nx represents data in-memory as flat binaries. While the Nx API has some map and reduce functions, they’re inefficient compared to alternative options in the API. In order to take advantage of Nx, you’ll need to rework your algorithms to work on tensors.
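For example, instead of walking a list element by element, you operate on whole tensors at once. A minimal sketch of the difference (the values here are purely illustrative):

# Enum walks the list one element at a time on the BEAM
Enum.map([1, 2, 3], fn x -> x * 2 end)
# returns [2, 4, 6]

# Nx performs the same element-wise operation on the entire tensor at once,
# which backends like EXLA can compile and accelerate
Nx.multiply(Nx.tensor([1, 2, 3]), 2)
# returns a rank-1 tensor containing [2, 4, 6]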
In order to motivate the usage of Nx, we’ll be solving a problem that would be incredibly slow using plain Elixir: creating a genetic algorithm that reconstructs a handwritten digit.
Requirements
In order to move forward with the rest of the post, you’ll need an installation of at least Elixir 1.12 and OTP 24. Additionally, you’ll need to be able to use EXLA. EXLA has precompiled binaries for Linux and macOS, as well as variants with CUDA support.
It’s probably easiest to follow along with the code in a Livebook.
We’ll start by installing our required libraries. SciData is a library for grabbing well-known datasets in formats that are easily ingestible by Nx.
Mix.install([
  {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:scidata, "~> 0.1.0"}
])
The Data
Our objective is to write a genetic algorithm that reconstructs an example image, so to start we’ll need to extract one. First, use scidata to download the MNIST dataset:
{images, _} = Scidata.MNIST.download()
{data, type, shape} = images
We only need a single image, so we can do some binary pattern matching to extract the first image without dropping to Nx yet. shape is of the form {n_images, channels, height, width}, and the type of the data is {:u, 8}, meaning each value in the tensor is an unsigned 8-bit integer. In other words, each number in the binary represents a single pixel value from 0 to 255.
{_, channels, height, width} = shape
image_size = channels * height * width
<<image::binary-size(image_size), _rest::binary>> = data
To convert the data to a tensor, we can use the Nx function from_binary, which takes a flat binary and a type and returns a newly constructed tensor. We’ll also want to reshape the tensor to the desired shape and normalize pixel values to be between 0 and 1. This will help us write our genetic algorithm later on:
example_image =
  image
  |> Nx.from_binary(type)
  |> Nx.reshape({channels, height, width})
  |> Nx.divide(255.0)
To verify that we’ve correctly extracted a single image, let’s visualize it using Nx.to_heatmap/1:
example_image |> Nx.to_heatmap()
Defining the Algorithm
Initialization
With our example image extracted, we can start defining the genetic algorithm. The first thing we need to do is initialize our population. We’ll represent our population using a single tensor, where each row of the tensor represents a single flattened image. Because we normalized our example image, we can also use normalized representations of generated images:
population_size = 1000

initialize = fn ->
  Nx.random_uniform({population_size, image_size}, backend: Nx.Defn.Expr)
end
One note: we have to add backend: Nx.Defn.Expr so Nx doesn’t attempt to inline and evaluate the function as a constant during JIT compilation.
Evaluation
The objective is to generate an image that is as close as possible to our example image. We can measure the fitness of a generated image by calculating the pixel-wise mean squared error (MSE) between it and the example image:
evaluate = fn population, example ->
  # Reshape example so it broadcasts over each image in population
  example =
    example
    |> Nx.flatten()
    |> Nx.new_axis(0)

  # Calculate MSE between each image and example
  population
  |> Nx.subtract(example)
  |> Nx.power(2)
  |> Nx.mean(axes: [-1])
end
In the code above, we start by flattening and then expanding the example image so it broadcasts correctly over the population. Broadcasting is outside the scope of this post; however, you should know that thanks to broadcasting, our code will subtract example from every row in the population.
Next, we define the mean squared error calculation for each image. This is done by taking the difference between each image in the population and the example image, squaring the difference, and then taking the mean along the last axis in the population. Taking the mean in this way returns a pixel-wise average error for each image.
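If broadcasting is unfamiliar, this small standalone sketch (with made-up values, not part of the algorithm) shows how a {1, 3} tensor is subtracted from every row of a {2, 3} tensor:

pop = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
ex = Nx.tensor([[1.0, 1.0, 1.0]])

# ex has shape {1, 3}, so it's broadcast across both rows of pop
Nx.subtract(pop, ex)
# yields [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]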
Selection
Now we need to define a selection strategy. To simplify things, we’ll select each image in our original population for crossover and replace the original population with the children generated from crossover. As another simplification, we’ll use a best selection strategy that pairs the best chromosomes for crossover:
select = fn population, target_image ->
  population
  |> evaluate.(target_image)
  |> Nx.argsort()
  |> then(&Nx.take(population, &1))
end
Crossover
Our selection strategy will ensure that parents are paired row-wise in the population. That means the first two flattened images should be combined, the second two should be combined, and so on. We’ll combine our population such that each child is an average of its parents. This will create two identical children per pair of parents; however, we can add some variability in the mutation step.
crossover = fn population ->
  {population_size, _} = Nx.shape(population)
  half_pop = div(population_size, 2)

  even_idx = Nx.multiply(Nx.iota({half_pop}), 2)
  odd_idx = Nx.add(Nx.multiply(Nx.iota({half_pop}), 2), 1)

  {evens, odds} = {
    Nx.take(population, even_idx),
    Nx.take(population, odd_idx)
  }

  children = Nx.divide(Nx.add(evens, odds), 2)
  Nx.concatenate([children, children], axis: 0)
end
In this implementation, we calculate exactly half of the population size and then extract even and odd rows into separate tensors. We take the element-wise average of the two tensors, and then concatenate two copies of the averaged children so the new population matches the size of the original.
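To make the pairing concrete, here’s what the index tensors look like for a toy population of 4 chromosomes (a hypothetical example, not part of the algorithm):

# With population_size = 4, half_pop = 2:
# Nx.iota({2}) is [0, 1]
# even_idx is [0, 2] and odd_idx is [1, 3],
# so Nx.take pulls rows 0 and 2 into evens and rows 1 and 3 into odds,
# pairing row 0 with row 1 and row 2 with row 3
even_idx = Nx.multiply(Nx.iota({2}), 2)
odd_idx = Nx.add(Nx.multiply(Nx.iota({2}), 2), 1)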
Mutation
Our crossover strategy will lead to premature convergence rather quickly, so we need to introduce some variability with mutation. We’ll do this by adding random noise to around 40% of the pixels in the new population:
mutate = fn population ->
  mask = Nx.random_uniform(Nx.shape(population), backend: Nx.Defn.Expr)
  noise = Nx.random_uniform(Nx.shape(population), -0.15, 0.15, backend: Nx.Defn.Expr)

  Nx.select(Nx.less(mask, 0.4), Nx.add(population, noise), population)
  |> Nx.clip(0, 1)
end
Here we generate a mask and noise which have the same shape as the original population. We constrain noise values between -0.15 and 0.15 so there are no extreme changes in pixel values. Finally, since it’s possible for mutated pixel values to fall outside the valid range, we clip the population back between 0 and 1.
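If you haven’t seen Nx.select/3 before, it chooses element-wise between two tensors based on a predicate tensor; a quick illustrative sketch with made-up values:

pred = Nx.tensor([1, 0, 1])
on_true = Nx.tensor([10, 20, 30])
on_false = Nx.tensor([0, 0, 0])

# Where pred is non-zero, take from on_true; elsewhere, take from on_false
Nx.select(pred, on_true, on_false)
# yields [10, 0, 30]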
The Algorithm
We’ve now implemented all of the required steps for our genetic algorithm. All that’s left is to run it for a fixed number of generations and see how close we can get to the original image.
evolve = fn population, target_image ->
  population
  |> select.(target_image)
  |> crossover.()
  |> mutate.()
end
population = Nx.Defn.jit(initialize, [], compiler: EXLA)

final_population =
  Enum.reduce(1..2500, population, fn i, population ->
    population = Nx.Defn.jit(evolve, [population, example_image], compiler: EXLA)

    best =
      Nx.Defn.jit(
        fn population, example_image ->
          population
          |> evaluate.(example_image)
          |> Nx.reduce_min()
        end,
        [population, example_image],
        compiler: EXLA
      )
      |> Nx.to_scalar()

    IO.write("\rGeneration: #{i} Best: #{:io_lib.format('~.5f', [best])}")
    population
  end)

# Visualize the top 3
final_population
|> select.(example_image)
|> Nx.slice_axis(0, 3, 0)
|> Nx.reshape({3, 28, 28})
|> Nx.to_heatmap()
In the code above, we implement evolve, which is the body of our genetic algorithm. Next, we initialize the population using Nx.Defn.jit/3, which tells Nx to JIT compile our functions using the EXLA compiler. This is necessary because we’re not working inside a module with defn, so by default our functions will not be JIT compiled. We then implement the genetic algorithm loop using reduce, and keep track of progress over time by extracting the best fitness score from the population after each generation.
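As an aside, if you were working inside a module, you could avoid the explicit Nx.Defn.jit/3 calls by writing the steps with defn. A hypothetical sketch (the module name is illustrative, and the config shown is an assumption; consult the Nx docs for the exact setup):

defmodule Digits do
  import Nx.Defn

  # With EXLA configured as the default defn compiler, e.g.
  #   config :nx, :default_defn_options, compiler: EXLA
  # every defn function in this module is JIT compiled automatically.
  defn evaluate(population, example) do
    example =
      example
      |> Nx.flatten()
      |> Nx.new_axis(0)

    population
    |> Nx.subtract(example)
    |> Nx.power(2)
    |> Nx.mean(axes: [-1])
  end
end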
Finally, we take the final evolved population and extract the top 3 candidates from the pool to inspect. If you look hard enough, you can see how our random population converged to resemble the original target image! We can probably do better with more complex crossover, selection, and mutation schemes, or by messing with hyperparameters, but for a quick implementation our results are pretty good!