
Generating text with LSTM



Mix.install([
  # {:axon, github: "elixir-nx/axon"},
  {:axon, "~> 0.2.0"},
  {:nx, "~> 0.3.0", override: true},
  {:exla, "~> 0.3.0"},
  {:req, "~> 0.3.0"}
])

Nx.Defn.default_options(compiler: EXLA)


Recurrent Neural Networks (RNNs) can be used as generative models. This means that in addition to being used for predictive models (making predictions) they can learn the sequences of a problem and then generate entirely new plausible sequences for the problem domain.

Generative models like this are useful not only to study how well a model has learned a problem, but to learn more about the problem domain itself.

In this example, we will discover how to create a generative model for text, character-by-character using Long Short-Term Memory (LSTM) recurrent neural networks in Elixir with Axon.


Using Project Gutenberg we can download books that are no longer protected by copyright, so we can experiment with them.

The one that we will use for this experiment is Alice’s Adventures in Wonderland by Lewis Carroll. You can choose any other text or book that you like for this experiment.

# Change the URL if you'd like to experiment with other books
download_url = "https://www.gutenberg.org/files/11/11-0.txt"

book_text = Req.get!(download_url).body

First of all, we need to normalize the content of the book. We are only interested in the sequence of English letters, periods and new lines. For now we don’t care about capitalization or things like apostrophes, so we can downcase everything and remove all other characters. We can use a regular expression for that.

We can also convert the string into a list of characters so we can handle them more easily. You will see exactly why a bit further on.

normalized_book_text =
  book_text
  |> String.downcase()
  |> String.replace(~r/[^a-z \.\n]/, "")
  |> String.to_charlist()

We converted the text to a list of characters, where each character is a number (specifically, a Unicode code point). Lowercase English characters are represented with numbers between 97 = a and 122 = z, a space is 32 = [ ], a new line is 10 = \n and the period is 46 = ..

So we should have 26 + 3 (= 29) characters in total. Let’s see if that’s true.

normalized_book_text |> Enum.uniq() |> Enum.count()
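To see what the normalization does on a small scale, the sketch below applies the same three steps to a made-up sentence. The `Normalize` module and the sample text are hypothetical and exist only for illustration; the pipeline itself is the one used above.

```elixir
defmodule Normalize do
  # Same three steps as the book pipeline: downcase, strip everything
  # except a-z, space, period and newline, then convert to a charlist.
  def run(text) do
    text
    |> String.downcase()
    |> String.replace(~r/[^a-z \.\n]/, "")
    |> String.to_charlist()
  end
end

sample = "Alice's cat, Dinah.\n"
normalized = Normalize.run(sample)

# A charlist is a plain list of integers (code points).
IO.inspect(normalized, charlists: :as_lists)

# The apostrophe and the comma are gone and everything is lowercase.
IO.puts(List.to_string(normalized))
```

Printing the charlist as a list of integers makes it easy to check the code points mentioned above, such as 32 for the space and 46 for the period.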

Since we want to use these 29 characters as possible values for each input in our neural network, we can re-map them to values between 0 and 28. So each specific neuron will indicate a specific character.

# Extract all the unique characters we have and sort them for clarity
characters = normalized_book_text |> Enum.uniq() |> Enum.sort()
characters_count = Enum.count(characters)

# Create a mapping for every character
char_to_idx = characters |> Enum.with_index() |> Map.new()
# And a reverse mapping to convert back to characters
idx_to_char = characters |> Enum.with_index(&{&2, &1}) |> Map.new()

IO.puts("Total book characters: #{Enum.count(normalized_book_text)}")
IO.puts("Total unique characters: #{characters_count}")
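As a small illustration of the two mappings, the sketch below builds them for a toy word and round-trips it through the indices. The `CharIndex` module and the `toy_` variable names are hypothetical, chosen so they do not clash with the notebook's own bindings.

```elixir
defmodule CharIndex do
  # Build the character -> index table and its inverse from a charlist.
  def build(chars) do
    characters = chars |> Enum.uniq() |> Enum.sort()
    char_to_idx = characters |> Enum.with_index() |> Map.new()
    idx_to_char = Map.new(char_to_idx, fn {char, idx} -> {idx, char} end)
    {char_to_idx, idx_to_char}
  end
end

{toy_char_to_idx, toy_idx_to_char} = CharIndex.build(~c"hello.")

# Sorted code points: "." (46), "e" (101), "h" (104), "l" (108), "o" (111)
toy_encoded = Enum.map(~c"hello.", &Map.fetch!(toy_char_to_idx, &1))
IO.inspect(toy_encoded)
# => [2, 1, 3, 3, 4, 0]

# Decoding with the reverse mapping gives the original text back.
toy_decoded = Enum.map(toy_encoded, &Map.fetch!(toy_idx_to_char, &1))
IO.puts(List.to_string(toy_decoded))
```

Because the characters are sorted before indexing, the same text always yields the same mapping, which matters when the trained parameters are reused later.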

Now we need to create our training and testing data sets. But how?

Our goal is to teach the machine what usually comes after a sequence of characters. For example, given the sequence “Hello, My name i”, the computer should be able to guess that the next character is probably “s”.

graph LR;
  A[Input: Hello my name i]-->NN[Neural Network]-->B[Output: s];

Let’s choose an arbitrary sequence length and create a data set from the book text. All we need to do is read X amount of characters from the book as the input and then read 1 more as the designated output.

After doing all that, we also want to convert every character to its index using the char_to_idx mapping that we created before.

Neural networks work best if you scale your inputs and outputs. In this case we are going to scale everything between 0 and 1 by dividing them by the number of unique characters that we have.

And for the final step we will reshape it so we can use the data in our LSTM model.

sequence_length = 100

train_data =
  normalized_book_text
  |> Enum.map(&Map.fetch!(char_to_idx, &1))
  |> Enum.chunk_every(sequence_length, 1, :discard)
  # We don't want the last chunk since we don't have a prediction for it.
  |> Enum.drop(-1)
  |> Nx.tensor()
  |> Nx.divide(characters_count)
  |> Nx.reshape({:auto, sequence_length, 1})

For our training results, we will do the same: drop the first sequence_length characters and then convert the rest using the mapping. Additionally, we will apply one-hot encoding.

The reason we want to use one-hot encoding is that we don’t want the model to return only a single character as the output. We want it to return the probability of each character. This way we can decide whether a given probability is good enough, choose between multiple possible outputs, or even discard everything if the network is not confident enough.

In Nx, you can achieve this encoding with a snippet like the following, where 3 stands in for the number of classes:

|> Nx.equal(Nx.iota({1, 3}))
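The trick works because a column of indices is compared against the row `[0, 1, 2]`, and broadcasting turns that into one row of zeros and ones per index. The sketch below mimics the comparison with plain lists instead of tensors; the `OneHot` module is a hypothetical helper, not part of Nx.

```elixir
defmodule OneHot do
  # Plain-list stand-in for `Nx.equal(indices, Nx.iota({1, num_classes}))`:
  # every index is compared against each class position 0..num_classes-1.
  def encode(indices, num_classes) do
    for idx <- indices do
      for class <- 0..(num_classes - 1) do
        if idx == class, do: 1, else: 0
      end
    end
  end
end

IO.inspect(OneHot.encode([0, 2, 1], 3))
# => [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

Each row contains exactly one 1, at the position of the original index.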

To sum it up, here is how we generate the training results.

train_results =
  normalized_book_text
  |> Enum.drop(sequence_length)
  |> Enum.map(&Map.fetch!(char_to_idx, &1))
  |> Nx.tensor()
  |> Nx.reshape({:auto, 1})
  |> Nx.equal(Nx.iota({1, characters_count}))
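To check that inputs and results line up, here is a minimal sketch that pairs each window with the character that follows it, using a five-letter word and a window of 3 instead of 100. The `Windows` module is hypothetical; the `Enum.chunk_every/4`, `Enum.drop/2` steps are the same ones used above, without the index mapping, scaling and tensor conversion.

```elixir
defmodule Windows do
  # Pair every window of `len` items with the item that follows it.
  def pairs(items, len) do
    inputs =
      items
      |> Enum.chunk_every(len, 1, :discard)
      # The last window has no following item to predict.
      |> Enum.drop(-1)

    # Targets start right after the first window.
    targets = Enum.drop(items, len)
    Enum.zip(inputs, targets)
  end
end

for {input, target} <- Windows.pairs(~c"hello", 3) do
  IO.puts("#{List.to_string(input)} -> #{<<target::utf8>>}")
end
# hel -> l
# ell -> o
```

Dropping the last window and dropping the first `len` items leave both lists with the same length, which is what lets the batches be zipped together during training.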

Defining the Model

# As the input, we expect the sequence_length characters

model =
  Axon.input("input_chars", shape: {nil, sequence_length, 1})
  # The LSTM layer of our network
  |> Axon.lstm(256)
  # Selecting only the output from the LSTM Layer
  |> then(fn {out, _} -> out end)
  # Since we only want the output of the last step in the sequence,
  # we slice it out of the LSTM output
  |> Axon.nx(fn t -> t[[0..-1//1, -1]] end)
  # 20% dropout so we will not become too dependent on specific neurons
  |> Axon.dropout(rate: 0.2)
  # The output layer: one neuron for each character, with softmax as
  # the activation so every node represents a probability
  |> Axon.dense(characters_count, activation: :softmax)
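For intuition about the final layer, the sketch below computes softmax over a few raw scores with plain Elixir. It is only an illustration of the formula, not Axon's implementation, and the `Softmax` module and the example scores are made up.

```elixir
defmodule Softmax do
  # Turn raw scores into probabilities that sum to 1.
  # Subtracting the largest score first keeps :math.exp/1 from overflowing
  # and does not change the result.
  def run(scores) do
    max_score = Enum.max(scores)
    exps = Enum.map(scores, &:math.exp(&1 - max_score))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

probs = Softmax.run([2.0, 1.0, 0.1])

# The largest score gets the largest probability.
IO.inspect(probs)

# The probabilities add up to 1 (within floating point error).
IO.inspect(Enum.sum(probs))
```

In the model above the same idea is applied to `characters_count` scores, one per character.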

Training the network

To train the network, we will use Axon’s Loop API. It is pretty straightforward.

For the loss function we can use categorical cross-entropy since we are dealing with categories (each character) in our output. For the optimizer we can use Adam.

We will train our network for 20 epochs. Note that we are working with a fair amount of data, so it may take a long time unless you run it on a GPU.

batch_size = 128
train_batches = Nx.to_batched(train_data, batch_size)
result_batches = Nx.to_batched(train_results, batch_size)

IO.puts("Total batches: #{Enum.count(train_batches)}")

params =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001))
  |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 20, compiler: EXLA)
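The loss used above can be sketched for a single sample as follows. This is a simplified, plain-list version of categorical cross-entropy, assuming every predicted probability is greater than zero; Axon's own implementation works on batched tensors and may differ in details. The `CrossEntropy` module is a hypothetical helper.

```elixir
defmodule CrossEntropy do
  # Loss for one sample: -sum(target_i * log(predicted_i)).
  # With a one-hot target this reduces to -log of the probability
  # assigned to the correct character.
  # Assumes all predicted probabilities are > 0.
  def loss(one_hot_target, predicted_probs) do
    one_hot_target
    |> Enum.zip(predicted_probs)
    |> Enum.map(fn {t, p} -> -t * :math.log(p) end)
    |> Enum.sum()
  end
end

target = [0, 1, 0]

# A confident, correct prediction gives a small loss (roughly 0.22).
IO.inspect(CrossEntropy.loss(target, [0.1, 0.8, 0.1]))

# Spreading probability over wrong characters gives a larger loss (roughly 1.61).
IO.inspect(CrossEntropy.loss(target, [0.4, 0.2, 0.4]))
```

Training pushes the network toward assigning high probability to the character that actually follows each sequence.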


Generating text

Now we have a trained neural network, so we can start generating text with it! We just need to pass the initial sequence as the input to the network and select the most probable output. Axon.predict/3 will give us the output layer and then using Nx.argmax/1 we get the most confident neuron index, then simply convert that index back to its Unicode representation.

generate_fn = fn model, params, init_seq ->
  # Normalize the initial sequence the same way as the training text
  # and convert every character to its index.
  init_seq =
    init_seq
    |> String.trim()
    |> String.downcase()
    |> String.to_charlist()
    |> Enum.map(&Map.fetch!(char_to_idx, &1))

  # Generate 100 characters, feeding every prediction back into the input.
  Enum.reduce(1..100, init_seq, fn _, seq ->
    # Use the last sequence_length indices, scaled like the training data.
    input =
      seq
      |> Enum.take(-sequence_length)
      |> Nx.tensor()
      |> Nx.divide(characters_count)
      |> Nx.reshape({1, sequence_length, 1})

    # Pick the index of the most probable character.
    char =
      Axon.predict(model, params, input)
      |> Nx.argmax()
      |> Nx.to_number()

    seq ++ [char]
  end)
  |> Enum.map(&Map.fetch!(idx_to_char, &1))
end
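The feedback loop at the heart of the generator can be tried without a trained network. In the sketch below a stub function stands in for the model; `ToyGenerate`, its arguments and the summing predictor are all hypothetical and only illustrate how each prediction is appended and then becomes part of the next input window.

```elixir
defmodule ToyGenerate do
  # `predict` stands in for the trained network: it receives the last
  # `window` items of the sequence and returns the next item.
  def run(seed, window, steps, predict) do
    Enum.reduce(1..steps, seed, fn _, seq ->
      next = seq |> Enum.take(-window) |> predict.()
      seq ++ [next]
    end)
  end
end

# Stub predictor: the next number is the sum of the current window.
IO.inspect(ToyGenerate.run([1, 1], 2, 5, &Enum.sum/1))
# => [1, 1, 2, 3, 5, 8, 13]
```

The real generator follows the same shape, with the window turned into a scaled tensor and the stub replaced by the network's prediction.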

# The initial sequence that we want the network to complete for us.
init_seq = """
not like to drop the jar for fear
of killing somebody underneath so managed to put it into one of the
cupboards as she fell past it.
"""

generate_fn.(model, params, init_seq) |> IO.puts()
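The generator above always takes the most probable character. Since the network returns a probability for every character, one possible variation is to accept the best guess only when it is confident enough. The sketch below shows that idea on plain lists; the `Pick` module and the `min_confidence` threshold are hypothetical and are not used by `generate_fn`.

```elixir
defmodule Pick do
  # Return {:ok, index} for the most probable entry, or :unsure when
  # even the best probability is below `min_confidence`.
  def best(probs, min_confidence) do
    {prob, idx} =
      probs
      |> Enum.with_index()
      |> Enum.max_by(fn {p, _idx} -> p end)

    if prob >= min_confidence, do: {:ok, idx}, else: :unsure
  end
end

IO.inspect(Pick.best([0.1, 0.7, 0.2], 0.5))
# => {:ok, 1}

IO.inspect(Pick.best([0.4, 0.3, 0.3], 0.5))
# => :unsure
```

The returned index would then be converted back to a character with `idx_to_char`, exactly as in the generator.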

Multi LSTM layers

We can improve our network by stacking multiple LSTM layers together. We just need to change our model and re-train our network.

new_model =
  Axon.input("input_chars", shape: {nil, sequence_length, 1})
  |> Axon.lstm(256)
  |> then(fn {out, _} -> out end)
  |> Axon.dropout(rate: 0.2)
  # This time we pass the whole output sequence to the next LSTM layer.
  # Only the output of the last LSTM layer needs to be sliced.
  |> Axon.lstm(256)
  |> then(fn {out, _} -> out end)
  |> Axon.nx(fn x -> x[[0..-1//1, -1]] end)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(characters_count, activation: :softmax)

Then we can train the network using the same code as before, this time with a smaller batch size and more epochs.

# Using a smaller batch size in this case will give the network more opportunity to learn
batch_size = 64
train_batches = Nx.to_batched(train_data, batch_size)
result_batches = Nx.to_batched(train_results, batch_size)

IO.puts("Total batches: #{Enum.count(train_batches)}")

new_params =
  new_model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001))
  |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 50, compiler: EXLA)


Generate text with the new network

generate_fn.(new_model, new_params, init_seq) |> IO.puts()

As you can see, the output improved a lot with the new model and the longer training. This time the network has learned rules like adding a space after a period.


The example above is heavily inspired by an article by Jason Brownlee.