RWKV in 150 lines in Livebook

rwkv.livemd

Mix.install(
  [
    {:axon, "~> 0.5.0"},
    {:bumblebee, "~> 0.2.0"},
    {:exla, "~> 0.5.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
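
Since the config sets EXLA as the default backend, all tensor work below runs through XLA. A quick check (illustrative, not part of the original notebook):

# Illustrative: confirm the default Nx backend picked up from the config
Nx.default_backend()
# => {EXLA.Backend, []} (expected)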

Fetch data

# Download the model weights and tokenizer into data/ (see the Makefile)
File.cd!(__DIR__, fn ->
  System.shell("make")
end)
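
If make is unavailable, the two files can be fetched directly. Below is a hypothetical equivalent using OTP's built-in :httpc client; the URLs are assumptions about the usual Hugging Face and GitHub locations for these artifacts, and the Makefile remains the source of truth.

# Hypothetical stand-in for make: download the checkpoint and tokenizer.
# The URLs are assumptions; check the Makefile for the real locations.
defmodule Fetch do
  @files %{
    "data/RWKV-4-Pile-430M-20220808-8066.pth" =>
      "https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth",
    "data/20B_tokenizer.json" =>
      "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
  }

  def run!() do
    {:ok, _} = Application.ensure_all_started(:inets)
    {:ok, _} = Application.ensure_all_started(:ssl)
    File.mkdir_p!(Path.absname("data", __DIR__))

    for {path, url} <- @files, not File.exists?(Path.absname(path, __DIR__)) do
      {:ok, {{_, 200, _}, _headers, body}} =
        :httpc.request(:get, {String.to_charlist(url), []}, [], body_format: :binary)

      File.write!(Path.absname(path, __DIR__), body)
    end

    :ok
  end
end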

Define Model and Tokenizer

# Loads the pre-trained RWKV-4 Pile 430M checkpoint into a map of Nx tensors
defmodule Model do
  @model_file "data/RWKV-4-Pile-430M-20220808-8066.pth"
  @n_layer 24
  @n_embed 1024

  def load!() do
    @model_file
    |> Path.absname(__DIR__)
    |> Bumblebee.Conversion.PyTorch.Loader.load!()
    |> reshape_and_cast()
    |> rename_keys()
  end

  # Per layer we track four vectors: the previous input to time mixing,
  # the WKV numerator and denominator, and the previous input to channel mixing
  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  # Squeeze the singleton dims off the time_* params and cast everything to f32
  defp reshape_and_cast(model) do
    for {key, tensor} <- model, into: %{} do
      tensor = if String.contains?(key, ".time_"), do: Nx.squeeze(tensor), else: tensor
      tensor = Nx.as_type(tensor, :f32)
      {key, tensor}
    end
  end

  defp rename_keys(model) do
    model
    |> Map.take(["emb.weight", "head.weight", "ln_out.bias", "ln_out.weight"])
    |> Map.put("blocks", extract_block_params(model))
  end

  # Group each layer's parameters, stripping the "blocks.N." prefix, and store
  # the layers in a tuple for constant-time access by index
  defp extract_block_params(model) do
    for i <- 0..(@n_layer - 1) do
      prefix = "blocks.#{i}."

      for {key, params} <- model, String.starts_with?(key, prefix), into: %{} do
        {String.replace_prefix(key, prefix, ""), params}
      end
    end
    |> List.to_tuple()
  end
end
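
As a quick sanity check (illustrative, not part of the original notebook), the loaded map can be inspected once the files are in place:

# Illustrative: peek at the loaded parameters and the state shape
model = Model.load!()

Map.keys(model)
# => ["blocks", "emb.weight", "head.weight", "ln_out.bias", "ln_out.weight"]

Nx.shape(model["emb.weight"])
# => {50277, 1024}, i.e. vocabulary size by embedding size

Nx.shape(Model.initial_state())
# => {24, 4, 1024}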

# Thin wrapper around Bumblebee's tokenizer utilities and the
# GPT-NeoX 20B tokenizer file
defmodule Tokenizer do
  alias Bumblebee.Utils.Tokenizers

  @tokenizer_file "data/20B_tokenizer.json"
  @pad_token "<|padding|>"

  def load!() do
    @tokenizer_file
    |> Path.absname(__DIR__)
    |> Tokenizers.load!()
  end

  def apply(tokenizer, input) do
    Tokenizers.apply(tokenizer, input, @pad_token)
  end

  def decode(tokenizer, token_ids) do
    Tokenizers.decode(tokenizer, token_ids)
  end
end

:ok
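
A quick round trip through the tokenizer (illustrative, assuming the data files are downloaded):

tokenizer = Tokenizer.load!()

ids =
  Tokenizer.apply(tokenizer, "Hello dragons")
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

Tokenizer.decode(tokenizer, ids)
# => "Hello dragons"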

Define RWKV and Probs

defmodule RWKV do
  import Nx.Defn

  # One step of the RNN: consume a single token id and the recurrent state,
  # returning the next-token probability distribution and the updated state
  def rwkv(model, token_id, state) do
    # Embed the token, then apply the extra "ln0" layer norm stored in block 0
    x = model["emb.weight"][token_id]
    x = layer_norm(x, {model, 0, 0})

    # Each block applies time mixing then channel mixing,
    # each behind a layer norm and a residual connection
    {x, state} =
      Enum.reduce(0..(tuple_size(model["blocks"]) - 1), {x, state}, fn i, {x, state} ->
        {dx, state} =
          x
          |> layer_norm({model, i, 1})
          |> time_mixing({model, i, state})

        x = Nx.add(x, dx)

        {dx, state} =
          x
          |> layer_norm({model, i, 2})
          |> channel_mixing({model, i, state})

        x = Nx.add(x, dx)

        {x, state}
      end)

    x = layer_norm(x, {model, "ln_out"})

    # Project back onto the vocabulary and normalize into probabilities
    probs = softmax(Nx.dot(model["head.weight"], x))

    {probs, state}
  end

  # Numerically stable softmax: subtract the max before exponentiating
  defnp softmax(x) do
    e_x = Nx.exp(x - Nx.reduce_max(x))
    e_x / Nx.sum(e_x)
  end

  defp layer_norm(x, {model, "ln_out"}) do
    w = model["ln_out.weight"]
    b = model["ln_out.bias"]
    layer_norm_(x, w, b)
  end

  defp layer_norm(x, {model, layer_idx, ln_idx}) do
    block = elem(model["blocks"], layer_idx)
    w = block["ln#{ln_idx}.weight"]
    b = block["ln#{ln_idx}.bias"]
    layer_norm_(x, w, b)
  end

  # Normalize to zero mean and unit variance, then scale and shift
  defnp layer_norm_(x, w, b) do
    (x - Nx.mean(x)) / Nx.standard_deviation(x) * w + b
  end

  defp time_mixing(x, {model, layer_idx, state}) do
    %{
      "att.time_decay" => decay,
      "att.time_first" => bonus,
      "att.time_mix_k" => mix_k,
      "att.time_mix_v" => mix_v,
      "att.time_mix_r" => mix_r,
      "att.key.weight" => wk,
      "att.value.weight" => wv,
      "att.receptance.weight" => wr,
      "att.output.weight" => wout
    } = elem(model["blocks"], layer_idx)

    # This layer's recurrent state: the previous input and the running
    # WKV numerator and denominator
    last_x = state[layer_idx][0]
    last_num = state[layer_idx][1]
    last_den = state[layer_idx][2]

    {dx, x, num, den} =
      time_mixing_(
        x,
        last_x,
        last_num,
        last_den,
        decay,
        bonus,
        mix_k,
        mix_v,
        mix_r,
        wk,
        wv,
        wr,
        wout
      )

    # Write the updated per-layer vectors back into the state tensor
    new_state =
      state
      |> Nx.put_slice([layer_idx, 0, 0], Nx.reshape(x, {1, 1, :auto}))
      |> Nx.put_slice([layer_idx, 1, 0], Nx.reshape(num, {1, 1, :auto}))
      |> Nx.put_slice([layer_idx, 2, 0], Nx.reshape(den, {1, 1, :auto}))

    {dx, new_state}
  end

  defnp time_mixing_(
          x,
          last_x,
          last_num,
          last_den,
          decay,
          bonus,
          mix_k,
          mix_v,
          mix_r,
          wk,
          wv,
          wr,
          wout
        ) do
    # Mix the current and previous inputs before projecting to key/value/receptance
    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    v = Nx.dot(wv, x * mix_v + last_x * (1 - mix_v))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))

    # WKV: exponentially weighted average of past values, with an extra
    # "bonus" weight for the current token
    wkv =
      (last_num + Nx.exp(bonus + k) * v) /
        (last_den + Nx.exp(bonus + k))

    # The sigmoid receptance gates how much of the WKV result flows through
    rwkv = Nx.sigmoid(r) * wkv

    # Decay the running numerator/denominator and fold in the current token
    num = Nx.exp(-Nx.exp(decay)) * last_num + Nx.exp(k) * v
    den = Nx.exp(-Nx.exp(decay)) * last_den + Nx.exp(k)

    {Nx.dot(wout, rwkv), x, num, den}
  end

  defp channel_mixing(x, {model, layer_idx, state}) do
    %{
      "ffn.time_mix_k" => mix_k,
      "ffn.time_mix_r" => mix_r,
      "ffn.key.weight" => wk,
      "ffn.value.weight" => wv,
      "ffn.receptance.weight" => wr
    } = elem(model["blocks"], layer_idx)

    # This layer's recurrent state: the previous input to channel mixing
    last_x = state[layer_idx][3]

    {dx, x} = channel_mixing_(x, last_x, mix_k, mix_r, wk, wv, wr)
    new_state = Nx.put_slice(state, [layer_idx, 3, 0], Nx.reshape(x, {1, 1, :auto}))

    {dx, new_state}
  end

  # A squared-ReLU feed-forward network gated by a sigmoid "receptance"
  defnp channel_mixing_(x, last_x, mix_k, mix_r, wk, wv, wr) do
    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))
    vk = Nx.dot(wv, Nx.max(k, 0) ** 2)
    {Nx.sigmoid(r) * vk, x}
  end
end
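
Restated as math (mirroring time_mixing_/13 above, with decay $d$, bonus $u$, and per-token key $k_t$ and value $v_t$), the running numerator and denominator let the exponentially weighted average over all past values be updated in constant time per token:

$$
\mathrm{wkv}_t = \frac{\mathrm{num}_{t-1} + e^{u + k_t}\, v_t}{\mathrm{den}_{t-1} + e^{u + k_t}}, \qquad
\mathrm{num}_t = e^{-e^{d}}\, \mathrm{num}_{t-1} + e^{k_t}\, v_t, \qquad
\mathrm{den}_t = e^{-e^{d}}\, \mathrm{den}_{t-1} + e^{k_t}
$$

The layer's output is then $W_{out}\,(\sigma(r_t) \odot \mathrm{wkv}_t)$, where $\sigma(r_t)$ is the sigmoid receptance gate.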

defmodule Probs do
  import Nx.Defn

  @temperature 1.0
  @top_p 0.85

  # Top-p (nucleus) sampling: keep only the most probable tokens whose
  # cumulative probability reaches top_p, then sample with temperature
  defn sample_probs(rand_key, probs, temperature \\ @temperature, top_p \\ @top_p) do
    sorted_probs = Nx.sort(probs, direction: :desc)
    cumulative_probs = Nx.cumulative_sum(sorted_probs)
    # The smallest probability still inside the nucleus
    cutoff = sorted_probs[Nx.argmax(cumulative_probs > top_p)]

    # Zero out tokens below the cutoff and apply the temperature
    probs = Nx.select(probs < cutoff, 0, probs)
    probs = probs ** (1 / temperature)

    # Renormalize and draw a single token id
    Nx.Random.choice(rand_key, Nx.iota({Nx.size(probs)}), probs / Nx.sum(probs))
  end
end

:ok
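
For deterministic output, a greedy sampler can be swapped in for Probs.sample_probs/4 (a hypothetical alternative, not in the original notebook):

# Hypothetical alternative sampler: always pick the most probable token.
# Returns the same {token_id_tensor, rand_key} shape as Probs.sample_probs/4,
# passing the random key through unused.
defmodule Greedy do
  import Nx.Defn

  defn sample_probs(rand_key, probs) do
    {Nx.reshape(Nx.argmax(probs), {1}), rand_key}
  end
end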

Run

context =
  "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

IO.puts(context)

model = Model.load!()
state = Model.initial_state()
tokenizer = Tokenizer.load!()

token_ids =
  Tokenizer.apply(tokenizer, context)
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

# Feed the prompt one token at a time to build up the recurrent state
{probs, state} =
  Enum.reduce(token_ids, {nil, state}, fn token_id, {_probs, state} ->
    RWKV.rwkv(model, token_id, state)
  end)

# Sample and print the continuation, one token at a time
rand_key = Nx.Random.key(:rand.uniform(2 ** 32))

Enum.reduce(0..100, {probs, state, rand_key}, fn _i, {probs, state, rand_key} ->
  {token_id_tensor, rand_key} = Probs.sample_probs(rand_key, probs)
  token_id = token_id_tensor |> Nx.squeeze() |> Nx.to_number()

  IO.write(Tokenizer.decode(tokenizer, [token_id]))

  {probs, state} = RWKV.rwkv(model, token_id, state)

  {probs, state, rand_key}
end)

:ok
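
The same loop can be packaged for reuse (an illustrative refactor, not part of the original notebook):

# Illustrative refactor: generation as a reusable function
defmodule Generate do
  def run(model, tokenizer, probs, state, n_tokens) do
    rand_key = Nx.Random.key(:rand.uniform(2 ** 32))

    Enum.reduce(1..n_tokens, {probs, state, rand_key}, fn _i, {probs, state, rand_key} ->
      {token_id_tensor, rand_key} = Probs.sample_probs(rand_key, probs)
      token_id = token_id_tensor |> Nx.squeeze() |> Nx.to_number()
      IO.write(Tokenizer.decode(tokenizer, [token_id]))

      {probs, state} = RWKV.rwkv(model, token_id, state)
      {probs, state, rand_key}
    end)

    :ok
  end
end

# Usage: Generate.run(model, tokenizer, probs, state, 100)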