Flax to Axon - code to code
Mix.install([
  {:axon, "~> 0.6.1"},
  {:kino, "~> 0.12.3"}
])
Goal
- Automatically import models from transformers
- Doesn’t have to be a 100% solution, every automated step helps
- Consider only Flax because it’s the most similar to Axon
- We want models that compute the same results for the same inputs
Non-Goals
- Beautiful code that can be maintained easily. That’s up to Bumblebee
Basics
Basically three levels:
- code
- graph
- jax/nx/xla
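To make the lowest level concrete before we touch Axon: in Elixir terms, a dense layer is just a matrix multiplication plus a bias, written directly in Nx (a minimal sketch, no framework involved).

# y = x * kernel + bias, computed with plain Nx
x = Nx.tensor([[1.0, 1.0]])
kernel = Nx.tensor([[2.0], [3.0]])
bias = Nx.tensor([0.0])
Nx.add(Nx.dot(x, kernel), bias)

Axon’s job is to build exactly this kind of computation for us from a model definition: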
model =
  Axon.input("input")
  |> Axon.dense(1)
Let’s run it with some input
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})

params = %{
  "dense_0" => %{
    "bias" => Nx.tensor([0.0], type: :f32),
    "kernel" => Nx.tensor([[2], [3]], type: :f32)
  }
}

predict_fn.(params, Nx.tensor([[1.0, 1.0]]))
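That covers the code level and, via Axon.build, the nx/xla level. For the graph level in between, we can ask Axon to show us the structure it built (a small sketch using Axon.Display, which we also rely on later; Kino renders the result in Livebook).

Axon.Display.as_graph(model, Nx.template({1, 2}, :f32))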
Flax/Jax
How does Flax/JAX work, and how do we get at the computational graph?
JAX is essentially the Nx of the Python world: it gives you the low-level building blocks that get compiled to XLA and executed. Flax is built on top of that to make building models easy. Sounds familiar? Yes, it plays the same role as Axon.
They have the linen API and an even newer NNX API; we will focus on linen for now.
If we could get all the data required to build our Axon struct from a Flax model, we could use it to run the model in Elixir.
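To get a feel for what that data looks like on the Axon side, we can inspect the raw struct behind the small model from above (a sketch; the exact fields depend on the Axon version).

# Print the plain data underneath the %Axon{} struct: node ids, layer ops and their options
Axon.input("input")
|> Axon.dense(1)
|> IO.inspect(structs: false, limit: :infinity)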
So let’s inspect Flax and define the same model as before.
from flax import linen as nn  # Linen API
from jax import numpy as jnp
from jax import random

class Dense(nn.Module):
    """A simple Dense model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=1)(x)
        return x

model = Dense()
key = random.key(0)
input = [1, 1]
params = model.init(key, input)
params = {'params': {'Dense_0': {'kernel': jnp.array([[2.], [3.]]), 'bias': jnp.array([0.])}}}
model.apply(params, input)
==> [5.0]
So, how can we create an equivalent Axon model given a Flax model. Here are the two models again:
model =
  Axon.input("input")
  |> Axon.dense(1)

class Dense(nn.Module):
    """A simple Dense model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=1)(x)
        return x
With this simple dense model, it’s easy. In Axon, we specify a special input layer, then we append all the layers that are present in the Flax model. Here, it’s only a single dense layer with a parameter (units in Axon, features in Flax) set to 1.
After that each framework provides a way to initialize the model, get the params and actually run it.
Alright we have everything we need to manually construct an Axon model from a very simple Flax model.
Let’s recap the steps we need to take given a Flax model:
- Create an input layer in Axon
- Append all layers found in Flax with the corresponding parameters
A slightly larger model
Let’s take it one step further and look at a slightly larger Axon model (taken from the Axon guides).
slightly_larger_model =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.activation(:relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1)
  |> Axon.activation(:softmax)
And the same model in Flax:
from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLarger(nn.Module):
    """A slightly larger model."""

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        x = nn.Dense(features=1)(x)
        x = nn.softmax(x)
        return x
Let’s see if we can reconstruct the Axon model from the Flax model. For this we follow our steps:
- Create an input layer
- Append all the layers found in Flax with the corresponding parameters
We define an input layer the same way as before
Axon.input("input")
Then we go on with one layer at a time:
nn.Dense(features=32)(x)
Axon.dense(32)
nn.relu(x)
Axon.activation(:relu)
nn.Dropout(rate=0.5, deterministic=not training)(x)
Axon.dropout(rate: 0.5)
Note that we don’t need the deterministic parameter in Axon.
nn.Dense(features=1)(x)
Axon.dense(1)
nn.softmax(x)
Axon.activation(:softmax)
And then we pipe the input through the Axon layers, and we’re done.
Axon.input("input")
|> Axon.dense(32)
|> Axon.activation(:relu)
|> Axon.dropout(rate: 0.5)
|> Axon.dense(1)
|> Axon.activation(:softmax)
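To check that the reconstruction at least builds and runs, we can compile it the same way as the first model (a quick sketch; it is the same pipeline we already bound to slightly_larger_model above, and with randomly initialized parameters the numbers themselves don’t mean much yet).

{init_fn, predict_fn} = Axon.build(slightly_larger_model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
predict_fn.(params, Nx.tensor([[1.0, 1.0]]))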
Flax with setup call
That’s cool, it worked with a slightly larger model. However, the model is still pretty simple. In particular, it’s defined in a single function, which maps naturally to the way we define models in Axon.
So, let’s continue to make it more difficult.
First, what we’ve seen so far is the compact way to define a Flax model (note the @nn.compact decorator). Another common way to define Flax models is to first initialize layers in a setup function. Afterwards, those layers get called in the __call__ function as before.
When we define our slightly larger Flax model using a setup function, it looks like this:
class SlightlyLargerSetup(nn.Module):
    """A slightly larger model defined with setup function."""

    def setup(self):
        self.dense0 = nn.Dense(features=32)
        self.dropout = nn.Dropout(rate=0.5)
        self.dense1 = nn.Dense(features=1)

    def __call__(self, x, training: bool):
        x = self.dense0(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training)
        x = self.dense1(x)
        x = nn.softmax(x)
        return x
Now, this looks fairly similar. However, to go back to the compact version we must replace the calls to the methods stored in our module attributes with the actual Flax layers in the __call__ function, including the parameters passed to the layers during initialization in setup:
class SlightlyLargerSetup(nn.Module):
    """A slightly larger model defined with setup function."""

    def setup(self):
        self.dense0 = nn.Dense(features=32)
        self.dropout = nn.Dropout(rate=0.5)
        self.dense1 = nn.Dense(features=1)

    def __call__(self, x, training: bool):
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not training)
        x = nn.Dense(features=1)(x)
        x = nn.softmax(x)
        return x
Note that the deterministic parameter of the dropout layer is now passed in the second pair of parentheses. While that works, we should move it back to the first pair.
When we remove the setup function and add the @nn.compact decorator, we’re back to the exact same model definition as before, and we know how to convert it to an Axon model.
So, we should add two steps to our recipe:
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
- In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
- Create an input layer in Axon
- Append all layers found in Flax with the corresponding parameters
Flax with attributes
Another fun thing you can do with Flax models is to pass attributes when creating the model. Have a look here:
from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLargerAttributes(nn.Module):
    """A slightly larger model with attributes."""

    features_dense_0: int
    dropout_rate: float
    features_dense_1: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(features=self.features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
        x = nn.Dense(features=self.features_dense_1)(x)
        x = nn.softmax(x)
        return x
The attributes are then used to initialize the layers. As mentioned, you pass their values when creating the model:
model = SlightlyLargerAttributes(features_dense_0=32, dropout_rate=0.5, features_dense_1=1)
Then, you can go on as before:
key = random.key(0)
input = [1,1]
params = model.init(key, input, training=True)
model.apply(params, input, training=False, rngs={'dropout': key })
This way the models are parameterized. So, how would we convert this model to Axon?
Let’s start by moving the attributes to the __call__ function as parameters, and replacing all references to the attributes with the function parameters when initializing the layers:
class SlightlyLargerAttributes(nn.Module):
    """A slightly larger model with attributes."""

    @nn.compact
    def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: float, features_dense_1: int):
        x = nn.Dense(features=features_dense_0)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
        x = nn.Dense(features=features_dense_1)(x)
        x = nn.softmax(x)
        return x
Alright, this looks similar to the models before. We basically know how to get an Axon model from this. However, now we do have parameters in the __call__ function, so what should we do with them?
Let’s wrap our regular Axon model definition in a function.
model_with_params =
  fn %{
       features_dense_0: features_dense_0,
       dropout_rate: dropout_rate,
       features_dense_1: features_dense_1
     } ->
    Axon.input("input")
    |> Axon.dense(features_dense_0)
    |> Axon.activation(:relu)
    |> Axon.dropout(rate: dropout_rate)
    |> Axon.dense(features_dense_1)
    |> Axon.activation(:softmax)
  end
This way we parameterized our Axon model as well. Watch how we can use it.
model = model_with_params.(%{features_dense_0: 32, dropout_rate: 0.5, features_dense_1: 1})
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
predict_fn.(params, Nx.tensor([[1.0, 1.0]]))
We add some steps to our recipe:
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
- If there are attributes, move them to the __call__ function as parameters
- Replace all the references of the attributes with the function parameters when initializing the layers
- In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
- Create an input layer in Axon
- Append all layers found in Flax with the corresponding parameters
- Wrap the Axon model in a function that takes all the required parameters from the __call__ function
Flax with loops
So far so good. Let’s make it more difficult.
When defining models in Flax, you can use loops. Look at this.
from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLargerLoop(nn.Module):
    """A slightly larger model with a loop."""

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        for i in range(1, 3):
            x = nn.Dense(features=i)(x)
        x = nn.softmax(x)
        return x
So, how would we create our Axon model from that?
Let’s extract the loop into a function.
loop = fn axon, enumerable ->
  for i <- enumerable, reduce: axon do
    ax -> ax |> Axon.dense(i)
  end
end
Then we plug the function into our pipeline as if it were a layer. Note that Python’s range(1, 3) is end-exclusive, so it corresponds to 1..2 in Elixir.
Axon.input("input")
|> Axon.dense(32)
|> Axon.activation(:relu)
|> Axon.dropout(rate: 0.5)
|> loop.(1..3)
|> Axon.dense(1)
|> Axon.activation(:softmax)
Or, we could plug the function directly into the pipeline using then.
Axon.input("input")
|> Axon.dense(32)
|> Axon.activation(:relu)
|> Axon.dropout(rate: 0.5)
|> then(fn axon ->
for i <- 1..3, reduce: axon do
ax -> ax |> Axon.dense(i)
end
end)
|> Axon.dense(1)
|> Axon.activation(:softmax)
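To double-check that the loop expanded into the layers we expect, we can render the graph of the converted model (a quick sketch reusing the loop function from above).

model_with_loop =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.activation(:relu)
  |> Axon.dropout(rate: 0.5)
  |> loop.(1..2)
  |> Axon.activation(:softmax)

Axon.Display.as_graph(model_with_loop, Nx.template({1, 2}, :f32))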
We might need to revisit loops later when we look at more complicated ones but for now we leave it at that.
Let’s add the steps we performed to the recipe.
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
- If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers
- In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
- Extract loops into a separate function, transform this function according to our rules, and plug the function into the pipeline.
- Create an input layer in Axon
- Append all layers found in Flax with the corresponding parameters
- Wrap the Axon model in a function that takes all the required parameters from the __call__ function
Flax with nested calls
So far so good. Let’s make it more difficult.
When defining models in Flax, you can use arbitrary functions. Here’s an example:
Combined Flax models
Of course, we can combine different Flax modules like in the following example.
from flax import linen as nn  # Linen API
from jax import numpy as jnp

class DenseSoftmax(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=1)(x)
        x = nn.softmax(x)
        return x

class SlightlyLargerCombined(nn.Module):
    """A slightly larger model which uses another module."""

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        x = DenseSoftmax()(x)
        return x
Here, SlightlyLargerCombined makes use of DenseSoftmax. How would we convert this to Axon?
Well, we would first convert DenseSoftmax without an input layer. Then, we wrap it in a function and plug it into the conversion of SlightlyLargerCombined.
dense_softmax = fn axon ->
  axon
  |> Axon.dense(1)
  |> Axon.softmax()
end

slightly_larger_combined =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dropout(rate: 0.5)
  |> dense_softmax.()
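As a quick sanity check that the composed pipeline builds, we can compile and run it just like the earlier models (a sketch; with random parameters the concrete output is not meaningful yet).

{init_fn, predict_fn} = Axon.build(slightly_larger_combined)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
predict_fn.(params, Nx.tensor([[1.0, 1.0]]))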
Again, we add our findings to the recipe.
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
- If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers
- In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
- Extract loops into a separate function, transform this function according to our rules, and plug the function into the pipeline.
- Append all layers found in Flax with the corresponding parameters
- Wrap the Axon model in a function that takes all the required parameters from the __call__ function
- Convert all called modules, wrap them in a function and plug them into the pipeline.
- Create an input layer in Axon
Handle plain jnp <-> Nx
Real world model
Let’s get serious and have a look at a real world model, the Flax implementation of resnet in transformers; you can find it here.
We start with FlaxResNetConvLayer.
class FlaxResNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
        self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        hidden_state = self.activation_func(hidden_state)
        return hidden_state
Alright, let’s follow our rules and see if we can get an Axon model.
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
- If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers
- In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
- Extract loops into a separate function, transform this function according to our rules, and plug the function into the pipeline.
- Append all layers found in Flax with the corresponding parameters
- Wrap the Axon model in a function that takes all the required parameters from the __call__ function
- Convert all called modules, wrap them in a function and plug them into the pipeline.
- Create an input layer in Axon
- If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.
class FlaxResNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)(hidden_state, use_running_average=deterministic)
        hidden_state = (ACT2FN[self.activation] if self.activation is not None else Identity())(hidden_state)
        return hidden_state
- If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = nn.Conv(
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            strides=stride,
            padding=kernel_size // 2,
            dtype=dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype)(hidden_state, use_running_average=deterministic)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
- In the __call__ function, move all additional parameters in second pairs of parentheses to the first pair.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = nn.Conv(
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            strides=stride,
            padding=kernel_size // 2,
            dtype=dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
- Extract loops into a separate function, transform this function according to our rules, and plug the function into the pipeline.
Nothing to do here, there is no loop.
- Replace the Flax layers with the corresponding Axon layers according to our conversion table. Take into account the parameters in the first pair of parentheses. Replace initializer functions with the corresponding Axon functions. Follow these rules to replace activation functions: if there is an activation parameter and the code makes use of ACT2FN[activation], replace that with Axon.activation(activation); if an actual activation function is called, replace it with Axon’s activation function according to the conversion table. Replace all = for arguments with :.
We start with the conv layer:
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
We must replace the kernel initializer.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
Next, the batch norm layer.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
And we must replace the activation function.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = Axon.activation(activation)(hidden_state)
        return hidden_state
We replace all = with : for function arguments.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)(hidden_state)
        hidden_state = Axon.activation(activation)(hidden_state)
        return hidden_state
- Move the single parameter in the second pair of parentheses to be the first parameter of the first pair.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(x,
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )
        hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
        hidden_state = Axon.activation(hidden_state, activation)
        return hidden_state
And remove the return keyword.
class FlaxResNetConvLayer(nn.Module):
    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(x,
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )
        hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
        hidden_state = Axon.activation(hidden_state, activation)
        hidden_state
- Wrap the Axon model in a function that takes all the required parameters from the __call__ function
  - name the function corresponding to the class name, but snake case
  - take the same arguments as the __call__ function
  - remove the self argument
  - remove the dtype argument
  - remove the type specs if present
  - wrap the function in do … end
def flax_res_net_conv_layer(x, deterministic, out_channels, kernel_size, stride, activation) do
  hidden_state = Axon.conv(x,
    out_channels,
    kernel_size: {kernel_size, kernel_size},
    strides: stride,
    padding: kernel_size // 2,
    use_bias: False,
    kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
  )
  hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
  hidden_state = Axon.activation(hidden_state, activation)
  hidden_state
end
- Convert all called modules, wrap them in a function and plug them in the pipeline.
Nothing to do here.
- Check if the function is valid Elixir code. Otherwise, fix all issues by converting Python expressions to Elixir expressions. E.g. // corresponds to div, and scientific notation like 1e-05 needs a decimal point in Elixir: 1.0e-05.
flax_res_net_conv_layer = fn x, out_channels, kernel_size, stride, activation ->
  hidden_state =
    Axon.conv(
      x,
      out_channels,
      kernel_size: kernel_size,
      strides: stride,
      padding: [
        {div(kernel_size, 2), div(kernel_size, 2)},
        {div(kernel_size, 2), div(kernel_size, 2)}
      ],
      use_bias: false,
      kernel_initializer:
        Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal)
    )

  hidden_state =
    Axon.batch_norm(hidden_state,
      momentum: 0.9,
      epsilon: 1.0e-05
    )

  hidden_state = Axon.activation(hidden_state, activation)
  hidden_state
end
- Create an input layer in Axon and pass it in as first argument
model = flax_res_net_conv_layer.(Axon.input("input"), 3, 3, 1, :relu)
Axon.Display.as_graph(model, Nx.template({1, 2, 2}, :f32))
Parameterize models based on input
Verify model
Alright, we converted the model from Flax to Axon. How do we know that everything went right? We must verify that.
out_channels = 3
kernel_size = 3
stride = 1
activation = :relu
model =
  flax_res_net_conv_layer.(Axon.input("input"), out_channels, kernel_size, stride, activation)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3, 3}, :f32), %{})

params = %{
  "batch_norm_0" => %{
    "beta" => Nx.tensor([1, 1, 1], type: :f32),
    "gamma" => Nx.tensor([1, 1, 1], type: :f32),
    "mean" => Nx.tensor([1, 2, 3], type: :f32),
    "var" => Nx.tensor([1, 2, 3], type: :f32)
  },
  "conv_0" => %{
    "kernel" => Nx.broadcast(Nx.tensor(1, type: :f32), {1, 3, 3})
  }
}
predict_fn.(params, Nx.tensor([[[1.0, 2.0, 3.0]]]))
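To actually verify the conversion we would run the same input through the Flax layer with the same hand-written params and compare the outputs. Here is a minimal sketch of the Elixir side of that comparison; outputs_match? is a hypothetical helper, and getting the Flax output into Elixir (copy/paste, a file, whatever) is left open.

# Hypothetical helper: true when the Axon and Flax outputs agree within a small tolerance
outputs_match? = fn axon_out, flax_out ->
  Nx.to_number(Nx.all_close(axon_out, flax_out, atol: 1.0e-5)) == 1
end

axon_out = predict_fn.(params, Nx.tensor([[[1.0, 2.0, 3.0]]]))
# flax_out = ... the tensor produced by the Flax layer for the same input and params
# outputs_match?.(axon_out, flax_out)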
Automate it
Wow, that’s a lot. We MUST automate this. Luckily it’s 2024 and we have something called LLMs.
Load models from HuggingFace
Flax vs Axon layers
| Framework | Layer | Params | | | | | | |
|---|---|---|---|---|---|---|---|---|
| Flax | nn.Conv | features | kernel_size | strides | padding | dtype | use_bias | kernel_init |
| Axon | Axon.conv | units (second pos. param) | kernel_size | strides | padding | - | use_bias | kernel_initializer |
| Hints | | | single number instead of tuple | | predefined values (:valid) | | — | |

| Framework | Layer | Params | | | | | | | | | | | | | | | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Flax | nn.BatchNorm | use_running_average | axis | momentum | epsilon | dtype | param_dtype | use_bias | use_scale | bias_init | scale_init | axis_name | axis_index_groups | use_fast_variance | force_float32_reductions | parent | name |
| Axon | Axon.batch_norm | - | - | - | epsilon | - | - | - | - | beta_initializer(?) | gamma_initializer(?) | - | - | - | - | - | name |
| Hints | | | | | | | | | | | | | | | | | |
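One way these tables could feed an automated converter is as plain data. The following is a hypothetical sketch (none of these names exist in any library); it only encodes the small subset of the mapping we actually used in this notebook, with Flax parameter names on the left and their Axon counterparts on the right.

# Hypothetical rule table for an automated converter
flax_to_axon_layers = %{
  "nn.Dense" => {:dense, %{"features" => :units}},
  "nn.Conv" =>
    {:conv,
     %{
       "features" => :units,
       "kernel_size" => :kernel_size,
       "strides" => :strides,
       "padding" => :padding,
       "use_bias" => :use_bias,
       "kernel_init" => :kernel_initializer
     }},
  "nn.BatchNorm" => {:batch_norm, %{"momentum" => :momentum, "epsilon" => :epsilon}},
  "nn.Dropout" => {:dropout, %{"rate" => :rate}},
  "nn.relu" => {:relu, %{}},
  "nn.softmax" => {:softmax, %{}}
}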
Flax vs Axon initializers
| Framework | Initializer | Params | | | | | | |
|---|---|---|---|---|---|---|---|---|
| Flax | nn.initializers.variance_scaling | scale (1. pos) | mode (2. pos) | distribution (3. pos) | in_axis | out_axis | batch_axis | dtype |
| Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | | | | |
| Hints | | | | | | | | |
Allow adding custom rules