Artistic Style Transfer

Mix.install([
  {:tflite_elixir, "~> 0.3.0"},
  {:evision, "0.1.31"},
  {:req, "~> 0.3.0"},
  {:kino, "~> 0.9.0"}
])

Introduction

One of the most exciting developments in deep learning to come out recently is artistic style transfer, or the ability to create a new image, known as a pastiche, based on two input images: one representing the artistic style and one representing the content.

Using this technique, we can generate beautiful new artworks in a range of styles.

https://www.tensorflow.org/lite/examples/style_transfer/overview

Understand the model architecture

This Artistic Style Transfer model consists of two submodels:

  1. Style Prediction Model: A MobileNetV2-based neural network that transforms an input style image into a 100-dimension style bottleneck vector.
  2. Style Transform Model: A neural network that applies a style bottleneck vector to a content image and creates a stylized image. (The tensor shapes involved are sketched right below.)
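
Here is a minimal, illustrative sketch of the tensor shapes that flow through the two submodels. The dummy tensors below are placeholders only; the 256x256 style input, 384x384 content input, and 100-dimension bottleneck match the sizes described in the pre-processing section later in this notebook.

# Illustrative shapes only; the values are meaningless zeros.
style_input = Nx.broadcast(0.0, {1, 256, 256, 3})
content_input = Nx.broadcast(0.0, {1, 384, 384, 3})

# The style prediction model maps style_input to a 100-dimension style bottleneck
# vector; the style transform model combines that bottleneck with content_input to
# produce a stylized image of shape {1, 384, 384, 3}.
{Nx.shape(style_input), Nx.shape(content_input)}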

It’s useful to alias a module to something shorter when we make extensive use of its functions.

alias Evision, as: Cv
alias TFLiteElixir, as: TFLite
alias TFLiteElixir.TFLiteTensor

Download data files

Download the content and style images, and the pre-trained TensorFlow Lite models.

# /data is the writable portion of a Nerves system
downloads_dir =
  if Code.ensure_loaded?(Nerves.Runtime), do: "/data/livebook", else: File.cwd!()

download = fn url ->
  save_as = Path.join(downloads_dir, URI.encode_www_form(url))
  unless File.exists?(save_as), do: Req.get!(url, output: save_as)
  save_as
end
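
The helper above caches each download under a flat, URL-encoded filename, so re-running the notebook reuses files that already exist. For example, the style image URL (from the list below) collapses into a single filename:

URI.encode_www_form(
  "https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg"
)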

data_files =
  [
    img_content:
      "https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg",
    img_style:
      "https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg",
    style_predict:
      "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite",
    style_transform:
      "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite"
  ]
  |> Enum.map(fn {key, url} -> {key, download.(url)} end)
  |> Map.new()

data_files
|> Enum.map(fn {k, v} -> [name: k, location: v] end)
|> Kino.DataTable.new(name: "Data files")

Load input images

content_image_input = Kino.Input.image("Content image")

content_image_mat =
  case Kino.Input.read(content_image_input) do
    %{data: data, height: height, width: width} ->
      Cv.Mat.from_binary(data, {:u, 8}, height, width, 3)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())

    nil ->
      Cv.imread(data_files.img_content)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  end

:ok

style_image_input = Kino.Input.image("Style image")

style_image_mat =
  case Kino.Input.read(style_image_input) do
    %{data: data, height: height, width: width} ->
      Cv.Mat.from_binary(data, {:u, 8}, height, width, 3)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())

    nil ->
      Cv.imread(data_files.img_style)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  end

:ok

[
  ["Content image", content_image_mat],
  ["Style image", style_image_mat]
]
|> Enum.map(fn [label, img] ->
  Kino.Layout.grid([img, Kino.Markdown.new("**#{label}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 2)

Pre-process the inputs

  • The content image and the style image must be RGB images with float32 pixel values in the range [0..1].
  • The style image size must be (1, 256, 256, 3). We central-crop the image and resize it.
  • The content image size must be (1, 384, 384, 3). We central-crop the image and resize it.

preprocess_image = fn image, target_dim ->
  # Resize the image so that the shorter dimension becomes target_dim pixels.
  {h, w, _} = image.shape

  {resize_h, resize_w} =
    if h > w do
      scale = target_dim / w
      {round(h * scale), target_dim}
    else
      scale = target_dim / h
      {target_dim, round(w * scale)}
    end

  # Evision follows OpenCV's dsize convention of {width, height}.
  image = Cv.resize(image, {resize_w, resize_h})

  # Central-crop the image to a target_dim x target_dim square.
  {centre_h, centre_w} = {resize_h / 2, resize_w / 2}
  x = round(centre_w - target_dim / 2)
  y = round(centre_h - target_dim / 2)
  # Evision.Mat.roi/2 takes a rect specified as {x, y, width, height}.
  cropped_image = Cv.Mat.roi(image, {x, y, target_dim, target_dim})

  %{
    image: cropped_image,
    tensor:
      cropped_image
      |> Cv.Mat.to_nx(Nx.BinaryBackend)
      |> Nx.new_axis(0)
      |> Nx.as_type({:f, 32})
      # Change the range from [0..255] to [0..1]
      |> Nx.divide(255)
  }
end

# Preprocess the input images.
preprocessed_content = preprocess_image.(content_image_mat, 384)
preprocessed_style = preprocess_image.(style_image_mat, 256)

IO.puts(["Style Image Shape: ", inspect(preprocessed_style.tensor.shape)])
IO.puts(["Content Image Shape: ", inspect(preprocessed_content.tensor.shape)])
[
  ["Content image preprocessed", preprocessed_content.image],
  ["Style image preprocessed", preprocessed_style.image]
]
|> Enum.map(fn [label, img] ->
  Kino.Layout.grid([img, Kino.Markdown.new("**#{label}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 2)

Run style prediction

# Run style prediction on preprocessed style image.
run_style_predict = fn <<_::binary>> = preprocessed_style_data ->
  # Load the model.
  {:ok, interpreter} = TFLite.Interpreter.new(data_files.style_predict)

  # Get input and output tensors.
  {:ok, _input_tensors} = TFLite.Interpreter.inputs(interpreter)
  {:ok, output_tensors} = TFLite.Interpreter.outputs(interpreter)

  # Run the model
  TFLite.Interpreter.input_tensor(interpreter, 0, preprocessed_style_data)
  TFLite.Interpreter.invoke(interpreter)

  # Calculate style bottleneck for the preprocessed style image.
  TFLite.Interpreter.tensor(interpreter, Enum.at(output_tensors, 0))
  |> TFLiteTensor.to_nx(backend: Nx.BinaryBackend)
end

style_bottleneck = run_style_predict.(Nx.to_binary(preprocessed_style.tensor))
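
The prediction model outputs a 100-dimension style bottleneck vector (the TFLite model exposes it as a rank-4 tensor). A quick shape check confirms what the transform step below will consume:

IO.inspect(Nx.shape(style_bottleneck), label: "Style bottleneck shape")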

Run style transform

# Run style transform on the preprocessed content image, using a style bottleneck.
run_style_transform = fn <<_::binary>> = style_bottleneck_data,
                         <<_::binary>> = preprocessed_content_data ->
  # Load the model.
  {:ok, interpreter} = TFLite.Interpreter.new(data_files.style_transform)

  # Run the model.
  TFLite.Interpreter.input_tensor(interpreter, 0, preprocessed_content_data)
  TFLite.Interpreter.input_tensor(interpreter, 1, style_bottleneck_data)
  TFLite.Interpreter.invoke(interpreter)

  # Get output tensors.
  {:ok, output_tensors} = TFLite.Interpreter.outputs(interpreter)

  # Transform content image.
  out_tensor = TFLite.Interpreter.tensor(interpreter, Enum.at(output_tensors, 0))
  [1 | shape] = TFLite.TFLiteTensor.dims(out_tensor)

  out_tensor
  |> TFLiteTensor.to_nx(backend: Nx.BinaryBackend)
  # Change the range from [0..1] to [0..255]
  |> Nx.multiply(255)
  |> Nx.reshape(List.to_tuple(shape))
  |> Nx.as_type({:u, 8})
  |> Cv.Mat.from_nx_2d()
  |> Cv.cvtColor(Cv.Constant.cv_COLOR_RGB2BGR())
end

# Stylize the content image using the style bottleneck.
run_style_transform.(
  Nx.to_binary(style_bottleneck),
  Nx.to_binary(preprocessed_content.tensor)
)
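
The call above returns the stylized image as an Evision Mat, which Livebook renders inline. If you also want to keep the result on disk, one option is Evision's imwrite. A minimal sketch, assuming the same inputs as above and an arbitrary file name under downloads_dir:

stylized_image =
  run_style_transform.(
    Nx.to_binary(style_bottleneck),
    Nx.to_binary(preprocessed_content.tensor)
  )

# Write the stylized image (BGR, as returned by run_style_transform) to a PNG file.
Cv.imwrite(Path.join(downloads_dir, "stylized.png"), stylized_image)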

Run style blending

# Calculate style bottleneck of the content image.
style_bottleneck_content =
  preprocess_image.(content_image_mat, 256).tensor
  |> Nx.to_binary()
  |> run_style_predict.()

# Define the content blending ratio between [0..1].
# 0.0: 0% style extracted from the content image.
# 1.0: 100% style extracted from the content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended =
  Nx.add(
    Nx.multiply(content_blending_ratio, style_bottleneck_content),
    Nx.multiply(1 - content_blending_ratio, style_bottleneck)
  )

# Stylize the content image using the blended style bottleneck.
run_style_transform.(
  Nx.to_binary(style_bottleneck_blended),
  Nx.to_binary(preprocessed_content.tensor)
)
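
To see how the blending ratio affects the result, you could sweep several ratios and compare the outputs side by side. A sketch reusing the functions defined above; the ratio values are arbitrary:

[0.0, 0.25, 0.5, 0.75, 1.0]
|> Enum.map(fn ratio ->
  blended =
    Nx.add(
      Nx.multiply(ratio, style_bottleneck_content),
      Nx.multiply(1 - ratio, style_bottleneck)
    )

  stylized =
    run_style_transform.(
      Nx.to_binary(blended),
      Nx.to_binary(preprocessed_content.tensor)
    )

  Kino.Layout.grid([stylized, Kino.Markdown.new("**ratio = #{ratio}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 3)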