Flax to Axon - Interactive Tool
# Notebook setup: installs the Elixir dependencies, a pinned Python (via asdf)
# and a virtualenv with the Python packages used to run the Flax side of the
# verification. Everything is installed next to the notebook (`__DIR__`).
Mix.install([
  {:axon, "~> 0.6.1"},
  {:stream_data, "~> 1.1"},
  {:nx, "~> 0.7.2"},
  {:safetensors, "~> 0.1.3"},
  {:kino, "~> 0.12.3"},
  {:langchain, "~> 0.3.0-rc.0"},
  {:exla, "~> 0.7.3"}
])

Nx.global_default_backend(EXLA.Backend)

# Single source of truth for the asdf location and the Python version, so the
# install step and the interpreter path below cannot drift apart.
asdf_dir = Path.join(__DIR__, ".asdf")
asdf_env = [{"ASDF_DATA_DIR", asdf_dir}]
python_version = "3.11.9"

# Clone asdf only once; later runs reuse the existing checkout.
if not File.exists?(asdf_dir) do
  {_, 0} =
    System.cmd("git", [
      "clone",
      "https://github.com/asdf-vm/asdf.git",
      asdf_dir,
      "--branch",
      "v0.14.0"
    ])
end

asdf = Path.join([asdf_dir, "bin", "asdf"])

# Each step asserts exit status 0 so a broken installation fails loudly here
# instead of surfacing later as a confusing Python error.
{_, 0} = System.cmd(asdf, ["plugin", "add", "python"], env: asdf_env)
{_, 0} = System.cmd(asdf, ["install", "python", python_version], env: asdf_env)

asdf_python = Path.join([asdf_dir, "installs", "python", python_version, "bin", "python"])

python_packages =
  ~w(
    safetensors
    torch
    transformers
    accelerate
    numpy
    datasets
    pillow
    flax
    jax
    jaxlib
  )

# Dedicated virtualenv for the packages above.
venv_dir = Path.join(__DIR__, "flax_to_axon_env")
{_, 0} = System.cmd(asdf_python, ["-m", "venv", "--copies", venv_dir])

python = Path.join([venv_dir, "bin", "python"])
pip = Path.join([venv_dir, "bin", "pip"])
{_, 0} = System.cmd(pip, ["install" | python_packages])

# Runs a Python snippet with the virtualenv interpreter.
# Returns the `{output, exit_status}` tuple of `System.cmd/3`.
run_python = fn command, opts ->
  System.cmd(python, ["-c", command], opts)
end

# Directory for the safetensors files exchanged between Elixir and Python.
# `mkdir_p!` is idempotent, so no existence check is needed.
data_dir = Path.join(__DIR__, "data")
File.mkdir_p!(data_dir)
WARNING
> This Livebook installs asdf, Python and some libraries in the directory of the livebook.
> Modify the notebook setup if you don’t want that.
>
> It also runs LLM generated Elixir code using `Code.eval_string`.
Introduction
You can configure the LLM provider and model here. For OpenAI, you must set the OPENAI_KEY
secret to your api key.
# Read the OpenAI API key from the environment. The text above refers to a
# secret named OPENAI_KEY; presumably Livebook exposes it with the `LB_`
# prefix — confirm against your Livebook version.
api_key = System.get_env("LB_OPENAI_KEY")
# Chat model configuration shared by all conversion chains in this notebook.
# A fixed `seed` is passed to the model configuration.
llm_config =
LangChain.ChatModels.ChatOpenAI.new!(%{
model: "gpt-4o-mini",
api_key: api_key,
seed: 0
})
Goals
- (Semi-)automatically convert Flax models to Axon (from `transformers`/`diffusers`)
- Verify that models compute the same results for the same inputs
For now, we consider only Flax (linen) because it’s similar to Axon and should be easier than PyTorch or TensorFlow
flowchart TD
Input[Read Flax File]
Classes[Split into classes and free functions]
subgraph each_class[for each class/free function]
subgraph conversion[Convert into Axon]
Convert[Convert using LLM instructions]
Replace[Replace previously \n converted classes \n with Axon function]
end
subgraph verify[Verify same results]
direction TB
Generate[Generate input and params in Elixir]
Compute_axon[Compute results in Axon]
Safetensors[Save everything in safetensors]
Load[Load safetensors in Python]
Compute_flax[Compute results in Flax]
Compare[Compare outputs for inputs]
end
Store[Store resulting Axon function]
end
Input-->Classes-->each_class
Convert-->Replace
Convert-->Convert
conversion-->verify
Generate-->Compute_axon-->Safetensors-->Load-->Compute_flax-->Compare
verify-->Store
Get inputs
We read a Python file which contains the model code. Then we get all top level definitions from the file.
Models in Flax are classes, so they are named in CamelCase. In Axon we represent them as functions, named in snake_case.
defmodule Utils do
  @moduledoc """
  Some utils to extract the information we need from Flax code.
  """

  @doc """
  Splits Flax code into its top level definitions.

  Considers only classes and free functions, i.e. lines that start with
  `class ` or `def ` at column 0. A definition on the very first line of
  `flax_code` is recognised as well.

  Ignores everything before the first class or function definition, i.e. imports etc.

  Returns a list of strings, one per definition, each starting with its
  `class`/`def` keyword. Returns `[]` when there is no top level definition.
  """
  def top_level_defs(flax_code) do
    # `include_captures: true` keeps the matched keyword between the parts, so
    # after dropping the preamble the list alternates keyword, body, keyword, ...
    [_preamble | parts] =
      Regex.split(~r{(?:\A|\n)(class |def )\b}, flax_code, include_captures: true)

    parts
    |> Enum.chunk_every(2)
    |> Enum.map(fn [keyword, body] -> String.trim_leading(keyword) <> body end)
  end

  @doc """
  Finds the first top level definition in `flax_code`.

  Then returns a tuple with the name in Flax and the corresponding name in Axon
  (the Flax name converted to snake_case).

  Works for definitions without an argument list too, e.g. `class Foo:`.
  """
  def extract_names(flax_code) do
    flax_name =
      case Regex.run(~r{^(?:class|def)\s+([A-Za-z_]\w*)}m, flax_code, capture: :all_but_first) do
        [name] -> name
        # No recognisable definition: keep the previous best-effort behaviour.
        nil -> text_before_first_paren(flax_code)
      end

    {flax_name, Macro.underscore(flax_name)}
  end

  # Everything before the first "(", without a leading `class `/`def ` keyword.
  defp text_before_first_paren(flax_code) do
    [head | _] = String.split(flax_code, "(")

    head
    |> String.trim_leading("class ")
    |> String.trim_leading("def ")
  end
