
Flax to Axon - Interactive Tool

Mix.install([
  {:axon, "~> 0.6.1"},
  {:stream_data, "~> 1.1"},
  {:nx, "~> 0.7.2"},
  {:safetensors, "~> 0.1.3"},
  {:kino, "~> 0.12.3"},
  {:langchain, "~> 0.3.0-rc.0"},
  {:exla, "~> 0.7.3"}
])

Nx.global_default_backend(EXLA.Backend)

asdf_dir = "#{__DIR__}/.asdf"

unless File.exists?(asdf_dir) do
  {_, 0} =
    System.cmd("git", [
      "clone",
      "https://github.com/asdf-vm/asdf.git",
      asdf_dir,
      "--branch",
      "v0.14.0"
    ])
end

asdf = "#{asdf_dir}/bin/asdf"
{_, 0} = System.cmd(asdf, ["plugin", "add", "python"], env: [{"ASDF_DATA_DIR", asdf_dir}])

{_, 0} =
  System.cmd(asdf, ["install", "python", "3.11.9"], env: [{"ASDF_DATA_DIR", asdf_dir}])

asdf_python = Path.join([asdf_dir, "installs", "python", "3.11.9", "bin", "python"])

python_packages =
  ~w(
    safetensors
    torch
    transformers
    accelerate
    numpy
    datasets
    pillow
    flax
    jax
    jaxlib
  )

venv_dir = Path.join(__DIR__, "flax_to_axon_env")
{_, 0} = System.cmd(asdf_python, ["-m", "venv", "--copies", venv_dir])

python = Path.join([venv_dir, "bin", "python"])
pip = Path.join([venv_dir, "bin", "pip"])

{_, 0} = System.cmd(pip, ["install" | python_packages])

run_python = fn command, opts ->
  System.cmd(python, ["-c", command], opts)
end

data_dir = Path.join(__DIR__, "data")

unless File.exists?(data_dir), do: :ok = File.mkdir(data_dir)

WARNING

> This Livebook installs asdf, Python and some libraries in the directory of the livebook.
> Modify the notebook setup if you don't want that.
>
> It also runs LLM-generated Elixir code using Code.eval_string.

Introduction

You can configure the LLM provider and model here. For OpenAI, you must set the OPENAI_KEY secret to your API key.

api_key = System.get_env("LB_OPENAI_KEY")

llm_config =
  LangChain.ChatModels.ChatOpenAI.new!(%{
    model: "gpt-4o-mini",
    api_key: api_key,
    seed: 0
  })

Goals

  • (Semi-)automatically convert Flax models to Axon (from transformers/diffusers)
  • Verify that the models compute the same results for the same inputs

For now, we consider only Flax (linen), because it is similar to Axon and should be easier to handle than PyTorch or TensorFlow.

flowchart TD
    Input[Read Flax File]
    Classes[Split into classes and free functions]

    subgraph each_class[for each class/free function]
      subgraph conversion[Convert into Axon]
        Convert[Convert using LLM instructions]
        Replace[Replace previously \n converted classes \n with Axon function]
      end
      subgraph verify[Verify same results]
        direction TB
        Generate[Generate input and params in Elixir]
        Compute_axon[Compute results in Axon]
        Safetensors[Save everything in safetensors]
        Load[Load safetensors in Python]
        Compute_flax[Compute results in Flax]
        Compare[Compare outputs for inputs]
      end

      Store[Store resulting Axon function]
    end


    Input-->Classes-->each_class

    Convert-->Replace
    Convert-->Convert
    conversion-->verify

    Generate-->Compute_axon-->Safetensors-->Load-->Compute_flax-->Compare

    verify-->Store

Get inputs

We read a Python file which contains the model code. Then we extract all top-level definitions from the file.

Models in Flax are classes, so they are named in CamelCase. In Axon we represent them as functions, named in snake_case.
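
The snake_case name is exactly what `Macro.underscore/1` produces, which `extract_names/1` below relies on:

Macro.underscore("FlaxResNetLayer")
# => "flax_res_net_layer"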

defmodule Utils do
  @moduledoc """
  Some utils to extract the information we need from Flax code.
  """

  @doc """
  Splits Flax code into its top level definitions.
  Considers only classes and free functions.
  Ignores everything before the first class or function definition, i.e. imports etc.
  """
  def top_level_defs(flax_code) do
    [_preamble | top_level_defs] =
      Regex.split(~r{\n(class |def )\b}, flax_code, include_captures: true)

    def_type = Enum.take_every(top_level_defs, 2)
    def_content = Enum.drop(top_level_defs, 1) |> Enum.take_every(2)

    Enum.zip_with(def_type, def_content, fn type, content ->
      String.trim_leading(type) <> content
    end)
  end

  @doc """
  Finds the first top level definition in `flax_code`.
  Then returns a tuple with the name in Flax and the corresponding name in Axon.
  """
  def extract_names(flax_code) do
    [flax_name | _] = String.split(flax_code, "(")

    flax_name =
      flax_name
      |> String.trim_leading("class ")
      |> String.trim_leading("def ")

    axon_name = Macro.underscore(flax_name)

    {flax_name, axon_name}
  end
end
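
Here is a quick sketch of how these helpers behave on a minimal, made-up Flax snippet (the class and function names are illustrative only):

flax_snippet = """
import flax.linen as nn

class TinyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=1)(x)

def helper(x):
  return x
"""

[class_def, fun_def] = Utils.top_level_defs(flax_snippet)

Utils.extract_names(class_def)
# => {"TinyModel", "tiny_model"}

Utils.extract_names(fun_def)
# => {"helper", "helper"}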

Conversion - Prompt Definitions

In this section we define the set of messages we pass to the LLM to help us convert the Flax code.

  • System Messages to give the LLM some context
  • Conversion tables for layers, activation functions and initializers (partially generated by ChatGPT)
  • Instructions the LLM will perform one at a time to convert the code

system_message =
  """
  You are an expert in machine learning frameworks.
  You help converting models from Python code in Flax linen framework to Elixir code in the Axon framework.
  You will do this by following instructions step by step.
  For each of the instructions, reply ONLY with the modified code of the model.
  Do NOT include any other content or comments.
  Perform ONLY the action the instruction asks you to do.
  In case you don't change anything return ONLY the unchanged model code.
  I will provide you with the model code and an instruction in the following format:
  INSTRUCTION: here is the instruction you perform
  MODEL: here is the model code you will modify
  """

Kino.nothing()

Conversion Tables

Created using ChatGPT, then manually refining the format.

Layers

Prompt:

Fetch the documentation about layers in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html

Build a separate markdown table for each of the layers in Flax linen and the corresponding layer in Axon. Each table should have two rows: First row: shows the Flax linen layer in the first column, then a column for each parameter it takes. Prefix the linen layer with nn. Second row: shows the corresponding Axon layer in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon layer with Axon.Layer.

If there is no corresponding layer, or no corresponding parameter, add a "-" instead.

Here is an example for the Conv layer of Flax linen:

| Flax | nn.Conv   | features | kernel_size | strides | padding | dtype | use_bias | kernel_init        |
| Axon | Axon.conv | units    | kernel_size | strides | padding | -     | use_bias | kernel_initializer |

conversion_layers_table =
  """
  ### Conv Layer

  |          | Layer     | features | kernel_size | strides | padding | dtype | use_bias | kernel_init        |
  | -------- | --------- | -------- | ----------- | ------- | ------- | ----- | -------- | ------------------ |
  | **Flax** | nn.Conv   | features | kernel_size | strides | padding | dtype | use_bias | kernel_init        |
  | **Axon** | Axon.conv | units    | kernel_size | strides | padding | -     | use_bias | kernel_initializer |

  ### Dense Layer

  |          | Layer      | features | use_bias | kernel_init        |
  | -------- | ---------- | -------- | -------- | ------------------ |
  | **Flax** | nn.Dense   | features | use_bias | kernel_init        |
  | **Axon** | Axon.dense | units    | use_bias | kernel_initializer |

  ### Dropout Layer

  |          | Layer        | rate |
  | -------- | ------------ | ---- |
  | **Flax** | nn.Dropout   | rate |
  | **Axon** | Axon.dropout | rate |

  ### BatchNorm Layer

  |          | Layer           | use_running_average | axis | momentum | epsilon | dtype |
  | -------- | --------------- | ------------------- | ---- | -------- | ------- | ----- |
  | **Flax** | nn.BatchNorm    | use_running_average | axis | momentum | epsilon | dtype |
  | **Axon** | Axon.batch_norm | -                   | axis | momentum | epsilon | -     |

  ### LayerNorm Layer

  |          | Layer           | axis | epsilon | dtype |
  | -------- | --------------- | ---- | ------- | ----- |
  | **Flax** | nn.LayerNorm    | axis | epsilon | dtype |
  | **Axon** | Axon.layer_norm | axis | epsilon | -     |

  ### Relu Layer

  |          | Layer     |
  | -------- | --------- |
  | **Flax** | nn.relu   |
  | **Axon** | Axon.relu |

  ### Sigmoid Layer

  |          | Layer        |
  | -------- | ------------ |
  | **Flax** | nn.sigmoid   |
  | **Axon** | Axon.sigmoid |

  ### Tanh Layer

  |          | Layer     |
  | -------- | --------- |
  | **Flax** | nn.tanh   |
  | **Axon** | Axon.tanh |

  ### MaxPool Layer

  |          | Layer         | window_shape | strides | padding |
  | -------- | ------------- | ------------ | ------- | ------- |
  | **Flax** | nn.max_pool   | window_shape | strides | padding |
  | **Axon** | Axon.max_pool | kernel_size  | strides | padding |

  ### AvgPool Layer

  |          | Layer         | window_shape | strides | padding |
  | -------- | ------------- | ------------ | ------- | ------- |
  | **Flax** | nn.avg_pool   | window_shape | strides | padding |
  | **Axon** | Axon.avg_pool | kernel_size  | strides | padding |

  ### GlobalAvgPool Layer

  |          | Layer                |
  | -------- | -------------------- |
  | **Flax** | nn.global_avg_pool   |
  | **Axon** | Axon.global_avg_pool |

  ### Flatten Layer

  |          | Layer        |
  | -------- | ------------ |
  | **Flax** | nn.Flatten   |
  | **Axon** | Axon.flatten |
  """

Kino.Markdown.new(conversion_layers_table)

Initializers

Prompt:

Fetch the documentation about initializers in Axon from here: https://hexdocs.pm/axon/Axon.Initializers.html
Fetch the documentation about initializers in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/initializers.html

Build a separate markdown table for each of the initializers in Flax linen and the corresponding initializer in Axon. Each table should have two rows: First row: shows the Flax linen initializer in the first column, then a column for each parameter it takes. Prefix the linen initializer with nn. Second row: shows the corresponding Axon initializer in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon initializer with Axon.Layer.

If there is no corresponding initializer, or no corresponding parameter, add a "-" instead.

Here is an example for the initializers.variance_scaling of Flax linen:

| Flax | nn.initializers.variance_scaling   | scale | mode | distribution | in_axis | out_axis | batch_axis | dtype |
| Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | -       | -        | -          | -     |

conversion_initializers_table =
  """
  ### Variance Scaling

  |      |                                    |       |      |              |         |          |            |       |
  | ---- | ---------------------------------- | ----- | ---- | ------------ | ------- | -------- | ---------- | ----- |
  | Flax | nn.initializers.variance_scaling   | scale | mode | distribution | in_axis | out_axis | batch_axis | dtype |
  | Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | -       | -        | -          | -     |

  ### Glorot Normal

  |      |                                 |       |
  | ---- | ------------------------------- | ----- |
  | Flax | nn.initializers.glorot_normal   | dtype |
  | Axon | Axon.Initializers.glorot_normal | -     |

  ### Glorot Uniform

  |      |                                  |       |
  | ---- | -------------------------------- | ----- |
  | Flax | nn.initializers.glorot_uniform   | dtype |
  | Axon | Axon.Initializers.glorot_uniform | -     |

  ### Lecun Normal

  |      |                                |       |
  | ---- | ------------------------------ | ----- |
  | Flax | nn.initializers.lecun_normal   | dtype |
  | Axon | Axon.Initializers.lecun_normal | -     |

  ### Lecun Uniform

  |      |                                 |       |
  | ---- | ------------------------------- | ----- |
  | Flax | nn.initializers.lecun_uniform   | dtype |
  | Axon | Axon.Initializers.lecun_uniform | -     |

  ### Orthogonal

  |      |                              |              |
  | ---- | ---------------------------- | ------------ |
  | Flax | nn.initializers.orthogonal   | scale        |
  | Axon | Axon.Initializers.orthogonal | distribution |

  ### Constant

  |      |                          |       |
  | ---- | ------------------------ | ----- |
  | Flax | nn.initializers.constant | value |
  | Axon | Axon.Initializers.full   | value |

  ### Normal

  |      |                          |      |        |
  | ---- | ------------------------ | ---- | ------ |
  | Flax | nn.initializers.normal   | mean | stddev |
  | Axon | Axon.Initializers.normal | mean | scale  |

  ### Uniform

  |      |                           |        |        |
  | ---- | ------------------------- | ------ | ------ |
  | Flax | nn.initializers.uniform   | minval | maxval |
  | Axon | Axon.Initializers.uniform | -      | scale  |

  ### Identity

  |      |                            |
  | ---- | -------------------------- |
  | Flax | nn.initializers.identity   |
  | Axon | Axon.Initializers.identity |

  ### Ones

  |      |                        |
  | ---- | ---------------------- |
  | Flax | nn.initializers.ones   |
  | Axon | Axon.Initializers.ones |

  ### Zeros

  |      |                         |
  | ---- | ----------------------- |
  | Flax | nn.initializers.zeros   |
  | Axon | Axon.Initializers.zeros |
  """

Kino.Markdown.new(conversion_initializers_table)

Activation Functions

Prompt:

Fetch the documentation about activations in Axon from here: https://hexdocs.pm/axon/Axon.Activations.html
Fetch the documentation about activations in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/activation_functions.html

Build a separate markdown table for each of the activations in Flax linen and the corresponding activation in Axon. Each table should have two rows: First row: shows the Flax linen activation in the first column, then a column for each parameter it takes. Prefix the linen activation with nn. Second row: shows the corresponding Axon activation in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon activation with Axon.Activations.

If there is no corresponding activation, or no corresponding parameter, add a "-" instead.

Here is an example for activation.softmax of Flax linen:

| Flax | nn.activation.softmax    | axis | where | initial |
| Axon | Axon.Activations.softmax | axis | -     | -       |

conversion_activations_table =
  """
  ### Softmax

  |      |                          |      |       |         |
  | ---- | ------------------------ | ---- | ----- | ------- |
  | Flax | nn.softmax               | axis | where | initial |
  | Axon | Axon.Activations.softmax | axis | -     | -       |

  ### ReLU

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.relu               |
  | Axon | Axon.Activations.relu |

  ### Leaky ReLU

  |      |                             |                |
  | ---- | --------------------------- | -------------- |
  | Flax | nn.leaky_relu               | negative_slope |
  | Axon | Axon.Activations.leaky_relu | alpha          |

  ### Sigmoid

  |      |                          |
  | ---- | ------------------------ |
  | Flax | nn.sigmoid               |
  | Axon | Axon.Activations.sigmoid |

  ### Tanh

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.tanh               |
  | Axon | Axon.Activations.tanh |

  ### GELU

  |      |                       |             |
  | ---- | --------------------- | ----------- |
  | Flax | nn.gelu               | approximate |
  | Axon | Axon.Activations.gelu | -           |

  ### ELU

  |      |                      |       |
  | ---- | -------------------- | ----- |
  | Flax | nn.elu               | alpha |
  | Axon | Axon.Activations.elu | alpha |

  ### SELU

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.selu               |
  | Axon | Axon.Activations.selu |

  ### Softplus

  |      |                           |
  | ---- | ------------------------- |
  | Flax | nn.softplus               |
  | Axon | Axon.Activations.softplus |

  ### Swish

  |      |                        |
  | ---- | ---------------------- |
  | Flax | nn.swish               |
  | Axon | Axon.Activations.swish |

  ### Log Softmax

  |      |                              |      |       |         |
  | ---- | ---------------------------- | ---- | ----- | ------- |
  | Flax | nn.log_softmax               | axis | where | initial |
  | Axon | Axon.Activations.log_softmax | axis | -     | -       |
  """

Kino.Markdown.new(conversion_activations_table)

Detailed instructions on how to convert the model

defmodule Instruction do
  defstruct [:name, :instruction, :conversion_table, :example_input, :example_output]
end

template = LangChain.PromptTemplate.from_template!(
  """
  INSTRUCTION:
  <%= @instruction %>
  
  <%= if assigns[:conversion_table] do %>
    <%= @conversion_table %> 
  <% end %>
  
  Here is an example.

  INPUT:  
  <%= @example_input %>

  OUTPUT:  
  <%= @example_output %>
  
  DO NOT PERFORM ANY OTHER ACTION, IF THERE IS NOTHING TO DO RETURN THE UNCHANGED CODE.

  MODEL:
  <%= @model_code %>
  """
)

conversion_instructions = [
  %Instruction{
    name: :setup,
    instruction: """
    If there is a setup function, move all of its content to the `__call__` function.
    Then remove the setup function.
    """,
    example_input: """
    ```python
    class SlightlyLargerSetup(nn.Module):  
      def setup(self):
        self.dense0 = nn.Dense(features=32)    
        self.dropout = nn.Dropout(rate=0.5)
        self.dense1 = nn.Dense(features=1)

      def __call__(self, x, training: bool):
        x = self.dense0(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training)
        x = self.dense1(x)
        x = nn.softmax(x)
        return x
    ```
    """,
    example_output: """
    ```python
    class SlightlyLargerSetup(nn.Module):  
      def __call__(self, x, training: bool):
        self.dense0 = nn.Dense(features=32)    
        self.dropout = nn.Dropout(rate=0.5)
        self.dense1 = nn.Dense(features=1)
        x = self.dense0(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training)
        x = self.dense1(x)
        x = nn.softmax(x)
        return x
    ```
    """
  },
  %Instruction{
    name: :attributes,
    instruction: """
    If there are attributes, move them to the `__call__` function as parameters and replace all the references of the attributes with the function parameters.
    """,
    example_input: """
    ```python
    class SlightlyLargerAttributes(nn.Module):
      features_dense_0: int
      dropout_rate: int
      features_dense_1: int

      @nn.compact
      def __call__(self, x, training: bool):
        x = nn.Dense(features=self.features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
        x = nn.Dense(features=self.features_dense_1)(x)
        x = nn.softmax(x)
        return x
    ```
    """,
    example_output: """
     ```python
    class SlightlyLargerAttributes(nn.Module):

      @nn.compact
      def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
        x = nn.Dense(features=features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
        x = nn.Dense(features=features_dense_1)(x)
        x = nn.softmax(x)
        return x
    ```
    """
  },
  %Instruction{
    name: :move_params_to_declaration,
    instruction: """
    In the `__call__` function, move all additional parameters when calling the layers to the initialization of the layers. Each layer should be called with a single argument.
    """,
    example_input: """
    ```python
    class SlightlyLargerAttributes(nn.Module):

      @nn.compact
      def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
        x = nn.Dense(features=features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=dropout_rate)(x, deterministic=not training)
        x = nn.Dense(features=features_dense_1)(x)
        x = nn.softmax(x)
        return x
    ```
    """,
    example_output: """
    ```python
    class SlightlyLargerAttributes(nn.Module):

      @nn.compact
      def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
        x = nn.Dense(features=features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
        x = nn.Dense(features=features_dense_1)(x)
        x = nn.softmax(x)
        return x
    ```
    """
  },
  %Instruction{
    name: :remove_self,
    instruction: """
    Remove all occurrences of "self." in the `__call__` function.
    Then drop the "self" parameter of the `__call__` function.
    """,
    example_input: """
    ```python
    class SlightlyLargerSetup(nn.Module):  
      def __call__(self, x, training: bool):
        self.dense0 = nn.Dense(features=32)
        self.dropout = nn.Dropout(rate=0.5)
        self.dense1 = nn.Dense(features=1)
        x = self.dense0(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training)
        x = self.dense1(x)
        x = nn.softmax(x)
        return x
    ```
    """,
    example_output: """
    ```python
    class SlightlyLargerSetup(nn.Module):  
      def __call__(x, training: bool):
        dense0 = nn.Dense(features=32)    
        dropout = nn.Dropout(rate=0.5)
        dense1 = nn.Dense(features=1)
        x = dense0(x)
        x = nn.relu(x)
        x = dropout(x, deterministic=not training)
        x = dense1(x)
        x = nn.softmax(x)
        return x
    ```
    """
  },
  %Instruction{
    name: :replace_numeric_ops,
    instruction: """
    Replace numeric operations with corresponding Axon functions when the operands are tensors.
    According to the following table:
    """,
    example_input: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state, residual ):
        hidden_state = hidden_state + residual
        return hidden_state
    ```
    """,
    example_output: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state, residual ):
        hidden_state = Axon.add(hidden_state, residual)
        return hidden_state
    ```
    """,
    conversion_table: """
    | Flax operation | Axon layer       |
    | ---            | ---              | 
    | +              | Axon.add()       |
    | -              | Axon.subtract()  |
    | *              | Axon.multiply()  |

    """
  },
  %Instruction{
    name: :replace_partials,
    instruction: """
    Replace `partial` calls with a lambda in Elixir.
    """,
    example_input: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state ):
        hidden_state = partial(myfunc, 1, 2, 3)
        return hidden_state
    ```
    """,
    example_output: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state ):
        hidden_state = &myfunc(&1,1, 2, 3)
        return hidden_state
    ```
    """
  },
  %Instruction{
    name: :replace_loops,
    instruction: """
    Replace for loops with an Elixir comprehension with reduce.
    """,
    example_input: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state):
      layers = [dense0, dense1]
  
      for layer in layers:
        hidden_state = layer(hidden_state)
       
      return hidden_state
    ```
    """,
    example_output: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state):
        layers = [dense0, dense1]

        hidden_state = for layer <- layers, reduce: hidden_state do
           layer(hidden_state)
        end
         
        return hidden_state
    ```
    """
  },
  %Instruction{
    name: :custom_classes,
    instruction: """
    Rewrite custom class names written in CamelCase inside the `__call__` function in snake_case.
    """,
    example_input: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state):
        conv_layer = FlaxConv(3)
    
        hidden_state = conv_layer(hidden_state)
         
        return hidden_state
    ```
    """,
    example_output: """
    ```python
    class SomeModule(nn.Module):  
      def __call__(self, hidden_state):
        conv_layer = flax_conv(3)

        hidden_state = conv_layer(hidden_state)
         
        return hidden_state
    ```
    """
  },
  %Instruction{
    name: :convert_layers,
    instruction: """
    Replace the Flax layers with the corresponding Axon layers.

    Return only the Elixir code for the model. 
    Take into account the parameters in the first parenthesis. 
    You must follow these rules to replace layers:
    - When a Flax layer is initialized but not called, replace it with a lambda of the corresponding Axon layer.
      Example: replace `dense0 = nn.Dense(features=32)` with `dense0 = &Axon.dense(&1, 32)`
    - When a previously initialized Flax layer is called, replace it with a call of the corresponding Axon lambda.
      Example: replace `x = dense0(x)` with `x = dense0.(x)`
    - When the Flax code specifies a name for the layer, add that name as option to the Axon layer:
      Example: replace `dense0 = nn.Dense(features=32, name="dense0")` with `dense0 = &Axon.dense(&1, 32, name: "dense0")`
     
    Use this table to find corresponding layers:
    """,
    conversion_table: conversion_layers_table,
    example_input: """
     ```python
    class SlightlyLarger(nn.Module):

      @nn.compact
      def __call__(self, x, training: bool):
        dense0 = nn.Dense(features=32)    
        dropout = nn.Dropout(rate=0.5)
        dense1 = nn.Dense(features=1)
        x = dense0(x)
        x = nn.relu(x)
        x = dropout(x, deterministic=not training)
        x = dense1(x)
        x = nn.softmax(x)
        return x
    ```
    """,
    example_output: """
      ```python
    class SlightlyLarger(nn.Module):
      def __call__(self, x, training: bool):
        dense0 = &Axon.dense(&1, 32)
        dropout = &Axon.dropout(&1, rate: 0.5)
        dense1 = &Axon.dense(&1, 1)
        x = dense0.(x)
        x = Axon.activation(x, :relu)
        x = dropout.(x)
        x = dense1.(x)
        x = Axon.activation(x, :softmax)
        x
    ```
    """
  },
  %Instruction{
    name: :convert_initializers,
    instruction: """
    Replace initializer functions with the corresponding Axon functions according to this table:
    """,
    conversion_table: conversion_initializers_table,
    example_input: """
    """,
    example_output: """
    """
  },
  %Instruction{
    name: :convert_activations,
    instruction: """
    Replace activation functions.
    You must follow these rules to replace activation functions: 
    - If there is an activation parameter, and the code makes use of ACT2FN[activation], replace that with Axon.activation(activation). 
    - If there is an actual activation function called, replace it with Axon's activation function according to the conversion table. 

    Use this table:
    """,
    conversion_table: conversion_activations_table,
    example_input: """
    """,
    example_output: """
    """
  },
  %Instruction{
    name: :wrap_in_function,
    instruction: """
    Wrap the Axon model in a function that takes all the required parameters from the `__call__` function:
    - name the function after the class name, but in snake_case
    - take the same arguments as the `__call__` function
    - remove the self argument
    - remove the dtype argument
    - remove the type specs if present
    - wrap the function in do ... end
    """,
    example_input: """
    ```python
    class SlightlyLarger(nn.Module):
      def __call__(self, x, training: bool):
        x = Axon.dense(x, 32)
        x = Axon.activation(x, :relu)
        x = Axon.dropout(x, rate: 0.5)
        x = Axon.dense(x, 1)
        x = Axon.activation(x, :softmax)
        x
    ```
    """,
    example_output: """
    ```python
    def slightly_larger(x, training) do
      x = Axon.dense(x, 32)
      x = Axon.activation(x, :relu)
      x = Axon.dropout(x, rate: 0.5)
      x = Axon.dense(x, 1)
      x = Axon.activation(x, :softmax)
      x
    end  
    ```
    """
  },
  %Instruction{
    name: :config,
    instruction: """
    If there is a config parameter replace it with an Elixir keyword list.
    Retrieve the config values using Keyword.get() at the beginning of the function.
    Provide sensible default values.
    """,
    example_input: """
    ```python
    def slightly_larger(x, config) do
        out_channels = config.out_channels
        x = Axon.dense(x, out_channels)
        x
    ```
    """,
    example_output: """
    ```python
    def slightly_larger(x, config) do
        out_channels = Keyword.get(config, :out_channels, 3)
        x = Axon.dense(x, out_channels)
        x
    end  
    ```
    """
  },
  %Instruction{
    name: :fix_param_values,
    instruction: """
    Correct the values of Axon parameters according to the tables.  
    """
  },
  %Instruction{
    name: :multiple_return_values,
    instruction: """
    If there are multiple return values, wrap them into an Axon.container.
    """,
    example_input: """
    ```python
    def slightly_larger(hidden_state, hidden_states, config) do
      out_channels = config.out_channels
      hidden_state = Axon.dense(hidden_state, out_channels)
  
      hidden_state, hidden_states
    ```
    """,
    example_output: """
    ```python
    def slightly_larger(hidden_state, hidden_states, config) do
      out_channels = config.out_channels
      hidden_state = Axon.dense(hidden_state, out_channels)

      Axon.container(%{hidden_state: hidden_state, hidden_states: hidden_states})
    ```
    """
  },
  %Instruction{
    name: :remove_directives,
    instruction: """
    Remove all `use` and `import` directives from the Elixir code.    
    """
  },
  %Instruction{
    name: :valid_elixir_code,
    instruction: """
    Check if the function is valid Elixir code. 
    Otherwise, fix all issues by converting Python expressions to Elixir expressions. E.g. // corresponds to div, scientific notation like 1e-05 needs a decimal point in Elixir 1.0e-05.
    """
  }
]
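
As a sanity check, we can expand a single instruction into the full prompt the LLM will receive. A sketch; the model code here is just a placeholder:

prompt_inputs =
  conversion_instructions
  |> hd()
  |> Map.from_struct()
  |> Map.put(:model_code, "class Tiny(nn.Module): ...")

template
|> LangChain.PromptTemplate.format(prompt_inputs)
|> IO.puts()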

Conversion - Code

defmodule Conversion do
  use Agent
  alias LangChain.Message
  alias LangChain.Chains.LLMChain

  @system_message system_message
  @instructions conversion_instructions
  # Returned for indexes past the last step; kept as a struct so callers can
  # treat it like any other instruction (e.g. access `.instruction`).
  @no_op_instruction %Instruction{name: :no_op, instruction: "return unchanged code"}
  @template template

  def num_instructions, do: length(@instructions)

  def conversion_instruction(index) do
    if index < num_instructions() do
      Enum.at(@instructions, index)
    else
      @no_op_instruction
    end
  end

  def start_link() do
    Agent.start_link(fn -> %{} end)
  end

  def get(converted, key) do
    Agent.get(converted, &Map.get(&1, key))
  end

  def put(converted, key, code, params_mapping) do
    Agent.update(converted, &Map.put(&1, key, %{code: code, params_mapping: params_mapping}))
  end

  def code_to_string(converted) do
    converted_funcs = Agent.get(converted, &Map.values(&1))

    converted_funcs
    |> Enum.map(& &1.code)
    |> Enum.join("\n\n")
  end

  def params_mappings_to_string(converted) do
    converted_funcs = Agent.get(converted, &Map.values(&1))

    converted_funcs
    |> Enum.map(& &1.params_mapping)
    |> Enum.join("\n")
  end

  def new_chain!(llm_config) do
    {:ok, chain, _response} =
      %{llm: llm_config}
      |> LLMChain.new!()
      |> LLMChain.add_message(Message.new_system!(@system_message))
      |> LLMChain.run()

    chain
  end

  defp send_user_message(chain, message) do
    chain
    |> LLMChain.add_message(Message.new_user!(message))
    |> LLMChain.run()
  end

  defp instruction_message(model_code, instruction) do
    prompt_inputs = Map.from_struct(instruction) |> Map.put(:model_code, model_code)

    LangChain.PromptTemplate.format(@template, prompt_inputs)
  end

  def update_model(model_code, chain, instruction) do
    message = instruction_message(model_code, instruction)

    case send_user_message(chain, message) do
      {:ok, updated_chain, response} -> {:ok, updated_chain, response.content}
      {:error, _chain, _response} -> {:error, chain, model_code}
    end
  end
end
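
Outside of the UI below, the same API can be driven by hand. A minimal sketch, assuming the `llm_config` from the introduction and a valid API key; the Flax snippet is a placeholder:

{:ok, scratch} = Conversion.start_link()
chain = Conversion.new_chain!(llm_config)

model_code = "class Tiny(nn.Module): ..."
first_instruction = Conversion.conversion_instruction(0)

# run one instruction against the LLM and keep the updated code
{:ok, _chain, model_code} = Conversion.update_model(model_code, chain, first_instruction)

Conversion.put(scratch, 0, model_code, "%{}")
Conversion.code_to_string(scratch)
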
defmodule Model do
  require Logger

  def apply_model(module_name, model_fn_name, args) do
    try do
      {:ok, apply(module_name, String.to_atom(model_fn_name), args)}
    rescue
      e ->
        Logger.error(Exception.message(e))
        {:error, e}
    end
  end

  def model_string(model_code, module_name, additional_code) do
    """
    defmodule #{module_name} do
    #{model_code}

    #{additional_code}
    end
    """
  end

  def build_model(model_code, axon_function_name, build_args, additional_code) do
    model = model_string(model_code, "ModelTest", additional_code)

    try do
      Code.eval_string(model, [], __ENV__)

      {:ok, model} = apply_model(Model.ModelTest, axon_function_name, build_args)

      model = Axon.build(model)

      {:ok, model}
    rescue
      error ->
        Logger.error(Exception.message(error))
        {:error, error}
    end
  end

  def get_params(init_fn, input_shape, input_type) do
    try do
      params = init_fn.(Nx.template(input_shape, input_type), %{})
      {:ok, params}
    rescue
      error ->
        Logger.error(Exception.message(error))
        {:error, error}
    end
  end
end
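
For example, a hand-written Axon function can be pushed through the same path. A sketch; `tiny/1` is made up for illustration:

tiny_code = """
def tiny(x) do
  x
  |> Axon.dense(2)
  |> Axon.activation(:relu)
end
"""

# compiles the code into Model.ModelTest, applies it, and builds the Axon model
{:ok, {init_fn, predict_fn}} = Model.build_model(tiny_code, "tiny", [Axon.input("in")], "")
{:ok, params} = Model.get_params(init_fn, {1, 3}, :f32)

predict_fn.(params, Nx.iota({1, 3}, type: :f32))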

Verification

Check that the Python installation works

IO.puts("Python is here: #{python}")
{_, 0} = run_python.("print('hello from Python')", [])

defmodule Verification do
  @python python

  defp run_python(code, opts) do
    System.cmd(@python, ["-c", code], opts)
  end

  @doc """
  Defines paths for `safetensors` files; we use these files to work on the same numbers in Python and Elixir.
  """
  def safetensor_files(dir, name) do
    {Path.join(dir, "#{name}_params_axon.safetensors"),
     Path.join(dir, "#{name}_params_flax.safetensors"),
     Path.join(dir, "#{name}_test_data_axon.safetensors"),
     Path.join(dir, "#{name}_test_data_flax.safetensors")}
  end

  @doc """
  Write `safetensors` files for Axon and Flax params.
  """
  def save_params(params, param_mapping, axon_params_path, flax_params_path) do
    axon_params =
      for {axon_key, _} <- param_mapping, into: %{} do
        {axon_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
      end

    flax_params =
      for {axon_key, flax_key} <- param_mapping, into: %{} do
        {flax_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
      end

    Safetensors.write!(axon_params_path, axon_params)
    Safetensors.write!(flax_params_path, flax_params)
  end

  def run_axon_and_save(predict_fn, params, input_shape, path) do
    input_data =
      for dim <- Enum.reverse(Tuple.to_list(input_shape)), reduce: StreamData.float() do
        acc -> StreamData.list_of(acc, length: dim)
      end

    test_data =
      for i <- 0..100 do
        input =
          input_data
          |> Enum.take(1)
          |> hd
          |> Nx.tensor()

        input_name = "input_#{i}"
        output_name = "output_#{i}"

        output = predict_fn.(params, input)

        [{input_name, input}, {output_name, output}]
      end
      |> List.flatten()
      |> Map.new()

    Safetensors.write!(path, test_data)
  end

  # need path for import, e.g. transformers.models.resnet.modeling_flax_resnet
  def run_flax_and_save(
        module,
        module_args,
        input_path,
        output_path,
        params_path,
        import_path,
        additional_imports
      ) do
    flax_script =
      """
      import jax
      from typing import Any, Callable, Sequence
      from jax import random, numpy as jnp
      import flax
      from flax import linen as nn

      from functools import partial
      from typing import Optional, Tuple

      from safetensors import safe_open
      from safetensors.flax import save_file

      from #{import_path} import #{module}

      #{additional_imports}

      def unflatten_dict(d, sep='.'):
        result = {}
        for key, value in d.items():
            parts = key.split(sep)
            node = result
            for part in parts[:-1]:
                node = node.setdefault(part, {})
            node[parts[-1]] = value
        return result

      tensors = {}
      with safe_open("#{input_path}", framework="flax") as f:
          for k in f.keys():
              tensors[k] = f.get_tensor(k)

      print("initializing: #{module}(#{module_args})")
      model = #{module}(#{module_args})


      params = {}
      with safe_open("#{params_path}", framework="flax") as f:
          for k in f.keys():
              params[k] = f.get_tensor(k)

      params = unflatten_dict(params)

      out_tensors = tensors.copy()
      input_keys = [key for key in tensors.keys() if key.startswith("input")]
      for input_key in input_keys:  
        input = tensors[input_key]

        output = model.apply(params, input)
        output_key = input_key.replace("input", "output")

        out_tensors[output_key] = output

      save_file(out_tensors, "#{output_path}")
      """

    run_python(flax_script, [])
  end

  defp assert_all_close(left, right) do
    atol = 1.0e-4
    rtol = 1.0e-4

    equals =
      left
      |> Nx.all_close(right, atol: atol, rtol: rtol)
      |> Nx.backend_transfer(Nx.BinaryBackend)

    equals == Nx.tensor(1, type: :u8, backend: Nx.BinaryBackend)
  end

  defp same_result?(axon_result, flax_result) do
    assert_all_close(axon_result, flax_result)
  end

  def verification_results(axon_path, flax_path) do
    axon_data = Safetensors.read!(axon_path)
    flax_data = Safetensors.read!(flax_path)

    for output_key <- Map.keys(axon_data), String.starts_with?(output_key, "output"), into: %{} do
      input_key = String.replace(output_key, "output", "input")

      got_same? = same_result?(axon_data[output_key], flax_data[output_key])

      {output_key,
       %{
         same_result?: got_same?,
         input: axon_data[input_key],
         axon_output: axon_data[output_key],
         flax_output: flax_data[output_key]
       }}
    end
  end

  def verify_model(axon_model, flax_model, params_mapping) do
    with {:ok, {init_fn, predict_fn}} <-
           Model.build_model(
             axon_model.code,
             axon_model.name,
             axon_model.build_args,
             axon_model.additional_code
           ),
         {:ok, params} <- Model.get_params(init_fn, axon_model.input_shape, axon_model.input_type) do
      save_params(params, params_mapping, axon_model.params_path, flax_model.params_path)

      run_axon_and_save(predict_fn, params, axon_model.input_shape, axon_model.results_path)

      # surface Python failures instead of silently comparing stale files
      {_output, 0} =
        run_flax_and_save(
          flax_model.name,
          flax_model.model_args,
          axon_model.results_path,
          flax_model.results_path,
          flax_model.params_path,
          flax_model.import_path,
          flax_model.additional_imports
        )

      # check if results are the same 
      results = verification_results(axon_model.results_path, flax_model.results_path)

      if Enum.all?(results, fn {_k, res} -> res.same_result? end) do
        {:ok, results}
      else
        {:verification_failed, results}
      end
    end
  end
end
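
The pass/fail decision comes down to `Nx.all_close/3` with absolute and relative tolerances of 1.0e-4. For intuition, outside the pipeline:

left = Nx.tensor([1.0, 2.0, 3.0])
right = Nx.tensor([1.00005, 2.0, 3.0])

# returns a u8 scalar tensor: 1 when every entry is within atol + rtol * |right|
Nx.all_close(left, right, atol: 1.0e-4, rtol: 1.0e-4)
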
defmodule ParamsUtils do
  def flatten_keys(%{} = params) do
    for key <- Map.keys(params) do
      prefixed_keys(params[key], key)
    end
    |> List.flatten()
  end

  defp prefixed_keys(%Nx.Tensor{}, key), do: key

  defp prefixed_keys(%{} = params, prefix) do
    for key <- Map.keys(params) do
      prefixed_keys(params[key], "#{prefix}.#{key}")
    end
  end

  def get_from_flattened_key(params, flattened_key) do
    keys = String.split(flattened_key, ".")

    for key <- keys, reduce: params do
      acc -> acc[key]
    end
  end

  def unflatten_and_put(params, flattened_key, value) do
    single_param_map = flattened_map(flattened_key, value)

    merge_recursive(params, single_param_map)
  end

  def merge_recursive(%{} = map1, %{} = map2) do
    Map.merge(map1, map2, fn _k, m1, m2 -> merge_recursive(m1, m2) end)
  end

  defp flattened_map(flattened_key, value) do
    case String.split(flattened_key, ".", parts: 2) do
      [key] -> %{key => value}
      [key, other_keys] -> %{key => flattened_map(other_keys, value)}
    end
  end
end
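
A quick sketch of the flattened-key helpers on a tiny, made-up params map:

params = %{
  "dense_0" => %{
    "kernel" => Nx.tensor([[1.0]]),
    "bias" => Nx.tensor([0.0])
  }
}

ParamsUtils.flatten_keys(params)
# => ["dense_0.bias", "dense_0.kernel"]

ParamsUtils.get_from_flattened_key(params, "dense_0.kernel")
# => the kernel tensor

ParamsUtils.unflatten_and_put(params, "dense_1.kernel", Nx.tensor([[2.0]]))
# => adds %{"dense_1" => %{"kernel" => ...}} next to "dense_0"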

Params Mapping

defmodule ParamsMapping do
  require Logger
  
  alias LangChain.Message
  alias LangChain.Chains.LLMChain

  @python python

  def new_chain!(llm_config) do
    {:ok, chain, _response} =
      %{llm: llm_config}
      |> LLMChain.new!()
      |> LLMChain.add_message(
        Message.new_system!("You are an expert in Python, Flax, Elixir and Axon.")
      )
      |> LLMChain.run()

    chain
  end

  defp params_mapping_prompt(axon_params, flax_params) do
    """
    this is a first set of ids of params of a neural network:
    #{axon_params}

    this is a second set of ids of params of a neural network:
    #{flax_params}

    Both sets of ids refer to the same params. Create an Elixir map for the corresponding ids.
    The keys of the map should be the parameter ids from the first set; the corresponding value should be the id from the second set that belongs to the key.
    Reply ONLY with the Elixir map.
    """
  end

  def axon_params(model, input_shape, input_type) do
    try do
      {init_fn, _predict_fn} = Axon.build(model)

      params =
        init_fn.(Nx.template(input_shape, input_type), %{})
        |> ParamsUtils.flatten_keys()

      {:ok, params}
    rescue
      e ->
        Logger.error(Exception.message(e))
        {:error, e}
    end
  end

  def flax_params(module, import_path, additional_imports, model_args, input_shape) do
    # e.g. {3, 3} -> "(3, 3)"
    flax_shape = "(" <> (input_shape |> Tuple.to_list() |> Enum.join(", ")) <> ")"

    input = "jnp.broadcast_to(1, #{flax_shape})"

    python_script =
      """
      import jax
      from typing import Any, Callable, Sequence
      from jax import random, numpy as jnp
      import flax
      from flax import linen as nn

      from functools import partial
      from typing import Optional, Tuple

      from safetensors import safe_open
      from safetensors.flax import save_file

      from #{import_path} import #{module}

      #{additional_imports}
      model = #{module}(#{model_args})
      key = random.key(0)
      input = #{input}

      params = model.init(key, input)


      def flatten_dict(d, parent_key='', sep='.'):
          items = []
          for k, v in d.items():
              new_key = parent_key + sep + k if parent_key else k
              if isinstance(v, dict):
                  items.extend(flatten_dict(v, new_key, sep=sep).items())
              else:
                  items.append((new_key, v))
          return dict(items)

      l = flatten_dict(params)
      print(list(l.keys()))
      """

    case System.cmd(@python, ["-c", python_script], []) do
      {params, 0} -> {:ok, params}
      {output, _code} -> {:error, output}
    end
  end

  def params_mapping(chain, axon_params, flax_params) do
    prompt = params_mapping_prompt(axon_params, flax_params)

    {:ok, _chain, response} =
      chain
      |> LLMChain.add_message(Message.new_user!(prompt))
      |> LLMChain.run()

    response.content
  end
end
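
The reply is later evaluated with `Code.eval_string/1` in the verification handler, so it must be a literal Elixir map from Axon parameter ids to Flax parameter ids. A hypothetical example of the expected shape (the concrete keys depend on the model):

%{
  "dense_0.kernel" => "params.Dense_0.kernel",
  "dense_0.bias" => "params.Dense_0.bias"
}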

Interactive Tool - Global Definitions

{:ok, converted} = Conversion.start_link()

conversion_chain = Conversion.new_chain!(llm_config)
params_mapping_chain = ParamsMapping.new_chain!(llm_config)

Kino.nothing()

Interactive Tool - Setup

input_file = Kino.Input.file("Flax File")
input =
  case Kino.Input.read(input_file) do
    %{file_ref: file_ref} -> Kino.Input.file_path(file_ref) |> File.read!()
    nil -> ""
  end

inputs = Utils.top_level_defs(input)

select_list =
  for {input, index} <- Enum.with_index(inputs) do
    [label | _] = String.split(input, "(")
    {index, label}
  end

select_list =
  if Enum.empty?(select_list) do
    [{0, "No Flax Input"}]
  else
    select_list
  end

target_select = Kino.Input.select("Conversion Target", select_list)

import Kino.Shorts

model_index = Kino.Input.read(target_select)
model = Enum.at(inputs, model_index)

run_instruction = Kino.Control.button("Run Instruction")
run_all_remaining_instructions = Kino.Control.button("Run All Remaining")
instruction_status = frame()
previous_instruction = Kino.Control.button("Previous")
next_instruction = Kino.Control.button("Next")

save = Kino.Control.button("Save Converted Model")
reload = Kino.Control.button("Reload")

previous_model_frame = frame()
current_model_frame = frame()

instruction_frame = frame()

first_instruction = Conversion.conversion_instruction(0)
num_instructions = Conversion.num_instructions()

Kino.Frame.render(
  instruction_frame,
  markdown("[0/#{num_instructions}] #{first_instruction.instruction}")
)

original_model_text = Kino.Input.textarea("Model Code", default: model, monospace: true)
Kino.Frame.render(previous_model_frame, text("No previous model yet"))
Kino.Frame.render(current_model_frame, original_model_text)

conversion_layout =
  grid(
    [
      markdown("### Next LLM instruction"),
      text(""),
      instruction_frame,
      grid(
        [
          previous_instruction,
          run_instruction,
          next_instruction,
          run_all_remaining_instructions,
          instruction_status
        ],
        boxed: true,
        columns: 2
      ),
      markdown("### Previous version"),
      markdown("### Next LLM input"),
      previous_model_frame,
      current_model_frame
    ],
    boxed: true,
    columns: 2
  )

params_mapping = Kino.Input.textarea("Params Mapping")
generate_params_mapping = Kino.Control.button("Generate Params Mapping")
generate_params_status = frame()
Kino.Frame.render(generate_params_status, text("Waiting"))
params_mapping_frame = frame()
Kino.Frame.render(params_mapping_frame, params_mapping)

import_path = Kino.Input.text("Import Path", default: "transformers.models.this")
additional_imports = Kino.Input.text("Additional Imports", default: "#from path import Class")
input_shape = Kino.Input.text("Input Shape", default: "{3, 3}")
input_type = Kino.Input.text("Input Type", default: "f32")

axon_build_args =
  Kino.Input.text("Axon Build Args (function inputs)", default: "[Axon.input(\"in\"), 1, 2]")

flax_model_args = Kino.Input.text("Flax Model Args (model attributes)", default: "3, 5")
verification_params_frame = frame()
verification_status = frame()
Kino.Frame.render(verification_status, text("Waiting"))

Kino.Frame.render(
  verification_params_frame,
  grid(
    [
      import_path,
      additional_imports,
      input_shape,
      input_type,
      axon_build_args,
      flax_model_args
    ],
    boxed: true,
    columns: 2
  )
)

verify = Kino.Control.button("Run Verification")
result_status_frame = frame()
results_frame = frame()

verification_layout =
  grid(
    [
      current_model_frame,
      reload,
      params_mapping_frame,
      grid([generate_params_mapping, generate_params_status]),
      verification_params_frame,
      grid([verify, verification_status]),
      markdown("### Verification Status"),
      markdown("### Failed Verification Results"),
      result_status_frame,
      results_frame
    ],
    boxed: true,
    columns: 2,
    gap: 16
  )

converted_frame = frame()

converted_code_md =
  """
  ```elixir
  defmodule YourAxonModel do
    #{Conversion.code_to_string(converted)}
  end
  ```
  """

Kino.Frame.render(converted_frame, markdown(converted_code_md))

converted_layout =
  grid(
    [
      markdown("### Converted"),
      text(""),
      converted_frame,
      save
    ],
    boxed: true,
    columns: 2
  )

ui =
  grid([
    grid([markdown("### Select Input File"), input_file]),
    target_select,
    text(""),
    markdown("### Original Model"),
    markdown("```python\n#{model}\n```"),
    tabs(
      Conversion: conversion_layout,
      Verification: verification_layout
    ),
    converted_layout
  ])

inputs = %{
  target_select: target_select,
  model_input: original_model_text,
  import_path: import_path,
  additional_imports: additional_imports,
  input_shape: input_shape,
  input_type: input_type,
  params_mapping: params_mapping,
  axon_build_args: axon_build_args,
  flax_model_args: flax_model_args
}

stream =
  Kino.Control.tagged_stream(
    run_instruction: run_instruction,
    run_all_remaining_instructions: run_all_remaining_instructions,
    previous_instruction: previous_instruction,
    next_instruction: next_instruction,
    save: save,
    verify: verify,
    reload: reload,
    gen_params_mapping: generate_params_mapping
  )

Kino.nothing()

Kino.listen(
  stream,
  {conversion_chain, inputs, 0},
  fn event, {chain, inputs, instruction_index} ->
    case event do
      {:run_all_remaining_instructions, _e} ->
        model_code = Kino.Input.read(inputs.model_input)
        Kino.Frame.render(instruction_status, text("Running remaining instructions"))

        {chain, model_code} =
          for index <- instruction_index..(num_instructions - 1)//1, reduce: {chain, model_code} do
            {chain, model_code} ->
              instruction =
                Conversion.conversion_instruction(index)

              {:ok, updated_chain, model_code} =
                Conversion.update_model(model_code, chain, instruction)

              {updated_chain, model_code}
          end

        Kino.Frame.render(instruction_status, text("Waiting for instruction"))

        model_code =
          model_code
          |> String.trim_leading("```elixir")
          |> String.trim_leading("```python")
          |> String.trim_trailing("```")

        model_input =
          Kino.Input.textarea("Model Code", default: model_code, monospace: true)

        Kino.Frame.render(current_model_frame, model_input)

        last_instruction = Conversion.conversion_instruction(num_instructions)

        Kino.Frame.render(
          instruction_frame,
          markdown("[#{num_instructions}/#{num_instructions}] #{last_instruction.instruction}")
        )

        {:cont, {chain, %{inputs | model_input: model_input}, num_instructions}}

      {:run_instruction, _e} ->
        instruction =
          Conversion.conversion_instruction(instruction_index)

        model_code = Kino.Input.read(inputs.model_input)

        Kino.Frame.render(
          previous_model_frame,
          markdown("""
          ```
          #{model_code}
          ```
          """)
        )

        Kino.Frame.render(instruction_status, text("Running instruction"))

        {:ok, updated_chain, model_code} =
          Conversion.update_model(model_code, chain, instruction)

        model_code =
          model_code
          |> String.trim_leading("```elixir")
          |> String.trim_leading("```python")
          |> String.trim_trailing("```")

        model_input =
          Kino.Input.textarea("Model Code", default: model_code, monospace: true)

        Kino.Frame.render(current_model_frame, model_input)

        next_index = instruction_index + 1
        next_instruction = Conversion.conversion_instruction(next_index)

        Kino.Frame.render(
          instruction_frame,
          markdown("[#{next_index}/#{num_instructions}] #{next_instruction.instruction}")
        )

        Kino.Frame.render(instruction_status, text("Waiting for instruction"))

        {:cont, {updated_chain, %{inputs | model_input: model_input}, next_index}}

      {:previous_instruction, _e} ->
        previous_index =
          if instruction_index > 0 do
            instruction_index - 1
          else
            0
          end

        previous_instruction =
          Conversion.conversion_instruction(previous_index)

        Kino.Frame.render(
          instruction_frame,
          markdown("[#{previous_index}/#{num_instructions}] #{previous_instruction.instruction}")
        )

        {:cont, {chain, inputs, previous_index}}

      {:next_instruction, _e} ->
        next_index =
          if instruction_index < num_instructions do
            instruction_index + 1
          else
            instruction_index
          end

        next_instruction =
          Conversion.conversion_instruction(next_index)

        Kino.Frame.render(
          instruction_frame,
          markdown("[#{next_index}/#{num_instructions}] #{next_instruction.instruction}")
        )

        {:cont, {chain, inputs, next_index}}

      {:save, _e} ->
        conversion_target = Kino.Input.read(inputs.target_select)
        model_code = Kino.Input.read(inputs.model_input)

        params_mapping = Kino.Input.read(inputs.params_mapping)

        Conversion.put(converted, conversion_target, model_code, params_mapping)

        converted_code_md = """
        ```elixir
        defmodule YourAxonModel do
        #{Conversion.code_to_string(converted)}
        end
        ```
        """

        Kino.Frame.render(converted_frame, markdown(converted_code_md))

        {:cont, {chain, inputs, instruction_index}}

      {:verify, _e} ->
        converted_code = Kino.Input.read(inputs.model_input)
        original_code = Kino.Input.read(original_model_text)

        {flax_name, axon_name} = Utils.extract_names(original_code)

        input_type = Kino.Input.read(inputs.input_type) |> String.to_atom()
        {input_shape, _} = Kino.Input.read(inputs.input_shape) |> Code.eval_string()

        {params_mapping, _} = Kino.Input.read(inputs.params_mapping) |> Code.eval_string()

        import_path = Kino.Input.read(inputs.import_path)
        additional_imports = Kino.Input.read(inputs.additional_imports)

        {axon_build_args, _} = Kino.Input.read(inputs.axon_build_args) |> Code.eval_string()
        flax_model_args = Kino.Input.read(inputs.flax_model_args)

        previously_converted_code = Conversion.code_to_string(converted)

        {axon_params_path, flax_params_path, axon_path, flax_path} =
          Verification.safetensor_files(data_dir, axon_name)

        flax_model_data = %{
          name: flax_name,
          code: original_code,
          model_args: flax_model_args,
          import_path: import_path,
          additional_imports: additional_imports,
          params_path: flax_params_path,
          results_path: flax_path
        }

        axon_model_data = %{
          name: axon_name,
          code: converted_code,
          additional_code: previously_converted_code,
          input_shape: input_shape,
          input_type: input_type,
          build_args: axon_build_args,
          params_path: axon_params_path,
          results_path: axon_path
        }

        Kino.Frame.render(verification_status, text("Running Verification"))

        case Verification.verify_model(
               axon_model_data,
               flax_model_data,
               params_mapping
             ) do
          {:ok, _verification_details} ->
            Kino.Frame.render(verification_status, text("Waiting"))
            Kino.Frame.render(result_status_frame, :ok)
            Kino.Frame.render(results_frame, [])

          {:verification_failed, verification_details} ->
            Kino.Frame.render(verification_status, text("Waiting"))
            Kino.Frame.render(result_status_frame, :verification_failed)

            failed =
              Enum.filter(verification_details, fn {_key, value} -> not value.same_result? end)

            Kino.Frame.render(results_frame, failed)

          {:error, error} ->
            Kino.Frame.render(verification_status, error)
            Kino.Frame.render(result_status_frame, :error)
        end

        {:cont, {chain, inputs, instruction_index}}

      {:reload, _e} ->
        converted_code =
          Kino.Input.read(inputs.model_input)
          |> String.trim()

        model_input =
          Kino.Input.textarea("Model Input", default: converted_code, monospace: true)

        Kino.Frame.render(current_model_frame, model_input)

        params_mapping = Kino.Input.read(inputs.params_mapping)
        params_mapping = Kino.Input.textarea("Params Mapping", default: params_mapping)
        Kino.Frame.render(params_mapping_frame, params_mapping)
        Kino.Frame.render(generate_params_status, text("Waiting"))

        import_path = Kino.Input.read(inputs.import_path)
        import_path = Kino.Input.text("Import Path", default: import_path)

        additional_imports = Kino.Input.read(inputs.additional_imports)
        additional_imports = Kino.Input.text("Additional Imports", default: additional_imports)
        input_shape = Kino.Input.read(inputs.input_shape)
        input_shape = Kino.Input.text("Input Shape", default: input_shape)
        input_type = Kino.Input.read(inputs.input_type)
        input_type = Kino.Input.text("Input Type", default: input_type)

        axon_build_args = Kino.Input.read(inputs.axon_build_args)

        axon_build_args =
          Kino.Input.text("Axon Build Args (function inputs)", default: axon_build_args)

        flax_model_args = Kino.Input.read(inputs.flax_model_args)

        flax_model_args =
          Kino.Input.text("Flax Model Args (model attributes)", default: flax_model_args)

        Kino.Frame.render(
          verification_params_frame,
          grid(
            [
              import_path,
              additional_imports,
              input_shape,
              input_type,
              axon_build_args,
              flax_model_args
            ],
            boxed: true,
            columns: 2
          )
        )

        Kino.Frame.render(verification_status, text("Waiting"))

        updated_inputs = %{
          inputs
          | model_input: model_input,
            params_mapping: params_mapping,
            import_path: import_path,
            additional_imports: additional_imports,
            input_shape: input_shape,
            input_type: input_type,
            axon_build_args: axon_build_args,
            flax_model_args: flax_model_args
        }

        {:cont, {chain, updated_inputs, instruction_index}}

      {:gen_params_mapping, _e} ->
        converted_code = Kino.Input.read(inputs.model_input)
        original_code = Kino.Input.read(original_model_text)

        {flax_name, axon_name} = Utils.extract_names(original_code)

        input_type = Kino.Input.read(inputs.input_type) |> String.to_atom()
        {input_shape, _} = Kino.Input.read(inputs.input_shape) |> Code.eval_string()

        import_path = Kino.Input.read(inputs.import_path)
        additional_imports = Kino.Input.read(inputs.additional_imports)

        {axon_build_args, _} = Kino.Input.read(inputs.axon_build_args) |> Code.eval_string()
        flax_model_args = Kino.Input.read(inputs.flax_model_args)

        previously_converted = Conversion.code_to_string(converted)

        with {:ok, _model} <-
               Model.build_model(converted_code, axon_name, axon_build_args, previously_converted),
             {:ok, model} <-
               Model.apply_model(Model.ModelTest, axon_name, axon_build_args),
             {:ok, axon_params} <-
               ParamsMapping.axon_params(model, input_shape, input_type),
             {:ok, flax_params} <-
               ParamsMapping.flax_params(
                 flax_name,
                 import_path,
                 additional_imports,
                 flax_model_args,
                 input_shape
               ) do
          Kino.Frame.render(generate_params_status, text("Generating Params Mapping"))

          params_mapping =
            ParamsMapping.params_mapping(params_mapping_chain, axon_params, flax_params)
            |> String.trim_leading("```elixir")
            |> String.trim_trailing("```")

          Kino.Frame.render(generate_params_status, text("Waiting"))

          params_mapping = Kino.Input.textarea("Params Mapping", default: params_mapping)

          Kino.Frame.render(params_mapping_frame, params_mapping)

          {:cont, {chain, %{inputs | params_mapping: params_mapping}, instruction_index}}
        else
          {:error, error} ->
            Kino.Frame.render(generate_params_status, error)

            old_params_mapping = Kino.Input.read(inputs.params_mapping)
            params_mapping = Kino.Input.textarea("Params Mapping", default: old_params_mapping)
            Kino.Frame.render(params_mapping_frame, params_mapping)

            {:cont, {chain, %{inputs | params_mapping: params_mapping}, instruction_index}}
        end
    end
  end
)

Interactive Tool - UI

Amplify the output of the next cell to give the tool more screen space.

ui