# B: Style Transfer
```elixir
Logger.configure(level: :info)

# All necessary dependencies are installed by installing the package below
Mix.install([
  {:workshop_elixir_conf_us_2024, path: Path.join(__DIR__, "../..")}
])
```
## Get node

The value returned by `Node.self()` can be used to fetch metrics from the running pipelines. Go to `livebooks/metrics.livemd` to visualize them.

```elixir
Node.self()
```
## Axon models

We have two public modules in the `:workshop_elixir_conf_us_2024` application:

- `Workshop.Models.Mosaic`
- `Workshop.Models.Candy`

Each of them defines two functions:

- `model(image_height, image_width)`, which returns an Axon model
- `postprocess(tensor)`, which postprocesses the tensor returned by the Axon model

Weights for these two models are stored in the `priv/nx` directory.
```shell
$ ls priv/nx
candy.nx
mosaic.nx
```
Both models expect a `t:Nx.tensor()` that represents an image on their input. The tensor shape should be `{batch_size, height, width, colors}`, where:

- `batch_size` is the number of images passed to the model at once (for us, it may always be equal to 1)
- `height` is the image's height (in other words, the raw video frame height)
- `width` is the image's width (in other words, the raw video frame width)
- `colors` is always equal to 3, since in the RGB format each pixel is described by three numbers, one per color
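For instance, a single 360x640 RGB frame batched as one image would have the following shape (the resolution here is just an illustrative assumption):

```elixir
# Hypothetical 360x640 RGB frame, batched as a single image
frame = Nx.iota({1, 360, 640, 3}, type: :f32)
Nx.shape(frame)
# => {1, 360, 640, 3}, i.e. {batch_size, height, width, colors}
```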
## Exercise B1: Write Style Transfer Element

Write your own `StyleTransferFilter` that uses one of the models above to perform style transfer on our video clip with Big Buck Bunny. Then, add your element to the `StyleTransferPipeline` in the cell below the exercise's description. Use the `Axon` library in your filter's implementation.
Below are some general tips that might be helpful (a short sketch illustrating a few of them follows the list):

- Raw video frames returned by `SWScale.Converter{format: :RGB}` have the following arrangement: `H x W x colors`. A number representing one color of a pixel takes 1 byte and is unsigned.
- The `stream_format` received in the `handle_stream_format/4` callback contains information about the input video width and height.
- The `t:Nx.tensor()` returned by the model needs to be clamped to the range between 0 and 255.
- Always set the `t:Nx.tensor()` backend to `EXLA.Backend`, otherwise operations on tensors will be slow (take a look at the `Nx.backend_transfer/2` function and the `:backend` option in `Nx.from_binary/3`).
- Postprocessing is already implemented in the `postprocess/1` function in the model's module. However, you need to implement `StyleTransferFilter.preprocess/2` on your own. Below the cell with the definition of `StyleTransferFilter`, there is another cell with a test checking if your preprocessing implementation is done properly.
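Here is a minimal sketch of the first and fourth tips (not a complete `preprocess/2` implementation); the payload and dimensions are dummy values standing in for the buffer contents and the received stream format:

```elixir
# Assumed inputs for this sketch: a raw RGB frame payload and its dimensions
height = 4
width = 5
payload = 0..(height * width * 3 - 1) |> Enum.to_list() |> :binary.list_to_bin()

# Color values are unsigned 1-byte integers; setting the :backend option at
# creation time avoids a later transfer to EXLA
tensor =
  payload
  |> Nx.from_binary(:u8, backend: EXLA.Backend)
  |> Nx.reshape({height, width, 3})

# Nx.clip/3 clamps values into a range, e.g. the 0..255 color range
_clamped = Nx.clip(tensor, 0, 255)
```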
Some further tips concerning `Axon` (a combined sketch follows the list):

- To get a model's reference, call the `model/2` function of a specific model's module.
- To load weights, read the content of one of the `*.nx` files and deserialize it to an `Nx.tensor()`.
- To run a model, execute `Axon.predict/3`. This function expects an `Axon` model, weights, and a map with the input data under the `"data"` key: `Axon.predict(model, weights, %{"data" => preprocessed_image_tensor})`.
```elixir
defmodule StyleTransferFilter do
  use Membrane.Filter

  def_input_pad :input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}

  def_output_pad :output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}

  @impl true
  def handle_init(_ctx, _opts), do: {[], %{}}

  @impl true
  def handle_stream_format(:input, _stream_format, _ctx, state) do
    # handle the stream_format
    {[], state}
  end

  # more callbacks

  def preprocess(_input_binary, _stream_format) do
    # ...
  end
end
```
The cell below contains a test for `StyleTransferFilter.preprocess/2`.

```elixir
stream_format = %{pixel_format: :RGB, height: 4, width: 5}

# input_payload is <<0, 1, 2, 3, ..., 58, 59>>
input_payload = 0..59 |> Enum.to_list() |> to_string()

expected_preprocessed_tensor = Nx.iota({1, 4, 5, 3}, type: :f32, backend: EXLA.Backend)

preprocessed_data = StyleTransferFilter.preprocess(input_payload, stream_format)

if Nx.to_list(preprocessed_data) == Nx.to_list(expected_preprocessed_tensor) do
  :ok
else
  raise """
  Your preprocess/2 function returned
  #{inspect(preprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_preprocessed_tensor, limit: :infinity, pretty: true)}
  """
end
```
```elixir
defmodule StyleTransferPipeline do
  use Membrane.Pipeline

  @impl true
  def handle_init(_ctx, _options) do
    input_path = "#{__DIR__}/../../priv/fixtures/bunny_with_sound.mp4"
    output_path = "#{__DIR__}/../../priv/outputs/style_transfer_bunny.mp4"

    spec = [
      child(:source, %Membrane.File.Source{location: input_path})
      |> child(:mp4_demuxer, Membrane.MP4.Demuxer.ISOM)
      |> via_out(:output, options: [kind: :video])
      |> child({:h264_parser, 1}, %Membrane.H264.Parser{output_stream_structure: :annexb})
      |> child(:h264_decoder, Membrane.H264.FFmpeg.Decoder)
      |> child(:rgb_converter, %Membrane.FFmpeg.SWScale.Converter{format: :RGB, output_width: 640})
      |> child(:style_transfer, StyleTransferFilter)
      |> child(:yuv_converter, %Membrane.FFmpeg.SWScale.Converter{format: :I420})
      |> child(:h264_encoder, %Membrane.H264.FFmpeg.Encoder{preset: :ultrafast})
      |> child({:h264_parser, 2}, %Membrane.H264.Parser{output_stream_structure: :avc1})
      |> child(:mp4_muxer, Membrane.MP4.Muxer.ISOM)
      |> child(:file_sink, %Membrane.File.Sink{location: output_path}),
      get_child(:mp4_demuxer)
      |> via_out(:output, options: [kind: :audio])
      |> get_child(:mp4_muxer)
    ]

    {[spec: spec], %{}}
  end

  @impl true
  def handle_element_end_of_stream(:file_sink, _input, _ctx, state) do
    {[terminate: :normal], state}
  end

  @impl true
  def handle_element_end_of_stream(_element, _input, _ctx, state), do: {[], state}
end
```
## Run the pipeline

```elixir
{:ok, supervisor, _pipeline} = Membrane.Pipeline.start_link(StyleTransferPipeline)
ref = Process.monitor(supervisor)

receive do
  {:DOWN, ^ref, _process, _pid, reason} -> reason
end
```
```elixir
Kino.Download.new(
  fn -> File.read!("#{__DIR__}/../../priv/outputs/style_transfer_bunny.mp4") end,
  label: "Download the video",
  filename: "style_transfer_bunny.mp4"
)
```
## ONNX models

In the `priv/models` directory there are several models in the `.onnx` format. Each of them performs a specific style transfer on the input data.

```shell
$ ls priv/models
candy.onnx
kaganawa.onnx
mosaic.onnx
mosaic_mobile.onnx
picasso.onnx
princess.onnx
udnie.onnx
vangogh.onnx
```

Each model expects a `t:Nx.tensor()` that represents an image on its input. The tensor shape should be `{batch_size, colors, height, width}`. Note that the order of the axes here is slightly different from the order expected by the Axon models.
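Since only the axis order differs, `Nx.transpose/2` converts between the two layouts, for example:

```elixir
# Channels-last layout, as used by the Axon models
nhwc = Nx.iota({1, 4, 5, 3}, type: :f32)

# Channels-first layout, as expected by the ONNX models
nchw = Nx.transpose(nhwc, axes: [0, 3, 1, 2])
Nx.shape(nchw)
# => {1, 3, 4, 5}
```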
## Exercise B2: Use Ortex

Now, modify your code from `Exercise B1` so that it uses `Ortex` instead of `Axon`.

Some tips regarding `Ortex` (a combined sketch follows the list):

- To use a specific model, we have to load it first, using the `Ortex.load/1` function (see the Ortex docs).
- Our models expect two input tensors: one representing an image, and a second one specifying parameters used in the layers inside the models. The code snippet below illustrates how to run a model with those two inputs. In your case, you can always run the model with the same parameters as the ones used in the example.

  ```elixir
  {output_tensor} =
    Ortex.run(
      loaded_model,
      {preprocessed_image_tensor, Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)}
    )
  ```

- Data preprocessing and postprocessing here are slightly different than in exercise `B1`, so below there are two cells with tests for both the preprocessing and the postprocessing.
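Putting the tips together, a minimal sketch assuming the `mosaic.onnx` model (the path and the dummy input resolution are assumptions for illustration):

```elixir
# Load the ONNX model once (e.g. in handle_setup/2) and reuse it for every frame
loaded_model = Ortex.load("#{__DIR__}/../../priv/models/mosaic.onnx")

# A dummy channels-first input of shape {batch_size, colors, height, width}
preprocessed_image_tensor = Nx.iota({1, 3, 360, 640}, type: :f32)

# Run with the image tensor and the fixed parameter tensor from the tip above
{output_tensor} =
  Ortex.run(
    loaded_model,
    {preprocessed_image_tensor, Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)}
  )
```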
```elixir
# test preprocessing
stream_format = %{pixel_format: :RGB, height: 4, width: 5}

# input_payload is <<0, 1, 2, 3, ..., 58, 59>>
input_payload = 0..59 |> Enum.to_list() |> to_string()

expected_preprocessed_tensor =
  Nx.iota({1, 4, 5, 3}, type: :f32, backend: EXLA.Backend)
  |> Nx.transpose(axes: [0, 3, 1, 2])

preprocessed_data = StyleTransferFilter.preprocess(input_payload, stream_format)

if Nx.to_list(preprocessed_data) == Nx.to_list(expected_preprocessed_tensor) do
  :ok
else
  raise """
  Your preprocess/2 function returned
  #{inspect(preprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_preprocessed_tensor, limit: :infinity, pretty: true)}
  """
end
```
```elixir
# test postprocessing
stream_format = %{pixel_format: :RGB, height: 4, width: 5}

tensor_to_postprocess =
  -300..290//10
  |> Enum.to_list()
  |> Nx.tensor(type: :f32, backend: EXLA.Backend)
  |> Nx.reshape({4, 5, 3})
  |> Nx.transpose(axes: [2, 0, 1])

expected_postprocessed_data =
  <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200,
    210, 220, 230, 240, 250, 255, 255, 255, 255>>

postprocessed_data = StyleTransferFilter.postprocess(tensor_to_postprocess, stream_format)

if postprocessed_data == expected_postprocessed_data do
  :ok
else
  raise """
  Your postprocess/2 function returned
  #{inspect(postprocessed_data, limit: :infinity, pretty: true)}
  while it should return
  #{inspect(expected_postprocessed_data, limit: :infinity, pretty: true)}
  """
end
```
## Exercise B3: Compose models

Try to compose two style transfers. You can do it in any way you want (one possible approach is sketched below). Which approach to this problem is the simplest to implement? Is it the most efficient one?
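One possible approach is to simply chain two filters in the pipeline spec. This is only a sketch: the `:model` option is a hypothetical assumption about your filter, not part of the skeleton above.

```elixir
# Hypothetical spec fragment: two style transfer elements in a row, each
# configured with a different style via an assumed :model option
child(:rgb_converter, %Membrane.FFmpeg.SWScale.Converter{format: :RGB, output_width: 640})
|> child(:style_transfer_mosaic, %StyleTransferFilter{model: :mosaic})
|> child(:style_transfer_candy, %StyleTransferFilter{model: :candy})
```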
## Exercise B4: Rotate styles

Now, let's introduce the following changes to your `StyleTransferFilter` (a sketch of the timer part follows the list):

- Load all models in the `handle_setup/2` callback.
- Every 1.5 seconds, change the model in use, so that the style of the output video changes over time.
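A sketch of one way to rotate the models, assuming they are kept as a list under a hypothetical `state.models` key; Membrane elements support the `handle_info/3` callback, so a plain timer message works here:

```elixir
@impl true
def handle_setup(_ctx, state) do
  # Load all models here (omitted) and schedule the first rotation
  Process.send_after(self(), :rotate_model, 1_500)
  {[], state}
end

@impl true
def handle_info(:rotate_model, _ctx, state) do
  # Move the current model to the back of the list and reschedule
  Process.send_after(self(), :rotate_model, 1_500)
  [current | rest] = state.models
  {[], %{state | models: rest ++ [current]}}
end
```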