RWKV-Axon in 150 lines in Livebook

Mix.install(
  [
    {:axon, "~> 0.5.0"},
    {:bumblebee, "~> 0.2.0"},
    {:exla, "~> 0.5.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Fetch data

File.cd!(__DIR__, fn ->
  # The Makefile that sits next to this notebook fetches the data files
  # (model checkpoint and tokenizer) into data/
  System.shell("make")
end)
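
The Makefile is assumed to place the 430M RWKV-4 checkpoint and the 20B tokenizer used below under data/. If make is not available, here is a minimal stand-in sketch; the URLs are assumptions about where the files are hosted, so adjust them as needed:

File.mkdir_p!(Path.join(__DIR__, "data"))

System.shell(
  "curl -L -o data/RWKV-4-Pile-430M-20220808-8066.pth " <>
    "https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth",
  cd: __DIR__
)

System.shell(
  "curl -L -o data/20B_tokenizer.json " <>
    "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/tokenizer.json",
  cd: __DIR__
)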

Define Params and Tokenizer

defmodule Params do
  @model_file "data/RWKV-4-Pile-430M-20220808-8066.pth"
  @n_layer 24
  @n_embed 1024

  def load!() do
    @model_file
    |> Path.absname(__DIR__)
    |> Bumblebee.Conversion.PyTorch.Loader.load!()
    |> reshape_and_cast()
    |> to_params()
  end

  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  defp reshape_and_cast(model) do
    for {key, tensor} <- model, into: %{} do
      # The time_* parameters are stored with extra singleton dimensions
      tensor = if String.contains?(key, ".time_"), do: Nx.squeeze(tensor), else: tensor
      # Do all computation in f32
      tensor = Nx.as_type(tensor, :f32)
      {key, tensor}
    end
  end

  defp to_params(model) do
    params = %{
      "emb.weight" => model["emb.weight"],
      "head.weight" => model["head.weight"],
      "input_norm" => extract_params(model, "blocks.0.ln0."),
      "output_norm" => extract_params(model, "ln_out.")
    }

    for i <- 0..(@n_layer - 1) do
      [
        {"block_#{i}.layer_norm_0", extract_params(model, "blocks.#{i}.ln1.")},
        {"block_#{i}.layer_norm_1", extract_params(model, "blocks.#{i}.ln2.")},
        {"block_#{i}.time_mixing", extract_params(model, "blocks.#{i}.att.")},
        {"block_#{i}.channel_mixing", extract_params(model, "blocks.#{i}.ffn.")}
      ]
    end
    |> List.flatten()
    |> Enum.into(params)
  end

  defp extract_params(map, prefix) do
    for {key, value} <- map, String.starts_with?(key, prefix), into: %{} do
      {String.trim_leading(key, prefix), value}
    end
  end
end
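
Once the data files are in place, here is an illustrative way to peek at the flattened key layout that Params.load!/0 produces and confirm the per-block maps were extracted:

Params.load!()
|> Map.keys()
|> Enum.sort()
|> Enum.take(6)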

defmodule Tokenizer do
  alias Bumblebee.Utils.Tokenizers

  @tokenizer_file "data/20B_tokenizer.json"
  @pad_token "<|padding|>"

  def load!() do
    @tokenizer_file
    |> Path.absname(__DIR__)
    |> Tokenizers.load!()
  end

  def apply(tokenizer, input) do
    Tokenizers.apply(tokenizer, input, @pad_token)
  end

  def decode(tokenizer, token_ids) do
    Tokenizers.decode(tokenizer, token_ids)
  end
end

:ok
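
As a quick, illustrative sanity check that encoding and decoding round-trip once the tokenizer file has been fetched:

tokenizer = Tokenizer.load!()

token_ids =
  Tokenizer.apply(tokenizer, "Hello from Livebook!")
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

Tokenizer.decode(tokenizer, token_ids)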

Define RWKV and Probs

defmodule RWKV do
  import Nx.Defn

  @n_layer 24
  @n_embed 1024

  def template() do
    token = Nx.template({@n_embed}, :f32)
    state = Nx.template({@n_layer, 4, @n_embed}, :f32)

    %{"token" => token, "state" => state}
  end

  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  def predict(model, params, {token_id, state}, opts \\ []) do
    # Look the token embedding up directly in the loaded weights
    token = params["emb.weight"][token_id]

    {x, state} = Axon.predict(model, params, %{"token" => token, "state" => state}, opts)

    # Project back onto the vocabulary and normalize into a distribution
    probs = softmax(Nx.dot(params["head.weight"], x))

    {probs, state}
  end

  deftransform model() do
    token = Axon.input("token", shape: {@n_embed})
    state = Axon.input("state", shape: {@n_layer, 4, @n_embed})

    Axon.container({token, state})
    |> layer_norm(name: "input_norm")
    |> blocks(@n_layer)
    |> layer_norm(name: "output_norm")
  end

  deftransformp layer_norm(%Axon{} = input, opts \\ []) do
    weight = Axon.param("weight", fn _ -> {@n_embed} end)
    bias = Axon.param("bias", fn _ -> {@n_embed} end)

    Axon.layer(&layer_norm_impl/4, [input, weight, bias], opts)
  end

  defnp layer_norm_impl({x, state}, w, b, _opts) do
    x = (x - Nx.mean(x)) / Nx.standard_deviation(x) * w + b

    {x, state}
  end

  deftransformp blocks(input, n) do
    Enum.reduce(0..(n - 1), input, fn i, input -> block(input, i) end)
  end

  deftransformp block(%Axon{} = input, idx) do
    input =
      input
      |> layer_norm(name: "block_#{idx}.layer_norm_0")
      |> time_mixing(idx, name: "block_#{idx}.time_mixing")
      |> add_delta_to(input, name: "block_#{idx}.add_delta_to_0")

    input
    |> layer_norm(name: "block_#{idx}.layer_norm_1")
    |> channel_mixing(idx, name: "block_#{idx}.channel_mixing")
    |> add_delta_to(input, name: "block_#{idx}.add_delta_to_1")
  end

  deftransformp time_mixing(%Axon{} = input, idx, opts \\ []) do
    decay = Axon.param("time_decay", fn _ -> {@n_embed} end)
    bonus = Axon.param("time_first", fn _ -> {@n_embed} end)

    mixes = [
      Axon.param("time_mix_k", fn _ -> {@n_embed} end),
      Axon.param("time_mix_v", fn _ -> {@n_embed} end),
      Axon.param("time_mix_r", fn _ -> {@n_embed} end)
    ]

    weights = [
      Axon.param("key.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("value.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("receptance.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("output.weight", fn _ -> {@n_embed, @n_embed} end)
    ]

    opts = Keyword.put(opts, :block_index, idx)

    Axon.layer(&time_mixing_impl/11, [input, decay, bonus] ++ mixes ++ weights, opts)
  end

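  # The WKV recurrence carried in the state, written out (num/den are the
  # running numerator and denominator of the weighted average):
  #
  #   wkv_t = (num_{t-1} + e^(bonus + k_t) * v_t) / (den_{t-1} + e^(bonus + k_t))
  #   num_t = e^(-e^(decay)) * num_{t-1} + e^(k_t) * v_t
  #   den_t = e^(-e^(decay)) * den_{t-1} + e^(k_t)
  #
  # Per layer, state slot 0 holds the previous x, slots 1 and 2 hold num/den.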
  defnp time_mixing_impl({x, state}, decay, bonus, mix_k, mix_v, mix_r, wk, wv, wr, wout, opts) do
    idx = opts[:block_index]

    last_x = state[idx][0]
    last_num = state[idx][1]
    last_den = state[idx][2]

    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    v = Nx.dot(wv, x * mix_v + last_x * (1 - mix_v))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))

    wkv =
      (last_num + Nx.exp(bonus + k) * v) /
        (last_den + Nx.exp(bonus + k))

    rwkv = Nx.sigmoid(r) * wkv
    dx = Nx.dot(wout, rwkv)

    num = Nx.exp(-Nx.exp(decay)) * last_num + Nx.exp(k) * v
    den = Nx.exp(-Nx.exp(decay)) * last_den + Nx.exp(k)

    state =
      state
      |> put_state(idx, 0, x)
      |> put_state(idx, 1, num)
      |> put_state(idx, 2, den)

    {dx, state}
  end

  deftransformp channel_mixing(%Axon{} = input, idx, opts \\ []) do
    mixes = [
      Axon.param("time_mix_k", fn _ -> {@n_embed} end),
      Axon.param("time_mix_r", fn _ -> {@n_embed} end)
    ]

    weights = [
      Axon.param("key.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("value.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("receptance.weight", fn _ -> {@n_embed, @n_embed} end)
    ]

    opts = Keyword.put(opts, :block_index, idx)

    Axon.layer(&channel_mixing_impl/7, [input] ++ mixes ++ weights, opts)
  end

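  # Channel mixing is a squared-ReLU feed-forward with a sigmoid gate:
  # dx = sigmoid(r) * (wv . max(k, 0)^2), where k and r interpolate between
  # the current x and the previous x kept in state slot 3.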
  defnp channel_mixing_impl({x, state}, mix_k, mix_r, wk, wv, wr, opts) do
    idx = opts[:block_index]

    last_x = state[idx][3]

    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))
    vk = Nx.dot(wv, Nx.max(k, 0) ** 2)
    dx = Nx.sigmoid(r) * vk

    state = put_state(state, idx, 3, x)

    {dx, state}
  end

  deftransformp add_delta_to(%Axon{} = delta, %Axon{} = input, _opts) do
    # Residual connection: add the block's delta back onto the original x
    # while keeping the state the block produced
    Axon.container({delta, input})
    |> Axon.nx(fn {{dx, new_state}, {x, _old_state}} ->
      {Nx.add(x, dx), new_state}
    end)
  end

  defnp softmax(x) do
    e_x = Nx.exp(x - Nx.reduce_max(x))
    e_x / Nx.sum(e_x)
  end

  defnp put_state(state, layer_idx, state_idx, value) do
    Nx.put_slice(state, [layer_idx, state_idx, 0], Nx.reshape(value, {1, 1, :auto}))
  end
end

defmodule Probs do
  import Nx.Defn

  @temperature 1.0
  @top_p 0.85

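  # Top-p ("nucleus") sampling: keep only the most probable tokens whose
  # cumulative probability reaches top_p, reshape the remaining distribution
  # with the temperature, then draw a single token id.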
  defn sample_probs(rand_key, probs, temperature \\ @temperature, top_p \\ @top_p) do
    sorted_probs = Nx.sort(probs, direction: :desc)
    cumulative_probs = Nx.cumulative_sum(sorted_probs)
    cutoff = sorted_probs[Nx.argmax(cumulative_probs > top_p)]

    probs = Nx.select(probs < cutoff, 0, probs)
    probs = probs ** (1 / temperature)

    Nx.Random.choice(rand_key, Nx.iota({Nx.size(probs)}), probs / Nx.sum(probs))
  end
end

:ok
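
A tiny, illustrative run of the sampler on a toy three-token distribution:

key = Nx.Random.key(42)
{token_id, _key} = Probs.sample_probs(key, Nx.tensor([0.05, 0.15, 0.80]))
Nx.squeeze(token_id)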

Run

defmodule Script do
  def main() do
    model = RWKV.model()
    state = RWKV.initial_state()
    params = Params.load!()
    tokenizer = Tokenizer.load!()

    context =
      "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

    token_ids =
      Tokenizer.apply(tokenizer, context)
      |> Map.get("input_ids")
      |> Nx.to_flat_list()

    # Feed the context through the model token by token to build up the state
    {probs, state} =
      Enum.reduce(token_ids, {nil, state}, fn token_id, {_probs, state} ->
        RWKV.predict(model, params, {token_id, state})
      end)

    IO.puts(context)

    rand_key = Nx.Random.key(:rand.uniform(2 ** 32))

    # Sample and print the next tokens, feeding each one back into the model
    Enum.reduce(0..100, {probs, state, rand_key}, fn _i, {probs, state, rand_key} ->
      {token_id_tensor, rand_key} = Probs.sample_probs(rand_key, probs)
      token_id = token_id_tensor |> Nx.squeeze() |> Nx.to_number()

      IO.write(Tokenizer.decode(tokenizer, [token_id]))

      {probs, state} = RWKV.predict(model, params, {token_id, state})

      {probs, state, rand_key}
    end)

    :ok
  end
end

Script.main()