Genetic Algorithm in Nx

Preface

After writing Genetic Algorithms in Elixir, I had no real hopes or expectations that numerical computing would become a focus for Elixir. However, to my surprise, the Nx project developed rather quickly and proved that Elixir could be used for practical machine learning and numerical computing applications. While I don’t have any plans (yet) of reworking the examples in my book to take advantage of the acceleration enabled by Nx, EXLA, and other projects, I feel it’s necessary to show how you could go about rewriting your genetic algorithms to take advantage of Elixir’s new numerical computing libraries.

Introduction

Nx is a numerical computing library for Elixir which supports the creation and manipulation of multi-dimensional arrays (called tensors in the API), automatic differentiation, and just-in-time (JIT) compilation to CPU, GPU, and other accelerators via pluggable backends and compilers. Nx opens up a realm of possibilities for Elixir developers including the ability to conduct accelerated simulations, perform machine learning, and manipulate large amounts of data in ways that were otherwise not possible in Elixir.

The Nx API is inspired by Python’s NumPy, an array programming library. Manipulating arrays, or tensors, in libraries such as Nx and NumPy requires a different way of thinking. In Genetic Algorithms in Elixir, you created a framework that represented populations as lists and performed most of the computations with Elixir’s Enum functions. Constructs such as map, reduce, and filter were the fundamental building blocks of your genetic algorithms. At the time, Elixir’s Enum API was pretty much the only option.

Nx represents data in memory as flat binaries. While the Nx API has some map and reduce functions, they’re inefficient compared to operations that act on an entire tensor at once. To take advantage of Nx, you’ll need to rework your algorithms to operate on whole tensors.
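
As a small illustration of that flat layout, the sketch below (assuming Nx is installed as in the setup cell) builds a 2x2 tensor of unsigned bytes, inspects its underlying binary, and then transforms every element with a single whole-tensor operation instead of an element-by-element map.

```elixir
# A 2x2 tensor of unsigned 8-bit integers
t = Nx.tensor([[1, 2], [3, 4]], type: {:u, 8})

# The data is stored as one flat, row-major binary: one byte per element
flat = Nx.to_binary(t)
# flat == <<1, 2, 3, 4>>

# Whole-tensor operation: doubles every element in one call
doubled = Nx.multiply(t, 2)
Nx.to_flat_list(doubled)
# => [2, 4, 6, 8]
```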

In order to motivate the usage of Nx, we’ll be solving a problem that would be incredibly slow using plain Elixir. We’ll be creating a genetic algorithm that reconstructs a handwritten digit.

Requirements

In order to move forward with the rest of the post, you’ll need an installation of at least Elixir 1.12 and OTP 24. Additionally, you’ll need to be able to use EXLA. EXLA has precompiled binaries for Linux and macOS, as well as variants with CUDA support.

Finally, it’s probably easiest to follow along with the code in a Livebook.

We’ll start by installing our required libraries. SciData is a library for grabbing well-known datasets in formats that are easily ingestible by Nx.

Mix.install([
  {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:scidata, "~> 0.1.0"}
])

The Data

Our objective is to write a genetic algorithm that can reconstruct an example image, so first we need to extract one. Use Scidata to download the MNIST dataset:

{images, _} = Scidata.MNIST.download()
{data, type, shape} = images

We only need a single image, so we can use binary pattern matching to extract the first one without converting the whole dataset to a tensor yet. shape has the form {n_images, channels, height, width}, and the type of the data is {:u, 8}, meaning each value is an unsigned 8-bit integer. In other words, each byte in the binary is a single pixel value from 0 to 255.

{_, channels, height, width} = shape
image_size = channels * height * width

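# Bind the first image_size bytes (the first image) and ignore the rest
<<image::binary-size(image_size), _rest::binary>> = data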
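
The same pattern can be tried on a tiny hand-made binary. The sketch below uses hypothetical toy values in place of the MNIST data: `binary-size(n)` matches exactly `n` bytes, and the trailing `binary` segment absorbs everything after them.

```elixir
# Hypothetical stand-in for the MNIST binary: three "images" of 4 bytes each
toy_data = <<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>>
toy_image_size = 1 * 2 * 2

# Bind the first toy_image_size bytes and keep the remainder in rest
<<first_image::binary-size(toy_image_size), rest::binary>> = toy_data

first_image
# => <<1, 2, 3, 4>>
byte_size(rest)
# => 8
```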

To convert the data to a tensor, we can use Nx.from_binary, which takes a flat binary and a type and returns a new one-dimensional tensor. We’ll then reshape the tensor to {channels, height, width} and divide by 255.0 to normalize pixel values to between 0 and 1, the same range as the random images we’ll generate during initialization:

example_image =
  image
  |> Nx.from_binary(type)
  |> Nx.reshape({channels, height, width})
  |> Nx.divide(255.0)
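
The same three steps can be checked on a tiny hand-made binary. This sketch uses hypothetical bytes in place of a real MNIST digit, with a 1x2x2 shape standing in for {channels, height, width}.

```elixir
# Hypothetical 2x2 single-channel "image" stored as raw unsigned bytes
raw = <<0, 51, 102, 255>>

toy_image =
  raw
  # 1-D tensor of shape {4} with type {:u, 8}
  |> Nx.from_binary({:u, 8})
  # {channels, height, width}
  |> Nx.reshape({1, 2, 2})
  # Float tensor with values between 0 and 1
  |> Nx.divide(255.0)

Nx.shape(toy_image)
# => {1, 2, 2}
```

The values should come out as approximately 0.0, 0.2, 0.4, and 1.0. They are 32-bit floats, so some may not print as exact decimals.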

To verify that we’ve correctly extracted a single image, let’s visualize it using Nx.to_heatmap/1:

example_image |> Nx.to_heatmap()

Defining the Algorithm

Initialization

With our example image extracted, we can start defining the genetic algorithm. The first step is to initialize our population. We’ll represent the population as a single tensor of shape {population_size, image_size}, where each row is a single flattened image. Because we normalized our example image to between 0 and 1, we can generate candidate images directly as uniform random values in that same range:

population_size = 1000

initialize = fn ->
  Nx.random_uniform({population_size, image_size}, backend: Nx.Defn.Expr)
end

One note: we have to add backend: Nx.Defn.Expr so Nx doesn’t attempt to inline and evaluate the function as a constant during JIT compilation.

Evaluation

The objective is to generate an image that is as close as possible to our example image. We can measure the fitness of a generated image by calculating the mean of the pixel-wise squared errors (the mean squared error, or MSE) between it and the example image. A lower MSE means a closer match:

evaluate = fn population, example ->
  # Reshape example so it broadcasts for each image in population
  example =
    example
    |> Nx.flatten()
    |> Nx.new_axis(0)

  # Calculate MSE between each image and example
  population
  |> Nx.subtract(example)
  |> Nx.power(2)
  |> Nx.mean(axes: [-1])
end
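
Written out as a formula (the symbols are introduced here for illustration), evaluate computes the following fitness for each candidate image $i$:

$$
\mathrm{MSE}_i = \frac{1}{P}\sum_{j=1}^{P}\left(x_{ij} - t_j\right)^2
$$

Here $x_{ij}$ is pixel $j$ of candidate $i$, $t_j$ is pixel $j$ of the example image, and $P$ is the number of pixels per image (image_size, which is $1 \times 28 \times 28 = 784$ for an MNIST digit). Lower values are better.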

In the code above, we start by flattening the example image and adding a leading axis, giving it shape {1, image_size} so it broadcasts correctly against the population’s {population_size, image_size} shape. Broadcasting is outside the scope of this post; however, you should know that thanks to broadcasting, our code will subtract example from every row in the population.
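
To see broadcasting in action on a small scale, the sketch below uses made-up integer tensors: a {3, 2} "population" and a {1, 2} "example". The single example row is subtracted from each of the three rows.

```elixir
# Shape {3, 2}: three hypothetical "images" of two pixels each
toy_population = Nx.tensor([[1, 2], [3, 4], [5, 6]])

# Shape {1, 2}: a flattened example with a leading axis added
toy_example = Nx.tensor([1, 1]) |> Nx.new_axis(0)

# The {1, 2} tensor is broadcast across all three rows
toy_diff = Nx.subtract(toy_population, toy_example)

Nx.shape(toy_diff)
# => {3, 2}
Nx.to_flat_list(toy_diff)
# => [0, 1, 2, 3, 4, 5]
```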

Next, we compute the mean squared error for each image. We take the difference between each image in the population and the example image, square the difference, and then take the mean along the last axis (the pixel axis). Taking the mean this way returns a single average error per image, so the result has shape {population_size}.
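
A small worked example on made-up values shows what the per-image mean looks like: each row’s squared differences collapse into one number, and the lowest number marks the closest image. This sketch squares with Nx.multiply/2 instead of Nx.power/2; both give the same result here.

```elixir
# Hypothetical population of three 2-pixel images and a 2-pixel target
toy_population = Nx.tensor([[0.0, 0.0], [1.0, 1.0], [0.5, 0.0]])
toy_target = Nx.tensor([[0.0, 0.0]])

toy_diff = Nx.subtract(toy_population, toy_target)

# Square the differences, then average across pixels (the last axis)
toy_mse = Nx.multiply(toy_diff, toy_diff) |> Nx.mean(axes: [-1])

Nx.to_flat_list(toy_mse)
# => [0.0, 1.0, 0.125]
```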

Selection

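Now we need to define a selection strategy. To simplify things, we’ll select every image in our original population for crossover and replace the original population with the children generated from crossover. As another simplification, we’ll use a best selection strategy that sorts the population by fitness, so the best chromosomes are paired with each other for crossover. Because Nx.argsort sorts in ascending order by default and a lower MSE is better, the best images end up first: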

select = fn population, target_image ->
  population
  |> evaluate.(target_image)
  |> Nx.argsort()
  # Reorder rows from lowest (best) to highest error
  |> then(&Nx.take(population, &1))
end
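
On a toy population with made-up fitness scores, the two steps of select look like this: argsort returns row indices from best to worst, and take reorders the rows accordingly.

```elixir
# Hypothetical population: each row is tagged so we can see where it moves
toy_population = Nx.tensor([[3, 3], [1, 1], [2, 2]])

# Hypothetical fitness (MSE) per row; lower is better
toy_fitness = Nx.tensor([0.3, 0.1, 0.2])

# argsort sorts ascending by default, so the best row's index comes first
order = Nx.argsort(toy_fitness)
Nx.to_flat_list(order)
# => [1, 2, 0]

# Take rows in that order to get a best-to-worst population
sorted = Nx.take(toy_population, order)
Nx.to_flat_list(sorted)
# => [1, 1, 2, 2, 3, 3]
```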

Crossover

Our selection strategy ensures that parents are paired row-wise in the population. That means the first two flattened images should be combined, the next two should be combined, and so on. We’ll combine our population such that each child is the average of its parents. This creates two identical children per pair of parents; however, we can add some variability in the mutation step.

crossover = fn population ->
  {population_size, _} = Nx.shape(population)
  half_pop = div(population_size, 2)
  even_idx = Nx.multiply(Nx.iota({half_pop}), 2)
  odd_idx = Nx.add(Nx.multiply(Nx.iota({half_pop}), 2), 1)

  {evens, odds} = {
    Nx.take(population, even_idx),
    Nx.take(population, odd_idx)
  }

  children = Nx.divide(Nx.add(evens, odds), 2)
  Nx.concatenate([children, children], axis: 0)
end

In this implementation, we calculate half of the population size (using integer division) and extract the even and odd rows into separate tensors. We take the element-wise average of the two tensors and then concatenate two copies of the averages along the first axis, so the new population matches the size of the original.
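
Tracing the crossover steps on a hypothetical four-row population makes the pairing visible: rows 0 and 1 form one pair, rows 2 and 3 the other.

```elixir
# Hypothetical sorted population of four 1-pixel "images"
toy_population = Nx.tensor([[0.0], [2.0], [4.0], [6.0]])
half = 2

# Indices [0, 2] and [1, 3]
even_idx = Nx.multiply(Nx.iota({half}), 2)
odd_idx = Nx.add(Nx.multiply(Nx.iota({half}), 2), 1)

# Rows 0 and 2, then rows 1 and 3
evens = Nx.take(toy_population, even_idx)
odds = Nx.take(toy_population, odd_idx)

# Each child is the average of a neighboring pair of parents
children = Nx.divide(Nx.add(evens, odds), 2)

# Stack two copies so the new population has four rows again
next_gen = Nx.concatenate([children, children], axis: 0)
Nx.to_flat_list(next_gen)
# => [1.0, 5.0, 1.0, 5.0]
```

Because half_pop uses integer division, an odd population size would lose one row per generation. The population size of 1000 used in this post is even, so that doesn’t come up here.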

Mutation

Our crossover strategy will lead to premature convergence rather quickly, so we need to introduce some variability with mutation. We’ll do this by adding random noise to around 40% of the pixels in the new population:

mutate = fn population ->
  mask = Nx.random_uniform(Nx.shape(population), backend: Nx.Defn.Expr)
  noise = Nx.random_uniform(Nx.shape(population), -0.15, 0.15, backend: Nx.Defn.Expr)

  Nx.select(Nx.less(mask, 0.4), Nx.add(population, noise), population)
  |> Nx.clip(0, 1)
end

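Here we generate a mask and a noise tensor, both with the same shape as the population. Since the mask values are uniform between 0 and 1, Nx.less(mask, 0.4) selects roughly 40% of the pixels, and Nx.select/3 adds noise only to those pixels. We constrain noise values between -0.15 and 0.15 so there are no extreme changes in pixel values. Finally, since adding noise can push pixel values below 0 or above 1, we clip the population back to between 0 and 1 with Nx.clip/3.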
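
To check the masking and clipping logic without randomness, this sketch replaces the random mask and noise with hypothetical fixed values. Only pixels whose mask value is below 0.4 receive noise, and clipping pulls out-of-range results back into [0, 1].

```elixir
# Hypothetical fixed values in place of the random tensors used in mutate
toy_population = Nx.tensor([0.95, 0.5, 0.1])
# Values below 0.4 mean "mutate this pixel"
toy_mask = Nx.tensor([0.1, 0.9, 0.3])
toy_noise = Nx.tensor([0.1, 0.1, -0.2])

mutated =
  Nx.select(Nx.less(toy_mask, 0.4), Nx.add(toy_population, toy_noise), toy_population)
  |> Nx.clip(0, 1)

Nx.to_flat_list(mutated)
# => [1.0, 0.5, 0.0]
```

The first pixel is mutated to about 1.05 and clipped to 1.0, the second is left alone because its mask value is 0.9, and the third is mutated to about -0.1 and clipped to 0.0.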

The Algorithm

We’ve now implemented all of the required steps for our genetic algorithm. Now we need to run it. We’ll run for a fixed number of generations to see how close we can get to the original image.

evolve = fn population, target_image ->
  population
  |> select.(target_image)
  |> crossover.()
  |> mutate.()
end

population = Nx.Defn.jit(initialize, [], compiler: EXLA)

final_population =
  Enum.reduce(1..2500, population, fn i, population ->
    population = Nx.Defn.jit(evolve, [population, example_image], compiler: EXLA)

    best =
      Nx.Defn.jit(
        fn population, example_image ->
          population
          |> evaluate.(example_image)
          |> Nx.reduce_min()
        end,
        [population, example_image],
        compiler: EXLA
      )
      |> Nx.to_scalar()

    IO.write("\rGeneration: #{i} Best: #{:io_lib.format('~.5f', [best])}")

    population
  end)

# Visualize the top 3
final_population
|> select.(example_image)
|> Nx.slice_axis(0, 3, 0)
|> Nx.reshape({3, 28, 28})
|> Nx.to_heatmap()

In the code above, we implement evolve, which is the body of our genetic algorithm. Next, we initialize the population using Nx.Defn.jit/3, which tells Nx to JIT compile our functions using the EXLA compiler. This is necessary because we’re not working inside a module with defn, so by default our functions will not be JIT compiled. Finally, we implement the genetic algorithm loop using Enum.reduce/3 over 2,500 generations, and track progress by printing the best (lowest) MSE in the population after each generation.
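
For comparison, here is a rough sketch of the module-based alternative mentioned above. The module name ToyGA is hypothetical. Inside defn, arithmetic operators such as - and * work on tensors. How to compile such a module with EXLA depends on the Nx version and is not shown here; without a configured compiler, Nx runs defn code with its default evaluator.

```elixir
defmodule ToyGA do
  # Brings the defn macro into scope
  import Nx.Defn

  # Per-image MSE: the same computation as the anonymous evaluate function
  defn evaluate(population, example) do
    example = example |> Nx.flatten() |> Nx.new_axis(0)
    diff = population - example
    Nx.mean(diff * diff, axes: [-1])
  end
end
```

Calling ToyGA.evaluate(population, example_image) would return one MSE value per row of the population.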

Finally, we take the final evolved population and extract the top 3 candidates from the pool to inspect. If you look hard enough, you can see how our random population converged to resemble the original target image! We can probably do better with more complex crossover, selection, and mutation schemes, or by messing with hyperparameters, but for a quick implementation our results are pretty good!