
Flax to Axon - Verification

flax_to_axon_verification.livemd

Mix.install([
  {:axon, "~> 0.6.1"},
  {:stream_data, "~> 1.1"},
  {:nx, "~> 0.7.2"},
  {:safetensors, "~> 0.1.3"},
  {:kino, "~> 0.12.3"}
])

asdf_dir = "#{__DIR__}/.asdf"

unless File.exists?(asdf_dir) do
  {_, 0} = System.cmd("git", [
    "clone",
    "https://github.com/asdf-vm/asdf.git",
    asdf_dir,
    "--branch",
    "v0.14.0"
  ])
end

asdf = "#{asdf_dir}/bin/asdf"
{_, 0} = System.cmd(asdf, ["plugin", "add", "python"], env: [{"ASDF_DATA_DIR", asdf_dir}])

{_, 0} = System.cmd(asdf, ["install", "python", "3.11.9"], env: [{"ASDF_DATA_DIR", "#{__DIR__}/.asdf"}])

asdf_python = Path.join([asdf_dir, "installs", "python", "3.11.9", "bin", "python"])

python_packages =
  ~w(
    safetensors 
    torch
    transformers
    accelerate 
    numpy
    datasets
    pillow
    flax
    jax
    jaxlib
  )

venv_dir = Path.join(__DIR__, "flax2axon_env")
{_, 0} = System.cmd(asdf_python, ["-m", "venv", "--copies", venv_dir])

python = Path.join([venv_dir, "bin", "python"])
pip = Path.join([venv_dir, "bin", "pip"])

{_, 0} = System.cmd(pip, ["install" | python_packages])

run_python = fn command, opts ->
  System.cmd(python, ["-c", command], opts)
end

data_dir = Path.join(__DIR__, "data")

File.mkdir_p!(data_dir)

Install and use Python

Since we must run the Flax model through Python, we first have to set up a Python environment.

The setup block above automatically installs Python using asdf, creates a virtual environment, installs all the packages we need, and defines a helper function run_python for running Python code.

IO.puts("Python is here: #{python}")
{_, 0} = run_python.("print('hello from Python')", [])
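
The second argument of run_python is forwarded to System.cmd/3, so the usual options apply. For example, merging stderr into stdout (a minimal sketch):

# Forward System.cmd/3 options, e.g. capture stderr together with stdout.
{out, 0} = run_python.("import sys; print(sys.version)", stderr_to_stdout: true)
IO.puts(out)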

We define some paths that we will use to store data in safetensors format. That gives us an easy way to work with the same values in both frameworks, Axon and Flax.

params_axon_path = Path.join(data_dir, "params_axon.safetensors")
params_flax_path = Path.join(data_dir, "params_flax.safetensors")
test_data_axon_path = Path.join(data_dir, "test_data_axon.safetensors")
test_data_flax_path = Path.join(data_dir, "test_data_flax.safetensors")
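
As a quick sanity check that the safetensors round trip works, we can write a small tensor map and read it back. A minimal sketch; the sanity_check.safetensors path is just for illustration:

# Hypothetical round trip: write a tensor map, then read it back.
sanity_path = Path.join(data_dir, "sanity_check.safetensors")
Safetensors.write!(sanity_path, %{"t" => Nx.tensor([1.0, 2.0, 3.0])})
Safetensors.read!(sanity_path)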

Manual test data

test_data = %{
  "input_0" => Nx.broadcast(Nx.tensor([1, 2, 3], type: :f32), {1, 1, 3}),
  "input_1" => Nx.broadcast(Nx.tensor([1, 2, 3], type: :f32), {1, 1, 3}),
  "output_0" => Nx.broadcast(Nx.tensor([5.999975204467773, 7.6568403244018555, 8.196144104003906], type: :f32), {1, 1, 3})
}

Safetensors.write!(test_data_axon_path, test_data)

Let’s define some test data; in the full conversion workflow, these values come from the previous steps.

We must run the models in Axon and Flax using the same parameters, otherwise we won’t get the same results.

params =
  %{
    "batch_norm_0" => %{
      "beta" => Nx.tensor([1, 1, 1], type: :f32),
      "gamma" => Nx.tensor([1, 1, 1], type: :f32),
      "mean" => Nx.tensor([1, 2, 3], type: :f32),
      "var" => Nx.tensor([1, 2, 3], type: :f32)
    },
    "conv_0" => %{
      "kernel" => Nx.broadcast(Nx.tensor(1, type: :f32), {1, 3, 3})
    }
  }

# flatten first

#test_data = Map.put(test_data, "params", params)
#Safetensors.write!(test_data_axon_path, test_data)

Test data from the code-to-code model

This is the model we converted.

flax_res_net_conv_layer = fn x, out_channels, kernel_size, stride, activation ->
  hidden_state =
    Axon.conv(
      x,
      out_channels,
      kernel_size: {kernel_size, kernel_size},
      strides: stride,
      padding: [
        {div(kernel_size, 2), div(kernel_size, 2)},
        {div(kernel_size, 2), div(kernel_size, 2)}
      ],
      use_bias: false,
      kernel_initializer:
        Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal)
    )

  hidden_state =
    Axon.batch_norm(hidden_state,
      momentum: 0.9,
      epsilon: 1.0e-05
    )

  hidden_state = Axon.activation(hidden_state, activation)
  hidden_state
end

This is our plan:

  • TODO: First, check what kind of inputs the model accepts.
  • Then, infer the param shapes using the init function.
  • TODO: Then, generate random data for params and inputs and run the model.
  • Save the inputs and outputs in safetensors format to check against the Flax model.
  • Map the params to Flax names, then save them in safetensors format, once for the Axon version and once for the Flax version.

out_channels = 3
kernel_size = 3
stride = 1
activation = :relu

model =
  flax_res_net_conv_layer.(Axon.input("input"), out_channels, kernel_size, stride, activation)

{init_fn, predict_fn} = Axon.build(model)

input_shape = {1, 3, 3, 3}
input_type = :f32

params = init_fn.(Nx.template(input_shape, input_type), %{})
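
Optionally, since Kino is installed, we can render the built model as a graph in Livebook. This is a hedged sketch assuming Axon.Display.as_graph/2 (shipped with Axon 0.6) accepts a template input:

# Hedged sketch: visualize the model graph (assumes Axon.Display.as_graph/2).
Axon.Display.as_graph(model, Nx.template(input_shape, input_type))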

Let’s check which params we need. safetensors requires flattened keys to store the params, so we concatenate the key hierarchy with a . as separator.

defmodule ParamsUtils do
  # Flattens a nested params map into a list of dot-separated keys,
  # e.g. %{"conv_0" => %{"kernel" => tensor}} -> ["conv_0.kernel"].
  def flatten_keys(%{} = params) do
    for key <- Map.keys(params) do
      prefixed_keys(params[key], key)
    end
    |> List.flatten()
  end

  # A tensor is a leaf: the accumulated prefix is the flattened key.
  defp prefixed_keys(%Nx.Tensor{}, key), do: key

  defp prefixed_keys(%{} = params, prefix) do
    for key <- Map.keys(params) do
      prefixed_keys(params[key], "#{prefix}.#{key}")
    end
  end

  # Looks up a value in a nested map by a dot-separated key.
  def get_from_flattened_key(params, flattened_key) do
    keys = String.split(flattened_key, ".")

    for key <- keys, reduce: params do
      acc -> acc[key]
    end
  end

  # Inserts a value into a nested map at the position given by a
  # dot-separated key, creating intermediate maps as needed.
  def unflatten_and_put(params, flattened_key, value) do
    single_param_map = flattened_map(flattened_key, value)

    merge_recursive(params, single_param_map)
  end

  def merge_recursive(%{} = map1, %{} = map2) do
    Map.merge(map1, map2, fn _k, m1, m2 -> merge_recursive(m1, m2) end)
  end

  # Builds a nested single-entry map from a dot-separated key.
  defp flattened_map(flattened_key, value) do
    case String.split(flattened_key, ".", parts: 2) do
      [key] -> %{key => value}
      [key, other_keys] -> %{key => flattened_map(other_keys, value)}
    end
  end
end

ParamsUtils.flatten_keys(params)
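
The flattened Axon keys line up with the left-hand side of the param_mapping defined below. A single parameter can be read back via its flattened key, for example:

# Look up one parameter by its flattened key.
ParamsUtils.get_from_flattened_key(params, "conv_0.kernel")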

In Flax these are the keys of our params:

['batch_stats.normalization.mean', 'batch_stats.normalization.var', 'params.convolution.kernel', 'params.normalization.bias', 'params.normalization.scale']

So, we create a mapping from the Axon world to the Flax world.

param_mapping = %{
  "batch_norm_0.beta" => "params.normalization.bias",
  "batch_norm_0.gamma" => "params.normalization.scale",
  "batch_norm_0.mean" => "batch_stats.normalization.mean",
  "batch_norm_0.var" => "batch_stats.normalization.var",
  "conv_0.kernel" => "params.convolution.kernel"
}

Then, we create another set of params with flattened keys according to Flax and store all params as safetensors.

axon_params =
  for {axon_key, _} <- param_mapping, into: %{} do
    {axon_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
  end

flax_params =
  for {axon_key, flax_key} <- param_mapping, into: %{} do
    {flax_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
  end

Safetensors.write!(params_axon_path, axon_params)
Safetensors.write!(params_flax_path, flax_params)
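
Going the other way, rebuilding a nested params map from flattened safetensors keys, is what unflatten_and_put is for. A minimal sketch, reloading the Axon params we just wrote:

# Sketch: reconstruct the nested params map from the flattened file.
for {key, tensor} <- Safetensors.read!(params_axon_path), reduce: %{} do
  acc -> ParamsUtils.unflatten_and_put(acc, key, tensor)
end

Next, we generate random nested lists with StreamData matching the input shape, run the Axon model on each sample, and store all inputs and outputs together:
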
input_data =
  for dim <- Enum.reverse(Tuple.to_list(input_shape)), reduce: StreamData.float() do
    acc -> StreamData.list_of(acc, length: dim)
  end

test_data =
  for i <- 0..100 do
    input =
      input_data
      |> Enum.take(1)
      |> hd()
      |> Nx.tensor()

    input_name = "input_#{i}"
    output_name = "output_#{i}"

    output = predict_fn.(params, input)

    [{input_name, input}, {output_name, output}]
  end
  |> List.flatten()
  |> Map.new()

Safetensors.write!(test_data_axon_path, test_data)

Calculate results in Flax

Now, we use a script to run the Flax model with the same inputs as our generated Axon model.

module = "FlaxResNetConvLayer"
test_flax =
  """
  import jax
  from typing import Any, Callable, Sequence
  from jax import random, numpy as jnp
  import flax
  from flax import linen as nn
  
  from functools import partial
  from typing import Optional, Tuple

  from safetensors import safe_open
  from safetensors.flax import save_file
  
  from transformers.models.resnet.modeling_flax_resnet import #{module}

  def unflatten_dict(d, sep='.'):
    result = {}
    for key, value in d.items():
        parts = key.split(sep)
        node = result
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return result

  tensors = {}
  with safe_open("#{test_data_axon_path}", framework="flax") as f:
      for k in f.keys():
          tensors[k] = f.get_tensor(k)
  
  model = #{module}(3)
  key = random.key(0)

  
  params = {}
  with safe_open("#{params_flax_path}", framework="flax") as f:
      for k in f.keys():
          params[k] = f.get_tensor(k)
  
  params = unflatten_dict(params)

  out_tensors = tensors.copy()
  input_keys = [key for key in tensors.keys() if key.startswith("input")]
  for input_key in input_keys:  
    input = tensors[input_key]
  
    output = model.apply(params, input)
    output_key = input_key.replace("input", "output")
  
    out_tensors[output_key] = output

  save_file(out_tensors, "#{test_data_flax_path}")
  """

{_, 0} = run_python.(test_flax, [])

Compare with results from Axon

Then, we compare the results from Flax with those from Axon.

axon_data = Safetensors.read!(test_data_axon_path)
flax_data = Safetensors.read!(test_data_flax_path)

assert_all_close = fn left, right ->
  atol = 1.0e-4
  rtol = 1.0e-4

  equals =
    left
    |> Nx.all_close(right, atol: atol, rtol: rtol)
    |> Nx.backend_transfer(Nx.BinaryBackend)

  equals == Nx.tensor(1, type: :u8, backend: Nx.BinaryBackend)
end

same_result? = fn axon_result, flax_result ->
  assert_all_close.(axon_result, flax_result)
end
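
For intuition, two tensors whose values fall within the configured tolerances compare as equal:

# Within atol/rtol of 1.0e-4, these count as the same result.
same_result?.(Nx.tensor([1.0, 2.0]), Nx.tensor([1.00001, 2.00001]))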

verification_results =
  for output_key <- Map.keys(axon_data), String.starts_with?(output_key, "output"), into: %{} do
    input_key = String.replace(output_key, "output", "input")

    got_same? = same_result?.(axon_data[output_key], flax_data[output_key])

    {output_key,
     %{
       same_result?: got_same?,
       input: axon_data[input_key],
       axon_output: axon_data[output_key],
       flax_output: flax_data[output_key]
     }}
  end
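
Finally, we collect any mismatches and inspect the first one:
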
wrong_results =
  verification_results
  |> Map.to_list()
  |> Enum.filter(fn {_, output} -> output.same_result? == false end)

Enum.count(wrong_results)
{_, first_wrong} = hd(wrong_results)

first_wrong.input
first_wrong.axon_output
first_wrong.flax_output
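
When outputs do differ, it helps to quantify the gap. A minimal sketch computing the maximum absolute elementwise difference between the two outputs:

# Largest absolute difference between the Axon and Flax outputs.
first_wrong.axon_output
|> Nx.subtract(first_wrong.flax_output)
|> Nx.abs()
|> Nx.reduce_max()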