GPT
Mix.install([
{:axon, ">= 0.5.0"},
{:nx, "~> 0.7.0"},
{:scholar, "~> 0.3.0"},
{:exla, "~> 0.2"},
{:kino, ">= 0.9.0"},
{:vega_lite, "~> 0.1.6"},
{:kino_vega_lite, "~> 0.1.11"}
])
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
Tokenizer
text = File.read!("input.txt")
chars =
text
|> to_charlist()
|> MapSet.new()
|> MapSet.to_list()
|> Enum.sort()
vocab_size = chars |> length()
stoi =
chars
|> Enum.with_index()
|> Map.new()
itos =
chars
|> Enum.with_index()
|> Map.new(fn {c, i} -> {i, c} end)
encode = fn string ->
string
|> to_charlist()
|> Enum.map(fn c -> stoi[c] end)
end
decode = fn list ->
list
|> Enum.map(fn l -> itos[l] end)
|> to_string()
end
encode.("hii there") |> IO.inspect()
encode.("hii there") |> decode.()
data =
text
|> encode.()
|> Nx.tensor()
Nx.shape(data) |> IO.inspect()
Nx.type(data) |> IO.inspect()
data[0..1000] |> IO.inspect()
size = Nx.size(data)
n = round(0.9 * size)
# Elixir ranges are inclusive, so stop the training split at n - 1
# to keep it disjoint from the validation split
train_data = data[0..(n - 1)]
val_data = data[n..(size - 1)]
Data loader
block_size = 8
train_data[0..block_size] # Elixir range is inclusive
x = train_data[0..(block_size - 1)]
y = train_data[1..block_size]
0..(block_size - 1)
|> Enum.each(fn t ->
context = x[0..t] |> Nx.to_flat_list()
target = y[t] |> Nx.to_number()
IO.puts("when input is #{inspect(context)} -- target: #{target}")
end)
key = Nx.Random.key(1337)
get_batch = fn split, batch_size, block_size ->
data = if split == "train", do: train_data, else: val_data
# Note: the key is fixed here, so repeated calls produce the same offsets;
# the streaming version below threads the new key between batches
{ix, _new_key} = Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size})
x =
Nx.to_list(ix)
|> Enum.map(fn i ->
data[i..(i + block_size - 1)]
end)
|> Nx.stack()
y =
Nx.to_list(ix)
|> Enum.map(fn i ->
data[(i + 1)..(i + block_size)]
end)
|> Nx.stack()
{x, y}
end
{xb, yb} = get_batch.("train", 4, 8)
IO.puts("inputs")
Nx.shape(xb) |> IO.inspect(label: "xb shape")
xb |> IO.inspect()
IO.puts("targets")
Nx.shape(yb) |> IO.inspect(label: "yb shape")
yb |> IO.inspect()
batch_size = 4
block_size = 8
for b <- 0..(batch_size - 1),
t <- 0..(block_size - 1) do
context = xb[b][0..t] |> Nx.to_flat_list()
target = yb[b][t] |> Nx.to_number()
IO.puts("when input is #{inspect(context, charlists: :as_lists)} -- target: #{target}")
end
xb
random_seed = 1337
# Since Axon Loop expects a stream or Enum, let's reimplement
# get_batch to be streamed
get_batch_stream = fn split, batch_size, block_size ->
Stream.resource(
fn ->
Nx.Random.key(random_seed)
end,
fn key ->
data = if split == "train", do: train_data, else: val_data
{ix, new_key} = Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size})
ix = Nx.to_list(ix)
x =
ix
|> Enum.map(fn i ->
data[i..(i + block_size - 1)]
end)
|> Nx.stack()
y =
ix
|> Enum.map(fn i ->
data[(i + 1)..(i + block_size)]
end)
|> Nx.stack()
{b, t} = Nx.shape(y)
flattened_y = Nx.reshape(y, {b * t})
out_data = {x, flattened_y}
{[out_data], new_key}
end,
fn _ -> :ok end
)
end
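As a quick, optional sanity check, we can pull a single batch off the stream and confirm the shapes match what the loss below will expect: x is {batch_size, block_size} and y is flattened to {batch_size * block_size}. (The variable names here are just for illustration.)
[{x_check, y_check}] = get_batch_stream.("train", 4, 8) |> Enum.take(1)
Nx.shape(x_check) |> IO.inspect(label: "x shape")
Nx.shape(y_check) |> IO.inspect(label: "y shape")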
Bigram Language Model
key = Nx.Random.key(1337)
defmodule Train do
import Nx.Defn
defn predict_fn(model_predict_fn, params, input) do
%{prediction: preds} = out = model_predict_fn.(params, input)
{b, t, c} = Nx.shape(preds)
logits = Nx.reshape(preds, {b * t, c})
%{ out | prediction: logits }
end
defn loss_fn(targets, logits) do
Axon.Losses.categorical_cross_entropy(targets, logits,
# PyTorch default
reduction: :mean,
# The labels are a sparse tensor with integer values
sparse: true,
# unnormalized logits
from_logits: true
)
end
end
bigram_model =
Axon.input("input")
|> Axon.embedding(vocab_size, vocab_size) # 65 x 65 token embedding table for this vocabulary
{init_fn, predict_fn} = Axon.build(bigram_model, mode: :train)
custom_predict_fn = &Train.predict_fn(predict_fn, &1, &2)
custom_loss_fn = &Train.loss_fn/2
train_batch = get_batch_stream.("train", 4, 8)
params =
{init_fn, custom_predict_fn}
|> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw())
|> Axon.Loop.run(train_batch, %{}, epochs: 1, iterations: 100, compiler: EXLA)
# Generate from the model
generate_fn = fn model, params, init_seq, max_new_tokens ->
Enum.reduce(1..max_new_tokens, init_seq, fn _i, acc ->
{_b, t} = Nx.shape(acc)
context_length = min(t, block_size)
context_range = -context_length..-1
context_slice = acc[[.., context_range]]
preds = Axon.predict(model, params, context_slice)
logits = preds[[.., -1, ..]]
probs = Axon.Activations.softmax(logits)
# {b, 1}
batch_car = Nx.argmax(probs, axis: 1, keep_axis: true)
Nx.concatenate([acc, batch_car], axis: -1)
end)
end
init_seq = Nx.iota({1, 5})
max_new_tokens = 500
generate_fn.(bigram_model, params, init_seq, max_new_tokens)
|> Nx.to_list()
|> Enum.map(fn encoded -> decode.(encoded) end)
|> List.first()
|> IO.puts()
Hmm… we get a lot of repetition. I suspect the issue is with Nx.argmax: greedy decoding always picks the single most likely next character, so the output quickly collapses into a loop, whereas the original PyTorch version samples the next token from the distribution with torch.multinomial. Need to investigate further!
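For reference, here is a minimal sketch of a sampling-based variant (not verified end to end). It assumes a batch size of 1 and uses Nx.Random.choice (available in recent Nx versions) to draw the next token id from the softmax probabilities, playing the role of torch.multinomial. It could be dropped in wherever generate_fn is called above.
sample_fn = fn model, params, init_seq, max_new_tokens ->
  gen_key = Nx.Random.key(1337)

  {out, _key} =
    Enum.reduce(1..max_new_tokens, {init_seq, gen_key}, fn _i, {acc, key} ->
      {_b, t} = Nx.shape(acc)
      context_length = min(t, block_size)
      context_slice = acc[[.., -context_length..-1]]
      preds = Axon.predict(model, params, context_slice)
      logits = preds[[.., -1, ..]]
      probs = Axon.Activations.softmax(logits)

      # Draw one token id from the categorical distribution over the vocabulary
      # (batch size 1 assumed, so we sample from probs[0])
      {next_token, new_key} =
        Nx.Random.choice(key, Nx.iota({vocab_size}), probs[0], samples: 1)

      {Nx.concatenate([acc, Nx.reshape(next_token, {1, 1})], axis: -1), new_key}
    end)

  out
end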
t = Nx.iota({3, 3}) |> Nx.tril()