Chapter 4: MNIST Basics



Mix.install([
  {:jason, "~> 1.4"},
  {:kino, "~> 0.8.0", override: true},
  {:kino_vega_lite, "~> 0.1.8"},
  {:image, "~> 0.28.1"},
  {:nx, "~> 0.5.3"},
  {:nx_image, "~> 0.1.1"},
  {:bumblebee, "~> 0.1"},
  {:explorer, "~> 0.5.6"},
  {:kino_explorer, "~> 0.1.2"},
  {:exla, "~> 0.4"},
  {:fastbook_elixir, path: "lib/fastbook_elixir"},
  {:httpoison, "~> 1.8"},
  {:poison, "~> 5.0"},
  {:axon, "~> 0.5.1"}
])

# Nx.Defn.default_options(compiler: EXLA)

alias FastbookElixir.URLs

Under the Hood: Training a Digit Classifier

Having seen what it looks like to actually train a variety of models in Chapter 2, let’s now look under the hood and see exactly what is going on. We’ll start by using computer vision to introduce fundamental tools and concepts for deep learning.

To be exact, we’ll discuss the roles of arrays and tensors and of broadcasting, a powerful technique for using them expressively. We’ll explain stochastic gradient descent (SGD), the mechanism for learning by updating weights automatically. We’ll discuss the choice of a loss function for our basic classification task, and the role of mini-batches. We’ll also describe the math that a basic neural network is actually doing. Finally, we’ll put all these pieces together.

In future chapters we’ll do deep dives into other applications as well, and see how these concepts and tools generalize. But this chapter is about laying foundation stones. To be frank, that also makes this one of the hardest chapters, because of how these concepts all depend on each other. Like an arch, all the stones need to be in place for the structure to stay up. Also like an arch, once that happens, it’s a powerful structure that can support other things. But it requires some patience to assemble.

Let’s begin. The first step is to consider how images are represented in a computer.

Pixels: The Foundations of Computer Vision

In order to understand what happens in a computer vision model, we first have to understand how computers handle images. We’ll use one of the most famous datasets in computer vision, MNIST, for our experiments. MNIST contains images of handwritten digits, collected by the National Institute of Standards and Technology and collated into a machine learning dataset by Yann LeCun and his colleagues. LeCun used MNIST in 1998 in LeNet-5, the first computer system to demonstrate practically useful recognition of handwritten digit sequences. This was one of the most important breakthroughs in the history of AI.

Sidebar: Tenacity and Deep Learning

The story of deep learning is one of tenacity and grit by a handful of dedicated researchers. After early hopes (and hype!) neural networks went out of favor in the 1990s and 2000s, and just a handful of researchers kept trying to make them work well. Three of them, Yann LeCun, Yoshua Bengio, and Geoffrey Hinton, were awarded the highest honor in computer science, the Turing Award (generally considered the “Nobel Prize of computer science”), in 2018 after triumphing despite the deep skepticism and disinterest of the wider machine learning and statistics community.

Geoff Hinton has told of how even academic papers showing dramatically better results than anything previously published would be rejected by top journals and conferences, just because they used a neural network. Yann LeCun’s work on convolutional neural networks, which we will study in the next section, showed that these models could read handwritten text, something that had never been achieved before. However, his breakthrough was ignored by most researchers, even as it was used commercially to read 10% of the checks in the US!

In addition to these three Turing Award winners, there are many other researchers who have battled to get us to where we are today. For instance, Jürgen Schmidhuber (who many believe should have shared in the Turing Award) pioneered many important ideas, including working with his student Sepp Hochreiter on the long short-term memory (LSTM) architecture (widely used for speech recognition and other text modeling tasks, and used in the IMDb example in <>). Perhaps most important of all, Paul Werbos in 1974 invented back-propagation for neural networks, the technique shown in this chapter and used universally for training neural networks (Werbos 1994). His development was almost entirely ignored for decades, but today it is considered the most important foundation of modern AI.

There is a lesson here for all of us! On your deep learning journey you will face many obstacles, both technical, and (even more difficult) posed by people around you who don’t believe you’ll be successful. There’s one guaranteed way to fail, and that’s to stop trying. We’ve seen that the only consistent trait amongst every fast.ai student that’s gone on to be a world-class practitioner is that they are all very tenacious.

3s and 7s Classifier

For this initial tutorial we are just going to try to create a model that can classify any image as a 3 or a 7. So let’s download a sample of MNIST that contains images of just these digits:

images = FastbookElixir.untar_data(URLs.mnist_sample())

threes = Enum.filter(images, fn {path, _} -> String.match?(path, ~r"/train/3/") end)
sevens = Enum.filter(images, fn {path, _} -> String.match?(path, ~r"/train/7/") end)


Let’s take a look at one now. Here’s an image of a handwritten number 3, taken from the famous MNIST dataset of handwritten numbers:

{im3_path, content} = Enum.random(threes)
im3 = Kino.Image.new(content, "image/png")

Here we are using the Kino.Image module from the Kino library. We have installed Kino as a dependency, so Livebook displays the image for us automatically.

In a computer, everything is represented as a number. To view the numbers that make up this image, we have to convert it to an Nx tensor. For instance, here’s what a section of the image looks like, converted to a tensor:

image_tensor =
  content
  |> Image.from_binary!()
  |> Image.to_nx!()

The Nx ecosystem also includes NxImage, which provides some standard image manipulations. Here we’ll grab a 16x16 crop from the center of the original image.

# NxImage code to create a 16x16 crop in the center
center_cropped_tensor = NxImage.center_crop(image_tensor, {16, 16})

# Code to render the image and cropped image in Livebook
original_image = Kino.Image.new(image_tensor)
original_label = Kino.Markdown.new("**Original image**")

cropped_image = Kino.Image.new(center_cropped_tensor)
cropped_label = Kino.Markdown.new("**Cropped image**")

Kino.Layout.grid([
  Kino.Layout.grid([original_image, original_label], boxed: true),
  Kino.Layout.grid([cropped_image, cropped_label], boxed: true)
])

TODO: Is there a way to render a dataframe like Python in Kino?

You can see that the background black pixels are stored as the number 0, white is the number 255, and shades of gray are between the two. The entire image contains 28 pixels across and 28 pixels down, for a total of 784 pixels. (This is much smaller than an image that you would get from a phone camera, which has millions of pixels, but is a convenient size for our initial learning and experiments. We will build up to bigger, full-color images soon.)

So, now you’ve seen what an image looks like to a computer, let’s recall our goal: create a model that can recognize 3s and 7s. How might you go about getting a computer to do that?

> Warning: Stop and Think!: Before you read on, take a moment to think about how a computer might be able to recognize these two different digits. What kinds of features might it be able to look at? How might it be able to identify these features? How could it combine them together? Learning works best when you try to solve problems yourself, rather than just reading somebody else’s answers; so step away from this book for a few minutes, grab a piece of paper and pen, and jot some ideas down…

First Try: Pixel Similarity

So, here is a first idea: how about we find the average pixel value for every pixel of the 3s, then do the same for the 7s. This will give us two group averages, defining what we might call the “ideal” 3 and 7. Then, to classify an image as one digit or the other, we see which of these two ideal digits the image is most similar to. This certainly seems like it should be better than nothing, so it will make a good baseline.

> jargon: Baseline: A simple model which you are confident should perform reasonably well. It should be very simple to implement, and very easy to test, so that you can then test each of your improved ideas, and make sure they are always better than your baseline. Without starting with a sensible baseline, it is very difficult to know whether your super-fancy models are actually any good. One good approach to creating a baseline is doing what we have done here: think of a simple, easy-to-implement model. Another good approach is to search around to find other people that have solved similar problems to yours, and download and run their code on your dataset. Ideally, try both of these!

Step one for our simple model is to get the average of pixel values for each of our two groups. In the process of doing this, we will learn a lot of neat Elixir numeric programming tricks!

Let’s create a tensor containing all of our 3s stacked together. We already know how to create a tensor containing a single image. To create a tensor containing all the images in a directory, we will first use Enum.map to create a plain list of the single image tensors.

We will use Livebook to do some little checks of our work along the way—in this case, making sure that the number of returned items seems reasonable:

seven_tensors =
  Enum.map(sevens, fn {_path, content} ->
    content
    |> Image.from_binary!()
    |> Image.to_nx!()
    |> Nx.reshape({28, 28})
  end)

three_tensors =
  Enum.map(threes, fn {_path, content} ->
    content
    |> Image.from_binary!()
    |> Image.to_nx!()
    |> Nx.reshape({28, 28})
  end)

{Enum.count(three_tensors), Enum.count(seven_tensors)}

We’ll also check that one of the images looks okay.

Since we now have tensors (which Livebook by default will print as data), rather than images, we need to use Livebook’s Kino.Image module to display it:

# Creating an image from a tensor requires a third rank to represent 'channel'

image_tensor = Nx.reshape(hd(three_tensors), {28, 28, 1})
Kino.Image.new(image_tensor)

For every pixel position, we want to compute the average over all the images of the intensity of that pixel. To do this we first combine all the images in this list into a single three-dimensional tensor. The most common way to describe such a tensor is to call it a rank-3 tensor. We often need to stack up individual tensors in a collection into a single tensor. Unsurprisingly, Nx comes with a function called stack that we can use for this purpose.

Notice that Nx does not require special handling to cast types to a float.

Generally when images are floats, the pixel values are expected to be between 0 and 1, so we will also divide by 255 here:

stacked_sevens = Nx.stack(seven_tensors) |> Nx.divide(255)
stacked_threes = Nx.stack(three_tensors) |> Nx.divide(255)

Nx.shape(stacked_threes)


Perhaps the most important attribute of a tensor is its shape. This tells you the length of each axis. In this case, we can see that we have 6,131 images, each of size 28×28 pixels. There is nothing specifically about this tensor that says that the first axis is the number of images, the second is the height, and the third is the width—the semantics of a tensor are entirely up to us, and how we construct it. As far as Nx is concerned, it is just a bunch of numbers in memory.

The length of a tensor’s shape is its rank:

stacked_threes |> Nx.shape() |> tuple_size()


It is really important for you to commit to memory and practice these bits of tensor jargon: rank is the number of axes or dimensions in a tensor; shape is the size of each axis of a tensor.

A: Watch out because the term “dimension” is sometimes used in two ways. Consider that we live in “three-dimensional space” where a physical position can be described by a 3-vector v. But according to PyTorch, the attribute v.ndim (which sure looks like the “number of dimensions” of v) equals one, not three! Why? Because v is a vector, which is a tensor of rank one, meaning that it has only one axis (even if that axis has a length of three). In other words, sometimes dimension is used for the size of an axis (“space is three-dimensional”); other times, it is used for the rank, or the number of axes (“a matrix has two dimensions”). When confused, I find it helpful to translate all statements into terms of rank, axis, and length, which are unambiguous terms.
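The same distinction can be checked in Nx. This is only a small sketch: it assumes Nx is already available from the notebook's setup cell, and the vector v is made up for the illustration.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# v is a made-up position in "three-dimensional space".
v = Nx.tensor([1.0, 2.0, 3.0])

# v has a single axis, so its rank is 1 ...
v_rank = Nx.rank(v)

# ... even though that single axis has a length of 3.
v_shape = Nx.shape(v)
v_length = Nx.axis_size(v, 0)

{v_rank, v_shape, v_length}
```

The rank counts the axes, while the shape records how long each axis is.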

We can also get a tensor’s rank directly with rank:

Nx.rank(stacked_threes)


Finally, we can compute what the ideal 3 looks like. We calculate the mean of all the image tensors by taking the mean along dimension 0 of our stacked, rank-3 tensor. This is the dimension that indexes over all the images.

In other words, for every pixel position, this will compute the average of that pixel over all images. The result will be one value for every pixel position, or a single image. Here it is:

mean3 = Nx.mean(stacked_threes, axes: [0])

# Note to display the image we need to convert back to an unsigned int between 0-255
mean3_image =
  mean3
  |> Nx.multiply(255)
  |> Nx.floor()
  |> Nx.reshape({28, 28, 1})
  |> Nx.as_type(:u8)

Kino.Image.new(mean3_image)


According to this dataset, this is the ideal number 3! (You may not like it, but this is what peak number 3 performance looks like.) You can see how it’s very dark where all the images agree it should be dark, but it becomes wispy and blurry where the images disagree.
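To see what taking the mean along axis 0 does on a scale small enough to check by hand, here is a sketch with two made-up 2×2 "images". It assumes Nx is available from the setup cell; the name tiny_stack is hypothetical.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Two made-up 2x2 "images" stacked into a rank-3 tensor of shape {2, 2, 2}.
tiny_stack =
  Nx.tensor([
    [[0, 2], [4, 6]],
    [[2, 4], [6, 8]]
  ])

# Averaging over axis 0 collapses the image axis and keeps one value per pixel.
tiny_mean = Nx.mean(tiny_stack, axes: [0])

{Nx.shape(tiny_mean), Nx.to_flat_list(tiny_mean)}
```

Each output pixel is the average of the pixels at the same position in the two inputs, which is what mean3 holds for the 28×28 case.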

Let’s do the same thing for the 7s, but put all the steps together at once to save some time:

mean7 = Nx.mean(stacked_sevens, axes: [0])


Let’s now pick an arbitrary 3 and measure its distance from our “ideal digits.”

> stop: Stop and Think!: How would you calculate how similar a particular image is to each of our ideal digits? Remember to step away from this book and jot down some ideas before you move on! Research shows that recall and understanding improve dramatically when you are engaged with the learning process by solving problems, experimenting, and trying new ideas yourself.

Here’s a sample 3:

a_3 = stacked_threes[1]

How can we determine its distance from our ideal 3? We can’t just add up the differences between the pixels of this image and the ideal digit. Some differences will be positive while others will be negative, and these differences will cancel out, resulting in a situation where an image that is too dark in some places and too light in others might be shown as having zero total differences from the ideal. That would be misleading!

To avoid this, there are two main ways data scientists measure distance in this context:

  • Take the mean of the absolute value of differences (absolute value is the function that replaces negative values with positive values). This is called the mean absolute difference or L1 norm.
  • Take the mean of the square of differences (which makes everything positive) and then take the square root (which undoes the squaring). This is called the root mean squared error (RMSE) or L2 norm.
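A tiny made-up example shows both the cancellation problem and the two fixes. This is a sketch that assumes Nx is available from the setup cell; the four-pixel sample and ideal tensors are invented for the illustration.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Made-up 4-pixel "images": the sample is too bright in two pixels
# and too dark in the other two, by the same amount.
sample = Nx.tensor([1.0, 0.0, 1.0, 0.0])
ideal = Nx.tensor([0.0, 1.0, 0.0, 1.0])
diff = Nx.subtract(sample, ideal)

# Signed differences cancel out, which wrongly suggests a perfect match.
signed_total = diff |> Nx.sum() |> Nx.to_number()

# Mean absolute difference (L1 norm).
l1 = diff |> Nx.abs() |> Nx.mean() |> Nx.to_number()

# Root mean squared error (L2 norm).
l2 = diff |> Nx.pow(2) |> Nx.mean() |> Nx.sqrt() |> Nx.to_number()

{signed_total, l1, l2}
```

The signed total is zero even though every pixel is wrong, while both distance measures report a difference.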

> important: It’s Okay to Have Forgotten Your Math: In this book we generally assume that you have completed high school math, and remember at least some of it… But everybody forgets some things! It all depends on what you happen to have had reason to practice in the meantime. Perhaps you have forgotten what a square root is, or exactly how they work. No problem! Any time you come across a maths concept that is not explained fully in this book, don’t just keep moving on; instead, stop and look it up. Make sure you understand the basic idea, how it works, and why we might be using it. One of the best places to refresh your understanding is Khan Academy. For instance, Khan Academy has a great introduction to square roots.

Let’s try both of these now:

dist_3_abs = Nx.subtract(a_3, mean3) |> Nx.abs() |> Nx.mean()
dist_3_sqr = Nx.subtract(a_3, mean3) |> Nx.pow(2) |> Nx.mean() |> Nx.sqrt()

{dist_3_abs, dist_3_sqr}

dist_7_abs = Nx.subtract(a_3, mean7) |> Nx.abs() |> Nx.mean()
dist_7_sqr = Nx.subtract(a_3, mean7) |> Nx.pow(2) |> Nx.mean() |> Nx.sqrt()

{dist_7_abs, dist_7_sqr}

In both cases, the distance between our 3 and the “ideal” 3 is less than the distance to the ideal 7. So our simple model will give the right prediction in this case.

Elixir’s Axon library already provides both of these as loss functions. You’ll find these inside Axon.Losses:

# The python code
# F.l1_loss(a_3.float(),mean7), F.mse_loss(a_3,mean7).sqrt()

{
  Axon.Losses.mean_absolute_error(a_3, mean7) |> Nx.mean(),
  Axon.Losses.mean_squared_error(a_3, mean7) |> Nx.mean() |> Nx.sqrt()
}

Note that mean squared error is often abbreviated as MSE, and mean absolute error is sometimes called L1, after the standard mathematical jargon for it (in math it’s called the L1 norm).

> S: Intuitively, the difference between L1 norm and mean squared error (MSE) is that the latter will penalize bigger mistakes more heavily than the former (and be more lenient with small mistakes).
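To make that intuition concrete, here is a sketch comparing two made-up sets of errors that have the same mean absolute error. It assumes Nx is available from the setup cell; the helper names mae and rmse are hypothetical.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Two invented error patterns with the same total error:
# four small mistakes versus a single large one.
spread_out = Nx.tensor([2.0, 2.0, 2.0, 2.0])
one_big = Nx.tensor([0.0, 0.0, 0.0, 8.0])

mae = fn errors -> errors |> Nx.abs() |> Nx.mean() |> Nx.to_number() end
rmse = fn errors -> errors |> Nx.pow(2) |> Nx.mean() |> Nx.sqrt() |> Nx.to_number() end

{mae.(spread_out), mae.(one_big), rmse.(spread_out), rmse.(one_big)}
```

The mean absolute error cannot tell the two cases apart, while the squared version scores the single large mistake as twice as bad.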

> J: When I first came across this “L1” thingie, I looked it up to see what on earth it meant. I found on Google that it is a vector norm using absolute value, so looked up vector norm and started reading: Given a vector space V over a field F of the real or complex numbers, a norm on V is a nonnegative-valued function p: V → [0,+∞) with the following properties: For all a ∈ F and all u, v ∈ V, p(u + v) ≤ p(u) + p(v)… Then I stopped reading. “Ugh, I’ll never understand math!” I thought, for the thousandth time. Since then I’ve learned that every time these complex mathy bits of jargon come up in practice, it turns out I can replace them with a tiny bit of code! Like, the L1 loss is just equal to (a-b).abs().mean(), where a and b are tensors. I guess mathy folks just think differently than me… I’ll make sure in this book that every time some mathy jargon comes up, I’ll give you the little bit of code it’s equal to as well, and explain in common-sense terms what’s going on.

Nx Tensors

To create a tensor, pass a list (or list of lists, or list of lists of lists, etc.) to Nx.tensor():

data = [[1, 2, 3], [4, 5, 6]]
tns = Nx.tensor(data)

You can select a row (note that, like lists in Elixir, tensors are 0-indexed so 1 refers to the second row/column):

tns[1]


With Nx you can also name axes and access tensors by dimension names.

This returns the same results as above but is more understandable with names:

tns_named = Nx.tensor(tns, names: [:y, :x])
tns_named[y: 1]

And you would use names again to access a column from a tensor (we sometimes refer to the dimensions of tensors/arrays as axes):

tns_named[x: 1]

You can combine these with an Elixir range (start..end with end being included) to select part of a row or column:

tns_named[y: 1, x: 1..2]

In order to apply standard operations like add, subtract, multiply, or divide, you have to use the Nx functions:

Nx.add(tns, 1)

Or you can use the special Nx.Defn module to create functions which allow you to use standard operators such as +, -, *, or / and many others.

> BS: Special Note: Using Nx.Defn is the idiomatic and cleaner way to write Nx-based functions. However, this book will primarily use anonymous functions and explicit Nx function calls in order to follow the flow of the original Fastbook content.

defmodule Example do
  import Nx.Defn

  defn do_some_math(tensor) do
    tensor + 1
  end
end


The anonymous function + explicit Nx function version of the do_some_math() above would be:

do_some_math = fn tensor ->
  Nx.add(tensor, 1)
end


Tensors will automatically adjust types as needed:

  Nx.multiply(tns, 1.5).type

Computing Metrics using Broadcasting

So, is our baseline model any good? To quantify this, we must define a metric.

Recall that a metric is a number that is calculated based on the predictions of our model, and the correct labels in our dataset, in order to tell us how good our model is. For instance, we could use either of the functions we saw in the previous section, mean squared error, or mean absolute error, and take the average of them over the whole dataset. However, neither of these are numbers that are very understandable to most people; in practice, we normally use accuracy as the metric for classification models.

As we’ve discussed, we want to calculate our metric over a validation set. This is so that we don’t inadvertently overfit—that is, train a model to work well only on our training data. This is not really a risk with the pixel similarity model we’re using here as a first try, since it has no trained components, but we’ll use a validation set anyway to follow normal practices and to be ready for our second try later.

To get a validation set we need to remove some of the data from training entirely, so it is not seen by the model at all. As it turns out, the creators of the MNIST dataset have already done this for us. Do you remember how there was a whole separate directory called valid? That’s what this directory is for!

So to start with, let’s create tensors for our 3s and 7s from that directory. These are the tensors we will use to calculate a metric measuring the quality of our first-try model, which measures distance from an ideal image:

valid_threes = Enum.filter(images, fn {path, _} -> String.match?(path, ~r"/valid/3/") end)
valid_sevens = Enum.filter(images, fn {path, _} -> String.match?(path, ~r"/valid/7/") end)

valid_7_tens =
  Enum.map(valid_sevens, fn {_path, content} ->
    content
    |> Image.from_binary!()
    |> Image.to_nx!()
    |> Nx.reshape({28, 28})
  end)
  |> Nx.stack()
  |> Nx.divide(255)

valid_3_tens =
  Enum.map(valid_threes, fn {_path, content} ->
    content
    |> Image.from_binary!()
    |> Image.to_nx!()
    |> Nx.reshape({28, 28})
  end)
  |> Nx.stack()
  |> Nx.divide(255)

[valid_3_tens.shape, valid_7_tens.shape]

It’s good to get in the habit of checking shapes as you go. Here we see two tensors, one representing the 3s validation set of 1,010 images of size 28×28, and one representing the 7s validation set of 1,028 images of size 28×28.

We ultimately want to write a function, is_3, that will decide if an arbitrary image is a 3 or a 7. It will do this by deciding which of our two “ideal digits” this arbitrary image is closer to. For that we need to define a notion of distance—that is, a function that calculates the distance between two images.

We can write a simple function that calculates the mean absolute error using an expression very similar to the one we wrote in the last section:

mnist_distance = fn a, b ->
  Nx.subtract(a, b)
  |> Nx.abs()
  |> Nx.mean(axes: [-1, -2])
end

mnist_distance.(a_3, mean3)

This is the same value we previously calculated for the distance between these two images, the ideal 3 mean3 and the arbitrary sample 3 a_3, which are both single-image tensors with a shape of [28,28].

But in order to calculate a metric for overall accuracy, we will need to calculate the distance to the ideal 3 for every image in the validation set. How do we do that calculation? We could write a loop over all of the single-image tensors that are stacked within our validation set tensor, valid_3_tens, which has a shape of [1010,28,28] representing 1,010 images. But there is a better way.

Something very interesting happens when we take this exact same distance function, designed for comparing two single images, but pass in as an argument valid_3_tens, the tensor that represents the 3s validation set:

valid_3_dist = mnist_distance.(valid_3_tens, mean3)
[valid_3_dist, valid_3_dist.shape]

Instead of complaining about shapes not matching, it returned the distance for every single image as a vector (i.e., a rank-1 tensor) of length 1,010 (the number of 3s in our validation set). How did that happen?

Take another look at our function mnist_distance, and you’ll see the subtraction Nx.subtract(a, b). The magic trick is that Nx, when it tries to perform a simple subtraction operation between two tensors of different ranks, will use broadcasting. That is, it will automatically expand the tensor with the smaller rank to have the same size as the one with the larger rank. Broadcasting is an important capability that makes tensor code much easier to write.

After broadcasting so the two argument tensors have the same rank, Nx applies its usual logic for two tensors of the same rank: it performs the operation on each corresponding element of the two tensors, and returns the tensor result. For instance:

Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor(1))

So in this case, Nx treats mean3, a rank-2 tensor representing a single image, as if it were 1,010 copies of the same image, and then subtracts each of those copies from each 3 in our validation set. What shape would you expect this tensor to have? Try to figure it out yourself before you look at the answer below:

Nx.subtract(valid_3_tens, mean3).shape

We are calculating the difference between our “ideal 3” and each of the 1,010 3s in the validation set, for each of 28×28 images, resulting in the shape [1010,28,28].
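The same behavior can be checked on tensors small enough to read by eye. In this sketch, which assumes Nx is available from the setup cell, a made-up rank-1 row is added to a made-up rank-2 matrix, first by relying on broadcasting and then with the expansion written out through Nx.broadcast.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]])
row = Nx.tensor([10, 20, 30])

# Implicit broadcasting: the rank-1 row is matched against each row of the matrix.
implicit = Nx.add(matrix, row)

# The same result, with the expansion to shape {2, 3} written out by hand.
explicit = Nx.add(matrix, Nx.broadcast(row, {2, 3}))

{Nx.shape(implicit), Nx.to_flat_list(implicit), Nx.to_flat_list(explicit)}
```

Subtracting mean3 from valid_3_tens follows the same pattern, with a {28, 28} tensor matched against each of the 1,010 images.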

TODO: Need someone who knows Nx internals to validate these statements

There are a couple of important points about how broadcasting is implemented, which make it valuable not just for expressivity but also for performance:

  • PyTorch doesn’t actually copy mean3 1,010 times. It pretends it were a tensor of that shape, but doesn’t actually allocate any additional memory.
  • It does the whole calculation in C (or, if you’re using a GPU, in CUDA, the equivalent of C on the GPU), tens of thousands of times faster than pure Python (up to millions of times faster on a GPU!).

This is true of all broadcasting and elementwise operations and functions done in PyTorch. It’s the most important technique for you to know to create efficient PyTorch code.

Next in mnist_distance we see Nx.abs. You might be able to guess now what this does when applied to a tensor. It applies the function to each individual element in the tensor, and returns a tensor of the results (that is, it applies the function “elementwise”). So in this case, we’ll get back 1,010 matrices of absolute values.

Finally, our function calls Nx.mean(axes: [-1,-2]). This tells Nx which axes we want to calculate the mean over. In Elixir, -1 refers to the last element, and -2 refers to the second-to-last. So in this case, this tells Nx that we want to take the mean ranging over the values indexed by the last two axes of the tensor. The last two axes are the horizontal and vertical dimensions of an image. After taking the mean over the last two axes, we are left with just the first tensor axis, which indexes over our images, which is why our final size was (1010). In other words, for every image, we averaged the intensity of all the pixels in that image.
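Here is a small sketch of that reduction, assuming Nx is available from the setup cell. The tensor tiny_batch is made up: two 2×2 "images" stacked along the first axis.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Shape {2, 2, 2}: two made-up images, each 2 pixels by 2 pixels.
tiny_batch =
  Nx.tensor([
    [[0, 2], [4, 6]],
    [[1, 1], [1, 1]]
  ])

# Average over the last two axes (the pixels), keeping the image axis.
per_image_mean = Nx.mean(tiny_batch, axes: [-1, -2])

{Nx.shape(per_image_mean), Nx.to_flat_list(per_image_mean)}
```

Only the first axis survives, so we get one number per image, just as mnist_distance returns one distance per validation image.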

We’ll be learning lots more about broadcasting throughout this book, especially in <>, and will be practicing it regularly too.

We can use mnist_distance to figure out whether an image is a 3 or not by using the following logic: if the distance between the digit in question and the ideal 3 is less than the distance to the ideal 7, then it’s a 3. This function will automatically do broadcasting and be applied elementwise, just like all Nx functions and operators:

# using anonymous function syntax
is_3 = fn x -> Nx.less(mnist_distance.(x, mean3), mnist_distance.(x, mean7)) end

is_3.(a_3)


Note that Nx represents the Boolean result as an integer: we get 1 for a true value and 0 for false. Thanks to broadcasting, we can also test it on the full validation set of 3s:

is_3.(valid_3_tens)


Now we can calculate the accuracy for each of the 3s and 7s by taking the average of that function for all 3s and its inverse for all 7s:

accuracy_3s = is_3.(valid_3_tens) |> Nx.mean()
accuracy_7s = Nx.subtract(1, is_3.(valid_7_tens)) |> Nx.mean()
average_accuracy = Nx.add(accuracy_3s, accuracy_7s) |> Nx.divide(2)

{accuracy_3s, accuracy_7s, average_accuracy}


This looks like a pretty good start! We’re getting over 90% accuracy on both 3s and 7s, and we’ve seen how to define a metric conveniently using broadcasting.
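The metric itself is easy to follow on a handful of made-up numbers. The sketch below assumes Nx is available from the setup cell and uses invented distances for four images that are all really 3s; the variable names are hypothetical.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Made-up distances for four validation images that are all really 3s.
dist_to_3 = Nx.tensor([0.10, 0.30, 0.20, 0.15])
dist_to_7 = Nx.tensor([0.25, 0.20, 0.40, 0.30])

# 1 where the image is closer to the ideal 3, 0 otherwise.
predicted_3 = Nx.less(dist_to_3, dist_to_7)

# Every image here is a 3, so the mean of the 1s and 0s is the accuracy.
toy_accuracy = predicted_3 |> Nx.mean() |> Nx.to_number()

{Nx.to_flat_list(predicted_3), toy_accuracy}
```

Three of the four comparisons come out in favor of the 3, so the accuracy on this toy set is 0.75.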

But let’s be honest: 3s and 7s are very different-looking digits. And we’re only classifying 2 out of the 10 possible digits so far. So we’re going to need to do better!

To do better, perhaps it is time to try a system that does some real learning—that is, that can automatically modify itself to improve its performance. In other words, it’s time to talk about the training process, and SGD.

Stochastic Gradient Descent (SGD)

Do you remember the way that Arthur Samuel described machine learning, which we quoted in <>?

> Suppose we arrange for some automatic means of testing the effectiveness of any current weight assignment in terms of actual performance and provide a mechanism for altering the weight assignment so as to maximize the performance. We need not go into the details of such a procedure to see that it could be made entirely automatic and to see that a machine so programmed would “learn” from its experience.

As we discussed, this is the key to allowing us to have a model that can get better and better—that can learn. But our pixel similarity approach does not really do this. We do not have any kind of weight assignment, or any way of improving based on testing the effectiveness of a weight assignment. In other words, we can’t really improve our pixel similarity approach by modifying a set of parameters. In order to take advantage of the power of deep learning, we will first have to represent our task in the way that Arthur Samuel described it.

Instead of trying to find the similarity between an image and an “ideal image,” we could instead look at each individual pixel and come up with a set of weights for each one, such that the highest weights are associated with those pixels most likely to be black for a particular category. For instance, pixels toward the bottom right are not very likely to be activated for a 7, so they should have a low weight for a 7, but they are likely to be activated for an 8, so they should have a high weight for an 8. This can be represented as a function and set of weight values for each possible category—for instance the probability of being the number 8:

# defn must live inside a module that imports Nx.Defn (as in Example above)
defn pr_eight(x, w) do
  (x * w) |> Nx.sum()
end

Here we are assuming that x is the image, represented as a vector—in other words, with all of the rows stacked up end to end into a single long line. And we are assuming that the weights are a vector w. If we have this function, then we just need some way to update the weights to make them a little bit better. With such an approach, we can repeat that step a number of times, making the weights better and better, until they are as good as we can make them.

We want to find the specific values for the vector w that causes the result of our function to be high for those images that are actually 8s, and low for those images that are not. Searching for the best vector w is a way to search for the best function for recognising 8s. (Because we are not yet using a deep neural network, we are limited by what our function can actually do—we are going to fix that constraint later in this chapter.)
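As a sketch of what this function computes, here it is written with explicit Nx calls on made-up numbers. It assumes Nx is available from the setup cell; the 2×2 image and the weights are invented, and real weights would be learned rather than typed in.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# x: a made-up 2x2 image with its rows laid end to end into a vector of 4 pixels.
x = Nx.tensor([[0, 1], [1, 0]]) |> Nx.flatten()

# w: one made-up weight per pixel.
w = Nx.tensor([1, 2, -1, 3])

# Elementwise multiply, then sum: the same computation as (x * w) |> Nx.sum().
score = Nx.multiply(x, w) |> Nx.sum() |> Nx.to_number()

# Multiplying elementwise and summing is a dot product, so Nx.dot agrees.
dot_score = Nx.dot(x, w) |> Nx.to_number()

{score, dot_score}
```

Only the pixels that are switted on contribute, each scaled by its weight, so changing w changes the score an image receives.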

To be more specific, here are the steps that we are going to require, to turn this function into a machine learning classifier:

  1. Initialize the weights.
  2. For each image, use these weights to predict whether it appears to be a 3 or a 7.
  3. Based on these predictions, calculate how good the model is (its loss).
  4. Calculate the gradient, which measures for each weight, how changing that weight would change the loss.
  5. Step (that is, change) all the weights based on that calculation.
  6. Go back to step 2, and repeat the process.
  7. Iterate until you decide to stop the training process (for instance, because the model is good enough or you don’t want to wait any longer).

These seven steps, illustrated in <>, are the key to the training of all deep learning models. That deep learning turns out to rely entirely on these steps is extremely surprising and counterintuitive. It’s amazing that this process can solve such complex problems. But, as you’ll see, it really does!

flowchart LR
  id1([init]) --> id2([predict])
  id2 --> id3([loss])
  id3 --> id4([gradient])
  id4 --> id5([step])
  id5 --> id6([stop])
  id5 -->|repeat| id2

There are many different ways to do each of these seven steps, and we will be learning about them throughout the rest of this book. These are the details that make a big difference for deep learning practitioners, but it turns out that the general approach to each one generally follows some basic principles. Here are a few guidelines:

  • Initialize: We initialize the parameters to random values. This may sound surprising. There are certainly other choices we could make, such as initializing them to the percentage of times that pixel is activated for that category. But since we already know that we have a routine to improve these weights, it turns out that just starting with random weights works perfectly well.
  • Loss: This is what Samuel referred to when he spoke of testing the effectiveness of any current weight assignment in terms of actual performance. We need some function that will return a number that is small if the performance of the model is good (the standard approach is to treat a small loss as good, and a large loss as bad, although this is just a convention).
  • Step: A simple way to figure out whether a weight should be increased a bit, or decreased a bit, would be just to try it: increase the weight by a small amount, and see if the loss goes up or down. Once you find the correct direction, you could then change that amount by a bit more, and a bit less, until you find an amount that works well. However, this is slow! As we will see, the magic of calculus allows us to directly figure out in which direction, and by roughly how much, to change each weight, without having to try all these small changes. The way to do this is by calculating gradients. This is just a performance optimization; we would get essentially the same results by using the slower manual process as well.
  • Stop: Once we’ve decided how many epochs to train the model for (a few suggestions for this were given in the earlier list), we apply that decision. For our digit classifier, we would keep training until the accuracy of the model started getting worse, or we ran out of time.
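
The “just try it” approach from the Step guideline can be sketched in plain Elixir, without Nx. The loss function below and the size of the nudge are made up purely for illustration; they are not part of the digit classifier.

```elixir
# Hypothetical loss with its minimum at w = 2.0 (illustration only).
loss = fn w -> (w - 2.0) ** 2 end

w = 5.0
nudge = 0.001

# Try a small increase and see whether the loss goes up or down.
delta = loss.(w + nudge) - loss.(w)

# If the loss went up, the weight should move down, and vice versa.
direction = if delta > 0, do: :decrease, else: :increase

# rise / run approximates the slope that calculus gives directly: 2 * (w - 2).
approx_slope = delta / nudge

IO.inspect({direction, approx_slope})
```

Doing this for every weight means one extra loss evaluation per weight, which is why the manual process is so much slower than computing gradients.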

Before applying these steps to our image classification problem, let’s illustrate what they look like in a simpler case. First we will define a very simple function, the quadratic—let’s pretend that this is our loss function, and x is a weight parameter of the function:

f = fn x -> x ** 2 end

Here is a graph of that function:

The sequence of steps we described earlier starts by picking some random value for a parameter, and calculating the value of the loss:

Now we look to see what would happen if we increased or decreased our parameter by a little bit—the adjustment. This is simply the slope at a particular point:

We can change our weight by a little in the direction of the slope, calculate our loss and adjustment again, and repeat this a few times. Eventually, we will get to the lowest point on our curve:
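
As a rough sketch of that loop, here is the quadratic minimized in plain Elixir. The slope is written out by hand, and the starting point, step size, and number of repetitions are arbitrary values assumed for illustration.

```elixir
# The "loss" from the text and its slope, written out by hand (d/dx x**2 = 2x).
f = fn x -> x ** 2 end
slope = fn x -> 2 * x end

# Assumed values, chosen only for illustration.
start = -1.5
step_size = 0.1

# Repeat: move x a little against the slope, i.e. downhill.
final =
  Enum.reduce(1..20, start, fn _i, x ->
    x - step_size * slope.(x)
  end)

IO.inspect({final, f.(final)})
```

Each pass shrinks x by the same factor, so the value creeps toward 0, the lowest point of the curve.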

This basic idea goes all the way back to Isaac Newton, who pointed out that we can optimize arbitrary functions in this way. Regardless of how complicated our functions become, this basic approach of gradient descent will not significantly change. The only minor changes we will see later in this book are some handy ways we can make it faster, by finding better steps.

Calculating Gradients

The one magic step is the bit where we calculate the gradients. As we mentioned, we use calculus as a performance optimization; it allows us to more quickly calculate whether our loss will go up or down when we adjust our parameters up or down. In other words, the gradients will tell us how much we have to change each weight to make our model better.

You may remember from your high school calculus class that the derivative of a function tells you how much a change in its parameters will change its result. If not, don’t worry; lots of us forget calculus once high school is behind us! But you will have to have some intuitive understanding of what a derivative is before you continue, so if this is all very fuzzy in your head, head over to Khan Academy and complete the lessons on basic derivatives. You won’t have to know how to calculate them yourself; you just have to know what a derivative is.

The key point about a derivative is this: for any function, such as the quadratic function we saw in the previous section, we can calculate its derivative. The derivative is another function. It calculates the change, rather than the value. For instance, the derivative of the quadratic function at the value 3 tells us how rapidly the function changes at the value 3. More specifically, you may recall that gradient is defined as rise/run, that is, the change in the value of the function, divided by the change in the value of the parameter. When we know how our function will change, then we know what we need to do to make it smaller. This is the key to machine learning: having a way to change the parameters of a function to make it smaller. Calculus provides us with a computational shortcut, the derivative, which lets us directly calculate the gradients of our functions.

One important thing to be aware of is that our function has lots of weights that we need to adjust, so when we calculate the derivative we won’t get back one number, but lots of them—a gradient for every weight. But there is nothing mathematically tricky here; you can calculate the derivative with respect to one weight, and treat all the other ones as constant, then repeat that for each other weight. This is how all of the gradients are calculated, for every weight.
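
To make “one weight at a time, all the others held constant” concrete, here is a minimal plain-Elixir sketch that estimates each partial derivative numerically. The function (a sum of squares), the sample weights, and the helper name `partial` are assumptions chosen for illustration; this is not how Nx computes gradients.

```elixir
# A function of several weights: the sum of their squares.
f = fn ws -> ws |> Enum.map(&(&1 * &1)) |> Enum.sum() end

# Estimate the partial derivative with respect to weight i by nudging only
# that weight and holding all the others constant (central difference).
partial = fn ws, i, h ->
  up = List.update_at(ws, i, &(&1 + h))
  down = List.update_at(ws, i, &(&1 - h))
  (f.(up) - f.(down)) / (2 * h)
end

ws = [3.0, 4.0, 10.0]

# One partial derivative per weight gives the full gradient.
gradient = for i <- 0..(length(ws) - 1), do: partial.(ws, i, 1.0e-4)

IO.inspect(gradient)
```

For this function each entry comes out close to twice the corresponding weight, which matches the calculus rule for squares.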

We mentioned just now that you won’t have to calculate any gradients yourself. How can that be? Amazingly enough, Nx is able to automatically compute the derivative of nearly any function! What’s more, it does it very fast. Most of the time, it will be at least as fast as any derivative function that you can create by hand. Let’s see an example.

First, let’s pick a tensor value which we want gradients at:

xt = Nx.tensor(3)

Now we can define our x**2 function using Nx and then ask Nx to compute the gradient at the point defined in the tensor above.

f = fn x -> Nx.pow(x, 2) end
gradient = Nx.Defn.grad(f).(xt)

If you remember your high school calculus rules, the derivative of x**2 is 2*x, and we have x=3, so the gradients should be 2*3=6, which is what Nx calculated for us!

Now we’ll repeat the preceding steps, but with a vector argument for our function:

xt = Nx.tensor([3, 4, 10])

And we’ll add sum to our function so it can take a vector (i.e., a rank-1 tensor), and return a scalar (i.e., a rank-0 tensor):

grad_func = Nx.Defn.grad(fn x -> Nx.pow(x, 2) |> Nx.sum() end)
xt = Nx.tensor([3, 4, 10])
grad_func.(xt)

fx = fn x ->
  IO.puts("I'm running")
  Nx.pow(x, 2) |> Nx.sum()
end

dxdt = Nx.Defn.grad(fx)
dxdt.(xt)


Our gradients are 2*xt, as we’d expect!


The gradients only tell us the slope of our function; they don’t tell us exactly how far to adjust the parameters. But they do give us some idea of how far: if the slope is very large, that may suggest that we have more adjustments to do, whereas if the slope is very small, that may suggest that we are close to the optimal value.

Stepping With a Learning Rate

Deciding how to change our parameters based on the values of the gradients is an important part of the deep learning process. Nearly all approaches start with the basic idea of multiplying the gradient by some small number, called the learning rate (LR). The learning rate is often a number between 0.001 and 0.1, although it could be anything. Often, people select a learning rate just by trying a few, and finding which results in the best model after training (we’ll show you a better approach later in this book, called the learning rate finder). Once you’ve picked a learning rate, you can adjust your parameters using this simple function:

# Pseudocode: Elixir has no in-place `-=`, so in real code we rebind w instead
w = w - gradient(w) * lr

This is known as stepping your parameters, using an optimizer step. Notice how we subtract gradient * lr from the parameter to update it. This moves the parameter downhill, against the slope: the parameter increases when the slope is negative and decreases when the slope is positive. We step in this direction because our goal in deep learning is to minimize the loss.

If you pick a learning rate that’s too low, it can mean having to do a lot of steps.

But picking a learning rate that’s too high is even worse: it can actually result in the loss getting worse!

If the learning rate is too high, it may also “bounce” around, rather than actually diverging; training can still succeed, but it takes many steps.
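
These three behaviours can be reproduced on the quadratic from earlier with a few lines of plain Elixir. The specific learning rates, the starting point of 1.0, and the helper name `descend` are assumptions picked to make each case visible, not recommended values.

```elixir
# Gradient descent on f(x) = x**2, whose slope is 2x.
# Returns the x value after each step, starting from x = 1.0.
descend = fn lr, steps ->
  Enum.scan(1..steps, 1.0, fn _i, x -> x - lr * 2 * x end)
end

# Learning rates are assumptions picked to show each behaviour.
too_low = descend.(0.01, 10)   # creeps toward 0 very slowly
too_high = descend.(1.1, 10)   # overshoots further each step: diverges
bouncy = descend.(0.9, 10)     # flips sign every step but still shrinks

IO.inspect(List.last(too_low), label: "lr = 0.01")
IO.inspect(List.last(too_high), label: "lr = 1.1")
IO.inspect(List.last(bouncy), label: "lr = 0.9")
```

On this particular function each step multiplies x by (1 - 2 * lr), so the outcome depends only on whether that factor is close to 1, between -1 and 0, or below -1.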

Now let’s apply all of this in an end-to-end example.

An End-to-End SGD Example

We’ve seen how to use gradients to find a minimum. Now it’s time to look at an SGD example and see how finding a minimum can be used to train a model to fit data better.

Let’s start with a simple, synthetic, example model. Imagine you were measuring the speed of a roller coaster as it went over the top of a hump. It would start fast, and then get slower as it went up the hill; it would be slowest at the top, and it would then speed up again as it went downhill. You want to build a model of how the speed changes over time. If you were measuring the speed manually every second for 20 seconds, it might look something like this:

time = Nx.iota({20})

> BS: Here we used the iota() creation function which returns a tensor of a given shape with values from an increasing index by default.

The following function simply creates fake “actual” data for the next example.

# Python
# speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1

get_speed = fn x ->
  key = Nx.Random.key(42)
  # Note: uniform integer noise, unlike the Gaussian torch.randn in the Python original
  {randint, _new_key} = Nx.Random.randint(key, 0, 20, shape: Nx.shape(x))

  Nx.multiply(randint, 3)
  |> Nx.add(Nx.multiply(0.75, Nx.pow(Nx.subtract(x, 9.5), 2)))
  |> Nx.add(1)
end

speed = get_speed.(time)

We’ve added a bit of random noise, since measuring things manually isn’t precise. This means it’s not that easy to answer the question: what was the roller coaster’s speed? Using SGD we can try to find a function that matches our observations. We can’t consider every possible function, so let’s use a guess that it will be quadratic; i.e., a function of the form a*(time**2) + (b*time) + c.

We want to distinguish clearly between the function’s input (the time when we are measuring the coaster’s speed) and its parameters (the values that define which quadratic we’re trying). So, let’s collect the parameters in one argument and thus separate the input, time, and the parameters, params, in the function’s signature:

quadratic = fn time, params ->
  [a, b, c] = [params[0], params[1], params[2]]

  Nx.multiply(a, Nx.pow(time, 2))
  |> Nx.add(Nx.multiply(b, time))
  |> Nx.add(c)
end

In other words, we’ve restricted the problem of finding the best imaginable function that fits the data, to finding the best quadratic function. This greatly simplifies the problem, since every quadratic function is fully defined by the three parameters a, b, and c. Thus, to find the best quadratic function, we only need to find the best values for a, b, and c.

If we can solve this problem for the three parameters of a quadratic function, we’ll be able to apply the same approach for other, more complex functions with more parameters, such as a neural net. Let’s find the parameters for quadratic first, and then we’ll come back and do the same thing for the MNIST dataset with a neural net.

We need to define first what we mean by “best.” We define this precisely by choosing a loss function, which will return a value based on a prediction and a target, where lower values of the function correspond to “better” predictions. It is important for loss functions to return lower values when predictions are more accurate, as the SGD procedure we defined earlier will try to minimize this loss. For continuous data, it’s common to use mean squared error:

# Python
# def mse(preds, targets): return ((preds-targets)**2).mean()

mse = fn preds, targets -> Nx.subtract(preds, targets) |> Nx.pow(2) |> Nx.mean() end
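
Written as a formula, mean squared error is the average of the squared differences between predictions and targets. The symbols p_i (predictions), t_i (targets), and n (number of observations) are introduced here only for illustration:

```latex
\mathrm{MSE}(p, t) = \frac{1}{n} \sum_{i=1}^{n} (p_i - t_i)^2
```

Squaring makes every term non-negative and penalizes large misses more heavily than small ones.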

Now, let’s work through our seven-step process.

Step 1: Initialize the parameters

First, we initialize the parameters to random values as a tensor so we can use Nx to manipulate them later:

key = Nx.Random.key(42)
{rand, _key} = Nx.Random.normal(key, shape: {3})

params = rand

Step 2: Calculate the predictions

Next, we calculate the predictions:

preds = quadratic.(time, params)
TODO: Figure out how to plot charts in Livebook and graph loss

Step 3: Calculate the loss

We calculate the loss as follows:

# Python
# loss = mse(preds, speed)

loss = mse.(preds, speed)

Our goal is now to improve this. To do that, we’ll need to know the gradients.

Step 4: Calculate the gradients

The next step is to calculate the gradients. time and speed are our fake observed values so these will be treated as constants here. We will therefore create an anonymous function that takes in our params tensor and returns our calculated loss.

In other words, calculate an approximation of how the parameters need to change to reduce the loss:

f = fn params ->
  preds = quadratic.(time, params)
  mse.(preds, speed)
end

gradients = Nx.Defn.grad(f).(params)

We can use these gradients to improve our parameters. We’ll need to pick a learning rate (we’ll discuss how to do that in practice in the next chapter; for now we’ll just use 1e-5, or 0.00001):

rate = 0.00001
Nx.multiply(gradients, rate)

Step 5: Step the weights

Now we need to update the parameters based on the gradients we just calculated:

lr = 0.00001
params = Nx.subtract(params, Nx.multiply(lr, gradients))

Let’s see if the loss has improved:

# Python
# preds = f(time, params)
# mse(preds, speed)

preds = quadratic.(time, params)
mse.(preds, speed)

We need to repeat this a few times, so we’ll create a function to apply one step:

# Python
# def apply_step(params, prn=True):
#     preds = f(time, params)
#     loss = mse(preds, speed)
#     loss.backward()
#     params.data -= lr * params.grad.data
#     params.grad = None
#     if prn: print(loss.item())
#     return preds

f = fn time, speed, pms ->
  preds = quadratic.(time, pms)
  mse.(preds, speed)
end

# Computes the loss and the gradients in one call, then steps the parameters
apply_step = fn params ->
  {loss, gradients} = Nx.Defn.value_and_grad(fn p -> f.(time, speed, p) end).(params)

  params = Nx.subtract(params, Nx.multiply(lr, gradients))
  {loss, params}
end

# Variant: computes the loss and the gradients separately
apply_step1 = fn params ->
  preds = quadratic.(time, params)
  loss = mse.(preds, speed)
  grad_func = Nx.Defn.grad(fn p -> f.(time, speed, p) end)
  gradients = grad_func.(params)
  params = Nx.subtract(params, Nx.multiply(lr, gradients))
  {loss, params}
end

# Variant: like apply_step, with the differentiated function bound to a name first
apply_step2 = fn params ->
  # preds = quadratic.(time, params)
  # loss = mse.(preds, speed)
  grad_func = Nx.Defn.value_and_grad(fn p -> f.(time, speed, p) end)
  {loss, gradients} = grad_func.(params)
  params = Nx.subtract(params, Nx.multiply(lr, gradients))
  {loss, params}
end

Step 6: Repeat process

Now we iterate. By looping and performing many improvements, we hope to reach a good result:

# for i in range(10): apply_step(params)

# for i <- 0..9, reduce: params do
#   params ->
#     {loss, new_params} = apply_step.(params)
#     IO.inspect(loss, label: "loss #{i}")
#     new_params
# end

params =
  Enum.reduce(0..9, params, fn i, params ->
    {loss, new_params} = apply_step.(params)
    IO.inspect(loss, label: "loss #{i}")
    new_params
  end)
TODO: Is there a better way to loop and update params than using reduce w/ acc above?
TODO: Use Livebook plotting to visualize the improvements in the parameters after the 10 iterations of SGD

Summarizing Gradient Descent

flowchart LR
  id1([init]) --> id2([predict])
  id2 --> id3([loss])
  id3 --> id4([gradient])
  id4 --> id5([step])
  id5 --> id6([stop])
  id5 -->|repeat| id2

To summarize, at the beginning, the weights of our model can be random (training from scratch) or come from a pretrained model (transfer learning). In the first case, the output we will get from our inputs won’t have anything to do with what we want, and even in the second case, it’s very likely the pretrained model won’t be very good at the specific task we are targeting. So the model will need to learn better weights.

We begin by comparing the outputs the model gives us with our targets (we have labeled data, so we know what result the model should give) using a loss function, which returns a number that we want to make as low as possible by improving our weights. To do this, we take a few data items (such as images) from the training set and feed them to our model. We compare the corresponding targets using our loss function, and the score we get tells us how wrong our predictions were. We then change the weights a little bit to make it slightly better.

To find how to change the weights to make the loss a bit better, we use calculus to calculate the gradients. (Actually, we let Nx do it for us!) Let’s consider an analogy. Imagine you are lost in the mountains with your car parked at the lowest point. To find your way back to it, you might wander in a random direction, but that probably wouldn’t help much. Since you know your vehicle is at the lowest point, you would be better off going downhill. By always taking a step in the direction of the steepest downward slope, you should eventually arrive at your destination. We use the magnitude of the gradient (i.e., the steepness of the slope) to tell us how big a step to take; specifically, we multiply the gradient by a number we choose called the learning rate to decide on the step size. We then iterate until we have reached the lowest point, which will be our parking lot; then we can stop.

Everything we just saw can be transposed directly to the MNIST dataset, except for the loss function. Let’s now see how we can define a good training objective.

The MNIST Loss Function

We already have our independent variables x: these are the images themselves. We’ll concatenate them all into a single tensor, and also change them from a list of matrices (a rank-3 tensor) to a list of vectors (a rank-2 tensor). We can do this using Nx.reshape, an Nx function that changes the shape of a tensor without changing its contents. :auto is a special value for a dimension that means “make this axis as big as necessary to fit all the data”:

# Python
# train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)

train_x =
  Nx.concatenate([stacked_threes, stacked_sevens])
  |> Nx.reshape({:auto, 28 * 28})

We need a label for each image. We’ll use 1 for 3s and 0 for 7s:

answers = List.duplicate(1, length(threes)) ++ List.duplicate(0, length(sevens))
train_y = Nx.tensor(answers) |> Nx.reshape({:auto, 1})


A Dataset in PyTorch is required to return a tuple of (x,y) when indexed. Python provides a zip function which, when combined with list, provides a simple way to get this functionality. In this Elixir port, we simply keep the x and y tensors together in a list:

# Python
# dset = list(zip(train_x,train_y))
# x,y = dset[0]
# x.shape,y

dset = [train_x, train_y]
valid_x =
  Nx.concatenate([valid_3_tens, valid_7_tens])
  |> Nx.reshape({:auto, 28 * 28})

answers = List.duplicate(1, length(valid_threes)) ++ List.duplicate(0, length(valid_sevens))

valid_y = Nx.tensor(answers) |> Nx.reshape({:auto, 1})

valid_dset = [valid_x, valid_y]

Now we need an (initially random) weight for every pixel (this is the initialize step in our seven-step process):

init_params = fn shape ->
  key = Nx.Random.key(42)
  {params, _new_key} = Nx.Random.normal(key, shape: shape)
  params
end

weights = init_params.({28 * 28})

The function weights*pixels won’t be flexible enough—it is always equal to 0 when the pixels are equal to 0 (i.e., its intercept is 0). You might remember from high school math that the formula for a line is y=w*x+b; we still need the b. We’ll initialize it to a random number too:

bias = init_params.({1})

In neural networks, the w in the equation y=w*x+b is called the weights, and the b is called the bias. Together, the weights and bias make up the parameters.

> jargon: Parameters: The weights and biases of a model. The weights are the w in the equation w*x+b, and the biases are the b in that equation.

We can now calculate a prediction for one image:

# (train_x[0]*weights.T).sum() + bias

Nx.multiply(train_x[0], weights)
|> Nx.sum()
|> Nx.add(bias)

While we could use Elixir to loop through and calculate the prediction for each image, that would be very slow. Because Elixir loops don’t run on the GPU, we need to represent as much of the computation in a model as possible using higher-level functions.

In this case, there’s an extremely convenient mathematical operation that calculates w*x for every row of a matrix: it’s called matrix multiplication. The figure shows what matrix multiplication looks like.

This image shows two matrices, A and B, being multiplied together. Each item of the result, which we’ll call AB, contains each item of its corresponding row of A multiplied by each item of its corresponding column of B, added together. For instance, row 1, column 2 (the yellow dot with a red border) is calculated as $a_{1,1} * b_{1,2} + a_{1,2} * b_{2,2}$. If you need a refresher on matrix multiplication, we suggest you take a look at the Intro to Matrix Multiplication on Khan Academy, since this is the most important mathematical operation in deep learning.
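
To make the row-times-column rule concrete, here is a small plain-Elixir sketch that multiplies two 2×2 matrices stored as lists of rows. The helper name `matmul` and the sample values are made up for illustration; in the notebook itself Nx does this work.

```elixir
# Multiply two matrices stored as lists of rows (no Nx, for illustration).
matmul = fn a, b ->
  # Columns of b, so each result item is a row of a times a column of b.
  b_cols = b |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  for row <- a do
    for col <- b_cols do
      Enum.zip(row, col) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
    end
  end
end

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]

ab = matmul.(a, b)
IO.inspect(ab)

# Row 1, column 2 of the result: a11 * b12 + a12 * b22 = 1 * 6 + 2 * 8
IO.inspect(ab |> Enum.at(0) |> Enum.at(1))
```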

Let’s try it:

linear1 = fn xb ->
  Nx.dot(xb, weights)
  |> Nx.add(bias)
end

preds = linear1.(train_x)

The first element is the same as we calculated before, as we’d expect. This equation which uses matrix multiplication (shown with the @ operator from Python), batch@weights + bias, is one of the two fundamental equations of any neural network (the other one is the activation function, which we’ll see in a moment).

> BS: In Elixir, we’ll use Nx.dot/2 for matrix multiplication. See the Nx docs.

Let’s check our accuracy. To decide if an output represents a 3 or a 7, we can just check whether it’s greater than 0.0, so our accuracy for each item can be calculated (using broadcasting, so no loops!) with:

corrects =
  Nx.greater(preds, 0)
  |> Nx.equal(Nx.squeeze(train_y))


Now let’s see what the change in accuracy is for a small change in the weights:

weights = Nx.multiply(weights, 1.0001)

linear1 = fn xb ->
  Nx.dot(xb, weights)
  |> Nx.add(bias)
end

preds = linear1.(train_x)
corrects = Nx.greater(preds, 0) |> Nx.equal(Nx.squeeze(train_y))


As we’ve seen, we need gradients in order to improve our model using SGD, and in order to calculate gradients we need some loss function that represents how good our model is. That is because the gradients are a measure of how that loss function changes with small tweaks to the weights.

So, we need to choose a loss function. The obvious approach would be to use accuracy, which is our metric, as our loss function as well. In this case, we would calculate our prediction for each image, collect these values to calculate an overall accuracy, and then calculate the gradients of each weight with respect to that overall accuracy.

Unfortunately, we have a significant technical problem here. The gradient of a function is its slope, or its steepness, which can be defined as rise over run, that is, how much the value of the function goes up or down, divided by how much we changed the input. We can write this mathematically as: (y_new - y_old) / (x_new - x_old). This gives us a good approximation of the gradient when x_new is very similar to x_old, meaning that their difference is very small. But accuracy only changes at all when a prediction changes from a 3 to a 7, or vice versa. The problem is that a small change in weights from x_old to x_new isn’t likely to cause any prediction to change, so (y_new - y_old) will almost always be 0. In other words, the gradient is 0 almost everywhere.

A very small change in the value of a weight will often not actually change the accuracy at all. This means it is not useful to use accuracy as a loss function—if we do, most of the time our gradients will actually be 0, and the model will not be able to learn from that number.

> S: In mathematical terms, accuracy is a function that is constant almost everywhere (except at the threshold, 0.5), so its derivative is zero almost everywhere (and infinity at the threshold). This then gives gradients that are 0 or infinite, which are useless for updating the model.
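
A toy example in plain Elixir shows the problem. The four data points, the single weight w, and the squared-error stand-in for a smooth loss are all assumptions made for illustration; they are not the MNIST model or the loss defined later in this section.

```elixir
# Toy data: one input feature per item, label 1 or 0 (made up for illustration).
xs = [-2.0, -1.0, 0.5, 1.5]
ys = [0, 0, 1, 1]

# Accuracy: threshold the raw prediction w * x at 0 and count the hits.
accuracy = fn w ->
  hits =
    Enum.zip(xs, ys)
    |> Enum.count(fn {x, y} -> (if w * x > 0, do: 1, else: 0) == y end)

  hits / length(xs)
end

# A smooth stand-in loss: mean squared distance between w * x and the label.
smooth_loss = fn w ->
  Enum.zip(xs, ys)
  |> Enum.map(fn {x, y} -> (w * x - y) ** 2 end)
  |> Enum.sum()
  |> Kernel./(length(xs))
end

w = 1.0
nudged = w * 1.0001

IO.inspect(accuracy.(nudged) - accuracy.(w), label: "change in accuracy")
IO.inspect(smooth_loss.(nudged) - smooth_loss.(w), label: "change in loss")
```

The tiny nudge flips no prediction, so accuracy does not move at all, while the smooth loss registers a small change that a gradient can pick up.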

Instead, we need a loss function which, when our weights result in slightly better predictions, gives us a slightly better loss. So what does a “slightly better prediction” look like, exactly? Well, in this case, it means that if the correct answer is a 3 the score is a little higher, or if the correct answer is a 7 the score is a little lower.

Let’s write such a function now. What form does it take?

The loss function receives not the images themselves, but the predictions from the model. Let’s make one argument, prds, of values between 0 and 1, where each value is the prediction that an image is a 3. It is a vector (i.e., a rank-1 tensor), indexed over the images.

The purpose of the loss function is to measure the difference between predicted values and the true values — that is, the targets (aka labels). Let’s make another argument, trgts, with values of 0 or 1 which tells whether an image actually is a 3 or not. It is also a vector (i.e., another rank-1 tensor), indexed over the images.

So, for instance, suppose we had three images which we knew were a 3, a 7, and a 3. And suppose our model predicted with high confidence (0.9) that the first was a 3, with slight confidence (0.4) that the second was a 7, and with fair confidence (0.2), but incorrectly, that the last was a 7. This would mean our loss function would receive these values as its inputs:

trgts = Nx.tensor([1, 0, 1])
prds = Nx.tensor([0.9, 0.4, 0.2])

Here’s a first try at a loss function that measures the distance between predictions and targets:

# Python
# def mnist_loss(predictions, targets):
#     return torch.where(targets==1, 1-predictions, predictions).mean()

mnist_loss = fn predictions, targets ->
  Nx.select(Nx.equal(targets, 1), Nx.subtract(1, predictions), predictions)
  |> Nx.mean()
end

Nx.select(Nx.equal(trgts, 1), Nx.subtract(1, prds), prds)

You can see that this function returns a lower number when predictions are more accurate, when accurate predictions are more confident (higher absolute values), and when inaccurate predictions are less confident. In Nx, we always assume that a lower value of a loss function is better. Since we need a scalar for the final loss, mnist_loss takes the mean of the previous tensor:

mnist_loss.(prds, trgts)

For instance, if we change our prediction for the one “false” target from 0.2 to 0.8 the loss will go down, indicating that this is a better prediction:

mnist_loss.(Nx.tensor([0.9, 0.4, 0.8]), trgts)
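
The arithmetic behind those two calls can be checked by hand. The sketch below redoes it in plain Elixir on lists rather than tensors; the helper names `distance` and `mean` are introduced here for illustration only.

```elixir
# Distance from the target: 1 - p when the target is 1, p when it is 0.
distance = fn p, t -> if t == 1, do: 1 - p, else: p end

mean = fn values -> Enum.sum(values) / length(values) end

targets = [1, 0, 1]

# Distances [0.1, 0.4, 0.8] for the original predictions...
first_loss = Enum.zip_with([0.9, 0.4, 0.2], targets, distance) |> then(mean)

# ...and [0.1, 0.4, 0.2] once the last prediction moves from 0.2 to 0.8.
better_loss = Enum.zip_with([0.9, 0.4, 0.8], targets, distance) |> then(mean)

IO.inspect({first_loss, better_loss})
```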

One problem with mnist_loss as currently defined is that it assumes that predictions are always between 0 and 1. We need to ensure, then, that this is actually the case! As it happens, there is a function that does exactly that—let’s take a look.


The sigmoid function always outputs a number between 0 and 1. It’s defined as follows:

# Python
# def sigmoid(x): return 1/(1+torch.exp(-x))

sigmoid = fn x -> Nx.divide(1, Nx.add(1, Nx.exp(Nx.negate(x)))) end

Nx defines an accelerated version for us, Nx.sigmoid/1, so we don’t really need our own. This is an important function in deep learning, since we often want to ensure values are between 0 and 1. This is what it looks like:

As you can see, it takes any input value, positive or negative, and smooshes it onto an output value between 0 and 1. It’s also a smooth curve that only goes up, which makes it easier for SGD to find meaningful gradients.
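
In the absence of a plot, a few sample values give a feel for the curve. This is a scalar version in plain Elixir using Erlang’s :math.exp/1; the name `scalar_sigmoid` and the sample inputs are chosen for illustration, and it is not a replacement for the tensor version.

```elixir
# Scalar sigmoid in plain Elixir, using Erlang's :math.exp/1.
scalar_sigmoid = fn x -> 1 / (1 + :math.exp(-x)) end

inputs = [-10.0, -1.0, 0.0, 1.0, 10.0]
outputs = Enum.map(inputs, scalar_sigmoid)

# Large negative inputs land near 0, large positive inputs near 1,
# and an input of 0 lands exactly in the middle.
for {x, y} <- Enum.zip(inputs, outputs) do
  IO.inspect(y, label: "sigmoid(#{x})")
end
```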

Let’s update mnist_loss to first apply sigmoid to the inputs:

# Python
# def mnist_loss(predictions, targets):
#     predictions = predictions.sigmoid()
#     return torch.where(targets==1, 1-predictions, predictions).mean()

mnist_loss = fn predictions, targets ->
  predictions = Nx.sigmoid(predictions)

  Nx.select(Nx.equal(Nx.squeeze(targets), 1), Nx.subtract(1, predictions), predictions)
  |> Nx.mean()
end

Now we can be confident our loss function will work, even if the predictions are not between 0 and 1. All that is required is that a higher prediction corresponds to higher confidence an image is a 3.

Having defined a loss function, now is a good moment to recapitulate why we did this. After all, we already had a metric, which was overall accuracy. So why did we define a loss?

The key difference is that the metric is to drive human understanding and the loss is to drive automated learning. To drive automated learning, the loss must be a function that has a meaningful derivative. It can’t have big flat sections and large jumps, but instead must be reasonably smooth. This is why we designed a loss function that would respond to small changes in confidence level. This requirement means that sometimes it does not really reflect exactly what we are trying to achieve, but is rather a compromise between our real goal and a function that can be optimized using its gradient. The loss function is calculated for each item in our dataset, and then at the end of an epoch the loss values are all averaged and the overall mean is reported for the epoch.

Metrics, on the other hand, are the numbers that we really care about. These are the values that are printed at the end of each epoch that tell us how our model is really doing. It is important that we learn to focus on these metrics, rather than the loss, when judging the performance of a model.

SGD and Mini-Batches

Now that we have a loss function that is suitable for driving SGD, we can consider some of the details involved in the next phase of the learning process, which is to change or update the weights based on the gradients. This is called an optimization step.

In order to take an optimization step we need to calculate the loss over one or more data items. How many should we use? We could calculate it for the whole dataset, and take the average, or we could calculate it for a single data item. But neither of these is ideal. Calculating it for the whole dataset would take a very long time. Calculating it for a single item would not use much information, so it would result in a very imprecise and unstable gradient. That is, you’d be going to the trouble of updating the weights, but taking into account only how that would improve the model’s performance on that single item.

So instead we take a compromise between the two: we calculate the average loss for a few data items at a time. This is called a mini-batch. The number of data items in the mini-batch is called the batch size. A larger batch size means that you will get a more accurate and stable estimate of your dataset’s gradients from the loss function, but it will take longer, and you will process fewer mini-batches per epoch. Choosing a good batch size is one of the decisions you need to make as a deep learning practitioner to train your model quickly and accurately. We will talk about how to make this choice throughout this book.

Another good reason for using mini-batches rather than calculating the gradient on individual data items is that, in practice, we nearly always do our training on an accelerator such as a GPU. These accelerators only perform well if they have lots of work to do at a time, so it’s helpful if we can give them lots of data items to work on. Using mini-batches is one of the best ways to do this. However, if you give them too much data to work on at once, they run out of memory—making GPUs happy is also tricky!

As we saw in our earlier discussion of data augmentation, we get better generalization if we can vary things during training. One simple and effective thing we can vary is what data items we put in each mini-batch. Rather than simply enumerating our dataset in order for every epoch, instead what we normally do is randomly shuffle it on every epoch, before we create mini-batches. PyTorch and fastai provide a class that will do the shuffling and mini-batch collation for you, called DataLoader.

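A DataLoader can take any Python collection and turn it into an iterator over mini-batches. The commented Python below shows the original; in Elixir we get a similar result by shuffling a tensor with Nx.Random.shuffle and splitting it with Nx.to_batched, like so: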

# coll = range(15)
# dl = DataLoader(coll, batch_size=5, shuffle=True)
# list(dl)

key = Nx.Random.key(42)

{shuffled, _new_key} = Nx.Random.shuffle(key, Nx.iota({15}))

shuffled |> Nx.to_batched(5) |> Enum.to_list()

For training a model, we don’t just want any collection, but a collection containing independent and dependent variables (that is, the inputs and targets of the model). A collection that contains tuples of independent and dependent variables is known as a Dataset. Here’s an example of an extremely simple Dataset:
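As a rough illustration, the sketch below builds such a collection in plain Elixir by pairing the integers 0 to 25 with the letters of the alphabet, then shuffles and batches it. It is only a stand-in for what a DataLoader does, not a port of the fastai class; the names `dataset` and `batches` and the batch size of 6 are assumptions made for this example.

```elixir
# A tiny Dataset: {independent, dependent} tuples pairing 0..25 with "a".."z"
letters = Enum.map(?a..?z, fn c -> <<c::utf8>> end)
dataset = Enum.zip(0..25, letters)

# A minimal stand-in for a DataLoader: shuffle, cut into mini-batches,
# then collate each batch into an {inputs, targets} tuple of lists
batches =
  dataset
  |> Enum.shuffle()
  |> Enum.chunk_every(6)
  |> Enum.map(&Enum.unzip/1)

IO.inspect(hd(batches), label: "first mini-batch")
```

With 26 items and a batch size of 6, this gives four full batches and a final batch of two. Unlike Nx.to_batched, which by default pads the last batch by repeating items, Enum.chunk_every simply leaves the last batch short.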

TODO: Section with Dataloader and alphabet

Putting It All Together

It’s time to implement the gradient descent process we saw earlier in this chapter. In code, our process will be implemented something like this for each epoch:

# Pseudo-code for understandability

for inputs,outputs in steps:
    preds = model(params, inputs)
    loss = loss_func(preds, outputs)
    gradients = grad_func(params)
    params = params - (gradients * learning_rate)
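The same loop can be sketched as runnable plain Elixir on a toy problem. This is an illustration under stated assumptions, not the chapter's model: the data follows y = 2x, the loss is mean squared error instead of mnist_loss, the gradient is derived by hand instead of with Nx.Defn.grad, and every `toy_` name is hypothetical.

```elixir
# Toy data (hypothetical): targets follow y = 2x, so the ideal weight is 2.0.
# Each element is one mini-batch of {inputs, outputs}.
toy_steps = [{[1.0, 2.0], [2.0, 4.0]}, {[3.0, 4.0], [6.0, 8.0]}]
toy_lr = 0.02

# model: a single weight and no bias
toy_model = fn w, inputs -> Enum.map(inputs, fn x -> w * x end) end

# Gradient of the mean squared error with respect to w, derived by hand:
# d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
toy_grad = fn w, inputs, outputs ->
  preds = toy_model.(w, inputs)

  [preds, inputs, outputs]
  |> Enum.zip_with(fn [p, x, y] -> 2 * (p - y) * x end)
  |> then(fn terms -> Enum.sum(terms) / length(terms) end)
end

# One epoch: one parameter update per mini-batch
toy_epoch = fn w ->
  Enum.reduce(toy_steps, w, fn {inputs, outputs}, w ->
    w - toy_grad.(w, inputs, outputs) * toy_lr
  end)
end

toy_w = Enum.reduce(1..20, 0.0, fn _epoch, w -> toy_epoch.(w) end)
IO.inspect(toy_w, label: "learned weight")
```

Because Elixir data is immutable, the parameter is threaded through Enum.reduce as an accumulator rather than updated in place. The rest of this section follows the same pattern with Nx tensors.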

First, let’s re-initialize our parameters:

weights = init_params.({28 * 28})
bias = init_params.({1})

And redefine our linear1 function to use weights and bias as input:

linear1 = fn xb, weights, bias ->
  Nx.dot(xb, weights)
  |> Nx.add(bias)
end

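A DataLoader can be created from a Dataset. We don't have a DataLoader here, so let's get our image data ready for batches by pairing batches of inputs with batches of targets.

Let’s look at the first batch of our image data: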

# Python
# dl = DataLoader(dset, batch_size=256)
# xb,yb = first(dl)
# xb.shape,yb.shape

# dset = Explorer.DataFrame.new(x: train_x, y: train_y)

dl =
  [
    train_x |> Nx.to_batched(256) |> Enum.to_list(),
    train_y |> Nx.to_batched(256) |> Enum.to_list()
  ]
  |> Enum.zip()

{xb, yb} = hd(dl)
[xb.shape, yb.shape]

We’ll do the same for the validation set:

valid_dl =
  [
    valid_x |> Nx.to_batched(256) |> Enum.to_list(),
    valid_y |> Nx.to_batched(256) |> Enum.to_list()
  ]
  |> Enum.zip()

But first, let’s test a few things out with a small batch.

Let’s create a mini-batch of size 4 for testing:

# Todo - because these aren't shuffled we're getting all of the same type and will be all correct or all incorrect

batch = train_x[6129..6132]
preds = linear1.(batch, weights, bias)
answers = Nx.squeeze(train_y[6129..6132])
loss = mnist_loss.(preds, answers)

Now we can calculate the gradients:

# loss.backward()
# weights.grad.shape,weights.grad.mean(),bias.grad

# Loss as a function of the parameters, so it can be differentiated
f = fn images, answers, params ->
  {weights, bias} = params
  preds = linear1.(images, weights, bias)
  mnist_loss.(preds, answers)
end

params = {weights, bias}

gradients = Nx.Defn.grad(fn p -> f.(batch, answers, p) end).(params)

Let’s put that all in a function:

# def calc_grad(xb, yb, model):
#     preds = model(xb)
#     loss = mnist_loss(preds, yb)
#     loss.backward()

calc_grad = fn xb, yb, params ->
  # Differentiate the loss `f` with respect to params ({weights, bias})
  Nx.Defn.grad(fn p -> f.(xb, yb, p) end).(params)
end

and test it:

# calc_grad(batch, train_y[:4], linear1)
# weights.grad.mean(),bias.grad

calc_grad.(batch, answers, {weights, bias})

Our only remaining step is to update the weights and biases based on the gradient and learning rate. Here’s our basic training loop for an epoch:

# def train_epoch(model, lr, params):
#     for xb,yb in dl:
#         calc_grad(xb, yb, model)
#         for p in params:
#             p.data -= p.grad*lr
#             p.grad.zero_()

train_epoch = fn data, lr, params ->
  Enum.reduce(data, params, fn {xb, yb}, params ->
    {weight, bias} = params

    {w_gradient, b_gradient} = Nx.Defn.grad(fn p -> f.(xb, yb, p) end).(params)

    # Return the updated parameters as the accumulator for the next batch
    {
      Nx.subtract(weight, Nx.multiply(lr, w_gradient)),
      Nx.subtract(bias, Nx.multiply(lr, b_gradient))
    }
  end)
end

> BS: The Python versions of this code use global variables and mutate values outside the explicit scope of the functions. This is hard to read, and thankfully Elixir does not allow it, although it does make our code more verbose.
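A minimal sketch of that difference, in plain Elixir with made-up variable names: rebinding a variable inside an anonymous function does not change the outer binding, so state has to be passed through an accumulator and bound explicitly.

```elixir
total = 0

# Rebinding `total` inside the function creates a new local binding;
# the outer `total` is untouched.
Enum.each(1..3, fn i ->
  total = total + i
  total
end)

IO.inspect(total, label: "outer total after Enum.each")

# The idiomatic alternative: thread the state through an accumulator
# and bind the result explicitly.
summed = Enum.reduce(1..3, 0, fn i, acc -> acc + i end)
IO.inspect(summed, label: "sum via Enum.reduce")
```

This is why train_epoch returns the updated parameters instead of modifying them in place.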

range = 6120..6140

batch_x = train_x[range]
batch_y = train_y[range]
lr = 1.0
params = {weights, bias}
params = train_epoch.(dl, lr, params)

# Python
# def batch_accuracy(xb, yb):
#     preds = xb.sigmoid()
#     correct = (preds>0.5) == yb
#     return correct.float().mean()

batch_accuracy = fn xb, yb ->
  preds = Nx.sigmoid(xb)
  correct = Nx.greater(preds, 0.5) |> Nx.equal(yb)
  # Fraction of correct predictions in the batch
  Nx.mean(correct)
end

# def validate_epoch(model):
#     accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
#     return round(torch.stack(accs).mean().item(), 4)

validate_epoch = fn valid_dl, model ->
  accs =
    Enum.map(valid_dl, fn {xb, yb} ->
      batch_accuracy.(model.(xb), yb)
    end)

  Nx.stack(accs) |> Nx.mean()
end

model = fn xb -> linear1.(xb, weights, bias) end
validate_epoch.(valid_dl, model)

That’s our starting point. Let’s train for one epoch, and see if the accuracy improves:

params = {weights, bias}
{weights, bias} = train_epoch.(dl, lr, params)

model = fn xb -> linear1.(xb, weights, bias) end
validate_epoch.(valid_dl, model)


Then do a few more:

# for i in range(20):
#     train_epoch(linear1, lr, params)
#     print(validate_epoch(linear1), end=' ')
params = {weights, bias}

Enum.reduce(0..4, params, fn _, params ->
  {weights, bias} = train_epoch.(dl, lr, params)

  model = fn xb -> linear1.(xb, weights, bias) end
  IO.inspect(validate_epoch.(valid_dl, model), label: "latest epoch")
  {weights, bias}
end)

Looking good! We’re already about at the same accuracy as our “pixel similarity” approach, and we’ve created a general-purpose foundation we can build on. Our next step will be to create an object that will handle the SGD step for us. In PyTorch, it’s called an optimizer.

Creating an Optimizer

We will see that Axon provides several Optimizers for us, but to demonstrate their purpose we’ll build one from scratch.
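As a minimal sketch of the idea, the block below writes the SGD step as a closure over the learning rate. It is an assumption about how such an optimizer might look in plain Elixir, not Axon's API and not code from this notebook: `basic_optim` is a hypothetical helper, and parameters and gradients are plain lists of floats to keep the example dependency-free.

```elixir
# Hypothetical helper: given a learning rate, returns a `step` function
# that maps {params, grads} to updated params.
basic_optim = fn lr ->
  fn params, grads ->
    Enum.zip_with(params, grads, fn p, g -> p - lr * g end)
  end
end

toy_step = basic_optim.(0.1)

# Made-up parameter and gradient values for illustration
toy_params = [1.0, -2.0]
toy_grads = [0.5, -1.0]

toy_new_params = toy_step.(toy_params, toy_grads)
IO.inspect(toy_new_params, label: "updated params")
```

There is no counterpart to PyTorch's zero_grad here, because gradients are computed fresh as return values on every step rather than accumulated on the parameters.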

# linear_model = nn.Linear(28*28,1)
# linear_model = Nx.tensor()

Now we add parameters to train to the model.

{params, _key} = Nx.Random.normal(key, shape: {784, 1})
# params.shape
# {weights, bias} = params
# [weights.shape, bias.shape]

# def train_epoch(model):
#     for xb,yb in dl:
#         calc_grad(xb, yb, model)
#         opt.step()
#         opt.zero_grad()

train_epoch = fn model, dl ->
  Enum.reduce(dl, model, fn {_xb, _yb}, model ->
    # TODO: calculate the gradients for this batch and apply the optimizer step
    model
  end)
end

# def train_model(model, epochs):
#     for i in range(epochs):
#         train_epoch(model)
#         print(validate_epoch(model), end=' ')

train_model = fn model, epochs ->
  for _ <- 1..epochs, reduce: model do
    model ->
      train_epoch.(model, dl)
  end
end

Use Prebuilt optimizer to do the same thing
# Python
# linear_model = nn.Linear(28*28,1)
# opt = SGD(linear_model.parameters(), lr)
# train_model(linear_model, 20)

# lr = 1.0
# {inputs, targets} = hd(dl)

# f = fn params, images, answers ->
#   {weights, bias} = params
#   preds = linear1.(images, weights, bias)
#   mnist_loss.(preds, answers)
# end

# {params, _} = Nx.Random.normal(key, shape: {784,1})
# # {bias, _} = Nx.Random.uniform(key, shape: {1})
# {init_fn, update_fn} = Axon.Optimizers.sgd(lr)

# # params = {weights, bias}
# optimizer_state = init_fn.(params)
# {loss, gradient} = Nx.Defn.value_and_grad(params, f)
# {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params)

Use Prebuilt learner to do the same thing

Further Elixir Reading


  • Move static diagrams to VegaLite
  • Move http to Req library
  • Create the w/ Python versions and remove Python code here
  • Change the Enum.reduce loops to for-loops for readability