Solutions: B
# Configure Logger with the :info level before anything else runs.
Logger.configure(level: :info)
# All necessary dependencies are installed by installing the package below
# The package is resolved from a local path two directories above this file.
Mix.install([
{:workshop_elixir_conf_us_2024, path: Path.join(__DIR__, "../..")}
])
Exercise B1
defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that passes every raw RGB video frame through the Axon model
  returned by `Workshop.Models.Mosaic.model/2`.

  The model is created for the frame dimensions announced in the stream format,
  and its parameters are deserialized from `priv/nx/mosaic.nx`. The prediction
  function is built once per stream format and reused for every buffer.
  """
  use Membrane.Filter

  def_input_pad(:input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  def_output_pad(:output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  @impl true
  def handle_init(_ctx, _opts) do
    # Everything is filled in by handle_stream_format/4, because the model
    # depends on the frame dimensions.
    {[], %{model: nil, loaded_params: nil, predict_fn: nil}}
  end

  @impl true
  def handle_stream_format(:input, format, _ctx, state) do
    model = Workshop.Models.Mosaic.model(format.height, format.width)

    loaded_params =
      __DIR__
      |> Path.join("../../priv/nx/mosaic.nx")
      |> File.read!()
      |> Nx.deserialize()

    # Axon.predict/4 is build/2 followed by a call to the predict function.
    # Building here, once per stream format, keeps that work out of the
    # per-frame hot path in handle_buffer/4.
    {_init_fn, predict_fn} = Axon.build(model)

    state = %{state | model: model, loaded_params: loaded_params, predict_fn: predict_fn}
    {[stream_format: {:output, format}], state}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    format = ctx.pads.input.stream_format

    input_tensor = preprocess(buffer.payload, format)
    output_tensor = state.predict_fn.(state.loaded_params, %{"data" => input_tensor})
    output_payload = postprocess(output_tensor, format)

    {[buffer: {:output, %{buffer | payload: output_payload}}], state}
  end

  @doc """
  Converts a raw frame payload into the model input tensor.

  The binary is read as `:u8` values on `EXLA.Backend`, cast to `:f32` and
  reshaped to `{1, format.height, format.width, 3}`.
  """
  def preprocess(payload, format) do
    payload
    |> Nx.from_binary(:u8, backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.reshape({1, format.height, format.width, 3})
  end

  @doc """
  Converts the model output tensor back into a buffer payload.

  The tensor is transferred to `EXLA.Backend`, reshaped to
  `{3, height, width}` and handed to `Workshop.Models.Mosaic.postprocess/1`,
  whose result is used as the outgoing buffer payload.
  """
  def postprocess(tensor, input_stream_format) do
    tensor
    |> Nx.backend_transfer(EXLA.Backend)
    |> Nx.reshape({3, input_stream_format.height, input_stream_format.width})
    |> Workshop.Models.Mosaic.postprocess()
  end
end
Exercise B2
defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that runs every raw RGB video frame through the ONNX model
  stored in `priv/models/picasso.onnx`, using Ortex.
  """
  use Membrane.Filter

  def_input_pad(:input, accepted_format: %Membrane.RawVideo{pixel_format: :RGB})

  def_output_pad(:output, accepted_format: %Membrane.RawVideo{pixel_format: :RGB})

  @impl true
  def handle_init(_ctx, _opts) do
    # The model is loaded later, in handle_setup/2.
    {[], %{model: nil}}
  end

  @impl true
  def handle_setup(_ctx, state) do
    onnx_model = Ortex.load("#{__DIR__}/../../priv/models/picasso.onnx")
    {[], %{state | model: onnx_model}}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    video_format = ctx.pads.input.stream_format

    frame = preprocess(buffer.payload, video_format)
    # Second model input: a constant four-element f32 tensor of ones.
    ones = Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)

    {styled} = Ortex.run(state.model, {frame, ones})

    styled_buffer = %{buffer | payload: postprocess(styled, video_format)}
    {[buffer: {:output, styled_buffer}], state}
  end

  @doc """
  Turns a raw frame payload into the model input tensor.

  The binary is read as `:u8` on `EXLA.Backend`, cast to `:f32`, reshaped to
  `{1, height, width, 3}` and transposed to `{1, 3, height, width}`.
  """
  def preprocess(payload, format) do
    bytes = Nx.from_binary(payload, :u8, backend: EXLA.Backend)
    floats = Nx.as_type(bytes, :f32)
    frame_hwc = Nx.reshape(floats, {1, format.height, format.width, 3})
    Nx.transpose(frame_hwc, axes: [0, 3, 1, 2])
  end

  @doc """
  Turns the model output tensor back into a raw frame payload.

  The tensor is moved to `EXLA.Backend`, reshaped to `{3, height, width}`,
  transposed to `{height, width, 3}`, limited to the 0..255 range, rounded,
  cast to `:u8` and serialized to a binary.
  """
  def postprocess(tensor, format) do
    on_exla = Nx.backend_transfer(tensor, EXLA.Backend)
    planar = Nx.reshape(on_exla, {3, format.height, format.width})
    interleaved = Nx.transpose(planar, axes: [1, 2, 0])

    interleaved
    |> clamp()
    |> Nx.round()
    |> Nx.as_type(:u8)
    |> Nx.to_binary()
  end

  # Limits every value to the range a u8 pixel component can hold.
  defp clamp(tensor) do
    lower_bounded = Nx.max(tensor, 0.0)
    Nx.min(lower_bounded, 255.0)
  end