end
Conversion - Prompt Definitions
In this section we define a set of messages we pass to the LLM to make it help us convert the Flax code.
- System Messages to give the LLM some context
- Conversion tables for layers, activation functions and initializers (partially generated by ChatGPT)
- Instructions the LLM will perform one at a time to convert the code
# System prompt for the conversion chain: sets the LLM's role and describes
# the INSTRUCTION / MODEL format of the user messages that follow.
system_message =
"""
You are an expert in machine learning frameworks.
You help converting models from Python code in Flax linen framework to Elixir code in the Axon framework.
You will do this by following instructions step by step.
For each of the instructions, reply ONLY with the modified code of the model.
Do NOT include any other content or comments.
Perform ONLY the action the instruction asks you to do.
In case you don't change anything return ONLY the unchanged model code.
I will provide you with the model code and an instruction in the following format:
INSTRUCTION: here is the instruction you perform
MODEL: here is the model code you will modify
"""
# Keep the cell from rendering the prompt string as its output.
Kino.nothing()
Conversion Tables
Created using ChatGPT and refining format.
Layers
Prompt:
Fetch the documentation about layers in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html
Build a separate markdown table for each of the layers in Flax linen and the corresponding layer in Axon.
Each table should have two rows:
First row: shows the Flax linen layer in the first column, then a column for each parameter it takes. Prefix the linen layer with nn.
Second row: shows the corresponding Axon layer in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon layer with Axon.Layer.
If there is no corresponding layer, or no corresponding parameter, add a “-“ instead.
Here an example for the Conv layer of Flax linen:
Flax | nn.Conv | features | kernel_size | strides | padding | dtype | use_bias | kernel_init |
Axon | Axon.conv | units | kernel_size | strides | padding | - | use_bias | kernel_initializer |
# Markdown tables mapping Flax linen layers (and their parameters) to the
# corresponding Axon layers. Passed to the LLM as part of the
# `:convert_layers` instruction and rendered below for the reader.
#
# Axon's normalization layers select the normalized axis with the
# `:channel_index` option, so Flax's `axis` maps to `channel_index`.
conversion_layers_table =
  """
  ### Conv Layer

  |          | Layer     | features | kernel_size | strides | padding | dtype | use_bias | kernel_init        |
  | -------- | --------- | -------- | ----------- | ------- | ------- | ----- | -------- | ------------------ |
  | **Flax** | nn.Conv   | features | kernel_size | strides | padding | dtype | use_bias | kernel_init        |
  | **Axon** | Axon.conv | units    | kernel_size | strides | padding | -     | use_bias | kernel_initializer |

  ### Dense Layer

  |          | Layer      | features | use_bias | kernel_init        |
  | -------- | ---------- | -------- | -------- | ------------------ |
  | **Flax** | nn.Dense   | features | use_bias | kernel_init        |
  | **Axon** | Axon.dense | units    | use_bias | kernel_initializer |

  ### Dropout Layer

  |          | Layer        | rate |
  | -------- | ------------ | ---- |
  | **Flax** | nn.Dropout   | rate |
  | **Axon** | Axon.dropout | rate |

  ### BatchNorm Layer

  |          | Layer           | use_running_average | axis          | momentum | epsilon | dtype |
  | -------- | --------------- | ------------------- | ------------- | -------- | ------- | ----- |
  | **Flax** | nn.BatchNorm    | use_running_average | axis          | momentum | epsilon | dtype |
  | **Axon** | Axon.batch_norm | -                   | channel_index | momentum | epsilon | -     |

  ### LayerNorm Layer

  |          | Layer           | axis          | epsilon | dtype |
  | -------- | --------------- | ------------- | ------- | ----- |
  | **Flax** | nn.LayerNorm    | axis          | epsilon | dtype |
  | **Axon** | Axon.layer_norm | channel_index | epsilon | -     |

  ### Relu Layer

  |          | Layer     |
  | -------- | --------- |
  | **Flax** | nn.relu   |
  | **Axon** | Axon.relu |

  ### Sigmoid Layer

  |          | Layer        |
  | -------- | ------------ |
  | **Flax** | nn.sigmoid   |
  | **Axon** | Axon.sigmoid |

  ### Tanh Layer

  |          | Layer     |
  | -------- | --------- |
  | **Flax** | nn.tanh   |
  | **Axon** | Axon.tanh |

  ### MaxPool Layer

  |          | Layer         | window_shape | strides | padding |
  | -------- | ------------- | ------------ | ------- | ------- |
  | **Flax** | nn.max_pool   | window_shape | strides | padding |
  | **Axon** | Axon.max_pool | kernel_size  | strides | padding |

  ### AvgPool Layer

  |          | Layer         | window_shape | strides | padding |
  | -------- | ------------- | ------------ | ------- | ------- |
  | **Flax** | nn.avg_pool   | window_shape | strides | padding |
  | **Axon** | Axon.avg_pool | kernel_size  | strides | padding |

  ### GlobalAvgPool Layer

  |          | Layer                |
  | -------- | -------------------- |
  | **Flax** | nn.global_avg_pool   |
  | **Axon** | Axon.global_avg_pool |

  ### Flatten Layer

  |          | Layer        |
  | -------- | ------------ |
  | **Flax** | nn.Flatten   |
  | **Axon** | Axon.flatten |
  """

Kino.Markdown.new(conversion_layers_table)
Initializers
Prompt:
Fetch the documentation about initializers in Axon from here: https://hexdocs.pm/axon/Axon.Initializers.html
Fetch the documentation about initializers in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/initializers.html
Build a separate markdown table for each of the initializers in Flax linen and the corresponding initializer in Axon.
Each table should have two rows:
First row: shows the Flax linen initializer in the first column, then a column for each parameter it takes. Prefix the linen initializer with nn.
Second row: shows the corresponding Axon initializer in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon initializer with Axon.Layer.
If there is no corresponding initializer, or no corresponding parameter, add a “-“ instead.
Here an example for the `initializers.variance_scaling` of Flax linen:
Flax | nn.initializers.variance_scaling | scale | mode | distribution | in_axis | out_axis | batch_axis | dtype |
Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | - | - | - | - |
# Markdown tables mapping Flax linen initializers (and their parameters) to the
# corresponding Axon initializers. Passed to the LLM as part of the
# `:convert_initializers` instruction and rendered below for the reader.
#
# A "-" means there is no corresponding parameter on the Axon side.
# NOTE(review): the Flax parameter names for `normal` (`stddev`) and `uniform`
# (`scale`) and the missing Axon counterpart of orthogonal's `scale` were
# corrected from memory of the respective API docs — confirm against them.
conversion_initializers_table =
  """
  ### Variance Scaling

  |      |                                    |       |      |              |         |          |            |       |
  | ---- | ---------------------------------- | ----- | ---- | ------------ | ------- | -------- | ---------- | ----- |
  | Flax | nn.initializers.variance_scaling   | scale | mode | distribution | in_axis | out_axis | batch_axis | dtype |
  | Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | -       | -        | -          | -     |

  ### Glorot Normal

  |      |                                 |       |
  | ---- | ------------------------------- | ----- |
  | Flax | nn.initializers.glorot_normal   | dtype |
  | Axon | Axon.Initializers.glorot_normal | -     |

  ### Glorot Uniform

  |      |                                  |       |
  | ---- | -------------------------------- | ----- |
  | Flax | nn.initializers.glorot_uniform   | dtype |
  | Axon | Axon.Initializers.glorot_uniform | -     |

  ### Lecun Normal

  |      |                                |       |
  | ---- | ------------------------------ | ----- |
  | Flax | nn.initializers.lecun_normal   | dtype |
  | Axon | Axon.Initializers.lecun_normal | -     |

  ### Lecun Uniform

  |      |                                 |       |
  | ---- | ------------------------------- | ----- |
  | Flax | nn.initializers.lecun_uniform   | dtype |
  | Axon | Axon.Initializers.lecun_uniform | -     |

  ### Orthogonal

  |      |                              |       |
  | ---- | ---------------------------- | ----- |
  | Flax | nn.initializers.orthogonal   | scale |
  | Axon | Axon.Initializers.orthogonal | -     |

  ### Constant

  |      |                          |       |
  | ---- | ------------------------ | ----- |
  | Flax | nn.initializers.constant | value |
  | Axon | Axon.Initializers.full   | value |

  ### Normal

  |      |                          |        |       |
  | ---- | ------------------------ | ------ | ----- |
  | Flax | nn.initializers.normal   | stddev | dtype |
  | Axon | Axon.Initializers.normal | scale  | -     |

  ### Uniform

  |      |                           |       |       |
  | ---- | ------------------------- | ----- | ----- |
  | Flax | nn.initializers.uniform   | scale | dtype |
  | Axon | Axon.Initializers.uniform | scale | -     |

  ### Identity

  |      |                            |
  | ---- | -------------------------- |
  | Flax | nn.initializers.identity   |
  | Axon | Axon.Initializers.identity |

  ### Ones

  |      |                        |
  | ---- | ---------------------- |
  | Flax | nn.initializers.ones   |
  | Axon | Axon.Initializers.ones |

  ### Zeros

  |      |                         |
  | ---- | ----------------------- |
  | Flax | nn.initializers.zeros   |
  | Axon | Axon.Initializers.zeros |
  """

Kino.Markdown.new(conversion_initializers_table)
Activation Functions
Prompt:
Fetch the documentation about activations in Axon from here: https://hexdocs.pm/axon/Axon.Activations.html
Fetch the documentation about activations in Flax linen from here: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/activation_functions.html
Build a separate markdown table for each of the activations in Flax linen and the corresponding activation in Axon.
Each table should have two rows:
First row: shows the Flax linen activation in the first column, then a column for each parameter it takes. Prefix the linen activation with nn.
Second row: shows the corresponding Axon activation in the first column, then the corresponding Axon parameter to each of the Flax linen parameters. Prefix the Axon activation with Axon.Activations.
If there is no corresponding activation, or no corresponding parameter, add a “-“ instead.
Here an example for `activation.softmax` of Flax linen:
Flax | nn.activation.softmax | axis | where | initial |
Axon | Axon.Activations.softmax | axis | - | - |
# Markdown tables mapping Flax linen activation functions (and their
# parameters) to the corresponding functions in `Axon.Activations`. Passed to
# the LLM as part of the `:convert_activations` instruction and rendered below.
#
# Axon names the swish activation `silu`; there is no
# `Axon.Activations.swish`.
conversion_activations_table =
  """
  ### Softmax

  |      |                          |      |       |         |
  | ---- | ------------------------ | ---- | ----- | ------- |
  | Flax | nn.softmax               | axis | where | initial |
  | Axon | Axon.Activations.softmax | axis | -     | -       |

  ### ReLU

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.relu               |
  | Axon | Axon.Activations.relu |

  ### Leaky ReLU

  |      |                             |                |
  | ---- | --------------------------- | -------------- |
  | Flax | nn.leaky_relu               | negative_slope |
  | Axon | Axon.Activations.leaky_relu | alpha          |

  ### Sigmoid

  |      |                          |
  | ---- | ------------------------ |
  | Flax | nn.sigmoid               |
  | Axon | Axon.Activations.sigmoid |

  ### Tanh

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.tanh               |
  | Axon | Axon.Activations.tanh |

  ### GELU

  |      |                       |             |
  | ---- | --------------------- | ----------- |
  | Flax | nn.gelu               | approximate |
  | Axon | Axon.Activations.gelu | -           |

  ### ELU

  |      |                      |       |
  | ---- | -------------------- | ----- |
  | Flax | nn.elu               | alpha |
  | Axon | Axon.Activations.elu | alpha |

  ### SELU

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.selu               |
  | Axon | Axon.Activations.selu |

  ### Softplus

  |      |                           |
  | ---- | ------------------------- |
  | Flax | nn.softplus               |
  | Axon | Axon.Activations.softplus |

  ### Swish

  |      |                       |
  | ---- | --------------------- |
  | Flax | nn.swish              |
  | Axon | Axon.Activations.silu |

  ### Log Softmax

  |      |                              |      |       |         |
  | ---- | ---------------------------- | ---- | ----- | ------- |
  | Flax | nn.log_softmax               | axis | where | initial |
  | Axon | Axon.Activations.log_softmax | axis | -     | -       |
  """

