# GPT2

```elixir
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true},
    # Needed by the cells below (checkpoint loading and BPE tokenization);
    # the version requirements are assumptions.
    {:safetensors, "~> 0.1"},
    {:tokenizers, "~> 0.4"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)
```

```elixir
Nx.global_default_backend(EXLA.Backend)
```

```elixir
Nx.iota({3})
```
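If the setup cell worked, the tensor above should report `EXLA.Backend` as its backend rather than `Nx.BinaryBackend`.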

## Section

```elixir
# Load the raw GPT-2 checkpoint. The path is local to this machine;
# point it at your own safetensors export of the "gpt2" model.
params =
  "/home/kuku/Downloads/model.safetensors"
  |> File.read!()
  |> Safetensors.load!()
```

```elixir
# The checkpoint uses flat dotted keys such as "h.0.attn.c_attn.weight".
# Regroup them into nested maps: per-block params go under "block_<n>",
# then inner block ("attn"/"mlp"), layer, and parameter name; shallower keys
# like "h.0.ln_1.weight" nest one level less, and top-level entries such as
# "wte", "wpe", and "ln_f" keep a single level.
blocks =
  Enum.reduce(params, %{}, fn {key, value}, acc ->
    case String.split(key, ".") do
      ["h", block_num, inner_block_name, layer_name, param_name] ->
        init = %{inner_block_name => %{layer_name => %{param_name => value}}}

        Map.update(acc, "block_#{block_num}", init, fn block_params ->
          inner_init = %{layer_name => %{param_name => value}}

          Map.update(block_params, inner_block_name, inner_init, fn inner_block_params ->
            layer_init = %{param_name => value}

            Map.update(inner_block_params, layer_name, layer_init, fn layer_params ->
              Map.put(layer_params, param_name, value)
            end)
          end)
        end)

      ["h", block_num, layer_name, param_name] ->
        init = %{layer_name => %{param_name => value}}

        Map.update(acc, "block_#{block_num}", init, fn block_params ->
          layer_init = %{param_name => value}

          Map.update(block_params, layer_name, layer_init, fn layer_params ->
            Map.put(layer_params, param_name, value)
          end)
        end)

      [layer_name, param_name] ->
        Map.update(acc, layer_name, %{param_name => value}, fn layer_params ->
          Map.put(layer_params, param_name, value)
        end)
    end
  end)
```
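As a quick sanity check on the grouping, we can look up one nested entry. For the standard GPT-2 small checkpoint the fused QKV projection in the first block should have shape `{768, 2304}` (an assumption about the checkpoint, not something the notebook verifies elsewhere):

```elixir
# Expect {768, 2304}: 768 hidden dims projected to 3 * 768 for Q, K, and V.
blocks["block_0"]["attn"]["c_attn"]["weight"] |> Nx.shape()
```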
```elixir
defmodule Encoder do
  defstruct [:tokenizer]

  def new(model_id) do
    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained(model_id)
    %__MODULE__{tokenizer: tokenizer}
  end

  # Encode text into a {1, seq_len} tensor of token ids (a batch of one).
  def encode(%{tokenizer: tokenizer}, text) do
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)

    encoding
    |> Tokenizers.Encoding.get_ids()
    |> Nx.tensor()
    |> Nx.new_axis(0)
  end

  def decode(%{tokenizer: tokenizer}, id) do
    {:ok, token} = Tokenizers.Tokenizer.decode(tokenizer, [id])
    token
  end
end
```
```elixir
encoder = Encoder.new("gpt2")
```

```elixir
Encoder.encode(encoder, "Hello world!")
```

```elixir
Encoder.decode(encoder, 995)
```
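To see how ids line up with text, a round trip decodes each id back to its piece (assuming the standard `gpt2` tokenizer, `995` should come back as ` world`):

```elixir
# Decode each token id from "Hello world!" individually.
ids = Encoder.encode(encoder, "Hello world!") |> Nx.to_flat_list()
Enum.map(ids, &Encoder.decode(encoder, &1))
```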
```elixir
defmodule GPT do
  import Nx.Defn

  defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
    opts = keyword!(opts, n_head: 12)

    input
    |> embedding(wte, wpe)
    |> transformer(blocks, n_head: opts[:n_head])
    |> layer_norm(ln_f)
    # Project back to vocabulary logits with the tied token-embedding matrix.
    |> Nx.dot([-1], wte["weight"], [-1])
  end

  defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
    position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
    Nx.take(wte, x) + Nx.take(wpe, position_ids)
  end

  deftransform transformer(x, params, opts \\ []) do
    # Map iteration order is not numeric ("block_10" sorts before "block_2"
    # as a string), so order the blocks by their index before folding.
    params
    |> Enum.sort_by(fn {"block_" <> num, _} -> String.to_integer(num) end)
    |> Enum.reduce(x, fn {_block_name, block_params}, x ->
      transformer_block(x, block_params, n_head: opts[:n_head])
    end)
  end

  defn transformer_block(
         x,
         %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
         opts \\ []
       ) do
    opts = keyword!(opts, n_head: 12)

    x
    |> layer_norm(ln_1)
    |> mha(attn, n_head: opts[:n_head])
    |> then(fn x ->
      x
      |> layer_norm(ln_2)
      |> ffn(mlp)
      |> Nx.add(x)
    end)
  end

  defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
    opts = keyword!(opts, n_head: 12)
    x = linear(x, c_attn)

    {q, k, v} = split_qkv(x)
    q = split_heads(q, opts[:n_head])
    k = split_heads(k, opts[:n_head])
    v = split_heads(v, opts[:n_head])

    # Mask future positions. Axis 1 of x is the sequence length
    # ({batch, seq, 3 * dim}); axis 0 is the batch, which would build a
    # 1x1 mask that masks nothing.
    seq_len = Nx.axis_size(x, 1)
    causal_mask = (1 - Nx.tri(seq_len, seq_len)) * -1.0e10
    out = attention(q, k, v, causal_mask)

    linear(out, c_proj)
  end

  deftransformp split_qkv(tensor) do
    split_size = div(Nx.axis_size(tensor, -1), 3)
    q = tensor[[0..-1//1, 0..-1//1, 0..(split_size - 1)]]
    k = tensor[[0..-1//1, 0..-1//1, split_size..(2 * split_size - 1)]]
    v = tensor[[0..-1//1, 0..-1//1, (2 * split_size)..-1//1]]
    {q, k, v}
  end

  deftransformp split_heads(tensor, n_head) do
    {batch, seq, _dim} = Nx.shape(tensor)
    Nx.reshape(tensor, {batch, seq, n_head, :auto})
  end

  defn attention(q, k, v, mask) do
    k = Nx.transpose(k, axes: [0, 2, 1, 3])
    q = Nx.transpose(q, axes: [0, 2, 1, 3])
    v = Nx.transpose(v, axes: [0, 2, 1, 3])

    q
    |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
    |> Nx.dot([3], [0, 1], k, [3], [0, 1])
    # Apply the causal mask before softmax so future positions get ~zero weight.
    |> Nx.add(mask)
    |> softmax()
    |> Nx.dot([3], [0, 1], v, [2], [0, 1])
    |> Nx.transpose(axes: [0, 2, 1, 3])
    |> flatten_heads()
  end

  deftransformp flatten_heads(tensor) do
    shape = Nx.shape(tensor)
    rank = Nx.rank(tensor)

    new_shape =
      shape
      |> Tuple.delete_at(rank - 1)
      |> put_elem(rank - 2, :auto)

    Nx.reshape(tensor, new_shape)
  end

  defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
    x
    |> linear(c_fc)
    |> gelu()
    |> linear(c_proj)
  end

  @doc """
  Linear layer.

  ## Examples

      iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
      iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
      iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
      iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
      iex> Nx.shape(out)
      {32, 256}
      iex> Nx.type(out)
      {:f, 32}
  """
  defn linear(x, %{"weight" => w, "bias" => b}) do
    x |> Nx.dot(w) |> Nx.add(b)
  end

  @doc """
  Applies Layer Normalization.

  ## Examples

      iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
      iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
      iex> expected = Nx.tensor([
      ...>   [-0.70709, -0.70709, 1.41418],
      ...>   [-1.397, 0.508, 0.889]
      ...> ])
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-5)
    mean = Nx.mean(x, axes: [-1], keep_axes: true)
    variance = Nx.variance(x, axes: [-1], keep_axes: true)
    x = (x - mean) / Nx.sqrt(variance + opts[:eps])
    w * x + b
  end

  @doc """
  Applies GeLU Activation.

  ## Examples

      iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
      iex> expected = Nx.tensor([[0.84119, 1.9546], [-0.0454, 0.34571]])
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn gelu(x) do
    0.5 * x * (1 + Nx.tanh(Nx.sqrt(2 / Nx.Constants.pi()) * (x + 0.044715 * Nx.pow(x, 3))))
  end

  @doc """
  Applies Softmax Activation.

  ## Examples

      iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
      iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn softmax(x) do
    exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
    exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
  end
end
```
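Before wiring everything together, a toy shape check on `GPT.attention/4` (made-up sizes, not the real 768-dim model): 12 heads of 64 dims over a 4-token, batch-of-one sequence should flatten back to `{1, 4, 768}`:

```elixir
key = Nx.Random.key(0)
{q, key} = Nx.Random.uniform(key, shape: {1, 4, 12, 64})
{k, key} = Nx.Random.uniform(key, shape: {1, 4, 12, 64})
{v, _key} = Nx.Random.uniform(key, shape: {1, 4, 12, 64})
# Same causal mask construction as in mha/3, for a 4-token sequence.
mask = Nx.multiply(Nx.subtract(1, Nx.tri(4, 4)), -1.0e10)
GPT.attention(q, k, v, mask) |> Nx.shape()
```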
```elixir
predict_fun = fn input, params ->
  {wte, params} = Map.pop!(params, "wte")
  {wpe, params} = Map.pop!(params, "wpe")
  {ln_f, params} = Map.pop!(params, "ln_f")
  GPT.predict(input, wte, wpe, params, ln_f)
end

predict_fun = Nx.Defn.jit(predict_fun, compiler: EXLA)
```
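Popping `wte`, `wpe`, and `ln_f` out of the parameter map leaves only the `block_*` entries, which is exactly the shape `GPT.transformer/3` folds over.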
```elixir
input = Encoder.encode(encoder, "Hello World!")
output = predict_fun.(input, blocks)
```
```elixir
logits = output[[.., -1]]
new_token = Nx.argmax(logits, axis: -1)
```
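Decoding that id shows the model's greedy pick for the token after "Hello World!":

```elixir
Encoder.decode(encoder, Nx.to_number(Nx.squeeze(new_token)))
```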
```elixir
defmodule Generator do
  def generate(predict_fun, encoder, input, params, eos_id, max_seq_len) do
    encoded_input = Encoder.encode(encoder, input)
    seq_len = Nx.axis_size(encoded_input, 1)

    # Greedy decoding: run the model, append the argmax token, and repeat
    # until EOS appears or the sequence reaches max_seq_len.
    Enum.reduce_while(seq_len..max_seq_len, encoded_input, fn _idx, current_input ->
      output = predict_fun.(current_input, params)
      logits = output[[.., -1]]
      next_token = Nx.argmax(logits, axis: -1, keep_axis: true)

      if eos_id == Nx.to_number(Nx.squeeze(next_token)) do
        {:halt, current_input}
      else
        IO.write("#{Encoder.decode(encoder, Nx.to_number(Nx.squeeze(next_token)))}")
        new_sequence = Nx.concatenate([current_input, next_token], axis: -1)
        {:cont, new_sequence}
      end
    end)
  end
end
```
```elixir
# 50256 is GPT-2's <|endoftext|> token id.
Generator.generate(predict_fun, encoder, "Elixir is", blocks, 50256, 256)
```
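Greedy argmax tends to loop on repetitive text. A common alternative is temperature sampling; here is a minimal sketch using the Gumbel-max trick (the `Sampler` module and `temperature` knob are hypothetical, not part of the notebook, and wiring it in would mean threading the PRNG key through `reduce_while`):

```elixir
defmodule Sampler do
  # Sample a token id from logits: argmax(logits / t + Gumbel noise) is
  # equivalent to drawing from softmax(logits / t).
  def sample(logits, key, temperature \\ 0.8) do
    {u, key} = Nx.Random.uniform(key, shape: Nx.shape(logits))
    gumbel = Nx.negate(Nx.log(Nx.negate(Nx.log(u))))
    token = Nx.argmax(Nx.add(Nx.divide(logits, temperature), gumbel), axis: -1)
    {token, key}
  end
end
```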