Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

GPT

gpt.livemd

GPT

Mix.install([
  {:axon, ">= 0.5.0"},
  {:nx, "~> 0.7.0"},
  {:scholar, "~> 0.3.0"},
  {:exla, "~> 0.2"},
  {:kino, ">= 0.9.0"},
  {:vega_lite, "~> 0.1.6"},
  {:kino_vega_lite, "~> 0.1.11"}
])

# Use the EXLA backend for tensors by default.
Nx.global_default_backend(EXLA.Backend)
# Compile `defn` functions with EXLA by default, using the :host client.
Nx.Defn.global_default_options(compiler: EXLA, client: :host)

Tokenizer

# Load the raw corpus that the character-level vocabulary is built from.
text = File.read!("input.txt")

# Sorted list of the distinct codepoints that occur in the corpus.
chars =
  text
  |> String.to_charlist()
  |> Enum.uniq()
  |> Enum.sort()

vocab_size = length(chars)

# Lookup tables in both directions; a token id is the codepoint's position in `chars`.
indexed_chars = Enum.with_index(chars)
stoi = Map.new(indexed_chars)
itos = Map.new(indexed_chars, fn {char, id} -> {id, char} end)

# Converts a string into its list of token ids (one id per codepoint).
#
# Raises `KeyError` when the string contains a character that is not in the
# vocabulary, instead of silently emitting `nil` ids that only blow up later
# (for example when the list is handed to `Nx.tensor/1`).
encode = fn string ->
  string
  |> String.to_charlist()
  |> Enum.map(&Map.fetch!(stoi, &1))
end

# Converts a list of token ids back into the string they encode.
#
# Raises `KeyError` naming the offending id when an id is not in the
# vocabulary, rather than passing a `nil` codepoint on to the string conversion.
decode = fn list ->
  list
  |> Enum.map(&Map.fetch!(itos, &1))
  |> List.to_string()
end

# Round-trip a sample string: print its token ids, then decode them back.
sample_ids = encode.("hii there")
IO.inspect(sample_ids)
decode.(sample_ids)
# Encode the entire corpus into a single tensor of token ids.
data = Nx.tensor(encode.(text))

IO.inspect(Nx.shape(data))
IO.inspect(Nx.type(data))

# Preview the leading tokens (the range is inclusive on both ends).
IO.inspect(data[0..1000])
# Split the corpus: the first 90% is training data, the rest is validation data.
#
# Elixir ranges are inclusive on both ends, so the training slice has to stop
# at `n - 1`. Slicing `0..n` would also put the first validation token (index
# `n`) into the training set.
size = Nx.size(data)
n = round(0.9 * size)
train_data = data[0..(n - 1)]
val_data = data[n..(size - 1)]

Data loader

# Context length used for the walkthrough below.
block_size = 8

# A chunk of block_size + 1 tokens yields block_size (context, target) pairs.
train_data[0..block_size] # Elixir range is inclusive
x = train_data[0..(block_size - 1)]
y = train_data[1..block_size]

# Print every prefix of `x` next to the token that follows it in `y`.
Enum.each(0..(block_size - 1), fn pos ->
  prefix = Nx.to_flat_list(x[0..pos])
  next_token = Nx.to_number(y[pos])
  IO.puts("when input is #{inspect(prefix)} -- target: #{next_token}")
end)
# Fixed PRNG key; `get_batch` below closes over it.
key = Nx.Random.key(1337)

# Returns `{x, y}`: `batch_size` random windows of `block_size` tokens taken
# from the chosen split, and the same windows shifted one token to the right.
#
# Any `split` other than "train" selects the validation data. The key is never
# advanced (the key returned by `randint` is discarded), so repeated calls with
# the same arguments draw the same offsets.
get_batch = fn split, batch_size, block_size ->
  source =
    case split do
      "train" -> train_data
      _other -> val_data
    end

  upper = Nx.size(source) - block_size
  {offsets, _next_key} = Nx.Random.randint(key, 0, upper, shape: {batch_size})
  starts = Nx.to_list(offsets)

  # Window of `block_size` tokens beginning at index `first`.
  window = fn first -> source[first..(first + block_size - 1)] end

  x = starts |> Enum.map(window) |> Nx.stack()
  y = starts |> Enum.map(fn start -> window.(start + 1) end) |> Nx.stack()

  {x, y}
end

# Draw one small training batch and print both tensors with their shapes.
{xb, yb} = get_batch.("train", 4, 8)
IO.puts("inputs")
IO.inspect(Nx.shape(xb), label: "xb shape")
IO.inspect(xb)
IO.puts("targets")
IO.inspect(Nx.shape(yb), label: "yb shape")
IO.inspect(yb)

batch_size = 4
block_size = 8

# For every row and every position, print the context seen so far and its target.
for row <- 0..(batch_size - 1),
    pos <- 0..(block_size - 1) do
  prefix = Nx.to_flat_list(xb[row][0..pos])
  next_token = Nx.to_number(yb[row][pos])
  IO.puts("when input is #{inspect(prefix, charlists: :as_lists)} -- target: #{next_token}")
end
xb
random_seed = 1337

# Streaming flavour of `get_batch`, because Axon.Loop consumes a stream or an Enum.
#
# Each element is `{x, flat_y}`: `x` stacks `batch_size` windows of `block_size`
# tokens, and `flat_y` holds the one-token-shifted windows reshaped to a single
# axis. The PRNG key is the stream accumulator, so each batch is drawn with the
# key returned by the previous draw. The stream never halts on its own.
get_batch_stream = fn split, batch_size, block_size ->
  source = if split == "train", do: train_data, else: val_data
  upper = Nx.size(source) - block_size

  next_batch = fn key ->
    {offsets, next_key} = Nx.Random.randint(key, 0, upper, shape: {batch_size})
    starts = Nx.to_list(offsets)

    x =
      starts
      |> Enum.map(fn start -> source[start..(start + block_size - 1)] end)
      |> Nx.stack()

    y =
      starts
      |> Enum.map(fn start -> source[(start + 1)..(start + block_size)] end)
      |> Nx.stack()

    # Collapse the targets to one id per (row, position) pair.
    {rows, cols} = Nx.shape(y)
    {[{x, Nx.reshape(y, {rows * cols})}], next_key}
  end

  Stream.resource(fn -> Nx.Random.key(random_seed) end, next_batch, fn _key -> :ok end)
end

Bigram Language Model

# PRNG key seeded with 1337 (rebinds the `key` created in the data-loader section).
key = Nx.Random.key(1337)

defmodule Train do
  @moduledoc """
  Numerical helpers that adapt a model's `{batch, time, channels}` output to
  the loss used by the training loop.
  """

  import Nx.Defn

  @doc """
  Runs `model_predict_fn` on `params` and `input`, then merges the first two
  axes of the `:prediction` entry so it has shape `{batch * time, channels}`.

  Every other key of the model output is returned unchanged.
  """
  defn predict_fn(model_predict_fn, params, input) do
    output = model_predict_fn.(params, input)
    %{prediction: raw} = output
    {batch, time, channels} = Nx.shape(raw)
    %{output | prediction: Nx.reshape(raw, {batch * time, channels})}
  end

  @doc """
  Mean categorical cross-entropy between integer `targets` and unnormalized `logits`.
  """
  defn loss_fn(targets, logits) do
    Axon.Losses.categorical_cross_entropy(
      targets,
      logits,
      # Labels are integer class ids (sparse), not one-hot vectors.
      sparse: true,
      # Predictions are raw, unnormalized scores.
      from_logits: true,
      # Average the loss, matching the PyTorch default.
      reduction: :mean
    )
  end
end

# Bigram model: a single embedding table whose row for a token id is used
# directly as the scores for the next token. Both dimensions follow
# `vocab_size` so the model always matches the corpus that was tokenized.
bigram_model =
  Axon.input("input")
  |> Axon.embedding(vocab_size, vocab_size)

# Built in :train mode; `Train.predict_fn/3` expects the model output to be a
# map carrying a :prediction key.
{init_fn, predict_fn} = Axon.build(bigram_model, mode: :train)

# Wrap the model so predictions are flattened to line up with the flattened
# targets produced by `get_batch_stream`.
custom_predict_fn = &Train.predict_fn(predict_fn, &1, &2)
custom_loss_fn = &Train.loss_fn/2
train_batch = get_batch_stream.("train", 4, 8)

# One epoch of 100 iterations with AdamW, compiled with EXLA.
params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw())
  |> Axon.Loop.run(train_batch, %{}, epochs: 1, iterations: 100, compiler: EXLA)
# Generate from the model

# Extends `init_seq` (a {batch, time} tensor of token ids) by `max_new_tokens`
# tokens, one at a time, always appending the highest-scoring next token.
generate_fn = fn model, params, init_seq, max_new_tokens ->
  for _step <- 1..max_new_tokens, reduce: init_seq do
    tokens ->
      {_batch, len} = Nx.shape(tokens)

      # Feed the model at most the last `block_size` tokens.
      window = min(len, block_size)
      recent = -window..-1
      context = tokens[[.., recent]]

      # Keep only the scores predicted at the final position.
      all_preds = Axon.predict(model, params, context)
      last_logits = all_preds[[.., -1, ..]]

      # Pick the most likely token per row, kept as a {batch, 1} column.
      next_token =
        last_logits
        |> Axon.Activations.softmax()
        |> Nx.argmax(axis: 1, keep_axis: true)

      Nx.concatenate([tokens, next_token], axis: -1)
  end
end

# Seed generation with a single row of token ids 0..4, then print the decoded
# text of that row.
init_seq = Nx.iota({1, 5})
max_new_tokens = 500

generated = generate_fn.(bigram_model, params, init_seq, max_new_tokens)

generated
|> Nx.to_list()
|> Enum.map(decode)
|> List.first()
|> IO.puts()

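Hmm… we get a lot of repetition. I suspect the issue is with Nx.argmax: it always picks the single most likely next token (greedy decoding) instead of sampling from the predicted probabilities, so the output quickly falls into a loop. Need to investigate further!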

# Lower-triangular part of a 3x3 iota matrix.
t = Nx.tril(Nx.iota({3, 3}))