Flax to Axon - code to code

Mix.install([
  {:axon, "~> 0.6.1"},
  {:kino, "~> 0.12.3"}
])

Goal

  • Automatically import models from transformers
  • It doesn't have to be a 100% solution; every automated step helps
  • Consider only Flax because it's the most similar to Axon
  • We want models that compute the same results for the same inputs

Non-Goals

  • Beautiful, easily maintainable code. That's up to Bumblebee.

flowchart TD
    a[Code]-->b[Graph]-->c[Jax/Nx/XLA]

Basics

Basically three levels:

  • code
  • graph
  • jax/nx/xla
model =
  Axon.input("input")
  |> Axon.dense(1)
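At this point, model is not a compiled network yet, just a data structure describing the graph (the second of the three levels above). Since Kino is available, we can render it as a quick look:

Axon.Display.as_graph(model, Nx.template({1, 2}, :f32))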

Let’s run it with some input

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
params = %{
  "dense_0" => %{
    "bias" => Nx.tensor([0.0], type: :f32),
    "kernel" => Nx.tensor([[2], [3]], type: :f32)
  }
}
predict_fn.(params, Nx.tensor([[1.0, 1.0]]))
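For reference, a dense layer is just input · kernel + bias, so with the hand-set parameters we expect 1 * 2 + 1 * 3 + 0 = 5.0. A quick sanity check of the same computation directly in Nx:

# x · kernel + bias with the same hand-set values as above
Nx.tensor([[1.0, 1.0]])
|> Nx.dot(Nx.tensor([[2.0], [3.0]]))
|> Nx.add(Nx.tensor([0.0]))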

Flax/Jax

How does Flax/JAX work, and how do we get at the computational graph? JAX is basically the same as Nx: it gives you the low-level pieces to pass to XLA and run computations (I guess). Flax is built on top of that to easily build models. Sounds familiar? Yes, it plays the same role as Axon. Flax has a Linen API and an even newer NNX API; we will focus on Linen for now.

If we can get all the data required to build our Axon struct from a Flax model, we can use it to run the model in Elixir.
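On the Axon side that data is already explicit: the model we built above is an %Axon{} struct describing the graph node by node. A quick, truncated look, purely for inspection:

# Peek at the raw graph data behind the model (output truncated by :limit)
IO.inspect(model, structs: false, limit: 10)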

So let’s inspect Flax and define the same model as before.

from flax import linen as nn  # Linen API
from jax import numpy as jnp
from jax import random

class Dense(nn.Module):
  """A simple Dense model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=1)(x)
    return x

model = Dense()
key = random.key(0)
input = [1, 1]

params = model.init(key, input)
params = {'params': {'Dense_0': {'kernel': jnp.array([[2.],[3.]]), 'bias': jnp.array([0.])}}}

model.apply(params, input)

==> [5.0]

So, how can we create an equivalent Axon model given a Flax model? Here are the two models again:

model =
  Axon.input("input")
  |> Axon.dense(1)

class Dense(nn.Module):
  """A simple Dense model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=1)(x)
    return x

With this simple dense model, it’s easy. In Axon, we specify a special input layer, then we append all the layers that are present in the Flax model. Here, it’s only a single dense layer with a parameter (units in Axon, features in Flax) set to 1.

After that, each framework provides a way to initialize the model, get the params, and actually run it.

Alright, we have everything we need to manually construct an Axon model from a very simple Flax model.

Let’s recap the steps we need to take given a Flax model:

  1. Create an input layer in Axon
  2. Append all layers found in Flax with the corresponding parameters

A slightly larger model

Let’s take it one step further and look at a slightly larger Axon model (taken from the Axon guides).

slightly_larger_model =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.activation(:relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1)
  |> Axon.activation(:softmax)
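As a quick sanity check, this model builds and runs the same way as the small one; in the default inference mode the dropout layer is a no-op:

{init_fn, predict_fn} = Axon.build(slightly_larger_model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
predict_fn.(params, Nx.tensor([[1.0, 1.0]]))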

And the same model in Flax:

from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLarger(nn.Module):
  """A slightly larger model."""

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(features=32)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    x = nn.Dense(features=1)(x)
    x = nn.softmax(x)
    return x

Let’s see if we can reconstruct the Axon model from the Flax model. For this we follow our steps:

  1. Create an input layer
  2. Append all the layers found in Flax with the corresponding parameters

We define an input layer the same way as before:

Axon.input("input")

Then we go on with one layer at a time:

nn.Dense(features=32)(x)
Axon.dense(32)
nn.relu(x)
Axon.activation(:relu)
nn.Dropout(rate=0.5, deterministic=not training)(x)
Axon.dropout(rate: 0.5)

Note that we don't need the deterministic parameter in Axon; Axon handles this at build time, as sketched below.
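Flax toggles dropout per call via deterministic; in Axon the equivalent switch is the mode given to Axon.build. A minimal sketch, assuming the slightly_larger_model from above:

# Dropout is active when the model is built in :train mode...
{_init_fn, _train_predict_fn} = Axon.build(slightly_larger_model, mode: :train)

# ...and is a pass-through in the default :inference mode,
# which corresponds to deterministic=True in Flax.
{_init_fn, _predict_fn} = Axon.build(slightly_larger_model)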

nn.Dense(features=1)(x)
Axon.dense(1)
nn.softmax(x)
Axon.activation(:softmax)

And then we pipe the input through the Axon layers, and we’re done.

Axon.input("input")
  |> Axon.dense(32)
  |> Axon.activation(:relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1)
  |> Axon.activation(:softmax)
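To double-check, we can render the original slightly_larger_model as a graph and confirm that our reconstruction lists the same layers in the same order:

Axon.Display.as_graph(slightly_larger_model, Nx.template({1, 2}, :f32))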

Flax with setup call

That's cool, it worked with a slightly larger model. However, the model is still fairly simple. In particular, it's defined in a single function, which maps naturally to the way we define models in Axon.

So, let's continue to make it more difficult. First, what we've seen so far is the compact way to define a Flax model (note the @nn.compact decorator). Another common way to define Flax models is to first initialize layers in a setup function. Afterwards, those layers get called in the __call__ function as before.

When we define our slightly larger Flax model using a setup function it looks like this:

class SlightlyLargerSetup(nn.Module):
  """A slightly larger model defined with setup function."""
  def setup(self):
    self.dense0 = nn.Dense(features=32)    
    self.dropout = nn.Dropout(rate=0.5)
    self.dense1 = nn.Dense(features=1)
  
  def __call__(self, x, training: bool):
    x = self.dense0(x)
    x = nn.relu(x)
    x = self.dropout(x, deterministic=not training)
    x = self.dense1(x)
    x = nn.softmax(x)
    return x

Now, this looks fairly similar. However, to go back to the compact version, we must replace the calls to the methods stored in our module attributes with the actual Flax layers in the __call__ function, including the parameters passed to the layers during initialization in setup:

class SlightlyLargerSetup(nn.Module):
  """A slightly larger model defined with setup function."""
  def setup(self):
    self.dense0 = nn.Dense(features=32)    
    self.dropout = nn.Dropout(rate=0.5)
    self.dense1 = nn.Dense(features=1)
  
  def __call__(self, x, training: bool):
    x = nn.Dense(features=32)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=0.5)(x, deterministic=not training)
    x = nn.Dense(features=1)(x)
    x = nn.softmax(x)
    return x

Note that the deterministic parameter of the dropout layer is now passed in the second pair of parentheses. While that works, we should move it back to the first pair. When we remove the setup function and add the @nn.compact decorator, we’re now back with the exact same model definition as before, and we know how to convert it to an Axon model.

So, we add two new steps to the front of our recipe:

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layers, including the parameters.
  2. In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
  3. Create an input layer in Axon
  4. Append all layers found in Flax with the corresponding parameters

Flax with attributes

Another fun thing you can do with Flax models is to pass attributes when creating the model. Have a look here:

from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLargerAttributes(nn.Module):
  """A slightly larger model with attributes."""
  features_dense_0: int
  dropout_rate: float
  features_dense_1: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(features=self.features_dense_0)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
    x = nn.Dense(features=self.features_dense_1)(x)
    x = nn.softmax(x)
    return x

The attributes are then used to initialize the layers. As mentioned, you pass the values when creating the model:

model = SlightlyLargerAttributes(features_dense_0=32, dropout_rate=0.5, features_dense_1=1)

Then, you can go on as before:

key = random.key(0)
input = [1,1]

params = model.init(key, input, training=True)

model.apply(params, input, training=False, rngs={'dropout': key })

This way the models are parameterized. So, how would we convert this model to Axon?

Let’s start by moving the attributes to the __call__ function as parameters, and replace all the references of the attributes with the function parameters when initializing the layers:

class SlightlyLargerAttributes(nn.Module):
  """A slightly larger model with attributes."""

  @nn.compact
  def __call__(self, x, training: bool, features_dense_0: int, dropout_rate: float, features_dense_1: int):
    x = nn.Dense(features=features_dense_0)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=dropout_rate, deterministic=not training)(x)
    x = nn.Dense(features=features_dense_1)(x)
    x = nn.softmax(x)
    return x

Alright, this looks similar to the models before. We basically know how to get an Axon model from this. However, now we do have parameters in the __call__ function, so what should we do with them?

Let’s wrap our regular Axon model definition in a function.

model_with_params =
  fn %{
       features_dense_0: features_dense_0,
       dropout_rate: dropout_rate,
       features_dense_1: features_dense_1
     } ->
    Axon.input("input")
    |> Axon.dense(features_dense_0)
    |> Axon.activation(:relu)
    |> Axon.dropout(rate: dropout_rate)
    |> Axon.dense(features_dense_1)
    |> Axon.activation(:softmax)
  end

This way, we parameterized our Axon model as well. Here's how we can use it:

model = model_with_params.(%{features_dense_0: 32, dropout_rate: 0.5, features_dense_1: 1})

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})

predict_fn.(params, Nx.tensor([[1.0, 1.0]]))

We add some steps to our recipe:

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layers, including the parameters.
  2. If there are attributes, move them to the __call__ function as parameters
  3. Replace all the references of the attributes with the function parameters when initializing the layers
  4. In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
  5. Create an input layer in Axon
  6. Append all layers found in Flax with the corresponding parameters
  7. Wrap the Axon model in a function that takes all the required parameters from the __call__ function

Flax with loops

So far so good. Let’s make it more difficult.

When defining models in Flax, you can use loops. Look at this.

from flax import linen as nn  # Linen API
from jax import numpy as jnp

class SlightlyLargerLoop(nn.Module):
  """A slightly larger model with a loop."""

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(features=32)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    for i in range(1,3):
      x = nn.Dense(features=i)(x)
    x = nn.softmax(x)
    return x

So, how would we create our Axon model from that?

Let’s extract the loop into a function.

loop = fn axon, enumerable -> 
  for i <- enumerable, reduce: axon do
    ax -> ax |> Axon.dense(i)
  end
end

Then we plug the function into our pipeline as a layer.

Axon.input("input")
|> Axon.dense(32)
|> Axon.activation(:relu)
|> Axon.dropout(rate: 0.5)
|> loop.(1..3)
|> Axon.dense(1)
|> Axon.activation(:softmax)

Or, we could plug the function directly into the pipeline using then.

Axon.input("input")
|> Axon.dense(32)
|> Axon.activation(:relu)
|> Axon.dropout(rate: 0.5)
|> then(fn axon ->
  for i <- 1..3, reduce: axon do
    ax -> ax |> Axon.dense(i)
  end
end)
|> Axon.dense(1)
|> Axon.activation(:softmax)
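To see what the loop produced, we can bind the pipeline to a variable (loop_model is just our own name) and render it; there should be one dense layer per element of the range:

loop_model =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.activation(:relu)
  |> Axon.dropout(rate: 0.5)
  |> loop.(1..2)
  |> Axon.activation(:softmax)

Axon.Display.as_graph(loop_model, Nx.template({1, 2}, :f32))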

We might need to revisit loops later when we look at more complicated ones, but for now we leave it at that.

Let’s add the steps we performed to the recipe.

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layers, including the parameters.
  2. If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers
  3. In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
  4. Extract loops to a separate function, transform this function according to our rules. Plug the function into the pipeline.
  5. Create an input layer in Axon
  6. Append all layers found in Flax with the corresponding parameters
  7. Wrap the Axon model in a function that takes all the required parameters from the __call__ function

Flax with nested calls

So far so good. Let’s make it more difficult.

When defining models in Flax, you can use arbitrary functions and nest calls to other modules. Here's an example:

Combined Flax models

Of course, we can combine different Flax modules like in the following example.

from flax import linen as nn  # Linen API
from jax import numpy as jnp

class DenseSoftmax(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=1)(x)
    x = nn.softmax(x)
    return x

class SlightlyLargerCombined(nn.Module):
  """A slightly larger model which uses another module."""

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(features=32)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    x = DenseSoftmax()(x)
    return x

Here, SlightlyLargerCombined makes use of DenseSoftmax. How would we convert this to Axon? Well, we would first convert DenseSoftmax without an input layer. Then, we wrap it in a function and plug it into the conversion of SlightlyLargerCombined.

dense_softmax = fn axon ->
  axon
  |> Axon.dense(1)
  |> Axon.softmax()
end

slightly_larger_combined =
  Axon.input("input")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dropout(rate: 0.5)
  |> dense_softmax.()
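Rendering the combined model should show that the dense and softmax layers of DenseSoftmax ended up inlined in the graph:

Axon.Display.as_graph(slightly_larger_combined, Nx.template({1, 2}, :f32))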

Again, we add our findings to the recipe.

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layers, including the parameters.
  2. If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers
  3. In the __call__ function, move all parameters in second pairs of parentheses to the first pair.
  4. Extract loops to a separate function, transform this function according to our rules. Plug the function into the pipeline.
  5. Append all layers found in Flax with the corresponding parameters
  6. Wrap the Axon model in a function that takes all the required parameters from the __call__ function
  7. Convert all called modules, wrap them in a function and plug them in the pipeline.
  8. Create an input layer in Axon

Handle plain jnp <-> Nx

Real world model

Let's get serious and have a look at a real-world model: the Flax implementation of ResNet in the transformers library.

We start with FlaxResNetConvLayer.

class FlaxResNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
        self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        hidden_state = self.activation_func(hidden_state)
        return hidden_state

Alright, let’s follow our rules and see if we can get an Axon model.

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layer, including the parameters.

  2. If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers

  3. In the __call__ function, move all parameters in second pairs of parentheses to the first pair.

  4. Extract loops to a separate function, transform this function according to our rules. Plug the function into the pipeline.

  5. Append all layers found in Flax with the corresponding parameters

  6. Wrap the Axon model in a function that takes all the required parameters from the __call__ function

  7. Convert all called modules, wrap them in a function and plug them in the pipeline.

  8. Create an input layer in Axon

Now let's apply these steps one at a time.

  1. If there is a setup function, replace the calls to the stored methods in __call__ with the actual Flax layers, including the parameters.

class FlaxResNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)(hidden_state, use_running_average=deterministic)
        hidden_state = (ACT2FN[self.activation] if self.activation is not None else Identity())(hidden_state)
        return hidden_state
  2. If there are attributes, move them to the __call__ function as parameters and replace all the references of the attributes with the function parameters when initializing the layers.
class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = nn.Conv(
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            strides=stride,
            padding=kernel_size // 2,
            dtype=dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype)(hidden_state, use_running_average=deterministic)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
  3. In the __call__ function, move all additional parameters in second pairs of parentheses to the first pair.
class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = nn.Conv(
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            strides=stride,
            padding=kernel_size // 2,
            dtype=dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state
  4. Extract loops to a separate function, transform this function according to our rules. Plug the function into the pipeline.

Nothing to do here, there is no loop.

  5. Replace the Flax layers with the corresponding Axon layers according to our conversion table. Take into account the parameters in the first parenthesis. Replace initializer functions with the corresponding Axon functions. Follow these rules to replace activation functions: if there is an activation parameter and the code makes use of ACT2FN[activation], replace that with Axon.activation(activation); if an actual activation function is called, replace it with Axon's activation function according to the conversion table. Replace all = for arguments with :.

We start with the conv layer

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=dtype),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state

We must replace the kernel initializer.

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state

Next, the batch norm layer.

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = (ACT2FN[activation] if activation is not None else Identity())(hidden_state)
        return hidden_state

And we must replace the activation function.

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size={kernel_size, kernel_size},
            strides=stride,
            padding=kernel_size // 2,
            use_bias=False,
            kernel_init=Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum=0.9, epsilon=1e-05, dtype=dtype, use_running_average=deterministic)(hidden_state)
        hidden_state = Axon.activation(activation) (hidden_state)
        return hidden_state

We replace all = with : for function arguments.

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )(x)
        hidden_state = Axon.batch_norm(momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)(hidden_state)
        hidden_state = Axon.activation(activation) (hidden_state)
        return hidden_state
  6. Move the single parameter in the second pair of parentheses to be the first parameter of the first pair.
class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(x,
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )
        hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
        hidden_state = Axon.activation(hidden_state, activation) 
        return hidden_state

And remove the return keyword.

class FlaxResNetConvLayer(nn.Module):

    def __call__(self, x: jnp.ndarray, deterministic: bool = True, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: Optional[str] = "relu", dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
        hidden_state = Axon.conv(x,
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )
        hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
        hidden_state = Axon.activation(hidden_state, activation) 
        hidden_state
  7. Wrap the Axon model in a function that takes all the required parameters from the __call__ function:
  • name the function after the class name, but in snake case
  • take the same arguments as the __call__ function
  • remove the self argument
  • remove the dtype argument
  • remove the type specs if present
  • wrap the function body in do … end
def flax_res_net_conv_layer(x, deterministic, out_channels, kernel_size, stride, activation) do
        hidden_state = Axon.conv(x,
            out_channels,
            kernel_size: {kernel_size, kernel_size},
            strides: stride,
            padding: kernel_size // 2,
            use_bias: False,
            kernel_init: Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal),
        )
        hidden_state = Axon.batch_norm(hidden_state, momentum: 0.9, epsilon: 1e-05, use_running_average: deterministic)
        hidden_state = Axon.activation(hidden_state, activation) 
        hidden_state
end
  8. Convert all called modules, wrap them in a function and plug them into the pipeline.

Nothing to do here.

  9. Check if the function is valid Elixir code. Otherwise, fix all issues by converting Python expressions to Elixir expressions. E.g. // corresponds to div, and scientific notation like 1e-05 needs a decimal point in Elixir: 1.0e-05.
flax_res_net_conv_layer = fn x, out_channels, kernel_size, stride, activation ->
  hidden_state =
    Axon.conv(
      x,
      out_channels,
      kernel_size: kernel_size,
      strides: stride,
      padding: [
        {div(kernel_size, 2), div(kernel_size, 2)},
        {div(kernel_size, 2), div(kernel_size, 2)}
      ],
      use_bias: false,
      kernel_initializer:
        Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal)
    )

  hidden_state =
    Axon.batch_norm(hidden_state,
      momentum: 0.9,
      epsilon: 1.0e-05
    )

  hidden_state = Axon.activation(hidden_state, activation)
  hidden_state
end
  10. Create an input layer in Axon and pass it in as the first argument.
model = flax_res_net_conv_layer.(Axon.input("input"), 3, 3, 1, :relu)

Axon.Display.as_graph(model, Nx.template({1, 2, 2}, :f32))

Parameterize models based on input

Verify model

Alright, we converted the model from Flax to Axon. How do we know that everything went right? We must verify that.

out_channels = 3
kernel_size = 3
stride = 1
activation = :relu

model =
  flax_res_net_conv_layer.(Axon.input("input"), out_channels, kernel_size, stride, activation)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3, 3}, :f32), %{})

params =
  %{
    "batch_norm_0" => %{
      "beta" => Nx.tensor([1, 1, 1], type: :f32),
      "gamma" => Nx.tensor([1, 1, 1], type: :f32),
      "mean" => Nx.tensor([1, 2, 3], type: :f32),
      "var" => Nx.tensor([1, 2, 3], type: :f32)
    },
    "conv_0" => %{
      "kernel" => Nx.broadcast(Nx.tensor(1, type: :f32), {1, 3, 3})
    }
  }

predict_fn.(params, Nx.tensor([[[1.0, 2.0, 3.0]]]))
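To really verify the conversion, we would run the Flax layer with the same hand-set parameters and compare the outputs numerically. A minimal sketch on the Elixir side; the flax_out values are placeholders to be replaced with the numbers printed by the Flax run:

axon_out = predict_fn.(params, Nx.tensor([[[1.0, 2.0, 3.0]]]))

# Placeholder: paste the output of the corresponding Flax run here
flax_out = Nx.tensor([[[0.0, 0.0, 0.0]]])

Nx.all_close(axon_out, flax_out, atol: 1.0e-5)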

Automate it

Wow, that's a lot. We MUST automate this. Luckily, it's 2024 and we have something called LLMs.

Load models from HuggingFace

Flax vs Axon layers

| Framework | Layer | Params | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Flax | nn.Conv | features | kernel_size | strides | padding | dtype | use_bias | kernel_init |
| Axon | Axon.conv | units (second pos. param) | kernel_size | strides | padding | - | use_bias | kernel_initializer |
| Hints | | | single number instead of tuple | | predefined values (:valid) | — | | |

| Framework | Layer | Params | | | | | | | | | | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Flax | nn.BatchNorm | use_running_average | axis | momentum | epsilon | dtype | param_dtype | use_bias | use_scale | bias_init | scale_init | axis_name | axis_index_groups | use_fast_variance | force_float32_reductions | parent | name |
| Axon | Axon.batch_norm | - | - | - | epsilon | - | - | - | - | beta_initializer(?) | gamma_initializer(?) | - | - | - | - | - | name |
| Hints | | | | | | | | | | | | | | | | | |
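As a small illustration of the nn.Conv row (the shape and option values here are made up), a Flax call like nn.Conv(16, kernel_size=(3, 3), strides=1, padding=1, use_bias=False, kernel_init=...) maps to:

Axon.input("input", shape: {nil, 8, 8, 3})
|> Axon.conv(16,
  kernel_size: 3,
  strides: 1,
  padding: [{1, 1}, {1, 1}],
  use_bias: false,
  kernel_initializer:
    Axon.Initializers.variance_scaling(scale: 2.0, mode: :fan_out, distribution: :normal)
)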

Flax vs Axon initializers

| Framework | Initializer | Params | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Flax | nn.initializers.variance_scaling | scale (1. pos) | mode (2. pos) | distribution (3. pos) | in_axis | out_axis | batch_axis | dtype |
| Axon | Axon.Initializers.variance_scaling | scale | mode | distribution | | | | |
| Hints | | | | | | | | |

Allow adding custom rules