Flax to Axon - Verification
Mix.install([
{:axon, "~> 0.6.1"},
{:stream_data, "~> 1.1"},
{:nx, "~> 0.7.2"},
{:safetensors, "~> 0.1.3"},
{:kino, "~> 0.12.3"}
])
asdf_dir = "#{__DIR__}/.asdf"
unless File.exists?(asdf_dir) do
{_, 0} = System.cmd("git", [
"clone",
"https://github.com/asdf-vm/asdf.git",
asdf_dir,
"--branch",
"v0.14.0"
])
end
asdf = "#{asdf_dir}/bin/asdf"
{_, 0} = System.cmd(asdf, ["plugin", "add", "python"], env: [{"ASDF_DATA_DIR", asdf_dir}])
{_, 0} = System.cmd(asdf, ["install", "python", "3.11.9"], env: [{"ASDF_DATA_DIR", "#{__DIR__}/.asdf"}])
asdf_python = Path.join([asdf_dir, "installs", "python", "3.11.9", "bin", "python"])
python_packages =
~w(
safetensors
torch
transformers
accelerate
numpy
datasets
pillow
flax
jax
jaxlib
)
venv_dir = Path.join(__DIR__, "flax2axon_env")
{_, 0} = System.cmd(asdf_python, ["-m", "venv", "--copies", venv_dir])
python = Path.join([venv_dir, "bin", "python"])
pip = Path.join([venv_dir, "bin", "pip"])
{_, 0} = System.cmd(pip, ["install" | python_packages])
run_python = fn command, opts ->
System.cmd(python, ["-c", command], opts)
end
data_dir = Path.join(__DIR__, "data")
unless File.exists?(data_dir), do: :ok = File.mkdir(data_dir)
Install and use Python
Since we need to run the Flax model with Python, we have to set up a Python environment. The setup block above automatically installs Python via asdf, creates a virtual environment, installs all the packages we need, and defines a helper function run_python for running Python code.
IO.puts("Python is here: #{python}")
{_, 0} = run_python.("print('hello from Python')", [])
We define some paths that we will use to store data in safetensors format. That’s just an easy way to work with the same values in both frameworks, Axon and Flax.
params_axon_path = Path.join(data_dir, "params_axon.safetensors")
params_flax_path = Path.join(data_dir, "params_flax.safetensors")
test_data_axon_path = Path.join(data_dir, "test_data_axon.safetensors")
test_data_flax_path = Path.join(data_dir, "test_data_flax.safetensors")
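As a quick sanity check that safetensors round-trips Nx tensors unchanged, we can write and read back a small tensor (the scratch filename below is just for illustration):
scratch_path = Path.join(data_dir, "roundtrip_check.safetensors")
Safetensors.write!(scratch_path, %{"t" => Nx.iota({2, 3}, type: :f32)})
Safetensors.read!(scratch_path)["t"]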
Manual test data
test_data = %{
"input_0" => Nx.broadcast(Nx.tensor([1, 2, 3], type: :f32), {1, 1, 3}),
"input_1" => Nx.broadcast(Nx.tensor([1, 2, 3], type: :f32), {1, 1, 3}),
"output_0" => Nx.broadcast(Nx.tensor([5.999975204467773, 7.6568403244018555, 8.196144104003906], type: :f32), {1, 1, 3})
}
Safetensors.write!(test_data_axon_path, test_data)
These hand-written test values come from the previous steps of the conversion. We must run the Axon and Flax models with the same parameters, otherwise we won’t get the same results.
params =
%{
"batch_norm_0" => %{
"beta" => Nx.tensor([1, 1, 1], type: :f32),
"gamma" => Nx.tensor([1, 1, 1], type: :f32),
"mean" => Nx.tensor([1, 2, 3], type: :f32),
"var" => Nx.tensor([1, 2, 3], type: :f32)
},
"conv_0" => %{
"kernel" => Nx.broadcast(Nx.tensor(1, type: :f32), {1, 3, 3})
}
}
# flatten first
#test_data = Map.put(test_data, "params", params)
#Safetensors.write!(test_data_axon_path, test_data)
Test data from the code-to-code converted model
This is the model we converted.
flax_res_net_conv_layer = fn x, out_channels, kernel_size, stride, activation ->
hidden_state =
Axon.conv(
x,
out_channels,
kernel_size: {kernel_size, kernel_size},
strides: stride,
padding: [
{div(kernel_size, 2), div(kernel_size, 2)},
{div(kernel_size, 2), div(kernel_size, 2)}
],
use_bias: false,
kernel_initializer:
Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal)
)
hidden_state =
Axon.batch_norm(hidden_state,
momentum: 0.9,
epsilon: 1.0e-05
)
hidden_state = Axon.activation(hidden_state, activation)
hidden_state
end
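As a quick check of the padding arithmetic: with kernel_size = 3 and stride 1 the padding evaluates to {1, 1} per spatial dimension, which preserves the spatial size just like Flax’s SAME-style padding:
kernel_size = 3
{div(kernel_size, 2), div(kernel_size, 2)}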
This is our plan:
- TODO: First, we check what kind of inputs it accepts.
- Then, infer the param shapes using the init function.
- TODO: Then generate random data for params and inputs and run the model.
- Save the inputs and outputs in safetensors format to check against the Flax model.
- Map the params to Flax names, then save both the Axon and the Flax versions in safetensors format.
out_channels = 3
kernel_size = 3
stride = 1
activation = :relu
model =
flax_res_net_conv_layer.(Axon.input("input"), out_channels, kernel_size, stride, activation)
{init_fn, predict_fn} = Axon.build(model)
input_shape = {1, 3, 3, 3}
input_type = :f32
params = init_fn.(Nx.template(input_shape, input_type), %{})
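Before flattening anything, we can eyeball the initialized parameter shapes directly; this is a quick sketch that assumes init_fn returned a plain nested map (which the code below relies on as well):
for {layer, layer_params} <- params, {name, tensor} <- layer_params do
  {"#{layer}.#{name}", Nx.shape(tensor)}
end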
Let’s check which params we need. safetensors requires flattened keys to store the params, so we concatenate the key hierarchy using . as a separator.
defmodule ParamsUtils do
  @moduledoc """
  Helpers for flattening and unflattening nested param maps with
  dot-separated keys, as required by the safetensors format.
  """
def flatten_keys(%{} = params) do
for key <- Map.keys(params) do
prefixed_keys(params[key], key)
end
|> List.flatten()
end
defp prefixed_keys(%Nx.Tensor{}, key), do: key
defp prefixed_keys(%{} = params, prefix) do
for key <- Map.keys(params) do
prefixed_keys(params[key], "#{prefix}.#{key}")
end
end
def get_from_flattened_key(params, flattened_key) do
keys = String.split(flattened_key, ".")
for key <- keys, reduce: params do
acc -> acc[key]
end
end
def unflatten_and_put(params, flattened_key, value) do
single_param_map = flattened_map(flattened_key, value)
merge_recursive(params, single_param_map)
end
def merge_recursive(%{} = map1, %{} = map2) do
Map.merge(map1, map2, fn _k, m1, m2 -> merge_recursive(m1, m2) end)
end
defp flattened_map(flattened_key, value) do
case String.split(flattened_key, ".", parts: 2) do
[key] -> %{key => value}
[key, other_keys] -> %{key => flattened_map(other_keys, value)}
end
end
end
ParamsUtils.flatten_keys(params)
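To convince ourselves that the helpers are inverses of each other, we can round-trip a single key, a minimal sketch:
kernel = ParamsUtils.get_from_flattened_key(params, "conv_0.kernel")
%{"conv_0" => %{"kernel" => ^kernel}} = ParamsUtils.unflatten_and_put(%{}, "conv_0.kernel", kernel)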
In Flax these are the keys of our params:
['batch_stats.normalization.mean', 'batch_stats.normalization.var', 'params.convolution.kernel', 'params.normalization.bias', 'params.normalization.scale']
So, we create a mapping from the Axon world to the Flax world.
param_mapping = %{
"batch_norm_0.beta" => "params.normalization.bias",
"batch_norm_0.gamma" => "params.normalization.scale",
"batch_norm_0.mean" => "batch_stats.normalization.mean",
"batch_norm_0.var" => "batch_stats.normalization.var",
"conv_0.kernel" => "params.convolution.kernel"
}
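A quick sanity check that the mapping covers exactly the keys Axon initialized:
axon_keys = params |> ParamsUtils.flatten_keys() |> Enum.sort()
^axon_keys = param_mapping |> Map.keys() |> Enum.sort()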
Then, we create another set of params with flattened keys according to the Flax naming, and store both versions in safetensors format.
axon_params = for {axon_key, _} <- param_mapping, into: %{} do
{axon_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
end
flax_params = for {axon_key, flax_key} <- param_mapping, into: %{} do
{flax_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
end
Safetensors.write!(params_axon_path, axon_params)
Safetensors.write!(params_flax_path, flax_params)
# Build a StreamData generator of nested float lists matching input_shape
input_data =
for dim <- Enum.reverse(Tuple.to_list(input_shape)), reduce: StreamData.float() do
acc -> StreamData.list_of(acc, length: dim)
end
# Sample 101 random inputs and record the model's output for each
test_data =
for i <- 0..100 do
input =
input_data
|> Enum.take(1)
|> hd
|> Nx.tensor()
input_name = "input_#{i}"
output_name = "output_#{i}"
output = predict_fn.(params, input)
[{input_name, input}, {output_name, output}]
end
|> List.flatten()
|> Map.new()
Safetensors.write!(test_data_axon_path, test_data)
Calculate results in Flax
Now, we use a script to run the Flax model with the same inputs as our generated Axon model.
module = "FlaxResNetConvLayer"
test_flax =
"""
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from functools import partial
from typing import Optional, Tuple
from safetensors import safe_open
from safetensors.flax import save_file
from transformers.models.resnet.modeling_flax_resnet import #{module}
def unflatten_dict(d, sep='.'):
    result = {}
    for key, value in d.items():
        parts = key.split(sep)
        node = result
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return result

# Load the Axon-generated inputs
tensors = {}
with safe_open("#{test_data_axon_path}", framework="flax") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

model = #{module}(3)
key = random.key(0)

# Load the flattened params exported for Flax and rebuild the nesting
params = {}
with safe_open("#{params_flax_path}", framework="flax") as f:
    for k in f.keys():
        params[k] = f.get_tensor(k)
params = unflatten_dict(params)

# Run the Flax model on every input and store the outputs alongside
out_tensors = tensors.copy()
input_keys = [key for key in tensors.keys() if key.startswith("input")]
for input_key in input_keys:
    input = tensors[input_key]
    output = model.apply(params, input)
    output_key = input_key.replace("input", "output")
    out_tensors[output_key] = output

save_file(out_tensors, "#{test_data_flax_path}")
"""
{_, 0} = run_python.(test_flax, [])
Compare with results from Axon
Then, we compare the results from Flax with those from Axon.
axon_data = Safetensors.read!(test_data_axon_path)
flax_data = Safetensors.read!(test_data_flax_path)
assert_all_close = fn left, right ->
atol = 1.0e-4
rtol = 1.0e-4
equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)
equals == Nx.tensor(1, type: :u8, backend: Nx.BinaryBackend)
end
same_result? = fn axon_result, flax_result ->
assert_all_close.(axon_result, flax_result)
end
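A quick spot check of the tolerance before comparing real outputs (the values are picked for illustration only):
true = assert_all_close.(Nx.tensor([1.0]), Nx.tensor([1.00005]))
false = assert_all_close.(Nx.tensor([1.0]), Nx.tensor([1.1]))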
verification_results =
for output_key <- Map.keys(axon_data), String.starts_with?(output_key, "output"), into: %{} do
input_key = String.replace(output_key, "output", "input")
got_same? = same_result?.(axon_data[output_key], flax_data[output_key])
{output_key,
%{same_result?: got_same?, input: axon_data[input_key], axon_output: axon_data[output_key], flax_output: flax_data[output_key]}}
end
wrong_results =
  verification_results
  |> Enum.filter(fn {_, output} -> not output.same_result? end)
Enum.count(wrong_results)
# Inspect the first mismatch (raises if every output matched)
{_, first_wrong} = hd(wrong_results)
first_wrong.input
first_wrong.axon_output
first_wrong.flax_output
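Since Kino is already in the install list, one way to browse all mismatches at once is a data table; a sketch (the column labels are ours):
wrong_results
|> Enum.map(fn {key, r} ->
  %{
    output: key,
    axon: inspect(Nx.to_flat_list(r.axon_output)),
    flax: inspect(Nx.to_flat_list(r.flax_output))
  }
end)
|> Kino.DataTable.new()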