B: Style Transfer

Logger.configure(level: :info)

# All necessary dependencies are installed by installing the package below
Mix.install([
  {:workshop_elixir_conf_us_2024, path: Path.join(__DIR__, "../..")}
])

Get node

The value returned by Node.self() can be used to fetch the metrics from the running pipelines.

Go to livebooks/metrics.livemd to visualize them.

Node.self()

Axon models

We have two public modules in the :workshop_elixir_conf_us_2024 application:

  • Workshop.Models.Mosaic
  • Workshop.Models.Candy

Each of them defines two functions:

  1. model(image_height, image_width) which returns an Axon model
  2. postprocess(tensor) which postprocesses the tensor returned by the Axon model

Weights for these two models are stored in the priv/nx directory.

$ ls priv/nx
candy.nx
mosaic.nx
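
To load them, read a file and deserialize its contents into a tensor. A minimal sketch, assuming the *.nx files were written with Nx.serialize/2 (so that Nx.deserialize/1 can read them back):

weights =
  Path.join(__DIR__, "../../priv/nx/mosaic.nx")
  |> File.read!()
  |> Nx.deserialize()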

Both models expect an Nx.Tensor.t() that represents an image as their input. The tensor shape should be {batch_size, height, width, colors} (see the example after this list), where:

  • batch_size is the number of images passed to the model at once (for us, it can always equal 1)
  • height is the image’s height (in other words, the raw video frame height)
  • width is the image’s width (in other words, the raw video frame width)
  • colors is always equal to 3, since in the RGB format each pixel is described by three numbers, one per color
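
For example, a single 640x360 RGB frame batched for the model would be shaped like this (the dimensions are hypothetical):

frame = Nx.iota({1, 360, 640, 3}, type: :f32)

Nx.shape(frame)
#=> {1, 360, 640, 3}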

Exercise B1: Write Style Transfer Element

Write your own StyleTransferFilter that uses one of the models above to perform style transfer on our video clip with Big Buck Bunny.

Then, add your element to the StyleTransferPipeline in the cell below the exercise’s description.

Use the Axon library in your filter’s implementation.

Below are some general tips that might be helpful:

  • raw video frames returned by SWScale.Converter{format: :RGB} are arranged as H x W x colors; a number representing one color of a pixel takes 1 byte and is unsigned
  • the stream_format received in the handle_stream_format/4 callback contains information about the input video’s width and height
  • the Nx.Tensor.t() returned by the model needs to be clamped to the range between 0 and 255
  • always set the tensor’s backend to EXLA.Backend - otherwise operations on tensors will be slow (take a look at the Nx.backend_transfer/2 function and the :backend option in Nx.from_binary/3)
  • postprocessing is already implemented in the postprocess/1 function in the model’s module, but you need to implement StyleTransferFilter.preprocess/2 on your own (a rough sketch follows this list). Below the cell with the definition of the StyleTransferFilter there is another cell with a test checking whether your preprocessing is implemented properly.
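
A minimal sketch of preprocess/2, assuming stream_format carries the height and width fields (the test cell below this exercise pins down the exact contract):

def preprocess(input_binary, stream_format) do
  input_binary
  # each color value of a pixel is a single unsigned byte
  |> Nx.from_binary(:u8, backend: EXLA.Backend)
  # arrange the flat binary as a batch of one image: {1, H, W, 3}
  |> Nx.reshape({1, stream_format.height, stream_format.width, 3})
  # the model works on floats
  |> Nx.as_type(:f32)
end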

Some further tips concerning Axon:

  • to get a model’s reference, call the model/2 function of the specific model’s module
  • to load the weights, read the contents of one of the *.nx files and deserialize it into an Nx.Tensor.t()
  • to run a model, execute Axon.predict/3. This function expects an Axon model, the weights, and a map with the input data under the "data" key:
    Axon.predict(model, weights, %{"data" => preprocessed_image_tensor})
defmodule StyleTransferFilter do
  use Membrane.Filter

  def_input_pad(:input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  def_output_pad(:output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  @impl true
  def handle_init(_ctx, _opts), do: {[], %{}}

  @impl true
  def handle_stream_format(:input, _stream_format, _ctx, state) do
    # handle the stream_format
    {[], state}
  end

  # the remaining callbacks (e.g. handle_buffer/4) go here

  def preprocess(_input_binary, _stream_format) do
    # ...
  end
end
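
For orientation, handle_buffer/4 could follow this rough shape - a sketch only, assuming that state already holds the model and weights (e.g. loaded in handle_setup/2) and that postprocess/1 returns a raw RGB binary:

@impl true
def handle_buffer(:input, buffer, _ctx, state) do
  input_tensor = preprocess(buffer.payload, state.stream_format)

  # run inference; "data" is the input key expected by our models
  output_tensor = Axon.predict(state.model, state.weights, %{"data" => input_tensor})

  # turn the result tensor back into a raw video frame
  payload = Workshop.Models.Mosaic.postprocess(output_tensor)

  {[buffer: {:output, %{buffer | payload: payload}}], state}
end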

The cell below contains a test for StyleTransferFilter.preprocess/2.

stream_format = %{pixel_format: :RGB, height: 4, width: 5}

# input_payload is <<0, 1, 2, 3, ..., 58, 59>>
input_payload = 0..59 |> Enum.to_list() |> to_string()

expected_preprocessed_tensor = Nx.iota({1, 4, 5, 3}, type: :f32, backend: EXLA.Backend)

preprocessed_data = StyleTransferFilter.preprocess(input_payload, stream_format)

if Nx.to_list(preprocessed_data) == Nx.to_list(expected_preprocessed_tensor) do
  :ok
else
  raise """
  Your preprocess/2 function returned
  #{inspect(preprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_preprocessed_tensor, limit: :infinity, pretty: true)}
  """
end
defmodule StyleTransferPipeline do
  use Membrane.Pipeline

  @impl true
  def handle_init(_ctx, _options) do
    input_path = "#{__DIR__}/../../priv/fixtures/bunny_with_sound.mp4"
    output_path = "#{__DIR__}/../../priv/outputs/style_transfer_bunny.mp4"

    spec = [
      child(:source, %Membrane.File.Source{location: input_path})
      |> child(:mp4_demuxer, Membrane.MP4.Demuxer.ISOM)
      |> via_out(:output, options: [kind: :video])
      |> child({:h264_parser, 1}, %Membrane.H264.Parser{output_stream_structure: :annexb})
      |> child(:h264_decoder, Membrane.H264.FFmpeg.Decoder)
      |> child(:rgb_converter, %Membrane.FFmpeg.SWScale.Converter{format: :RGB, output_width: 640})
      |> child(:style_transfer, StyleTransferFilter)
      |> child(:yuv_converter, %Membrane.FFmpeg.SWScale.Converter{format: :I420})
      |> child(:h264_encoder, %Membrane.H264.FFmpeg.Encoder{preset: :ultrafast})
      |> child({:h264_parser, 2}, %Membrane.H264.Parser{output_stream_structure: :avc1})
      |> child(:mp4_muxer, Membrane.MP4.Muxer.ISOM)
      |> child(:file_sink, %Membrane.File.Sink{location: output_path}),
      get_child(:mp4_demuxer)
      |> via_out(:output, options: [kind: :audio])
      |> get_child(:mp4_muxer)
    ]

    {[spec: spec], %{}}
  end

  @impl true
  def handle_element_end_of_stream(:file_sink, _input, _ctx, state) do
    {[terminate: :normal], state}
  end

  @impl true
  def handle_element_end_of_stream(_element, _input, _ctx, state), do: {[], state}
end

Run the pipeline

{:ok, supervisor, _pipeline} = Membrane.Pipeline.start_link(StyleTransferPipeline)
ref = Process.monitor(supervisor)

receive do
  {:DOWN, ^ref, _process, _pid, reason} -> reason
end

Kino.Download.new(fn -> File.read!("#{__DIR__}/../../priv/outputs/style_transfer_bunny.mp4") end,
  label: "Download the video",
  filename: "style_transfer_bunny.mp4"
)

ONNX models

In the priv/models directory there are a couple of models in the .onnx format. Each of them performs a specific style transfer on the input data.

$ ls priv/models
candy.onnx
kaganawa.onnx
mosaic.onnx
mosaic_mobile.onnx
picasso.onnx
princess.onnx
udnie.onnx
vangogh.onnx

Each model expects an Nx.Tensor.t() that represents an image as its input. The tensor shape should be {batch_size, colors, height, width}.

Notice!

The order of the axes here differs from the order expected by the Axon models: the ONNX models are channels-first ({batch_size, colors, height, width}), while the Axon models are channels-last ({batch_size, height, width, colors}).
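
Nx.transpose/2 converts between the two layouts, for example:

nhwc = Nx.iota({1, 4, 5, 3}, type: :f32)

# channels-last {1, 4, 5, 3} -> channels-first {1, 3, 4, 5}
nchw = Nx.transpose(nhwc, axes: [0, 3, 1, 2])

Nx.shape(nchw)
#=> {1, 3, 4, 5}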

Exercise B2: Use Ortex

Now, modify your code from Exercise B1 so that it uses Ortex instead of Axon.

Some tips regarding Ortex:

  • to use a specific model, we have to load it first with the Ortex.load/1 function (Ortex docs)
  • our models expect two input tensors: one representing an image, and a second one specifying parameters used in the layers inside the models. The code snippet below illustrates how to run a model with those two inputs. In your case, you can always run the model with the same parameters as in the example.
    {output_tensor} =
      Ortex.run(loaded_model, {
        preprocessed_image_tensor,
        Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)
      })
  • data preprocessing and postprocessing here are slightly different than in Exercise B1, so below there are two cells with tests for both the preprocessing and the postprocessing
# test preprocessing

stream_format = %{pixel_format: :RGB, height: 4, width: 5}

# input_payload is <<0, 1, 2, 3, ..., 58, 59>>
input_payload = 0..59 |> Enum.to_list() |> to_string()

expected_preprocessed_tensor =
  Nx.iota({1, 4, 5, 3}, type: :f32, backend: EXLA.Backend)
  |> Nx.transpose(axes: [0, 3, 1, 2])

preprocessed_data = StyleTransferFilter.preprocess(input_payload, stream_format)

if Nx.to_list(preprocessed_data) == Nx.to_list(expected_preprocessed_tensor) do
  :ok
else
  raise """
  Your preprocess/2 function returned
  #{inspect(preprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_preprocessed_tensor, limit: :infinity, pretty: true)}
  """
end
# test postprocessing

stream_format = %{pixel_format: :RGB, height: 4, width: 5}

tensor_to_postprocess =
  -300..290//10
  |> Enum.to_list()
  |> Nx.tensor(type: :f32, backend: EXLA.Backend)
  |> Nx.reshape({4, 5, 3})
  |> Nx.transpose(axes: [2, 0, 1])

expected_postprocessed_data =
  <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200,
    210, 220, 230, 240, 250, 255, 255, 255, 255>>

postprocessed_data = StyleTransferFilter.postprocess(tensor_to_postprocess, stream_format)

if postprocessed_data == expected_postprocessed_data do
  :ok
else
  raise """
  Your postprocess/2 function returned
  #{inspect(postprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_postprocessed_data, limit: :infinity, pretty: true)}
  """
end
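
One possible shape for the two helpers that satisfies the tests above (a sketch, not the only valid implementation):

def preprocess(input_binary, stream_format) do
  input_binary
  |> Nx.from_binary(:u8, backend: EXLA.Backend)
  |> Nx.reshape({1, stream_format.height, stream_format.width, 3})
  |> Nx.as_type(:f32)
  # channels-last -> channels-first, as the ONNX models expect
  |> Nx.transpose(axes: [0, 3, 1, 2])
end

def postprocess(tensor, _stream_format) do
  tensor
  # channels-first -> channels-last
  |> Nx.transpose(axes: [1, 2, 0])
  |> Nx.clamp(0, 255)
  |> Nx.as_type(:u8)
  |> Nx.to_binary()
end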

Exercise B3: Compose models

Try to compose 2 style transfers. You can do it in any way you want.

Which approach to this problem is the simplest to implement? Is it the most efficient one?
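
One straightforward option is to run two filters in series in the pipeline spec. The :style option below is hypothetical - it assumes you extend StyleTransferFilter with def_options so the pipeline can pick a model per instance:

child(:style_transfer_1, %StyleTransferFilter{style: :mosaic})
|> child(:style_transfer_2, %StyleTransferFilter{style: :candy})

This is simple to implement, but every frame now goes through two inferences (and two preprocess/postprocess round trips), so it is unlikely to be the most efficient approach.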

Exercise B4: Rotate styles

Now, let’s introduce the following changes to your StyleTransferFilter:

  • load all models in the handle_setup/2 callback
  • every 1.5 seconds, switch to the next model, so that the style of the output video changes over time (a sketch follows this list)
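
A sketch of how this could look, assuming state holds a :models list and that load_model/1 is a hypothetical helper wrapping the loading code from Exercise B2:

@impl true
def handle_setup(_ctx, state) do
  # load every model once, up front
  models = Enum.map([:mosaic, :candy, :udnie], &load_model/1)
  {[], Map.put(state, :models, models)}
end

@impl true
def handle_buffer(:input, buffer, _ctx, state) do
  # pick a model from the frame timestamp: a new one every 1.5 seconds
  slot = div(buffer.pts, Membrane.Time.milliseconds(1500))
  model = Enum.at(state.models, rem(slot, length(state.models)))

  # ...run inference with `model` as in Exercise B2 and emit the buffer
  {[buffer: {:output, buffer}], state}
end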