Ch 9: RNNs
Mix.install([
{:scidata, "~> 0.1"},
{:axon, "~> 0.7"},
{:exla, "~> 0.6"},
{:nx, "~> 0.6"},
{:table_rex, "~> 3.1.1"},
{:kino, "~> 0.7"}
])
Nx.default_backend(EXLA.Backend)
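This sets EXLA as the default backend, so tensor operations run through XLA-compiled routines rather than Nx's pure-Elixir reference implementation.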
Get data
data = Scidata.IMDBReviews.download()
{train_data, test_data} =
data.review
|> Enum.zip(data.sentiment)
|> Enum.shuffle()
|> Enum.split(23_000)
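Scidata's IMDB set contains 25,000 labelled reviews, so after shuffling this leaves 2,000 examples held out for evaluation.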
Tokenising
Let's have a peek at our data and a simple tokenising strategy: normalise by downcasing and stripping punctuation, then split on whitespace.
This isn't suitable for every problem, but it fits this one acceptably well.
{review, _sentiment} = train_data |> hd()
normalise = fn review ->
review
|> String.downcase()
# remove punctuation and symbols
|> String.replace(~r/[\p{P}\p{S}]/, "")
|> String.split()
end
normalise.(review)
We'll use a sparse representation by mapping words to integer indices. To keep the vocabulary from ballooning, we'll only index the most frequent words.
frequencies =
Enum.reduce(train_data, %{}, fn {review, _}, tokens ->
review
|> normalise.()
|> Enum.reduce(tokens, &Map.update(&2, &1, 1, fn x -> x + 1 end))
end)
# num_tokens is an arbitrary vocabulary-size limit
num_tokens = 1024
tokens =
frequencies
|> Enum.sort_by(&elem(&1, 1), :desc)
|> Enum.take(num_tokens)
tokens =
tokens
|> Enum.with_index(fn {token, _}, i -> {token, i + 2} end)
|> Map.new()
Note we've started indexing from 2, so the 1,024 most frequent words map to indexes 2..1025, leaving 0 and 1 unassigned. These are reserved for:
-
a padding token
Nx requires static shapes and doesn't support ragged tensors* (tensors with non-uniform dimensions), so you need a strategy to convert all of your input reviews to a uniform shape. The most common way to do this is by padding or truncating each sequence to a fixed length.
-
an out-of-vocabulary (OOV) token
We need a single catch-all index for any word that falls outside the vocabulary.
*Some frameworks do support working with ragged tensors.
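To make the padding and truncation mechanics concrete, here's a minimal Nx.pad illustration with a hypothetical 5-token limit: a positive pad amount appends pad tokens, while a negative one removes elements.
# pad a 3-token sequence up to length 5
Nx.pad(Nx.tensor([4, 7, 2]), 0, [{0, 2, 0}])
# => [4, 7, 2, 0, 0]
# truncate a 7-token sequence down to length 5
Nx.pad(Nx.tensor([4, 7, 2, 9, 3, 5, 8]), 0, [{0, -2, 0}])
# => [4, 7, 2, 9, 3]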
Enum.find(tokens, "No zero or one index found", fn {_t, i} -> i == 0 or i == 1 end)
Now we have an indexed vocab.
Next we'll write a tokeniser function. Note that OOV words are replaced with the unknown_token index (1), while the pad_token index (0) fills each sequence out to max_seq_len.
pad_token = 0
unknown_token = 1
max_seq_len = 64
tokenize = fn review ->
review
|> normalise.()
|> Enum.map(&Map.get(tokens, &1, unknown_token))
|> Nx.tensor()
# pad to max_seq_len; Nx.pad with a negative amount truncates longer reviews
|> then(&Nx.pad(&1, pad_token, [{0, max_seq_len - Nx.size(&1), 0}]))
end
tokenize.(review)
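The result is a fixed-length tensor of 64 token indices, padded or truncated as needed.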
Input pipeline
batch_size = 64
train_pipeline =
train_data
|> Stream.map(fn {review, label} ->
{tokenize.(review), Nx.tensor(label)}
end)
|> Stream.chunk_every(batch_size, batch_size, :discard)
|> Stream.map(fn reviews_and_labels ->
{reviews, labels} = Enum.unzip(reviews_and_labels)
{Nx.stack(reviews), Nx.stack(labels) |> Nx.new_axis(-1)}
end)
test_pipeline =
test_data
|> Stream.map(fn {review, label} ->
{tokenize.(review), Nx.tensor(label)}
end)
|> Stream.chunk_every(batch_size, batch_size, :discard)
|> Stream.map(fn reviews_and_labels ->
{reviews, labels} = Enum.unzip(reviews_and_labels)
{Nx.stack(reviews), Nx.stack(labels) |> Nx.new_axis(-1)}
end)
Enum.take(train_pipeline, 1)
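Each batch is a {reviews, labels} pair of tensors with shapes {64, 64} and {64, 1} respectively.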
train = fn model, epochs ->
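# the models below end in a plain dense layer, so the loss is computed from logits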
loss =
&Axon.Losses.binary_cross_entropy(&1, &2,
from_logits: true,
reduction: :mean
)
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4)
model
|> Axon.Loop.trainer(loss, optimizer)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_pipeline, %{}, epochs: epochs, compiler: EXLA)
end
evaluate = fn model, trained_state ->
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_pipeline, trained_state, compiler: EXLA)
end
MLP as benchmark
mlp_model =
Axon.input("review")
|> Axon.embedding(num_tokens + 2, 64)
|> Axon.flatten()
|> Axon.dense(64, activation: :relu)
|> Axon.dense(1)
Axon.embedding/3 takes a sparse collection of tokens like the ones you have in each of your sequences and maps them to a dense vector representation.
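To see this concretely, here's a quick shape check (a sketch, assuming Axon 0.7's build/init API with Axon.ModelState.empty/0): it builds just the embedding layer and runs a single tokenized review through it.
{init_fn, predict_fn} =
  Axon.input("review")
  |> Axon.embedding(num_tokens + 2, 64)
  |> Axon.build()
params = init_fn.(Nx.template({1, 64}, :s64), Axon.ModelState.empty())
# a {1, 64} batch of token ids becomes a {1, 64, 64} batch of dense vectors
predict_fn.(params, Nx.new_axis(tokenize.(review), 0)) |> Nx.shape()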
input_template = Nx.template({64, 64}, :s64)
Axon.Display.as_graph(mlp_model, input_template)
mlp_trained_model_state =
train.(mlp_model, 10)
evaluate.(mlp_model, mlp_trained_model_state)
RNN
sequence = Axon.input("review")
embedded = sequence |> Axon.embedding(num_tokens + 2, 64)
# ignore padding tokens
mask = Axon.mask(sequence, 0)
{rnn_sequence, _state} =
Axon.lstm(
embedded,
64,
mask: mask,
unroll: :static
)
# extract final token from rnn sequence
final_token = Axon.nx(rnn_sequence, fn seq ->
Nx.squeeze(seq[[0..-1//1, -1, 0..-1//1]])
end)
Axon.nx/3 allows you to utilize generic Nx functions as Axon layers. In this code, you use Nx's tensor access syntax to grab the final token (at index max_seq_len - 1) from axis 1, the temporal axis.
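For reference, an equivalent layer (a sketch) using Nx.slice_along_axis/4 instead of access syntax:
Axon.nx(rnn_sequence, fn seq ->
  seq
  |> Nx.slice_along_axis(max_seq_len - 1, 1, axis: 1)
  |> Nx.squeeze(axes: [1])
end)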
rnn_model =
final_token
|> Axon.dense(64, activation: :relu)
|> Axon.dense(1)
input_template = Nx.template({64, 64}, :s64)
Axon.Display.as_graph(rnn_model, input_template)
rnn_trained_model_state = train.(rnn_model, 10)
evaluate.(rnn_model, rnn_trained_model_state)
Bidirectional RNN
Note that this is fairly similar to the RNN above but makes use of the convenience function Axon.bidirectional/4.
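It applies the given function to the sequence and to a time-reversed copy, then joins the two results with the merge function (the :axis option names the axis to reverse along, here axis 1, the temporal axis).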
sequence = Axon.input("review")
mask = Axon.mask(sequence, 0)
embedded = Axon.embedding(sequence, num_tokens + 2, 64)
# {rnn_sequence, _state} =
rnn_sequence =
Axon.bidirectional(
embedded,
&Axon.lstm(
&1,
64,
mask: mask,
unroll: :static
),
# how to join the result from each direction together
&Axon.concatenate/2,
axis: 1
)
# is the result a {forward, backward} tuple or a merged %Axon{} node?
IO.inspect(match?({_forward_out, _backward_out}, rnn_sequence))
IO.inspect(match?(%Axon{}, rnn_sequence))
# extract final token from rnn sequence
final_token =
Axon.nx(rnn_sequence, fn seq ->
Nx.squeeze(seq[[0..-1//1, -1, 0..-1//1]])
end)
bidir_rnn_model =
final_token
|> Axon.dense(64, activation: :relu)
|> Axon.dense(1)
input_template = Nx.template({64, 64}, :s64)
Axon.Display.as_graph(bidir_rnn_model, input_template)
bidir_rnn_trained_model_state =
train.(bidir_rnn_model, 10)
evaluate.(bidir_rnn_model, bidir_rnn_trained_model_state)
Hand-rolled bidirectional RNN
I can't get the section above to compile. The likely culprit: Axon.lstm returns an {output_sequence, state} tuple of Axon nodes, which the merge function can't concatenate directly. So let's try a hand-rolled bidirectional implementation instead.
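If that diagnosis is right, one possible fix (an untested sketch, mirroring the call above) is to wrap the LSTM so only the output sequence reaches the merge function:
rnn_sequence =
  Axon.bidirectional(
    embedded,
    fn seq ->
      # keep the output sequence, drop the {cell, hidden} state
      {out, _state} = Axon.lstm(seq, 64, mask: mask, unroll: :static)
      out
    end,
    &Axon.concatenate/2,
    axis: 1
  )
Either way, hand-rolling the two directions makes the mechanics explicit: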
input_l = Axon.input("review")
# note: this mask goes unused below; applying it to the reversed sequence would require reversing it too
mask = Axon.mask(input_l, 0)
forward_sequence = Axon.embedding(input_l, num_tokens + 2, 64)
backward_sequence =
Axon.nx(forward_sequence, &Nx.reverse(&1, axes: [1]))
# Axon.lstm returns {output_sequence, {cell, hidden}}; we only need the sequences
{forward_state, {_f_cell, _f_hidden}} = Axon.lstm(forward_sequence, 64)
{backward_state, {_b_cell, _b_hidden}} = Axon.lstm(backward_sequence, 64)
# re-reverse the backward outputs so both directions are time-aligned, then merge
out_state =
Axon.add(
forward_state,
Axon.nx(backward_state, &Nx.reverse(&1, axes: [1]))
)
# Alternative merge experiments (concatenation variants) commented out below
# out_state =
# Axon.concatenate(
# forward_state,
# Axon.nx(backward_state, &Nx.reverse(&1, axes: [1]))
# )
# out_state =
# Axon.concatenate(
# backward_state,
# Axon.nx(forward_state, &Nx.reverse(&1, axes: [1]))
# )
# extract final token from rnn sequence
final_token =
Axon.nx(out_state, fn seq ->
Nx.squeeze(seq[[0..-1//1, -1, 0..-1//1]])
end)
hrolled_bidir_model =
final_token
|> Axon.dense(64, activation: :relu)
|> Axon.dense(1)
input_template = Nx.template({64, 64}, :s64)
Axon.Display.as_graph(hrolled_bidir_model, input_template)
hrolled_bidir_trained_model_state =
train.(hrolled_bidir_model, 10)
evaluate.(hrolled_bidir_model, hrolled_bidir_trained_model_state)