Kino.Markdown.new(conversion_activations_table)
Detailed instructions how to convert the model
defmodule Instruction do
  @moduledoc """
  A single conversion step that is sent to the LLM.

  Fields:

    * `:name` - identifier of the step
    * `:instruction` - the text telling the LLM what to change
    * `:conversion_table` - optional lookup table appended to the instruction
    * `:example_input` / `:example_output` - optional before/after example

  All fields default to `nil`.
  """

  defstruct name: nil,
            instruction: nil,
            conversion_table: nil,
            example_input: nil,
            example_output: nil
end
# EEx prompt template for one conversion step. It is rendered with the fields
# of an `Instruction` plus `:model_code`.
#
# The conversion table and the example are optional: each section is emitted
# only when the corresponding assign is set, so instructions without an
# example no longer produce an empty "INPUT:/OUTPUT:" stub. The examples bring
# their own code fences, hence no fence is added here.
template =
  LangChain.PromptTemplate.from_template!("""
  INSTRUCTION:
  <%= @instruction %>
  <%= if assigns[:conversion_table] do %>
  <%= @conversion_table %>
  <% end %>
  <%= if assigns[:example_input] do %>
  Here an example.
  INPUT:
  <%= @example_input %>
  OUTPUT:
  <%= @example_output %>
  <% end %>
  DO NOT PERFORM ANY OTHER ACTION, IF THERE IS NOTHING TO DO RETURN THE UNCHANGED CODE.

  MODEL:
  <%= @model_code %>
  """)