end
Exercise B3
`handle_init/2`, `handle_setup/2` and `handle_tick/3` should be implemented as follows. `handle_buffer/4` should remain unchanged.
Add `|> via_in(:input, auto_demand_size: 5)` before `|> child(:style_transfer, ...)` in the spec returned in the pipeline, to avoid processing data in chunks in `StyleTransferFilter`.
If you have a working solution of Exercise B2, you can copy-paste `handle_init/2`, `handle_setup/2` and `handle_tick/3` from the solution below into it.
defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that runs every raw RGB video frame through an ONNX model
  and switches to a randomly chosen model on every timer tick.

  All `.onnx` files found in `priv/models` are loaded with Ortex during setup.
  A timer named `:my_timer` ticks every 1500 ms; on each tick a model is picked
  at random from the loaded ones and used for the following buffers.
  """
  use Membrane.Filter

  def_input_pad(:input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  def_output_pad(:output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  @impl true
  def handle_init(_ctx, _opts) do
    # Models are loaded later, in handle_setup/2.
    state = %{models: nil, current_model: nil}
    {[], state}
  end

  @impl true
  def handle_setup(_ctx, state) do
    models =
      __DIR__
      |> Path.join("../../priv/models")
      |> load_models()

    current_model = Enum.random(models)

    actions = [start_timer: {:my_timer, Membrane.Time.milliseconds(1500)}]
    state = %{state | models: models, current_model: current_model}
    {actions, state}
  end

  @impl true
  def handle_tick(:my_timer, _ctx, state) do
    # The pick is independent of the current model, so the same model may be
    # selected again.
    current_model = Enum.random(state.models)
    {[], %{state | current_model: current_model}}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    input_tensor = preprocess(buffer.payload, ctx.pads.input.stream_format)

    {output_tensor} =
      Ortex.run(state.current_model, {
        input_tensor,
        Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)
      })

    output_payload = postprocess(output_tensor, ctx.pads.input.stream_format)
    buffer = %{buffer | payload: output_payload}
    {[buffer: {:output, buffer}], state}
  end

  @doc """
  Turns a raw frame payload into the model input tensor.

  The binary is read as `:u8` on `EXLA.Backend`, cast to `:f32`, reshaped to
  `{1, height, width, 3}` and transposed to `{1, 3, height, width}`.
  """
  def preprocess(payload, format) do
    payload
    |> Nx.from_binary(:u8, backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.reshape({1, format.height, format.width, 3})
    |> Nx.transpose(axes: [0, 3, 1, 2])
  end

  @doc """
  Turns the model output tensor back into a raw frame payload.

  The tensor is moved to `EXLA.Backend`, reshaped to `{3, height, width}`,
  transposed to `{height, width, 3}`, limited to the 0..255 range, rounded,
  cast to `:u8` and serialized to a binary.
  """
  def postprocess(tensor, format) do
    tensor
    |> Nx.backend_transfer(EXLA.Backend)
    |> Nx.reshape({3, format.height, format.width})
    |> Nx.transpose(axes: [1, 2, 0])
    |> clamp()
    |> Nx.round()
    |> Nx.as_type(:u8)
    |> Nx.to_binary()
  end

  # Loads every `.onnx` file from `directory_path`, in sorted filename order.
  # Other directory entries are skipped so that a stray non-model file does not
  # crash setup. Raises when the directory contains no `.onnx` file at all.
  defp load_models(directory_path) do
    model_filenames =
      directory_path
      |> File.ls!()
      |> Enum.filter(fn filename ->
        filename |> Path.extname() |> String.downcase() == ".onnx"
      end)
      |> Enum.sort()

    if model_filenames == [] do
      raise "no .onnx models found in #{directory_path}"
    end

    Enum.map(model_filenames, fn filename ->
      directory_path
      |> Path.join(filename)
      |> Ortex.load()
    end)
  end

  # Limits every value to the range a u8 pixel component can hold.
  defp clamp(tensor) do
    tensor
    |> Nx.max(0.0)
    |> Nx.min(255.0)
  end
end