RWKV-Axon in 150 lines in Livebook
Mix.install(
  [
    {:axon, "~> 0.5.0"},
    {:bumblebee, "~> 0.2.0"},
    {:exla, "~> 0.5.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
Fetch data
# Run the bundled Makefile, which is expected to fetch the model weights
# and tokenizer into data/.
File.cd!(__DIR__, fn ->
  System.shell("make")
end)
Define Params and Tokenizer
defmodule Params do
  @model_file "data/RWKV-4-Pile-430M-20220808-8066.pth"
  @n_layer 24
  @n_embed 1024

  def load!() do
    @model_file
    |> Path.absname(__DIR__)
    |> Bumblebee.Conversion.PyTorch.Loader.load!()
    |> reshape_and_cast()
    |> to_params()
  end

  # Per-layer state slots: {last_x for time mixing, num, den, last_x for channel mixing}.
  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  # Squeeze the time_* tensors down to vectors and cast everything to f32.
  defp reshape_and_cast(model) do
    for {key, tensor} <- model, into: %{} do
      tensor = if String.contains?(key, ".time_"), do: Nx.squeeze(tensor), else: tensor
      tensor = Nx.as_type(tensor, :f32)
      {key, tensor}
    end
  end

  # Regroup the flat PyTorch key space into the nested layout the Axon model expects.
  defp to_params(model) do
    params = %{
      "emb.weight" => model["emb.weight"],
      "head.weight" => model["head.weight"],
      "input_norm" => extract_params(model, "blocks.0.ln0."),
      "output_norm" => extract_params(model, "ln_out.")
    }

    for i <- 0..(@n_layer - 1) do
      [
        {"block_#{i}.layer_norm_0", extract_params(model, "blocks.#{i}.ln1.")},
        {"block_#{i}.layer_norm_1", extract_params(model, "blocks.#{i}.ln2.")},
        {"block_#{i}.time_mixing", extract_params(model, "blocks.#{i}.att.")},
        {"block_#{i}.channel_mixing", extract_params(model, "blocks.#{i}.ffn.")}
      ]
    end
    |> List.flatten()
    |> Enum.into(params)
  end

  defp extract_params(map, prefix) do
    for {key, value} <- map, String.starts_with?(key, prefix), into: %{} do
      {String.trim_leading(key, prefix), value}
    end
  end
end
defmodule Tokenizer do
  alias Bumblebee.Utils.Tokenizers

  @tokenizer_file "data/20B_tokenizer.json"
  @pad_token "<|padding|>"

  def load!() do
    @tokenizer_file
    |> Path.absname(__DIR__)
    |> Tokenizers.load!()
  end

  def apply(tokenizer, input) do
    Tokenizers.apply(tokenizer, input, @pad_token)
  end

  def decode(tokenizer, token_ids) do
    Tokenizers.decode(tokenizer, token_ids)
  end
end
:ok
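As a quick sanity check, you can load both and inspect them. This is a sketch, assuming the Makefile above has populated data/ (50277 is the GPT-NeoX vocabulary size the RWKV-4 Pile checkpoints use):

params = Params.load!()
tokenizer = Tokenizer.load!()

# The embedding matrix is vocabulary size by embedding width.
Nx.shape(params["emb.weight"])
#=> {50277, 1024}

# The tokenizer should round-trip text.
ids =
  Tokenizer.apply(tokenizer, "Hello world")
  |> Map.get("input_ids")
  |> Nx.to_flat_list()

Tokenizer.decode(tokenizer, ids)
#=> "Hello world"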
Define RWKV and Probs
defmodule RWKV do
  import Nx.Defn

  @n_layer 24
  @n_embed 1024

  def template() do
    token = Nx.template({@n_embed}, :f32)
    state = Nx.template({@n_layer, 4, @n_embed}, :f32)
    %{"token" => token, "state" => state}
  end

  def initial_state() do
    Nx.broadcast(0, {@n_layer, 4, @n_embed})
  end

  # One step: look up the token embedding, run it and the recurrent state
  # through the blocks, then project onto the vocabulary.
  def predict(model, params, {token_id, state}, opts \\ []) do
    token = params["emb.weight"][token_id]
    {x, state} = Axon.predict(model, params, %{"token" => token, "state" => state}, opts)
    probs = softmax(Nx.dot(params["head.weight"], x))
    {probs, state}
  end

  deftransform model() do
    token = Axon.input("token", shape: {@n_embed})
    state = Axon.input("state", shape: {@n_layer, 4, @n_embed})

    Axon.container({token, state})
    |> layer_norm(name: "input_norm")
    |> blocks(@n_layer)
    |> layer_norm(name: "output_norm")
  end

  deftransformp layer_norm(%Axon{} = input, opts \\ []) do
    weight = Axon.param("weight", fn _ -> {@n_embed} end)
    bias = Axon.param("bias", fn _ -> {@n_embed} end)
    Axon.layer(&layer_norm_impl/4, [input, weight, bias], opts)
  end

  defnp layer_norm_impl({x, state}, w, b, _opts) do
    x = (x - Nx.mean(x)) / Nx.standard_deviation(x) * w + b
    {x, state}
  end
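  # Note the model has no batch or sequence axis: layer_norm normalizes a
  # single {@n_embed} vector, with learned scale w and shift b.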
  deftransformp blocks(input, n) do
    Enum.reduce(0..(n - 1), input, fn i, input -> block(input, i) end)
  end

  deftransformp block(%Axon{} = input, idx) do
    input =
      input
      |> layer_norm(name: "block_#{idx}.layer_norm_0")
      |> time_mixing(idx, name: "block_#{idx}.time_mixing")
      |> add_delta_to(input, name: "block_#{idx}.add_delta_to_0")

    input
    |> layer_norm(name: "block_#{idx}.layer_norm_1")
    |> channel_mixing(idx, name: "block_#{idx}.channel_mixing")
    |> add_delta_to(input, name: "block_#{idx}.add_delta_to_1")
  end
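  # Each block is the standard pre-norm residual pattern:
  # x <- x + time_mixing(layer_norm(x)), then x <- x + channel_mixing(layer_norm(x)).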
  deftransformp time_mixing(%Axon{} = input, idx, opts \\ []) do
    decay = Axon.param("time_decay", fn _ -> {@n_embed} end)
    bonus = Axon.param("time_first", fn _ -> {@n_embed} end)

    mixes = [
      Axon.param("time_mix_k", fn _ -> {@n_embed} end),
      Axon.param("time_mix_v", fn _ -> {@n_embed} end),
      Axon.param("time_mix_r", fn _ -> {@n_embed} end)
    ]

    weights = [
      Axon.param("key.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("value.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("receptance.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("output.weight", fn _ -> {@n_embed, @n_embed} end)
    ]

    opts = Keyword.put(opts, :block_index, idx)
    Axon.layer(&time_mixing_impl/11, [input, decay, bonus] ++ mixes ++ weights, opts)
  end

  defnp time_mixing_impl({x, state}, decay, bonus, mix_k, mix_v, mix_r, wk, wv, wr, wout, opts) do
    idx = opts[:block_index]
    last_x = state[idx][0]
    last_num = state[idx][1]
    last_den = state[idx][2]

    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    v = Nx.dot(wv, x * mix_v + last_x * (1 - mix_v))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))

    wkv =
      (last_num + Nx.exp(bonus + k) * v) /
        (last_den + Nx.exp(bonus + k))

    rwkv = Nx.sigmoid(r) * wkv
    dx = Nx.dot(wout, rwkv)

    num = Nx.exp(-Nx.exp(decay)) * last_num + Nx.exp(k) * v
    den = Nx.exp(-Nx.exp(decay)) * last_den + Nx.exp(k)

    state =
      state
      |> put_state(idx, 0, x)
      |> put_state(idx, 1, num)
      |> put_state(idx, 2, den)

    {dx, state}
  end
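  # The arithmetic above is RWKV's linear attention, kept as an explicit
  # numerator/denominator pair so the state stays a running weighted sum.
  # All operations are elementwise over the @n_embed channels:
  #
  #   wkv_t = (num_{t-1} + exp(bonus + k_t) * v_t) / (den_{t-1} + exp(bonus + k_t))
  #   num_t = exp(-exp(decay)) * num_{t-1} + exp(k_t) * v_t
  #   den_t = exp(-exp(decay)) * den_{t-1} + exp(k_t)
  #
  # `bonus` (time_first) boosts the current token, while exp(-exp(decay)) is
  # the per-channel forgetting factor applied to the rest of the history.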
  deftransformp channel_mixing(%Axon{} = input, idx, opts \\ []) do
    mixes = [
      Axon.param("time_mix_k", fn _ -> {@n_embed} end),
      Axon.param("time_mix_r", fn _ -> {@n_embed} end)
    ]

    weights = [
      Axon.param("key.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("value.weight", fn _ -> {@n_embed, @n_embed} end),
      Axon.param("receptance.weight", fn _ -> {@n_embed, @n_embed} end)
    ]

    opts = Keyword.put(opts, :block_index, idx)
    Axon.layer(&channel_mixing_impl/7, [input] ++ mixes ++ weights, opts)
  end

  defnp channel_mixing_impl({x, state}, mix_k, mix_r, wk, wv, wr, opts) do
    idx = opts[:block_index]
    last_x = state[idx][3]

    k = Nx.dot(wk, x * mix_k + last_x * (1 - mix_k))
    r = Nx.dot(wr, x * mix_r + last_x * (1 - mix_r))
    vk = Nx.dot(wv, Nx.max(k, 0) ** 2)
    dx = Nx.sigmoid(r) * vk

    state = put_state(state, idx, 3, x)
    {dx, state}
  end
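  # Channel mixing is a small gated feed-forward network: a squared ReLU on
  # k, gated by sigmoid(r). The time_mix_* factors blend in the previous
  # token's input ("token shift"), which is the only recurrence here.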
  # Residual add that also threads the updated state through.
  deftransformp add_delta_to(%Axon{} = delta, %Axon{} = input, _opts) do
    Axon.container({delta, input})
    |> Axon.nx(fn {{dx, new_state}, {x, _old_state}} ->
      {Nx.add(x, dx), new_state}
    end)
  end

  defnp softmax(x) do
    e_x = Nx.exp(x - Nx.reduce_max(x))
    e_x / Nx.sum(e_x)
  end

  defnp put_state(state, layer_idx, state_idx, value) do
    Nx.put_slice(state, [layer_idx, state_idx, 0], Nx.reshape(value, {1, 1, :auto}))
  end
end
defmodule Probs do
  import Nx.Defn

  @temperature 1.0
  @top_p 0.85

  # Top-p (nucleus) sampling: zero out every probability below the smallest
  # value still needed to reach cumulative mass top_p, apply temperature,
  # then draw a token id from the renormalized distribution.
  defn sample_probs(rand_key, probs, temperature \\ @temperature, top_p \\ @top_p) do
    sorted_probs = Nx.sort(probs, direction: :desc)
    cumulative_probs = Nx.cumulative_sum(sorted_probs)
    cutoff = sorted_probs[Nx.argmax(cumulative_probs > top_p)]

    probs = Nx.select(probs < cutoff, 0, probs)
    probs = probs ** (1 / temperature)
    Nx.Random.choice(rand_key, Nx.iota({Nx.size(probs)}), probs / Nx.sum(probs))
  end
end
:ok
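To see the sampler on its own, here is a toy run (the distribution and seed below are made up for illustration):

key = Nx.Random.key(1337)
toy_probs = Nx.tensor([0.05, 0.70, 0.15, 0.10])
{id, _key} = Probs.sample_probs(key, toy_probs)
Nx.to_flat_list(id)
#=> e.g. [1]; with top_p = 0.85 the cutoff is 0.10, so id 0 is zeroed out
#   and can never be drawn, while ids 1, 2, 3 keep their relative weights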
Run
defmodule Script do
  def main() do
    model = RWKV.model()
    state = RWKV.initial_state()
    params = Params.load!()
    tokenizer = Tokenizer.load!()

    context =
      "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

    token_ids =
      Tokenizer.apply(tokenizer, context)
      |> Map.get("input_ids")
      |> Nx.to_flat_list()

    # Feed the context one token at a time, keeping only the final
    # probabilities and state.
    {probs, state} =
      Enum.reduce(token_ids, {nil, state}, fn token_id, {_probs, state} ->
        RWKV.predict(model, params, {token_id, state})
      end)

    IO.puts(context)

    rand_key = Nx.Random.key(:rand.uniform(2 ** 32))

    # Sample and print the next tokens.
    Enum.reduce(0..100, {probs, state, rand_key}, fn _i, {probs, state, rand_key} ->
      {token_id_tensor, rand_key} = Probs.sample_probs(rand_key, probs)
      token_id = token_id_tensor |> Nx.squeeze() |> Nx.to_number()
      IO.write(Tokenizer.decode(tokenizer, [token_id]))
      {probs, state} = RWKV.predict(model, params, {token_id, state})
      {probs, state, rand_key}
    end)

    :ok
  end
end
Script.main()