conversion_instructions = [
%Instruction{
name: :setup,
instruction: """
If there is a setup function, move all of its content to the `__call__` function.
Then remove the setup function.
""",
example_input: """
```python
class SlightlyLargerSetup(nn.Module):
def setup(self):
self.dense0 = nn.Dense(features=32)
self.dropout = nn.Dropout(rate=0.5)
self.dense1 = nn.Dense(features=1)
def __call__(self, x, training: bool):
x = self.dense0(x)
x = nn.relu(x)
x = self.dropout(x, deterministic=not training)
x = self.dense1(x)
x = nn.softmax(x)
return x
```
""",
example_output: """
```python
class SlightlyLargerSetup(nn.Module):
def __call__(self, x, training: bool):
self.dense0 = nn.Dense(features=32)
self.dropout = nn.Dropout(rate=0.5)
self.dense1 = nn.Dense(features=1)
x = self.dense0(x)
x = nn.relu(x)
x = self.dropout(x, deterministic=not training)
x = self.dense1(x)
x = nn.softmax(x)
return x
```
"""
},
%Instruction{
name: :attributes,
instruction: """
If there are attributes, move them to the `__call__` function as parameters and replace all the references of the attributes with the function parameters.
""",
example_input: """
```python
class SlightlyLargerAttributes(nn.Module):
features_dense_0: int
dropout_rate: int
features_dense_1: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(features=self.features_dense_0)(x)
x = nn.relu(x)
x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
x = nn.Dense(features=self.features_dense_1)(x)
x = nn.softmax(x)
return x
```
""",
example_output: """
```python
class SlightlyLargerAttributes(nn.Module):
@nn.compact
def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
x = nn.Dense(features=features_dense_0)(x)
x = nn.relu(x)
x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
x = nn.Dense(features=features_dense_1)(x)
x = nn.softmax(x)
return x
```
"""
},
%Instruction{
name: :move_params_to_declaration,
instruction: """
In the `__call__` function, move all additional parameters when calling the layers to the initialization of the layers. Each layer should be called with a single argument.
""",
example_input: """
```python
class SlightlyLargerAttributes(nn.Module):
@nn.compact
def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
x = nn.Dense(features=features_dense_0)(x)
x = nn.relu(x)
x = nn.Dropout(rate=dropout_rate)(x, deterministic=not training)
x = nn.Dense(features=features_dense_1)(x)
x = nn.softmax(x)
return x
```
""",
example_output: """
```python
class SlightlyLargerAttributes(nn.Module):
@nn.compact
def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: int, features_dense_1: int):
x = nn.Dense(features=features_dense_0)(x)
x = nn.relu(x)
x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
x = nn.Dense(features=features_dense_1)(x)
x = nn.softmax(x)
return x
```
"""
},
%Instruction{
name: :remove_self,
instruction: """
Remove all occurrences of "self." in the `__call__` function.
Then drop the "self" parameter of the `__call__` function.
""",
example_input: """
```python
class SlightlyLargerSetup(nn.Module):
def __call__(self, x, training: bool):
self.dense0 = nn.Dense(features=32)
self.dropout = nn.Dropout(rate=0.5)
self.dense1 = nn.Dense(features=1)
x = self.dense0(x)
x = nn.relu(x)
x = self.dropout(x, deterministic=not training)
x = self.dense1(x)
x = nn.softmax(x)
return x
```
""",
example_output: """
```python
class SlightlyLargerSetup(nn.Module):
def __call__(x, training: bool):
dense0 = nn.Dense(features=32)
dropout = nn.Dropout(rate=0.5)
dense1 = nn.Dense(features=1)
x = dense0(x)
x = nn.relu(x)
x = dropout(x, deterministic=not training)
x = dense1(x)
x = nn.softmax(x)
return x
```
"""
},
# Step: turn arithmetic on tensors into the equivalent Axon combinator layers.
# The table must use the real Axon function names (`Axon.subtract/3`), since
# the LLM copies them verbatim into the generated code.
%Instruction{
  name: :replace_numeric_ops,
  instruction: """
  Replace numeric operations with corresponding Axon functions when the operands are tensors.
  According to the following table:
  """,
  example_input: """
  ```python
  class SomeModule(nn.Module):
      def __call__(self, hidden_state, residual ):
          hidden_state = hidden_state + residual
          return hidden_state
  ```
  """,
  example_output: """
  ```python
  class SomeModule(nn.Module):
      def __call__(self, hidden_state, residual ):
          hidden_state = Axon.add(hidden_state, residual)
          return hidden_state
  ```
  """,
  conversion_table: """
  | Flax operation | Axon layer |
  | --- | --- |
  | + | Axon.add() |
  | - | Axon.subtract() |
  | * | Axon.multiply() |
  """
},
# Step: rewrite `functools.partial` calls as Elixir captures.
# `partial(f, 1, 2, 3)` binds the *leading* arguments, so the remaining
# argument (`&1`) goes last in the capture.
%Instruction{
  name: :replace_partials,
  instruction: """
  Replace `partial` calls with a lambda in Elixir.
  """,
  example_input: """
  ```python
  class SomeModule(nn.Module):
      def __call__(self, hidden_state ):
          hidden_state = partial(myfunc, 1, 2, 3)
          return hidden_state
  ```
  """,
  example_output: """
  ```python
  class SomeModule(nn.Module):
      def __call__(self, hidden_state ):
          hidden_state = &myfunc(1, 2, 3, &1)
          return hidden_state
  ```
  """
},
%Instruction{
name: :replace_loops,
instruction: """
Replace for loops with an Elixir comprehension with reduce.
""",
example_input: """
```python
class SomeModule(nn.Module):
def __call__(self, hidden_state):
layers = [dense0, dense1]
for layer in layers:
hidden_state = layer(hidden_state)
return hidden_state
```
""",
example_output: """
```python
class SomeModule(nn.Module):
def __call__(self, hidden_state):
layers = [dense0, dense1]
hidden_state = for layer <- layers, reduce: hidden_state do
layer(hidden_state)
end
return hidden_state
```
"""
},
%Instruction{
name: :custom_classes,
instruction: """
Rewrite custom class names written in CamelCase inside the `__call__` function in snake_case.
""",
example_input: """
```python
class SomeModule(nn.Module):
def __call__(self, hidden_state):
conv_layer = FlaxConv(3)
hidden_state = conv_layer(hidden_state)
return hidden_state
```
""",
example_output: """
```python
class SomeModule(nn.Module):
def __call__(self, hidden_state):
conv_layer = flax_conv(3)
hidden_state = conv_layer(hidden_state)
return hidden_state
```
"""
},
# Step: replace Flax layers with Axon layers, using the layers conversion
# table. Layers that are created but not yet applied become captures
# (`&Axon.dense(&1, 32)`) which are later invoked with `.()`.
%Instruction{
  name: :convert_layers,
  instruction: """
  Replace the Flax layers with the corresponding Axon layers.
  Return only the Elixir code for the model.
  Take into account the parameters in the first parenthesis.
  You must follow these rules to replace layers:
  - When a Flax layer is initialized but not called, replace it with a lambda of the corresponding Axon layer.
  Example: replace `dense0 = nn.Dense(features=32)` with `dense0 = &Axon.dense(&1, 32)`
  - When a previously initialized Flax layer is called, replace it with a call of the corresponding Axon lambda.
  Example: replace `x = dense0(x)` with `x = dense0.(x)`
  - When the Flax code specifies a name for the layer, add that name as option to the Axon layer:
  Example: replace `dense0 = nn.Dense(features=32, name="dense0")` with `dense0 = &Axon.dense(&1, 32, name: "dense0")`
  Use this table to find corresponding layers:
  """,
  conversion_table: conversion_layers_table,
  example_input: """
  ```python
  class SlightlyLarger(nn.Module):
      @nn.compact
      def __call__(self, x, training: bool):
          dense0 = nn.Dense(features=32)
          dropout = nn.Dropout(rate=0.5)
          dense1 = nn.Dense(features=1)
          x = dense0(x)
          x = nn.relu(x)
          x = dropout(x, deterministic=not training)
          x = dense1(x)
          x = nn.softmax(x)
          return x
  ```
  """,
  example_output: """
  ```python
  class SlightlyLarger(nn.Module):
      def __call__(self, x, training: bool):
          dense0 = &Axon.dense(&1, 32)
          dropout = &Axon.dropout(&1, rate: 0.5)
          dense1 = &Axon.dense(&1, 1)
          x = dense0.(x)
          x = Axon.activation(x, :relu)
          x = dropout.(x)
          x = dense1.(x)
          x = Axon.activation(x, :softmax)
          x
  ```
  """
},
%Instruction{
name: :convert_initializers,
instruction: """
Replace initializer functions with the corresponding Axon functions according to this table:
""",
conversion_table: conversion_initializers_table,
example_input: """
""",
example_output: """
"""
},
%Instruction{
name: :convert_activations,
instruction: """
Replace activation functions.
You must follow these rules to replace activation functions:
- If there is an activation parameter, and the code makes use of ACT2FN[activation], replace that with Axon.activation(activation).
- If there is an actual activation function called, replace it with Axon's activation function according to the conversion table.
Use this table:
""",
conversion_table: conversion_activations_table,
example_input: """
""",
example_output: """
"""
},
%Instruction{
name: :wrap_in_function,
instruction: """
Wrap the Axon model in a function that takes all the required parameters from the `call` function
- name the function corresponding to the class name, but snake case.
- take the same arguments as the `call` function
- remove the self argument
- remove the dtype argument
- remove the type specs if present
- wrap the function in do ... end
""",
example_input: """
```python
class SlightlyLarger(nn.Module):
def __call__(self, x, training: bool):
x = Axon.dense(x, 32)
x = Axon.activation(x, :relu)
x = Axon.dropout(x, rate: 0.5)
x = Axon.dense(x, 1)
x = Axon.activation(x, :softmax)
x
```
""",
example_output: """
```python
def slightly_larger(x, training) do
x = Axon.dense(x, 32)
x = Axon.activation(x, :relu)
x = Axon.dropout(x, rate: 0.5)
x = Axon.dense(x, 1)
x = Axon.activation(x, :softmax)
x
end
```
"""
},
%Instruction{
name: :config,
instruction: """
If there is a config parameter replace it with an Elixir keyword list.
Retrieve the config values using Keyword.get() at the beginning of the function.
Provide sensible default values.
""",
example_input: """
```python
def slightly_larger(x, config) do
out_channels = config.out_channels
x = Axon.dense(x, out_channels)
x
```
""",
example_output: """
```python
def slightly_larger(x, config) do
out_channels = Keyword.get(config, :out_channels, 3)
x = Axon.dense(x, out_channels)
x
end
```
"""
},
%Instruction{
name: :fix_param_values,
instruction: """
Correct the values of Axon parameters according to the tables.
"""
},
%Instruction{
name: :multiple_return_values,
instruction: """
If there are multiple return values, wrap them into an Axon.container.
""",
example_input: """
```python
def slightly_larger(hidden_state, hidden_states, config) do
out_channels = config.out_channels
hidden_state = Axon.dense(hidden_state, out_channels)
hidden_state, hidden_state
```
""",
example_output: """
```python
def slightly_larger(hidden_state, hidden_states, config) do
out_channels = config.out_channels
hidden_state = Axon.dense(hidden_state, out_channels)
Axon.container(%{hidden_state: hidden_state, hidden_states: hidden_states})
```
"""
},
%Instruction{
name: :remove_directives,
instruction: """
Remove all `use` and `import` directives from the Elixir code.
"""
},
%Instruction{
name: :valid_elixir_code,
instruction: """
Check if the function is valid Elixir code.
Otherwise, fix all issues by converting Python expressions to Elixir expressions. E.g. // corresponds to div, scientific notation like 1e-05 needs a decimal point in Elixir 1.0e-05.
"""
}
]
Conversion - Code
defmodule Conversion do
  @moduledoc """
  Drives the step-by-step conversion of Flax code with the LLM and keeps the
  already converted functions in an `Agent`.

  The Agent state is a map `key => %{code: ..., params_mapping: ...}`.
  """
  use Agent

  alias LangChain.Chains.LLMChain
  alias LangChain.Message

  # Captured from the notebook bindings when the module is compiled.
  @system_message system_message
  @instructions conversion_instructions
  @no_op_instruction "return unchanged code"
  @template template

  @doc """
  Returns the number of conversion instructions.
  """
  def num_instructions, do: length(@instructions)

  @doc """
  Returns the instruction at `index`.

  For an `index` past the last instruction, the plain no-op instruction text
  is returned instead of an `Instruction` struct.
  """
  def conversion_instruction(index) do
    if index < num_instructions() do
      Enum.at(@instructions, index)
    else
      @no_op_instruction
    end
  end

  @doc """
  Starts the Agent that stores converted functions, initially empty.
  """
  def start_link() do
    Agent.start_link(fn -> %{} end)
  end

  @doc """
  Returns the entry stored under `key`, or `nil`.
  """
  def get(converted, key) do
    Agent.get(converted, &Map.get(&1, key))
  end

  @doc """
  Stores the converted `code` and its `params_mapping` under `key`.
  """
  def put(converted, key, code, params_mapping) do
    Agent.update(converted, &Map.put(&1, key, %{code: code, params_mapping: params_mapping}))
  end

  @doc """
  Returns the code of all converted functions, separated by blank lines.
  """
  def code_to_string(converted) do
    converted
    |> Agent.get(&Map.values/1)
    |> Enum.map_join("\n\n", & &1.code)
  end

  @doc """
  Returns the params mappings of all converted functions, one per line.
  """
  def params_mappings_to_string(converted) do
    converted
    |> Agent.get(&Map.values/1)
    |> Enum.map_join("\n", & &1.params_mapping)
  end

  @doc """
  Creates a chain for `llm_config`, adds the system message and runs it once.

  Raises if the LLM call fails.
  """
  def new_chain!(llm_config) do
    %{llm: llm_config}
    |> LLMChain.new!()
    |> LLMChain.add_message(Message.new_system!(@system_message))
    |> LLMChain.run()
    |> case do
      # The `~> 0.3.0-rc.0` requirement also admits later 0.3.x releases, whose
      # `run` presumably returns `{:ok, chain}` instead of the 3-tuple
      # (confirm against the LangChain changelog). Accept both.
      {:ok, chain, _response} -> chain
      {:ok, chain} -> chain
    end
  end

  defp send_user_message(chain, message) do
    chain
    |> LLMChain.add_message(Message.new_user!(message))
    |> LLMChain.run()
  end

  # `conversion_instruction/1` hands out a bare string for the no-op case;
  # wrap it so it can be rendered like any other instruction.
  defp instruction_message(model_code, instruction) when is_binary(instruction) do
    instruction_message(model_code, %Instruction{name: :no_op, instruction: instruction})
  end

  defp instruction_message(model_code, instruction) do
    prompt_inputs = instruction |> Map.from_struct() |> Map.put(:model_code, model_code)
    LangChain.PromptTemplate.format(@template, prompt_inputs)
  end

  @doc """
  Sends `instruction` together with `model_code` to the LLM.

  Returns `{:ok, updated_chain, new_model_code}` on success. When the LLM call
  fails, returns `{:error, chain, model_code}` with the chain and code that
  were passed in, so the caller can retry or continue with the old state.
  """
  def update_model(model_code, chain, instruction) do
    message = instruction_message(model_code, instruction)

    case send_user_message(chain, message) do
      {:ok, updated_chain, response} ->
        {:ok, updated_chain, response.content}

      {:ok, updated_chain} ->
        {:ok, updated_chain, updated_chain.last_message.content}

      {:error, _chain, _reason} ->
        {:error, chain, model_code}

      {:error, _reason} ->
        {:error, chain, model_code}
    end
  end
