GPT2
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true},
    # Needed below for Safetensors and Tokenizers (version requirements assumed).
    {:safetensors, "~> 0.1"},
    {:tokenizers, "~> 0.4"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)
Nx.global_default_backend(EXLA.Backend)

# Smoke test: this should come back as an EXLA-backed tensor.
Nx.iota({3})
Implementation
# GPT-2 (124M) weights: the model.safetensors file from the "gpt2" repository
# on the Hugging Face Hub, downloaded ahead of time.
params =
  "/home/kuku/Downloads/model.safetensors"
  |> File.read!()
  |> Safetensors.load!()
# Regroup the flat dotted safetensors keys into nested maps, e.g.
# "h.0.attn.c_attn.weight" becomes blocks["block_0"]["attn"]["c_attn"]["weight"]
# and "wte.weight" becomes blocks["wte"]["weight"].
blocks =
  Enum.reduce(params, %{}, fn {key, value}, acc ->
    path =
      case String.split(key, ".") do
        ["h", block_num | rest] -> ["block_#{block_num}" | rest]
        path -> path
      end

    put_in(acc, Enum.map(path, &Access.key(&1, %{})), value)
  end)
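A quick look at the regrouped structure: GPT-2 small has 12 transformer blocks, so alongside the top-level ln_f, wte and wpe entries we expect block_0 through block_11.

blocks |> Map.keys() |> Enum.sort()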
defmodule Encoder do
defstruct [:tokenizer]
def new(model_id) do
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained(model_id)
%__MODULE__{tokenizer: tokenizer}
end
  # Returns the token ids as a {1, seq_len} tensor (a batch of one).
  def encode(%{tokenizer: tokenizer}, text) do
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)
Nx.tensor(Tokenizers.Encoding.get_ids(encoding))
|> Nx.new_axis(0)
end
def decode(%{tokenizer: tokenizer}, id) do
{:ok, token} = Tokenizers.Tokenizer.decode(tokenizer, [id])
token
end
end
encoder = Encoder.new("gpt2")
Encoder.encode(encoder, "Hello world!")
Encoder.decode(encoder, 995)
defmodule GPT do
import Nx.Defn
  defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
    opts = keyword!(opts, n_head: 12)

    input
    |> embedding(wte, wpe)
    |> transformer(blocks, n_head: opts[:n_head])
    |> layer_norm(ln_f)
    # Project back to vocabulary logits with the (tied) token embedding matrix.
    |> Nx.dot([-1], wte["weight"], [-1])
  end

  # Sum of token embeddings and learned positional embeddings.
defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
Nx.take(wte, x) + Nx.take(wpe, position_ids)
end
  deftransform transformer(x, params, opts \\ []) do
    # Plain map order would put "block_10" before "block_2" (lexicographic),
    # so order the blocks numerically before folding the input through them.
    params
    |> Enum.sort_by(fn {"block_" <> num, _params} -> String.to_integer(num) end)
    |> Enum.reduce(x, fn {_block_name, block_params}, x ->
      transformer_block(x, block_params, n_head: opts[:n_head])
    end)
  end
  defn transformer_block(
         x,
         %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
         opts \\ []
       ) do
    opts = keyword!(opts, n_head: 12)

    # Pre-norm residual blocks: x = x + mha(ln_1(x)), then x = x + ffn(ln_2(x)).
    # Note the attention output is added back onto its input (the first
    # residual connection), not passed along on its own.
    x =
      x
      |> layer_norm(ln_1)
      |> mha(attn, n_head: opts[:n_head])
      |> Nx.add(x)

    x
    |> layer_norm(ln_2)
    |> ffn(mlp)
    |> Nx.add(x)
  end
  defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
    opts = keyword!(opts, n_head: 12)

    x = linear(x, c_attn)
    {q, k, v} = split_qkv(x)

    q = split_heads(q, opts[:n_head])
    k = split_heads(k, opts[:n_head])
    v = split_heads(v, opts[:n_head])

    # Build the causal mask over the sequence axis (axis 1; axis 0 is the
    # batch axis): future positions get a large negative score.
    seq_len = Nx.axis_size(x, 1)
    causal_mask = (1 - Nx.tri(seq_len, seq_len)) * -1.0e10

    out = attention(q, k, v, causal_mask)
    linear(out, c_proj)
end
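  # Split the concatenated qkv projection into three equal parts along the last axis.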
deftransformp split_qkv(tensor) do
split_size = div(Nx.axis_size(tensor, -1), 3)
q = tensor[[0..-1//1, 0..-1//1, 0..(split_size - 1)]]
k = tensor[[0..-1//1, 0..-1//1, split_size..(2 * split_size - 1)]]
v = tensor[[0..-1//1, 0..-1//1, (2 * split_size)..-1//1]]
{q, k, v}
end
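  # {batch, seq, dim} -> {batch, seq, n_head, head_dim}, where head_dim = dim / n_head.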
deftransformp split_heads(tensor, n_head) do
{batch, seq, _dim} = Nx.shape(tensor)
Nx.reshape(tensor, {batch, seq, n_head, :auto})
end
  defn attention(q, k, v, mask) do
    # Bring the head axis forward: {batch, n_head, seq, head_dim}.
    k = Nx.transpose(k, axes: [0, 2, 1, 3])
    q = Nx.transpose(q, axes: [0, 2, 1, 3])
    v = Nx.transpose(v, axes: [0, 2, 1, 3])

    q
    |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
    |> Nx.dot([3], [0, 1], k, [3], [0, 1])
    # The causal mask must be added to the raw scores *before* the softmax,
    # so masked positions end up with (near-)zero attention weight.
    |> Nx.add(mask)
    |> softmax()
    |> Nx.dot([3], [0, 1], v, [2], [0, 1])
    |> Nx.transpose(axes: [0, 2, 1, 3])
    |> flatten_heads()
end
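  # Inverse of split_heads/2: {batch, seq, n_head, head_dim} -> {batch, seq, dim}.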
deftransformp flatten_heads(tensor) do
shape = Nx.shape(tensor)
rank = Nx.rank(tensor)
new_shape =
shape
|> Tuple.delete_at(rank - 1)
|> put_elem(rank - 2, :auto)
Nx.reshape(tensor, new_shape)
end
defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
x
|> linear(c_fc)
|> gelu()
|> linear(c_proj)
end
@doc """
Linear layer.
## Examples
iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
iex> Nx.shape(out)
{32, 256}
iex> Nx.type(out)
{:f, 32}
"""
defn linear(x, %{"weight" => w, "bias" => b}) do
x |> Nx.dot(w) |> Nx.add(b)
end
@doc """
Applies Layer Normalization.
## Examples
iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
iex> expected = Nx.tensor([
...> [-0.70709, -0.70709, 1.41418],
...> [-1.397, 0.508, 0.889]
...> ])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
opts = keyword!(opts, eps: 1.0e-5)
mean = Nx.mean(x, axes: [-1], keep_axes: true)
variance = Nx.variance(x, axes: [-1], keep_axes: true)
x = (x - mean) / Nx.sqrt(variance + opts[:eps])
w * x + b
end
@doc """
Applies GeLU Activation.
## Examples
iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
      iex> expected = Nx.tensor([[0.84119, 1.9546], [-0.0454, 0.34571]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn gelu(x) do
0.5 * x * (1 + Nx.tanh(Nx.sqrt(2 / Nx.Constants.pi()) * (x + 0.044715 * Nx.pow(x, 3))))
end
@doc """
Applies Softmax Activation.
## Examples
iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn softmax(x) do
exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
end
end
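As a quick sanity check of the attention plumbing: a minimal sketch with random weights, assuming GPT-2 small sizes (n_embd = 768 and 12 heads, so c_attn projects to 3 * 768 = 2304). Multi-head attention should map a {batch, seq, n_embd} input back to the same shape.

key = Nx.Random.key(0)
{x, key} = Nx.Random.uniform(key, shape: {1, 5, 768})
{w_attn, key} = Nx.Random.uniform(key, shape: {768, 2304})
{b_attn, key} = Nx.Random.uniform(key, shape: {2304})
{w_proj, key} = Nx.Random.uniform(key, shape: {768, 768})
{b_proj, _key} = Nx.Random.uniform(key, shape: {768})

attn_params = %{
  "c_attn" => %{"weight" => w_attn, "bias" => b_attn},
  "c_proj" => %{"weight" => w_proj, "bias" => b_proj}
}

GPT.mha(x, attn_params, n_head: 12) |> Nx.shape()

This should return {1, 5, 768}.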
predict_fun = fn input, params ->
  # Pull out the top-level parameters; everything left is the "block_*" maps.
  {wte, params} = Map.pop!(params, "wte")
  {wpe, params} = Map.pop!(params, "wpe")
  {ln_f, params} = Map.pop!(params, "ln_f")
  GPT.predict(input, wte, wpe, params, ln_f)
end

predict_fun = Nx.Defn.jit(predict_fun, compiler: EXLA)
input = Encoder.encode(encoder, "Hello World!")
output = predict_fun.(input, blocks)

# Greedily pick the most likely next token from the logits at the last position.
logits = output[[.., -1]]
new_token = Nx.argmax(logits, axis: -1)
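Decoding that id shows the single token the model predicts next:

Encoder.decode(encoder, Nx.to_number(Nx.squeeze(new_token)))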
defmodule Generator do
  # Greedy decoding: repeatedly run the model, append the argmax token, and
  # stop at the end-of-sequence id or once max_seq_len tokens are reached.
  def generate(predict_fun, encoder, input, params, eos_id, max_seq_len) do
encoded_input = Encoder.encode(encoder, input)
seq_len = Nx.axis_size(encoded_input, 1)
Enum.reduce_while(seq_len..max_seq_len, encoded_input, fn _idx, current_input ->
output = predict_fun.(current_input, params)
logits = output[[.., -1]]
next_token = Nx.argmax(logits, axis: -1, keep_axis: true)
if eos_id == Nx.to_number(Nx.squeeze(next_token)) do
{:halt, current_input}
else
IO.write("#{Encoder.decode(encoder, Nx.to_number(Nx.squeeze(next_token)))}")
new_sequence = Nx.concatenate([current_input, next_token], axis: -1)
{:cont, new_sequence}
end
end)
end
end
# 50256 is GPT-2's <|endoftext|> token id.
Generator.generate(predict_fun, encoder, "Elixir is", blocks, 50256, 256)