Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Solutions: B

livebooks/solutions/B_solution.livemd

Solutions: B

Logger.configure(level: :info)

# A single workshop package, resolved from the directory two levels above this
# notebook, brings in every dependency the notebook needs.
workshop_dep = {:workshop_elixir_conf_us_2024, path: Path.join(__DIR__, "../..")}
Mix.install([workshop_dep])

Exercise B1

defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that applies the "mosaic" Axon style-transfer model
  (`Workshop.Models.Mosaic`) to every raw RGB video frame.

  The serialized model parameters (`priv/nx/mosaic.nx`) are read once in
  `handle_setup/2`. The Axon model itself is built in
  `handle_stream_format/4`, because it is constructed from the frame height
  and width.
  """
  use Membrane.Filter

  def_input_pad(:input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  def_output_pad(:output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  @impl true
  def handle_init(_ctx, _opts), do: {[], %{model: nil, loaded_params: nil}}

  @impl true
  def handle_setup(_ctx, state) do
    # The parameters do not depend on the stream format, so they are read from
    # disk and deserialized exactly once, not on every stream format update.
    loaded_params =
      Path.join(__DIR__, "../../priv/nx/mosaic.nx")
      |> File.read!()
      |> Nx.deserialize()

    {[], %{state | loaded_params: loaded_params}}
  end

  @impl true
  def handle_stream_format(:input, format, _ctx, state) do
    # The model is built for the concrete frame size, so it is (re)built for
    # each incoming stream format; the format is forwarded to the output pad
    # unchanged.
    model = Workshop.Models.Mosaic.model(format.height, format.width)
    {[stream_format: {:output, format}], %{state | model: model}}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    input_tensor = preprocess(buffer.payload, ctx.pads.input.stream_format)
    output_tensor = Axon.predict(state.model, state.loaded_params, %{"data" => input_tensor})

    output_payload = postprocess(output_tensor, ctx.pads.input.stream_format)
    buffer = %{buffer | payload: output_payload}

    {[buffer: {:output, buffer}], state}
  end

  @doc """
  Turns a raw RGB frame binary into an f32 tensor of shape
  `{1, height, width, 3}` on the EXLA backend.
  """
  def preprocess(payload, format) do
    payload
    |> Nx.from_binary(:u8, backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.reshape({1, format.height, format.width, 3})
  end

  @doc """
  Reshapes the model output to `{3, height, width}` and hands it to
  `Workshop.Models.Mosaic.postprocess/1`, which produces the output payload.
  """
  def postprocess(tensor, input_stream_format) do
    tensor
    |> Nx.backend_transfer(EXLA.Backend)
    |> Nx.reshape({3, input_stream_format.height, input_stream_format.width})
    |> Workshop.Models.Mosaic.postprocess()
  end
end

Exercise B2

defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that runs every raw RGB video frame through the
  `picasso.onnx` style-transfer model using Ortex.
  """
  use Membrane.Filter

  def_input_pad(:input, accepted_format: %Membrane.RawVideo{pixel_format: :RGB})

  def_output_pad(:output, accepted_format: %Membrane.RawVideo{pixel_format: :RGB})

  @impl true
  def handle_init(_ctx, _opts) do
    initial_state = %{model: nil}
    {[], initial_state}
  end

  @impl true
  def handle_setup(_ctx, state) do
    # The model is loaded once here and reused for every buffer.
    onnx_path = "#{__DIR__}/../../priv/models/picasso.onnx"
    {[], %{state | model: Ortex.load(onnx_path)}}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    format = ctx.pads.input.stream_format
    frame = preprocess(buffer.payload, format)
    # Second model input: a constant f32 vector of four ones.
    # NOTE(review): its meaning is defined by picasso.onnx — confirm.
    extra_input = Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)

    {styled} = Ortex.run(state.model, {frame, extra_input})

    {[buffer: {:output, %{buffer | payload: postprocess(styled, format)}}], state}
  end

  @doc """
  Converts a raw RGB frame binary into an f32 tensor of shape
  `{1, 3, height, width}` (channels first) on the EXLA backend.
  """
  def preprocess(payload, format) do
    payload
    |> Nx.from_binary(:u8, backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.reshape({1, format.height, format.width, 3})
    # {1, h, w, 3} -> {1, 3, h, w}
    |> Nx.transpose(axes: [0, 3, 1, 2])
  end

  @doc """
  Converts a channels-first model output back into a raw RGB frame binary:
  values are clamped to [0, 255], rounded and cast to `:u8`.
  """
  def postprocess(tensor, format) do
    channels_last =
      tensor
      |> Nx.backend_transfer(EXLA.Backend)
      |> Nx.reshape({3, format.height, format.width})
      |> Nx.transpose(axes: [1, 2, 0])

    channels_last
    |> clamp_to_byte_range()
    |> Nx.round()
    |> Nx.as_type(:u8)
    |> Nx.to_binary()
  end

  # Limits every value to [0.0, 255.0] so it fits the :u8 range before the cast.
  defp clamp_to_byte_range(tensor), do: tensor |> Nx.max(0.0) |> Nx.min(255.0)
end

Exercise B3

handle_init/2, handle_setup/2 and handle_tick/3 should be implemented as shown below. handle_buffer/4 stays the same as in Exercise B2, except that it runs `state.current_model` instead of `state.model`.

In the spec returned by the pipeline, add `|> via_in(:input, auto_demand_size: 5)` right before `|> child(:style_transfer, ...)`, so that StyleTransferFilter demands only a few buffers at a time instead of requesting large chunks of data at once.

If you have a working solution to Exercise B2, you can copy handle_init/2, handle_setup/2 and handle_tick/3 from the solution below into it.

defmodule StyleTransferFilter do
  @moduledoc """
  Membrane filter that restyles raw RGB video frames with ONNX
  style-transfer models run through Ortex.

  Every `.onnx` file in `priv/models` is loaded during setup. One of them,
  chosen at random, is applied to incoming frames, and a new random choice is
  made on every `:my_timer` tick (every 1500 ms).
  """
  use Membrane.Filter

  def_input_pad(:input,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  def_output_pad(:output,
    accepted_format: %Membrane.RawVideo{pixel_format: :RGB}
  )

  @impl true
  def handle_init(_ctx, _opts) do
    state = %{models: nil, current_model: nil}
    {[], state}
  end

  @impl true
  def handle_setup(_ctx, state) do
    directory_path = Path.join(__DIR__, "../../priv/models")

    # Only `.onnx` files are handed to Ortex, so unrelated files in the
    # directory (READMEs, OS metadata such as .DS_Store) do not crash setup.
    models =
      directory_path
      |> File.ls!()
      |> Enum.filter(fn filename -> Path.extname(filename) == ".onnx" end)
      |> Enum.map(fn model_filename ->
        directory_path
        |> Path.join(model_filename)
        |> Ortex.load()
      end)

    # Fail with a descriptive error instead of Enum.random/1's Enum.EmptyError.
    if models == [] do
      raise "no .onnx models found in #{directory_path}"
    end

    actions = [start_timer: {:my_timer, Membrane.Time.milliseconds(1500)}]
    state = %{state | models: models, current_model: Enum.random(models)}
    {actions, state}
  end

  @impl true
  def handle_tick(:my_timer, _ctx, state) do
    # Uniform random pick; the currently used model may be picked again.
    current_model = Enum.random(state.models)
    {[], %{state | current_model: current_model}}
  end

  @impl true
  def handle_buffer(:input, buffer, ctx, state) do
    input_tensor = preprocess(buffer.payload, ctx.pads.input.stream_format)

    # NOTE(review): the second input is a constant f32 vector of four ones;
    # its meaning is defined by the ONNX models — confirm.
    {output_tensor} =
      Ortex.run(state.current_model, {
        input_tensor,
        Nx.tensor([1.0, 1.0, 1.0, 1.0], type: :f32)
      })

    output_payload = postprocess(output_tensor, ctx.pads.input.stream_format)
    buffer = %{buffer | payload: output_payload}

    {[buffer: {:output, buffer}], state}
  end

  @doc """
  Converts a raw RGB frame binary into an f32 tensor of shape
  `{1, 3, height, width}` (channels first) on the EXLA backend.
  """
  def preprocess(payload, format) do
    payload
    |> Nx.from_binary(:u8, backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.reshape({1, format.height, format.width, 3})
    |> Nx.transpose(axes: [0, 3, 1, 2])
  end

  @doc """
  Converts a channels-first model output back into a raw RGB frame binary:
  values are clamped to [0, 255], rounded and cast to `:u8`.
  """
  def postprocess(tensor, format) do
    tensor
    |> Nx.backend_transfer(EXLA.Backend)
    |> Nx.reshape({3, format.height, format.width})
    |> Nx.transpose(axes: [1, 2, 0])
    |> clamp()
    |> Nx.round()
    |> Nx.as_type(:u8)
    |> Nx.to_binary()
  end

  # Limits every value to [0.0, 255.0] so it fits the :u8 range before the cast.
  defp clamp(tensor) do
    tensor
    |> Nx.max(0.0)
    |> Nx.min(255.0)
  end
end