end
defmodule Model do
  @moduledoc """
  Compiles generated Axon code and builds/initializes the resulting model.

  All functions return `{:ok, value}` or `{:error, exception}` instead of
  raising, because the code they run is generated by an LLM and is expected to
  fail regularly.
  """
  require Logger

  @doc """
  Calls `model_fn_name` (a string) on `module_name` with `args`.

  Returns `{:ok, result}`, or `{:error, exception}` if the call raises.
  """
  def apply_model(module_name, model_fn_name, args) do
    {:ok, apply(module_name, String.to_atom(model_fn_name), args)}
  rescue
    error ->
      log_exception(error, __STACKTRACE__)
      {:error, error}
  end

  @doc """
  Wraps `model_code` and `additional_code` in a module named `module_name`
  and returns the source as a string.
  """
  def model_string(model_code, module_name, additional_code) do
    """
    defmodule #{module_name} do
    #{model_code}
    #{additional_code}
    end
    """
  end

  @doc """
  Compiles the generated code, calls `axon_function_name` with `build_args`
  and builds the returned Axon model.

  Returns `{:ok, built_model}` (the result of `Axon.build/1`), or
  `{:error, exception}` if compiling, calling or building fails.
  """
  def build_model(model_code, axon_function_name, build_args, additional_code) do
    source = model_string(model_code, "ModelTest", additional_code)

    try do
      # SECURITY: this evaluates LLM generated code. It is evaluated in this
      # module's environment, so the module is defined as `Model.ModelTest`.
      Code.eval_string(source, [], __ENV__)

      # `with` passes the original `{:error, exception}` through instead of
      # turning it into a MatchError.
      with {:ok, model} <- apply_model(Model.ModelTest, axon_function_name, build_args) do
        {:ok, Axon.build(model)}
      end
    rescue
      error ->
        log_exception(error, __STACKTRACE__)
        {:error, error}
    end
  end

  @doc """
  Initializes the model parameters with `init_fn` for a template input of
  `input_shape` and `input_type`.

  Returns `{:ok, params}`, or `{:error, exception}` if initialization raises.
  """
  def get_params(init_fn, input_shape, input_type) do
    {:ok, init_fn.(Nx.template(input_shape, input_type), %{})}
  rescue
    error ->
      log_exception(error, __STACKTRACE__)
      {:error, error}
  end

  # Logger expects a message, not an exception struct, so format it first.
  defp log_exception(error, stacktrace) do
    Logger.error(Exception.format(:error, error, stacktrace))
  end
