# RWKV in 150 lines in Livebook
```elixir
Mix.install(
  [
    {:axon, "~> 0.5.0"},
    {:bumblebee, "~> 0.2.0"},
    {:exla, "~> 0.5.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
```
## Fetch data

Running `make` in the notebook directory fetches the model checkpoint and tokenizer file used below into `data/`.

```elixir
File.cd!(__DIR__, fn ->
  System.shell("make")
end)
```
## Define Model and Tokenizer
```elixir
defmodule Model do
  @model_file "data/RWKV-4-Pile-430M-20220808-8066.pth"
  @n_layer 24
  @n_embed 1024

  def load!() do
    @model_file
    |> Path.absname(__DIR__)
    |> Bumblebee.Conversion.PyTorch.Loader.load!()
    |> reshape_and_cast()
    |> rename_keys()
  end

  # One state row per layer with 4 slots: last x for time mixing,
  # the running WKV numerator and denominator, and last x for channel mixing.
  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  defp reshape_and_cast(model) do
    for {key, tensor} <- model, into: %{} do
      # The time_* parameters carry extra singleton dimensions; drop them.
      tensor = if String.contains?(key, ".time_"), do: Nx.squeeze(tensor), else: tensor
      tensor = Nx.as_type(tensor, :f32)
      {key, tensor}
    end
  end

  defp rename_keys(model) do
    model
    |> Map.take(["emb.weight", "head.weight", "ln_out.bias", "ln_out.weight"])
    |> Map.put("blocks", extract_block_params(model))
  end

  # Group the "blocks.<i>.<name>" parameters into a tuple of per-layer
  # maps keyed by <name>.
  defp extract_block_params(model) do
    for i <- 0..(@n_layer - 1) do
      prefix = "blocks.#{i}."

      for {key, params} <- model, String.starts_with?(key, prefix), into: %{} do
        {String.trim_leading(key, prefix), params}
      end
    end
    |> List.to_tuple()
  end
end
```
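Once loaded, the model is a plain map of `Nx` tensors plus a `"blocks"` tuple of per-layer maps. A minimal sketch to sanity-check the shapes (assuming the `make` download above succeeded; the key names follow the checkpoint):

```elixir
model = Model.load!()

# Embedding matrix: one 1024-dim row per vocabulary entry
Nx.shape(model["emb.weight"]) |> IO.inspect(label: "emb.weight")

# Parameters of the first block, e.g. the attention key projection
block0 = elem(model["blocks"], 0)
Nx.shape(block0["att.key.weight"]) |> IO.inspect(label: "blocks.0.att.key.weight")
```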
```elixir
defmodule Tokenizer do
  alias Bumblebee.Utils.Tokenizers

  @tokenizer_file "data/20B_tokenizer.json"
  @pad_token "<|padding|>"

  def load!() do
    @tokenizer_file
    |> Path.absname(__DIR__)
    |> Tokenizers.load!()
  end

  def apply(tokenizer, input) do
    Tokenizers.apply(tokenizer, input, @pad_token)
  end

  def decode(tokenizer, token_ids) do
    Tokenizers.decode(tokenizer, token_ids)
  end
end

:ok
```
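As a quick check, the tokenizer round-trips text through token ids. A minimal sketch (assuming `data/20B_tokenizer.json` is in place; this mirrors how the tokenizer is used in the Run section below):

```elixir
tokenizer = Tokenizer.load!()

ids =
  Tokenizer.apply(tokenizer, "Hello world")
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

IO.inspect(ids, label: "token ids")
IO.puts(Tokenizer.decode(tokenizer, ids))
```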
## Define RWKV and Probs
```elixir
defmodule RWKV do
  import Nx.Defn

  def rwkv(model, token_id, state) do
    x = model["emb.weight"][token_id]
    x = layer_norm(x, {model, 0, 0})

    {x, state} =
      Enum.reduce(0..(tuple_size(model["blocks"]) - 1), {x, state}, fn i, {x, state} ->
        {dx, state} =
          x
          |> layer_norm({model, i, 1})
          |> time_mixing({model, i, state})

        x = Nx.add(x, dx)

        {dx, state} =
          x
          |> layer_norm({model, i, 2})
          |> channel_mixing({model, i, state})

        x = Nx.add(x, dx)
        {x, state}
      end)

    x = layer_norm(x, {model, "ln_out"})
    probs = softmax(Nx.dot(model["head.weight"], x))
    {probs, state}
  end

  defnp softmax(x) do
    e_x = Nx.exp(x - Nx.reduce_max(x))
    e_x / Nx.sum(e_x)
  end

  defp layer_norm(x, {model, "ln_out"}) do
    w = model["ln_out.weight"]
    b = model["ln_out.bias"]
    layer_norm_(x, w, b)
  end

  defp layer_norm(x, {model, layer_idx, ln_idx}) do
    block = elem(model["blocks"], layer_idx)
    w = block["ln#{ln_idx}.weight"]
    b = block["ln#{ln_idx}.bias"]
    layer_norm_(x, w, b)
  end

  defnp layer_norm_(x, w, b) do
    (x - Nx.mean(x)) / Nx.standard_deviation(x) * w + b
  end

  defp time_mixing(x, {model, layer_idx, state}) do
    %{
      "att.time_decay" => decay,
      "att.time_first" => bonus,
      "att.time_mix_k" => mix_k,
      "att.time_mix_v" => mix_v,
      "att.time_mix_r" => mix_r,
      "att.key.weight" => wk,
      "att.value.weight" => wv,
      "att.receptance.weight" => wr,
      "att.output.weight" => wout
    } = elem(model["blocks"], layer_idx)

    # Slots 0..2 hold the previous x and the running WKV numerator/denominator.
    last_x = state[layer_idx][0]
    last_num = state[layer_idx][1]
    last_den = state[layer_idx][2]

    {dx, x, num, den} =
      time_mixing_(
        x, last_x, last_num, last_den,
        decay, bonus, mix_k, mix_v, mix_r,
        wk, wv, wr, wout
      )

    new_state =
      state
      |> Nx.put_slice([layer_idx, 0, 0], Nx.reshape(x, {1, 1, :auto}))
      |> Nx.put_slice([layer_idx, 1, 0], Nx.reshape(num, {1, 1, :auto}))
      |> Nx.put_slice([layer_idx, 2, 0], Nx.reshape(den, {1, 1, :auto}))

    {dx, new_state}
  end

  defnp time_mixing_(
          x, last_x, last_num, last_den,
          decay, bonus, mix_k, mix_v, mix_r,
          wk, wv, wr, wout
        ) do
    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    v = Nx.dot(wv, x * mix_v + last_x * (1 - mix_v))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))

    # The current token gets an extra "bonus" weight before joining the average.
    wkv =
      (last_num + Nx.exp(bonus + k) * v) /
        (last_den + Nx.exp(bonus + k))

    rwkv = Nx.sigmoid(r) * wkv

    # Decay the running sums, then add the current token's contribution.
    num = Nx.exp(-Nx.exp(decay)) * last_num + Nx.exp(k) * v
    den = Nx.exp(-Nx.exp(decay)) * last_den + Nx.exp(k)

    {Nx.dot(wout, rwkv), x, num, den}
  end

  defp channel_mixing(x, {model, layer_idx, state}) do
    %{
      "ffn.time_mix_k" => mix_k,
      "ffn.time_mix_r" => mix_r,
      "ffn.key.weight" => wk,
      "ffn.value.weight" => wv,
      "ffn.receptance.weight" => wr
    } = elem(model["blocks"], layer_idx)

    # Slot 3 holds the previous x for channel mixing.
    last_x = state[layer_idx][3]
    {dx, x} = channel_mixing_(x, last_x, mix_k, mix_r, wk, wv, wr)
    new_state = Nx.put_slice(state, [layer_idx, 3, 0], Nx.reshape(x, {1, 1, :auto}))
    {dx, new_state}
  end

  defnp channel_mixing_(x, last_x, mix_k, mix_r, wk, wv, wr) do
    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))
    vk = Nx.dot(wv, Nx.max(k, 0) ** 2)
    {Nx.sigmoid(r) * vk, x}
  end
end
```
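The `time_mixing_/13` step is the serial form of RWKV's WKV "attention"; the formulas below are read directly off the code. The running numerator and denominator evolve as

$$
\mathrm{num}_t = e^{-e^{\mathrm{decay}}}\,\mathrm{num}_{t-1} + e^{k_t}\,v_t,
\qquad
\mathrm{den}_t = e^{-e^{\mathrm{decay}}}\,\mathrm{den}_{t-1} + e^{k_t}
$$

and the mixed output at step $t$ is

$$
\mathrm{wkv}_t = \frac{\mathrm{num}_{t-1} + e^{\mathrm{bonus} + k_t}\,v_t}{\mathrm{den}_{t-1} + e^{\mathrm{bonus} + k_t}},
\qquad
\mathrm{dx}_t = W_{\mathrm{out}}\left(\sigma(r_t)\,\mathrm{wkv}_t\right)
$$

where $k_t$, $v_t$, $r_t$ are projections of the token-shifted input, e.g. $k_t = W_k\,(\mu_k x_t + (1 - \mu_k)\,x_{t-1})$. Channel mixing is a gated squared-ReLU feed-forward built the same way:

$$
\mathrm{dx}_t = \sigma(r_t)\odot W_v\,\max(k_t, 0)^2
$$

Note that this direct exponential form can overflow in `f32` when $k$ grows large; the reference RWKV implementation tracks an extra running exponent for numerical stability, which is left out here for brevity.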
```elixir
defmodule Probs do
  import Nx.Defn

  @temperature 1.0
  @top_p 0.85

  defn sample_probs(rand_key, probs, temperature \\ @temperature, top_p \\ @top_p) do
    sorted_probs = Nx.sort(probs, direction: :desc)
    cumulative_probs = Nx.cumulative_sum(sorted_probs)
    cutoff = sorted_probs[Nx.argmax(cumulative_probs > top_p)]
    probs = Nx.select(probs < cutoff, 0, probs)
    probs = probs ** (1 / temperature)
    Nx.Random.choice(rand_key, Nx.iota({Nx.size(probs)}), probs / Nx.sum(probs))
  end
end

:ok
```
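`sample_probs` implements nucleus (top-p) sampling: keep the smallest set of tokens whose probabilities sum past `top_p`, zero out the rest, apply temperature, renormalize, and draw. A toy run with a made-up 5-token distribution (illustrative values, not model output):

```elixir
key = Nx.Random.key(42)
toy_probs = Nx.tensor([0.5, 0.3, 0.1, 0.06, 0.04])

# Cumulative sums are [0.5, 0.8, 0.9, 0.96, 1.0], which first exceed
# top_p = 0.85 at the third entry, so only the first three tokens
# stay in the nucleus; indices 3 and 4 can never be drawn.
{token_id, _key} = Probs.sample_probs(key, toy_probs)
IO.inspect(Nx.to_number(Nx.squeeze(token_id)), label: "sampled index")
```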
## Run
```elixir
context =
  "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

IO.puts(context)

model = Model.load!()
state = Model.initial_state()
tokenizer = Tokenizer.load!()

token_ids =
  Tokenizer.apply(tokenizer, context)
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

# Feed the prompt one token at a time; only the final state matters.
{probs, state} =
  Enum.reduce(token_ids, {nil, state}, fn token_id, {_probs, state} ->
    RWKV.rwkv(model, token_id, state)
  end)

# Sample and print 101 continuation tokens.
rand_key = Nx.Random.key(:rand.uniform(2 ** 32))

Enum.reduce(0..100, {probs, state, rand_key}, fn _i, {probs, state, rand_key} ->
  {token_id_tensor, rand_key} = Probs.sample_probs(rand_key, probs)
  token_id = token_id_tensor |> Nx.squeeze() |> Nx.to_number()
  IO.write(Tokenizer.decode(tokenizer, [token_id]))
  {probs, state} = RWKV.rwkv(model, token_id, state)
  {probs, state, rand_key}
end)

:ok
```
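Each run prints a different continuation because the random key is seeded from `:rand.uniform/1`. For repeatable output you can fix the seed, or skip sampling entirely and always take the most likely token. A minimal greedy variant of the generation loop (a sketch, not part of the original notebook; it reuses the `probs` and `state` left over from feeding the context above):

```elixir
# Greedy decoding: pick the argmax token instead of sampling.
Enum.reduce(0..100, {probs, state}, fn _i, {probs, state} ->
  token_id = probs |> Nx.argmax() |> Nx.to_number()
  IO.write(Tokenizer.decode(tokenizer, [token_id]))
  RWKV.rwkv(model, token_id, state)
end)

:ok
```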