Self Compressing Neural Networks

mnist-self-compression.livemd

Mix.install(
  [
    :axon,
    :scidata,
    :nx,
    :exla,
    :kino,
    :kino_vega_lite
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Developing an Intuition

Paper: Self-Compressing Neural Networks [arxiv 2301.13142]

Typically, neural networks are set up with a fixed number of weights that we optimize over. However, this has always raised the question, “How many parameters should I be using?” Did you use too many? Too few? Is this really the optimal representation?

This paper explores this idea and proposes a method in which the network, while it is training, also tries to minimize the number of parameters and the amount of precision it needs to achieve its goal.

At a high level, it represents the weights as simulated floating point numbers, parameterized by e, the exponent, and b, the bit depth. Since these parameters are trainable, we can add them to the loss function to be minimized. When b is minimized, we are effectively simulating quantization of the weights.

Then, if a given b is ever minimized to zero, that weight effectively has zero impact on downstream layers, since multiplying anything by zero is zero. So we can remove that weight and its connected weights from the network.
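The removal step can be sketched in plain Elixir. This is only an illustration under assumptions: `PruneSketch`, the list-of-lists layout, and the sample values are hypothetical and are not part of the notebook or the paper's code. Each output channel has one bit depth, and a channel is dropped together with the matching input rows of the next layer once its bit depth is no longer positive.

```elixir
defmodule PruneSketch do
  # Indices of the output channels whose bit depth is still positive.
  def kept_indices(bit_depths) do
    bit_depths
    |> Enum.with_index()
    |> Enum.filter(fn {b, _index} -> b > 0 end)
    |> Enum.map(fn {_b, index} -> index end)
  end

  # Keeps only the rows at the given indices, in order.
  def take_rows(rows, indices), do: Enum.map(indices, &Enum.at(rows, &1))
end

# Hypothetical data: four output channels, one bit depth each.
bit_depths = [3.2, 0.0, 1.5, -0.4]
kept = PruneSketch.kept_indices(bit_depths)

# The same indices are applied to this layer's kernels and to the
# next layer's input rows, so both stay aligned after pruning.
layer_kernels = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
next_layer_inputs = [[1.0], [2.0], [3.0], [4.0]]

IO.inspect(PruneSketch.take_rows(layer_kernels, kept))
IO.inspect(PruneSketch.take_rows(next_layer_inputs, kept))
```

A real implementation would slice tensors along the channel axis instead of filtering lists, but the bookkeeping is the same.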

But let’s not get ahead of ourselves. First, let’s build an intuition for the simulated floating point numbers. The equation presented in the paper is:

$$q(x,b,e) = 2^e \left\lfloor \min(\max(2^{-e}x, -2^{b-1}), 2^{b-1} - 1) \right\rceil$$

Here ⌊⌉ is the rounding operator. Let’s get a sense of what that function looks like.
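Before plotting, it can help to work through a few single values by hand. The sketch below is plain Elixir on scalars rather than Nx tensors. `ScalarQuant` is a hypothetical name used only for this illustration, and it clamps negative b to zero in the same way as the relu in the notebook's `Paper.floating_point/3`.

```elixir
defmodule ScalarQuant do
  # q(x, b, e) = 2^e * round(min(max(2^-e * x, -2^(b-1)), 2^(b-1) - 1))
  def q(x, b, e) do
    # b cannot be less than zero
    b = max(b, 0.0)
    lo = -:math.pow(2.0, b - 1.0)
    hi = :math.pow(2.0, b - 1.0) - 1.0

    # Scale into integer units, clamp to the b-bit range, round, scale back.
    scaled = :math.pow(2.0, -e) * x
    clamped = scaled |> max(lo) |> min(hi)
    :math.pow(2.0, e) * round(clamped)
  end
end

# With b = 4 and e = -3 the step is 2^-3 = 0.125 and the integer range is -8..7.
IO.inspect(ScalarQuant.q(0.3, 4.0, -3.0))  # 0.3 * 8 = 2.4, rounds to 2, gives 0.25
IO.inspect(ScalarQuant.q(5.0, 4.0, -3.0))  # clamped to 7, gives 0.875
IO.inspect(ScalarQuant.q(-5.0, 4.0, -3.0)) # clamped to -8, gives -1.0
```

Values inside the range snap to the nearest multiple of 2^e, and values outside it saturate at the two ends.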

require VegaLite, as: Vl

defmodule Paper do
  import Nx.Defn

  @doc """
  Simulates a floating point number, used to simulate dynamic quantization of
  the input tensor x.
  """
  defn floating_point(x, b, e) do
    # b cannot be less than zero
    b_relu = Axon.Activations.relu(b)

    2.0 ** e *
      Nx.round(
        Nx.min(
          Nx.max(2.0 ** -e * x, -(2.0 ** (b_relu - 1.0))),
          2.0 ** (b_relu - 1.0) - 1.0
        )
      )
  end
end

Kino.nothing()
{xmin, xmax} = {-1.0, 1.0}
xs = Nx.linspace(xmin, xmax, n: 1200)

graph =
  Vl.new(width: 300, height: 300, title: "Representable numbers with floating points")
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x",
    type: :quantitative,
    scale: [domain: [xmin, xmax]],
    title: "weight"
  )
  |> Vl.encode_field(:y, "y",
    type: :quantitative,
    scale: [domain: [xmin, xmax]],
    title: "floating_point(weight, b, e)"
  )
  |> Kino.VegaLite.new()

inspect_frame = Kino.Frame.new()
b_input = Kino.Input.range("b", min: 0, max: 32.0, step: 0.1, default: 4.7, debounce: 16)
e_input = Kino.Input.range("e", min: -10, max: 0.0, step: 0.1, default: -4.6, debounce: 16)

form = Kino.Control.form([e: e_input, b: b_input], report_changes: true)

form
|> Kino.listen(fn %{data: %{b: b, e: e} = data} ->
  Kino.Frame.render(inspect_frame, data)

  data_points =
    [
      xs |> Nx.to_list(),
      Paper.floating_point(xs, b, e) |> Nx.to_list()
    ]
    |> Enum.zip()
    |> Enum.map(fn {x, y} -> %{x: x, y: y} end)

  Kino.VegaLite.clear(graph)
  Kino.VegaLite.push_many(graph, data_points)
end)

Kino.Layout.grid(
  [
    graph,
    Kino.Layout.grid([form, inspect_frame])
  ],
  columns: 2
)

Play around with the interactive plot above and you can see that the e parameter controls the resolution of the line (the step between neighboring representable values is 2^e), whereas b controls how long the diagonal line is, or rather how many distinct values are representable (roughly 2^b of them). Under the hood, this is how floating point and quantization work. It’s what the paper uses to simulate the floating point representation of the weights, and it allows us to minimize that representation while we’re training.
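The same observation can be made without the plot by listing the representable values directly. The sketch below assumes an integer bit depth for simplicity (in the notebook b is a continuous parameter), and `Grid` is a hypothetical helper name.

```elixir
defmodule Grid do
  # All values q can produce for an integer bit depth b >= 1 and exponent e:
  # the integers -2^(b-1)..2^(b-1)-1, each multiplied by the step 2^e.
  def representable(b, e) when is_integer(b) and b >= 1 do
    lo = -Integer.pow(2, b - 1)
    hi = Integer.pow(2, b - 1) - 1
    step = :math.pow(2.0, e)
    for k <- lo..hi, do: k * step
  end
end

# b = 3 gives 2^3 = 8 values, e = -2 spaces them 0.25 apart.
IO.inspect(Grid.representable(3, -2))
# [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]

# Dropping to b = 2 halves the count, raising e to -1 doubles the spacing.
IO.inspect(Grid.representable(2, -1))
# [-1.0, -0.5, 0.0, 0.5]
```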

The lower the b parameter, the fewer bits we need to represent the weight and the smaller the network.
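To make "smaller" concrete, here is a rough sketch of how the size in bits could be counted and folded into a loss. The names `SizePenalty`, `total_bits/2`, `penalized_loss/4`, and the `gamma` weighting are assumptions for illustration. The notebook does not implement this term, and the exact form used in the paper may differ.

```elixir
defmodule SizePenalty do
  # One bit depth per output channel, shared by `weights_per_channel` weights.
  # Negative bit depths count as zero, mirroring the relu applied to b.
  def total_bits(bit_depths, weights_per_channel) do
    bit_depths
    |> Enum.map(&max(&1, 0.0))
    |> Enum.sum()
    |> Kernel.*(weights_per_channel)
  end

  # Task loss plus gamma times the average number of bits per weight.
  def penalized_loss(task_loss, bits, total_weights, gamma) do
    task_loss + gamma * bits / total_weights
  end
end

# Four channels of a 3x3 kernel with one input channel: 9 weights per channel.
bits = SizePenalty.total_bits([8.0, 4.0, -1.0, 0.0], 9)
IO.inspect(bits)
# 108.0

IO.inspect(SizePenalty.penalized_loss(0.5, bits, 36, 0.01))
```

With a term like this in the loss, gradient descent is rewarded for lowering b wherever doing so does not hurt the task loss much.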

So, now that we have an intuition for the idea, let’s implement a quantized convolutional layer that we can use in our neural network.

defmodule QConv do
  import Nx.Defn

  def qconv(%Axon{} = x, units, opts \\ []) when is_integer(units) and units > 0 do
    # Mostly copied from Axon.conv
    opts =
      Keyword.validate!(opts, [
        :name,
        :activation,
        :meta,
        kernel_initializer: :glorot_uniform,
        bias_initializer: :zeros,
        use_bias: true,
        kernel_size: 1,
        strides: 1,
        padding: :valid,
        input_dilation: 1,
        kernel_dilation: 1,
        channels: :last,
        feature_group_size: 1
      ])

    kernel_size = opts[:kernel_size]
    strides = opts[:strides]
    padding = opts[:padding]
    input_dilation = opts[:input_dilation]
    kernel_dilation = opts[:kernel_dilation]
    channels = opts[:channels]
    feature_group_size = opts[:feature_group_size]

    kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, feature_group_size)
    kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

    # TODO: Actually calculate the shape
    e = Axon.param("e", {1, 1, 1, units}, initializer: Axon.Initializers.full(-8.0))
    # Start with 16 bits per weight
    b = Axon.param("b", {1, 1, 1, units}, initializer: Axon.Initializers.full(16.0))

    op = &do_qconv/5
    inputs = [x, kernel, e, b]

    Axon.layer(op, inputs,
      name: opts[:name],
      meta: opts[:meta],
      strides: strides,
      padding: padding,
      input_dilation: input_dilation,
      kernel_dilation: kernel_dilation,
      feature_group_size: feature_group_size,
      channels: channels,
      op_name: :qconv
    )
  end

  defnp do_qconv(input, kernel, e, b, opts \\ []) do
    # assert_min_rank!("Axon.Layers.conv", "input", input, 3)
    # assert_equal_rank!("Axon.Layers.conv", "input", input, "kernel", kernel)

    opts =
      keyword!(opts, [
        :meta,
        strides: 1,
        padding: :valid,
        input_dilation: 1,
        kernel_dilation: 1,
        feature_group_size: 1,
        batch_group_size: 1,
        channels: :last,
        mode: :inference
      ])

    # bias_reshape = Axon.Shape.conv_bias_reshape(input, bias, opts[:channels])
    {permutations, kernel_permutation} = Axon.Shape.conv_permutations(input, opts[:channels])

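    qk = qkernel(kernel, e, b)
    # Straight-through estimator: the forward pass uses the rounded kernel,
    # while the backward pass differentiates through qk as if rounding were the identity
    k = stop_grad(Nx.round(qk) - qk) + qk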

    input
    |> Nx.conv(k,
      strides: opts[:strides],
      padding: opts[:padding],
      input_dilation: opts[:input_dilation],
      kernel_dilation: opts[:kernel_dilation],
      feature_group_size: opts[:feature_group_size],
      batch_group_size: opts[:batch_group_size],
      input_permutation: permutations,
      kernel_permutation: kernel_permutation,
      output_permutation: permutations
    )

    # |> Nx.add(Nx.reshape(bias, bias_reshape))
  end

  defnp qkernel(kernel, e, b) do
    b_relu = Axon.Activations.relu(b)

    2.0 ** e *
      Nx.round(
        Nx.min(
          Nx.max(2.0 ** -e * kernel, -(2.0 ** (b_relu - 1.0))),
          2.0 ** (b_relu - 1.0) - 1.0
        )
      )
  end

  # defnp qbits()
end
{:module, QConv, <<70, 79, 82, 49, 0, 0, 26, ...>>, true}
input = Axon.input("input")

model =
  input
  |> QConv.qconv(16, kernel_size: {3, 3})
  |> Axon.relu()
  |> Axon.max_pool(kernel_size: {2, 2})
  |> QConv.qconv(32, kernel_size: {3, 3})
  |> Axon.relu()
  |> Axon.max_pool(kernel_size: {2, 2})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)


Axon.Display.as_graph(model, Nx.template({60_000, 28, 28, 1}, :f32))
graph TD;
13[/"input (:input) {60000, 28, 28, 1}"/];
14["qconv_0 (:qconv) {60000, 26, 26, 16}"];
15["relu_0 (:relu) {60000, 26, 26, 16}"];
16["max_pool_0 (:max_pool) {60000, 13, 13, 16}"];
17["qconv_1 (:qconv) {60000, 11, 11, 32}"];
18["relu_1 (:relu) {60000, 11, 11, 32}"];
19["max_pool_1 (:max_pool) {60000, 5, 5, 32}"];
20["flatten_0 (:flatten) {60000, 800}"];
21["dense_0 (:dense) {60000, 128}"];
22["relu_2 (:relu) {60000, 128}"];
23["dense_1 (:dense) {60000, 10}"];
24["softmax_0 (:softmax) {60000, 10}"];
23 --> 24;
22 --> 23;
21 --> 22;
20 --> 21;
19 --> 20;
18 --> 19;
17 --> 18;
16 --> 17;
15 --> 16;
14 --> 15;
13 --> 14;
{{images_binary, images_type, images_shape}, {labels_binary, labels_type, labels_shape}} =
  Scidata.MNIST.download()

images =
  Nx.from_binary(images_binary, images_type)
  |> Nx.reshape({:auto, 28, 28, 1}, names: [:images, :height, :width, :channels])
  |> Nx.divide(255)
  |> Nx.to_batched(32)

labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.reshape(labels_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(32)

{{images_binary, images_type, images_shape}, {labels_binary, labels_type, labels_shape}} =
  Scidata.MNIST.download_test()

test_images =
  Nx.from_binary(images_binary, images_type)
  |> Nx.reshape({:auto, 28, 28, 1}, names: [:images, :height, :width, :channels])
  |> Nx.divide(255)

test_labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.reshape(labels_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
#Nx.Tensor<
  u8[10000][10]
  EXLA.Backend
  [
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ...
  ]
>
defmodule SelfCompressingLoss do
  import Nx.Defn

  deftransform loss(y_true, y_pred, opts \\ []) do
    # Currently this computes only the cross-entropy; no compression term is added here
    Axon.Losses.categorical_cross_entropy(y_true, y_pred, opts)
  end
end

loss_fn = &SelfCompressingLoss.loss(&1, &2, reduction: :mean)

params =
  model
  |> Axon.Loop.trainer(loss_fn, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.validate(model, [{test_images, test_labels}], event: :epoch_completed)
  |> Axon.Loop.run(Stream.zip(images, labels), %{}, epochs: 2, compiler: EXLA)

16:57:38.276 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 1850, Accuracy: 0.1123717 loss: 2.3014505