end
Verification
Check that Python installation works
# Smoke test for the Python installation: print the interpreter path, then run
# a trivial script and assert that it exits with status 0.
IO.puts("Python is here: #{python}")
{_, 0} = run_python.("print('hello from Python')", [])
defmodule Verification do
@python python
defp run_python(code, opts) do
System.cmd(@python, ["-c", code], opts)
end
@doc """
Define paths for `safetensors` files, we will use those to work on the same numbers in Python and Elixir.
"""
def safetensor_files(dir, name) do
{Path.join(dir, "#{name}_params_axon.safetensors"),
Path.join(dir, "#{name}_params_flax.safetensors"),
Path.join(dir, "#{name}_test_data_axon.safetensors"),
Path.join(dir, "#{name}_test_data_flax.safetensors")}
end
@doc """
Write `safetensors` files for Axon and Flax params.
"""
def save_params(params, param_mapping, axon_params_path, flax_params_path) do
axon_params =
for {axon_key, _} <- param_mapping, into: %{} do
{axon_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
end
flax_params =
for {axon_key, flax_key} <- param_mapping, into: %{} do
{flax_key, ParamsUtils.get_from_flattened_key(params, axon_key)}
end
Safetensors.write!(axon_params_path, axon_params)
Safetensors.write!(flax_params_path, flax_params)
end
def run_axon_and_save(predict_fn, params, input_shape, path) do
input_data =
for dim <- Enum.reverse(Tuple.to_list(input_shape)), reduce: StreamData.float() do
acc -> StreamData.list_of(acc, length: dim)
end
test_data =
for i <- 0..100 do
input =
input_data
|> Enum.take(1)
|> hd
|> Nx.tensor()
input_name = "input_#{i}"
output_name = "output_#{i}"
output = predict_fn.(params, input)
[{input_name, input}, {output_name, output}]
end
|> List.flatten()
|> Map.new()
Safetensors.write!(path, test_data)
end
# need path for import, e.g. transformers.models.resnet.modeling_flax_resnet
def run_flax_and_save(
module,
module_args,
input_path,
output_path,
params_path,
import_path,
additional_imports
) do
flax_script =
"""
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from functools import partial
from typing import Optional, Tuple
from safetensors import safe_open
from safetensors.flax import save_file
from #{import_path} import #{module}
#{additional_imports}
def unflatten_dict(d, sep='.'):
result = {}
for key, value in d.items():
parts = key.split(sep)
node = result
for part in parts[:-1]:
node = node.setdefault(part, {})
node[parts[-1]] = value
return result
tensors = {}
with safe_open("#{input_path}", framework="flax") as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)
print("initializing: #{module}(#{module_args})")
model = #{module}(#{module_args})
params = {}
with safe_open("#{params_path}", framework="flax") as f:
for k in f.keys():
params[k] = f.get_tensor(k)
params = unflatten_dict(params)
out_tensors = tensors.copy()
input_keys = [key for key in tensors.keys() if key.startswith("input")]
for input_key in input_keys:
input = tensors[input_key]
output = model.apply(params, input)
output_key = input_key.replace("input", "output")
out_tensors[output_key] = output
save_file(out_tensors, "#{output_path}")
"""
run_python(flax_script, [])
end
defp assert_all_close(left, right) do
atol = 1.0e-4
rtol = 1.0e-4
equals =
left
|> Nx.all_close(right, atol: atol, rtol: rtol)
|> Nx.backend_transfer(Nx.BinaryBackend)
equals == Nx.tensor(1, type: :u8, backend: Nx.BinaryBackend)
end
defp same_result?(axon_result, flax_result) do
assert_all_close(axon_result, flax_result)
end
# Reads the safetensors files at `axon_path` and `flax_path` and compares every
# tensor whose key starts with "output".
#
# Returns a map keyed by output key. Each value holds:
#   * :same_result? - whether both outputs agree (see same_result?/2)
#   * :input        - the tensor stored in the Axon file under the matching
#                     "input..." key
#   * :axon_output  - the output read from the Axon file
#   * :flax_output  - the output read from the Flax file under the same key
def verification_results(axon_path, flax_path) do
  axon_tensors = Safetensors.read!(axon_path)
  flax_tensors = Safetensors.read!(flax_path)

  axon_tensors
  |> Map.keys()
  |> Enum.filter(&String.starts_with?(&1, "output"))
  |> Map.new(fn output_key ->
    axon_output = axon_tensors[output_key]
    flax_output = flax_tensors[output_key]
    input_key = String.replace(output_key, "output", "input")

    {output_key,
     %{
       same_result?: same_result?(axon_output, flax_output),
       input: axon_tensors[input_key],
       axon_output: axon_output,
       flax_output: flax_output
     }}
  end)
end
# Runs the same inputs through the Axon model and the Flax model and compares
# the outputs.
#
# `axon_model` must provide :code, :name, :build_args, :additional_code,
# :input_shape, :input_type, :params_path and :results_path; `flax_model` must
# provide :name, :model_args, :import_path, :additional_imports, :params_path
# and :results_path.
#
# Returns
#   * {:ok, results}                  - every output matched
#   * {:verification_failed, results} - at least one output differed
#   * any non-matching value produced by Model.build_model/4 or
#     Model.get_params/3 (passed through unchanged by `with`)
#
# NOTE(review): the return values of save_params/4, run_axon_and_save/4 and
# run_flax_and_save/7 are not inspected here — confirm whether a failed Python
# run should be surfaced before the result files are compared.
def verify_model(axon_model, flax_model, params_mapping) do
  with {:ok, {init_fn, predict_fn}} <-
         Model.build_model(
           axon_model.code,
           axon_model.name,
           axon_model.build_args,
           axon_model.additional_code
         ),
       {:ok, params} <-
         Model.get_params(init_fn, axon_model.input_shape, axon_model.input_type) do
    save_params(params, params_mapping, axon_model.params_path, flax_model.params_path)
    run_axon_and_save(predict_fn, params, axon_model.input_shape, axon_model.results_path)

    run_flax_and_save(
      flax_model.name,
      flax_model.model_args,
      axon_model.results_path,
      flax_model.results_path,
      flax_model.params_path,
      flax_model.import_path,
      flax_model.additional_imports
    )

    results = verification_results(axon_model.results_path, flax_model.results_path)

    # The tag reflects whether every compared output matched.
    status =
      if Enum.all?(results, fn {_key, details} -> details.same_result? end),
        do: :ok,
        else: :verification_failed

    {status, results}
  end
end
end
defmodule ParamsUtils do
  @moduledoc """
  Helpers for converting between nested parameter maps and flat,
  "."-separated parameter keys.
  """

  @doc """
  Returns the flattened key of every tensor found in the nested map `params`.

  Nested keys are joined with ".", so `%{"dense" => %{"kernel" => tensor}}`
  yields `["dense.kernel"]`. A tensor stored directly at the top level
  contributes its key unchanged.
  """
  def flatten_keys(%{} = params) do
    Enum.flat_map(params, fn {key, value} -> prefixed_keys(value, key) end)
  end

  # The tensor clause must come first: %Nx.Tensor{} is itself a map and would
  # otherwise be walked like a nested parameter map.
  defp prefixed_keys(%Nx.Tensor{}, key), do: [key]

  defp prefixed_keys(%{} = params, prefix) do
    Enum.flat_map(params, fn {key, value} ->
      prefixed_keys(value, "#{prefix}.#{key}")
    end)
  end

  @doc """
  Looks up the value stored under the "."-separated `flattened_key`.

  Returns `nil` as soon as one path segment is missing.
  """
  def get_from_flattened_key(params, flattened_key) do
    get_in(params, String.split(flattened_key, "."))
  end

  @doc """
  Stores `value` in `params` under the nested path described by the
  "."-separated `flattened_key`, keeping all other entries.

  An existing value at that path is replaced.
  """
  def unflatten_and_put(params, flattened_key, value) do
    merge_recursive(params, nested_map(flattened_key, value))
  end

  @doc """
  Deep-merges `map2` into `map1`.

  When both maps hold a plain map under the same key, those maps are merged
  recursively. For any other conflict (tensors, other structs, scalars) the
  value from `map2` wins.
  """
  def merge_recursive(%{} = map1, %{} = map2) do
    Map.merge(map1, map2, fn _key, left, right -> merge_values(left, right) end)
  end

  # Only plain maps are merged further. Structs such as %Nx.Tensor{} are leaves
  # and must not be merged field by field.
  defp merge_values(left, right)
       when is_map(left) and is_map(right) and not is_struct(left) and not is_struct(right),
       do: merge_recursive(left, right)

  defp merge_values(_left, right), do: right

  # Builds the single-path nested map for a flattened key,
  # e.g. "a.b" => %{"a" => %{"b" => value}}.
  defp nested_map(flattened_key, value) do
    case String.split(flattened_key, ".", parts: 2) do
      [key] -> %{key => value}
      [key, rest] -> %{key => nested_map(rest, value)}
    end
  end
end
Params Mapping
defmodule ParamsMapping do
  @moduledoc """
  Collects the parameter ids of an Axon model and of a Flax module and asks an
  LLM to map one set onto the other.
  """

  require Logger
  alias LangChain.Message
  alias LangChain.Chains.LLMChain

  # Python interpreter path, captured from the notebook binding `python` when
  # the module is compiled.
  @python python

  @doc """
  Creates an LLM chain for `llm_config`, adds the system message and runs the
  chain once.

  Raises a `MatchError` when the run does not return `{:ok, chain, response}`.
  """
  def new_chain!(llm_config) do
    {:ok, chain, _response} =
      %{llm: llm_config}
      |> LLMChain.new!()
      |> LLMChain.add_message(
        Message.new_system!("You are an expert in Python, Flax, Elixir and Axon.")
      )
      |> LLMChain.run()

    chain
  end

  defp params_mapping_prompt(axon_params, flax_params) do
    """
    this is a first set of ids of params of a neural network:
    #{format_ids(axon_params)}
    this is a second set of ids of params of a neural network:
    #{format_ids(flax_params)}
    Both sets of ids refer to the same params. Create an Elixir map for the corresponding ids.
    The keys of the map should be the parameter ids from the first set, the corresponding value should be the id from the second set that belongs to the key
    end.
    Reply ONLY with the Elixir map.
    """
  end

  # Interpolating a list of strings with `#{}` concatenates the elements with
  # no separator, which would merge all ids into one unreadable token. Lists
  # are therefore rendered with inspect/2; text that is already a string
  # (e.g. the list printed by Python) is used unchanged.
  defp format_ids(ids) when is_binary(ids), do: ids
  defp format_ids(ids), do: inspect(ids, limit: :infinity, printable_limit: :infinity)

  @doc """
  Builds `model`, initialises its parameters for a template of `input_shape`
  and `input_type`, and returns the flattened parameter keys.

  Returns `{:ok, keys}`, or `{:error, exception}` when building or
  initialising raises; the exception is also logged.
  """
  def axon_params(model, input_shape, input_type) do
    {init_fn, _predict_fn} = Axon.build(model)
    template = Nx.template(input_shape, input_type)
    params = init_fn.(template, %{})

    {:ok, ParamsUtils.flatten_keys(params)}
  rescue
    e ->
      # Log a readable message with stacktrace rather than the raw struct.
      Logger.error(Exception.format(:error, e, __STACKTRACE__))
      {:error, e}
  end

  @doc """
  Initialises the Flax `module` in Python and returns the flattened parameter
  keys as printed by the script.

  `import_path`, `additional_imports` and `model_args` are interpolated
  verbatim into the Python source. `input_shape` is an Elixir tuple of
  dimensions.

  Returns `{:ok, stdout}` when Python exits with status 0, otherwise
  `{:error, stdout}`.
  """
  def flax_params(module, import_path, additional_imports, model_args, input_shape) do
    # Always emit a Python tuple literal with a trailing comma: "(3, 3,)",
    # "(3,)" for rank 1 and "()" for rank 0. The previous form produced the
    # non-tuple "(3)" for rank 1 and the invalid "(" for rank 0.
    dims = input_shape |> Tuple.to_list() |> Enum.map_join(" ", &"#{&1},")
    flax_shape = "(#{dims})"
    input = "jnp.broadcast_to(1, #{flax_shape})"

    python_script =
      """
      import jax
      from typing import Any, Callable, Sequence
      from jax import random, numpy as jnp
      import flax
      from flax import linen as nn
      from functools import partial
      from typing import Optional, Tuple
      from safetensors import safe_open
      from safetensors.flax import save_file
      from #{import_path} import #{module}
      #{additional_imports}
      model = #{module}(#{model_args})
      key = random.key(0)
      input = #{input}
      params = model.init(key, input)
      def flatten_dict(d, parent_key='', sep='.'):
          items = []
          for k, v in d.items():
              new_key = parent_key + sep + k if parent_key else k
              if isinstance(v, dict):
                  items.extend(flatten_dict(v, new_key, sep=sep).items())
              else:
                  items.append((new_key, v))
          return dict(items)
      l = flatten_dict(params)
      print(list(l.keys()))
      """

    case System.cmd(@python, ["-c", python_script], []) do
      {params, 0} -> {:ok, params}
      {output, _code} -> {:error, output}
    end
  end

  @doc """
  Asks the LLM behind `chain` for an Elixir map from the Axon parameter ids to
  the Flax parameter ids and returns the content of the reply.
  """
  def params_mapping(chain, axon_params, flax_params) do
    prompt = params_mapping_prompt(axon_params, flax_params)

    {:ok, _chain, response} =
      chain
      |> LLMChain.add_message(Message.new_user!(prompt))
      |> LLMChain.run()

    response.content
  end
end
Interactive Tool - Global Definitions
# `converted` is the handle returned by Conversion.start_link/0 (presumably a
# pid — confirm against the Conversion module). Later cells pass it to
# Conversion.put/4 and Conversion.code_to_string/1.
{:ok, converted} = Conversion.start_link()
# Separate LLM chains for the code conversion and for the params mapping.
conversion_chain = Conversion.new_chain!(llm_config)
params_mapping_chain = ParamsMapping.new_chain!(llm_config)
# Suppress the cell output.
Kino.nothing()
Interactive Tool - Setup
# File picker for the Python source that contains the Flax definitions.
input_file = Kino.Input.file("Flax File")
i = Kino.Input.read(input_file)

# Content of the uploaded file, or "" while nothing has been uploaded.
input =
  case i do
    nil -> ""
    %{file_ref: file_ref} -> file_ref |> Kino.Input.file_path() |> File.read!()
  end

# Entries returned by Utils.top_level_defs/1 (defined elsewhere in the
# notebook); each entry is offered as a conversion target.
inputs = Utils.top_level_defs(input)

# One {index, label} option per entry; the label is the text before the first "(".
select_list =
  inputs
  |> Enum.with_index()
  |> Enum.map(fn {definition, index} ->
    label = definition |> String.split("(") |> hd()
    {index, label}
  end)

# The select always gets at least one option, so offer a placeholder when the
# file yielded nothing.
select_list =
  case select_list do
    [] -> [{0, "No Flax Input"}]
    options -> options
  end

target_select = Kino.Input.select("Conversion Target", select_list)
# Builds every widget, frame and layout of the interactive tool and wires the
# buttons into one tagged event stream. The bindings created here (frames,
# inputs, `inputs` map, `stream`, `ui`) are used by the listener cell and the
# final UI cell.
import Kino.Shorts

# The definition selected as conversion target; nil when the file yielded none.
model_index = Kino.Input.read(target_select)
model = Enum.at(inputs, model_index)

# --- Conversion tab: controls and frames -----------------------------------
run_instruction = Kino.Control.button("Run Instruction")
run_all_remaining_instructions = Kino.Control.button("Run All Remaining")
instruction_status = frame()
previous_instruction = Kino.Control.button("Previous")
next_instruction = Kino.Control.button("Next")
save = Kino.Control.button("Save Converted Model")
reload = Kino.Control.button("Reload")
previous_model_frame = frame()
current_model_frame = frame()
instruction_frame = frame()
first_instruction = Conversion.conversion_instruction(0)
num_instructions = Conversion.num_instructions()

Kino.Frame.render(
  instruction_frame,
  markdown("[0/#{num_instructions}] #{first_instruction.instruction}")
)

# Editable copy of the original model code; this is the first LLM input.
original_model_text = Kino.Input.textarea("Model Code", default: model, monospace: true)
Kino.Frame.render(previous_model_frame, text("No previous model yet"))
Kino.Frame.render(current_model_frame, original_model_text)

conversion_layout =
  grid(
    [
      markdown("### Next LLM instruction"),
      text(""),
      instruction_frame,
      grid(
        [
          previous_instruction,
          run_instruction,
          next_instruction,
          run_all_remaining_instructions,
          instruction_status
        ],
        boxed: true,
        columns: 2
      ),
      markdown("### Previous version"),
      markdown("### Next LLM input"),
      previous_model_frame,
      current_model_frame
    ],
    boxed: true,
    columns: 2
  )

# --- Verification tab: params mapping and verification parameters ----------
params_mapping = Kino.Input.textarea("Params Mapping")
generate_params_mapping = Kino.Control.button("Generate Params Mapping")
generate_params_status = frame()
Kino.Frame.render(generate_params_status, text("Waiting"))
params_mapping_frame = frame()
Kino.Frame.render(params_mapping_frame, params_mapping)
import_path = Kino.Input.text("Import Path", default: "transformers.models.this")
additional_imports = Kino.Input.text("Additional Imports", default: "#from path import Class")
# Shape and build args are entered as Elixir source; the listener evaluates
# them with Code.eval_string/1.
input_shape = Kino.Input.text("Input Shape", default: "{3, 3}")
input_type = Kino.Input.text("Input Type", default: "f32")

axon_build_args =
  Kino.Input.text("Axon Build Args (function inputs)", default: "[Axon.input(\"in\"), 1, 2]")

flax_model_args = Kino.Input.text("Flax Model Args (model attributes)", default: "3, 5")
verification_params_frame = frame()
verification_status = frame()
Kino.Frame.render(verification_status, text("Waiting"))

Kino.Frame.render(
  verification_params_frame,
  grid(
    [
      import_path,
      additional_imports,
      input_shape,
      input_type,
      axon_build_args,
      flax_model_args
    ],
    boxed: true,
    columns: 2
  )
)

verify = Kino.Control.button("Run Verification")
result_status_frame = frame()
results_frame = frame()

# The current model frame is shown on both tabs.
verification_layout =
  grid(
    [
      current_model_frame,
      reload,
      params_mapping_frame,
      grid([generate_params_mapping, generate_params_status]),
      verification_params_frame,
      grid([verify, verification_status]),
      markdown("### Verification Status"),
      markdown("### Failed Verification Results"),
      result_status_frame,
      results_frame
    ],
    boxed: true,
    columns: 2,
    gap: 16
  )

# --- Converted code preview --------------------------------------------------
converted_frame = frame()

converted_code_md =
  """
  ```elixir
  defmodule YourAxonModel do
  #{Conversion.code_to_string(converted)}
  end
  ```
  """

Kino.Frame.render(converted_frame, markdown(converted_code_md))

converted_layout =
  grid(
    [
      markdown("### Converted"),
      text(""),
      converted_frame,
      save
    ],
    boxed: true,
    columns: 2
  )

# --- Top-level layout --------------------------------------------------------
ui =
  grid([
    grid([markdown("### Select Input File"), input_file]),
    target_select,
    text(""),
    markdown("### Original Model"),
    markdown("```python\n#{model}\n```"),
    tabs(
      Conversion: conversion_layout,
      Verification: verification_layout
    ),
    converted_layout
  ])

# Rebinds `inputs`: from here on it is the map of input widgets that the
# listener keeps in its state and replaces as widgets are re-created.
inputs = %{
  target_select: target_select,
  model_input: original_model_text,
  import_path: import_path,
  additional_imports: additional_imports,
  input_shape: input_shape,
  input_type: input_type,
  params_mapping: params_mapping,
  axon_build_args: axon_build_args,
  flax_model_args: flax_model_args
}

# Every button event arrives as {tag, event} on a single stream.
stream =
  Kino.Control.tagged_stream(
    run_instruction: run_instruction,
    run_all_remaining_instructions: run_all_remaining_instructions,
    previous_instruction: previous_instruction,
    next_instruction: next_instruction,
    save: save,
    verify: verify,
    reload: reload,
    gen_params_mapping: generate_params_mapping
  )

Kino.nothing()
# Removes the Markdown code fence an LLM reply may be wrapped in.
strip_code_fence = fn code ->
  code
  |> String.trim_leading("```elixir")
  |> String.trim_leading("```python")
  |> String.trim_trailing("```")
end

# Shows instruction number `index`, with its position, in the instruction frame.
render_instruction = fn index ->
  instruction = Conversion.conversion_instruction(index)

  Kino.Frame.render(
    instruction_frame,
    markdown("[#{index}/#{num_instructions}] #{instruction.instruction}")
  )
end

# Event loop of the tool. The listener state is
# {llm_chain, input_widgets, instruction_index}; every branch returns
# {:cont, new_state}. Widgets that get re-created are swapped into the
# `inputs` map so later events read the visible widget.
Kino.listen(
  stream,
  {conversion_chain, inputs, 0},
  fn event, {chain, inputs, instruction_index} ->
    case event do
      {:run_all_remaining_instructions, _e} ->
        model_code = Kino.Input.read(inputs.model_input)
        Kino.Frame.render(instruction_status, text("Running remaining instructions"))

        # `//1` keeps the range ascending, so nothing runs (instead of running
        # backwards) should the index ever be past the last instruction.
        {chain, model_code} =
          for index <- instruction_index..num_instructions//1,
              reduce: {chain, model_code} do
            {chain, model_code} ->
              instruction = Conversion.conversion_instruction(index)

              {:ok, updated_chain, model_code} =
                Conversion.update_model(model_code, chain, instruction)

              {updated_chain, model_code}
          end

        Kino.Frame.render(instruction_status, text("Waiting for instruction"))

        model_input =
          Kino.Input.textarea("Model Code",
            default: strip_code_fence.(model_code),
            monospace: true
          )

        Kino.Frame.render(current_model_frame, model_input)
        render_instruction.(num_instructions)

        {:cont, {chain, %{inputs | model_input: model_input}, num_instructions}}

      {:run_instruction, _e} ->
        instruction = Conversion.conversion_instruction(instruction_index)
        code = Kino.Input.read(inputs.model_input)

        # The code is sent to the LLM wrapped in a Markdown fence.
        model_code = """
        ```
        #{code}
        ```
        """

        Kino.Frame.render(previous_model_frame, markdown(model_code))
        Kino.Frame.render(instruction_status, text("Running instruction"))

        {:ok, updated_chain, model_code} =
          Conversion.update_model(model_code, chain, instruction)

        model_input =
          Kino.Input.textarea("Model Code",
            default: strip_code_fence.(model_code),
            monospace: true
          )

        Kino.Frame.render(current_model_frame, model_input)

        # Clamp like the "Next" button does, so the index never moves past the
        # last instruction.
        next_index = min(instruction_index + 1, num_instructions)
        render_instruction.(next_index)
        Kino.Frame.render(instruction_status, text("Waiting for instruction"))

        {:cont, {updated_chain, %{inputs | model_input: model_input}, next_index}}

      {:previous_instruction, _e} ->
        previous_index = max(instruction_index - 1, 0)
        render_instruction.(previous_index)

        {:cont, {chain, inputs, previous_index}}

      {:next_instruction, _e} ->
        next_index =
          if instruction_index < num_instructions do
            instruction_index + 1
          else
            instruction_index
          end

        render_instruction.(next_index)

        {:cont, {chain, inputs, next_index}}

      {:save, _e} ->
        conversion_target = Kino.Input.read(inputs.target_select)
        model_code = Kino.Input.read(inputs.model_input)
        params_mapping = Kino.Input.read(inputs.params_mapping)
        Conversion.put(converted, conversion_target, model_code, params_mapping)

        converted_code_md = """
        ```elixir
        defmodule YourAxonModel do
        #{Conversion.code_to_string(converted)}
        end
        ```
        """

        Kino.Frame.render(converted_frame, markdown(converted_code_md))

        {:cont, {chain, inputs, instruction_index}}

      {:verify, _e} ->
        converted_code = Kino.Input.read(inputs.model_input)
        original_code = Kino.Input.read(original_model_text)
        {flax_name, axon_name} = Utils.extract_names(original_code)

        # NOTE: the following fields are evaluated as Elixir code / turned into
        # atoms. That is acceptable only because the notebook owner types them.
        input_type = Kino.Input.read(inputs.input_type) |> String.to_atom()
        {input_shape, _} = Kino.Input.read(inputs.input_shape) |> Code.eval_string()
        {params_mapping, _} = Kino.Input.read(inputs.params_mapping) |> Code.eval_string()
        import_path = Kino.Input.read(inputs.import_path)
        additional_imports = Kino.Input.read(inputs.additional_imports)
        {axon_build_args, _} = Kino.Input.read(inputs.axon_build_args) |> Code.eval_string()
        flax_model_args = Kino.Input.read(inputs.flax_model_args)
        previously_converted_code = Conversion.code_to_string(converted)

        {axon_params_path, flax_params_path, axon_path, flax_path} =
          Verification.safetensor_files(data_dir, axon_name)

        flax_model_data = %{
          name: flax_name,
          code: original_code,
          model_args: flax_model_args,
          import_path: import_path,
          additional_imports: additional_imports,
          params_path: flax_params_path,
          results_path: flax_path
        }

        axon_model_data = %{
          name: axon_name,
          code: converted_code,
          additional_code: previously_converted_code,
          input_shape: input_shape,
          input_type: input_type,
          build_args: axon_build_args,
          params_path: axon_params_path,
          results_path: axon_path
        }

        Kino.Frame.render(verification_status, text("Running Verification"))

        case Verification.verify_model(
               axon_model_data,
               flax_model_data,
               params_mapping
             ) do
          {:ok, _verification_details} ->
            Kino.Frame.render(verification_status, text("Waiting"))
            Kino.Frame.render(result_status_frame, :ok)
            # Every output matched, so there is nothing to list.
            Kino.Frame.render(results_frame, [])

          # verify_model/3 reports mismatching outputs with this tag. Without
          # this clause a failed verification crashed the listener and the
          # mismatches were never shown.
          {:verification_failed, verification_details} ->
            failed =
              Enum.filter(verification_details, fn {_key, value} ->
                not value.same_result?
              end)

            Kino.Frame.render(verification_status, text("Waiting"))
            Kino.Frame.render(result_status_frame, :verification_failed)
            Kino.Frame.render(results_frame, failed)

          {:error, error} ->
            Kino.Frame.render(verification_status, error)
            Kino.Frame.render(result_status_frame, :error)
        end

        {:cont, {chain, inputs, instruction_index}}

      {:reload, _e} ->
        # Re-create every input with its current value as the new default and
        # render the fresh widgets.
        converted_code =
          Kino.Input.read(inputs.model_input)
          |> String.trim()

        model_input =
          Kino.Input.textarea("Model Input", default: converted_code, monospace: true)

        Kino.Frame.render(current_model_frame, model_input)

        params_mapping = Kino.Input.read(inputs.params_mapping)
        params_mapping = Kino.Input.textarea("Params Mapping", default: params_mapping)
        Kino.Frame.render(params_mapping_frame, params_mapping)
        Kino.Frame.render(generate_params_status, text("Waiting"))

        import_path = Kino.Input.read(inputs.import_path)
        import_path = Kino.Input.text("Import Path", default: import_path)
        additional_imports = Kino.Input.read(inputs.additional_imports)
        additional_imports = Kino.Input.text("Additional Imports", default: additional_imports)
        input_shape = Kino.Input.read(inputs.input_shape)
        input_shape = Kino.Input.text("Input Shape", default: input_shape)
        input_type = Kino.Input.read(inputs.input_type)
        input_type = Kino.Input.text("Input Type", default: input_type)
        axon_build_args = Kino.Input.read(inputs.axon_build_args)

        axon_build_args =
          Kino.Input.text("Axon Build Args (function inputs)", default: axon_build_args)

        flax_model_args = Kino.Input.read(inputs.flax_model_args)

        flax_model_args =
          Kino.Input.text("Flax Model Args (model attributes)", default: flax_model_args)

        Kino.Frame.render(
          verification_params_frame,
          grid(
            [
              import_path,
              additional_imports,
              input_shape,
              input_type,
              axon_build_args,
              flax_model_args
            ],
            boxed: true,
            columns: 2
          )
        )

        Kino.Frame.render(verification_status, text("Waiting"))

        updated_inputs = %{
          inputs
          | model_input: model_input,
            params_mapping: params_mapping,
            import_path: import_path,
            additional_imports: additional_imports,
            input_shape: input_shape,
            input_type: input_type,
            axon_build_args: axon_build_args,
            flax_model_args: flax_model_args
        }

        {:cont, {chain, updated_inputs, instruction_index}}

      {:gen_params_mapping, _e} ->
        converted_code = Kino.Input.read(inputs.model_input)
        original_code = Kino.Input.read(original_model_text)
        {flax_name, axon_name} = Utils.extract_names(original_code)
        input_type = Kino.Input.read(inputs.input_type) |> String.to_atom()
        {input_shape, _} = Kino.Input.read(inputs.input_shape) |> Code.eval_string()
        import_path = Kino.Input.read(inputs.import_path)
        additional_imports = Kino.Input.read(inputs.additional_imports)
        {axon_build_args, _} = Kino.Input.read(inputs.axon_build_args) |> Code.eval_string()
        flax_model_args = Kino.Input.read(inputs.flax_model_args)
        previously_converted = Conversion.code_to_string(converted)

        with {:ok, _model} <-
               Model.build_model(
                 converted_code,
                 axon_name,
                 axon_build_args,
                 previously_converted
               ),
             {:ok, model} <-
               Model.apply_model(Model.ModelTest, axon_name, axon_build_args),
             {:ok, axon_params} <-
               ParamsMapping.axon_params(model, input_shape, input_type),
             {:ok, flax_params} <-
               ParamsMapping.flax_params(
                 flax_name,
                 import_path,
                 additional_imports,
                 flax_model_args,
                 input_shape
               ) do
          Kino.Frame.render(generate_params_status, text("Generating Params Mapping"))

          params_mapping =
            ParamsMapping.params_mapping(params_mapping_chain, axon_params, flax_params)
            |> String.trim_leading("```elixir")
            |> String.trim_trailing("```")

          Kino.Frame.render(generate_params_status, text("Waiting"))
          params_mapping = Kino.Input.textarea("Params Mapping", default: params_mapping)
          Kino.Frame.render(params_mapping_frame, params_mapping)

          {:cont, {chain, %{inputs | params_mapping: params_mapping}, instruction_index}}
        else
          {:error, error} ->
            # Show the error and keep whatever mapping the user already had.
            Kino.Frame.render(generate_params_status, error)
            old_params_mapping = Kino.Input.read(inputs.params_mapping)
            params_mapping = Kino.Input.textarea("Params Mapping", default: old_params_mapping)
            Kino.Frame.render(params_mapping_frame, params_mapping)

            {:cont, {chain, %{inputs | params_mapping: params_mapping}, instruction_index}}
        end
    end
  end
)
Interactive Tool - UI
Amplify the output of the next cell
# Evaluating `ui` as the last expression makes the layout this cell's output.
ui