Artistic Style Transfer
Mix.install([
{:tflite_elixir, "~> 0.3.0"},
{:evision, "0.1.31"},
{:req, "~> 0.3.0"},
{:kino, "~> 0.9.0"}
])
Introduction
One of the most exciting developments in deep learning to come out recently is artistic style transfer, or the ability to create a new image, known as a pastiche, based on two input images: one representing the artistic style and one representing the content.
Using this technique, we can generate beautiful new artworks in a range of styles.
https://www.tensorflow.org/lite/examples/style_transfer/overview
Understand the model architecture
This Artistic Style Transfer model consists of two submodels:
- Style Prediction Model: A MobileNetV2-based neural network that reduces the input style image to a 100-dimension style bottleneck vector.
- Style Transform Model: A neural network that applies a style bottleneck vector to a content image and creates a stylized image.
It’s useful to alias a module to something shorter when we make extensive use of its functions.
alias Evision, as: Cv
alias TFLiteElixir, as: TFLite
alias TFLiteElixir.TFLiteTensor
Download data files
Download the content and style images, and the pre-trained TensorFlow Lite models.
# /data is the writable portion of a Nerves system
downloads_dir =
if Code.ensure_loaded?(Nerves.Runtime), do: "/data/livebook", else: File.cwd!()
download = fn url ->
save_as = Path.join(downloads_dir, URI.encode_www_form(url))
unless File.exists?(save_as), do: Req.get!(url, output: save_as)
save_as
end
data_files =
[
img_content:
"https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg",
img_style:
"https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg",
style_predict:
"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite",
style_transform:
"https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite"
]
|> Enum.map(fn {key, url} -> {key, download.(url)} end)
|> Map.new()
data_files
|> Enum.map(fn {k, v} -> [name: k, location: v] end)
|> Kino.DataTable.new(name: "Data files")
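With the models downloaded, we can optionally check what input shape the style prediction model expects before feeding it any data. The snippet below is a small sanity-check sketch that reuses only functions used later in this notebook; the expected result of [1, 256, 256, 3] is an assumption based on the published int8 prediction model.

# Optional: inspect the expected input shape of the style prediction model.
{:ok, predict_interpreter} = TFLite.Interpreter.new(data_files.style_predict)
{:ok, [input_index | _]} = TFLite.Interpreter.inputs(predict_interpreter)

TFLite.Interpreter.tensor(predict_interpreter, input_index)
|> TFLiteTensor.dims()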
Load input images
content_image_input = Kino.Input.image("Content image")
content_image_mat =
  case Kino.Input.read(content_image_input) do
    %{data: data, height: height, width: width} ->
      # Kino supplies the uploaded image as raw RGB pixels, so no color
      # conversion is needed here.
      Cv.Mat.from_binary(data, {:u, 8}, height, width, 3)

    nil ->
      # Evision.imread returns BGR, so convert it to RGB for the model.
      Cv.imread(data_files.img_content)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  end
:ok
style_image_input = Kino.Input.image("Style image")
style_image_mat =
  case Kino.Input.read(style_image_input) do
    %{data: data, height: height, width: width} ->
      # Kino supplies the uploaded image as raw RGB pixels, so no color
      # conversion is needed here.
      Cv.Mat.from_binary(data, {:u, 8}, height, width, 3)

    nil ->
      # Evision.imread returns BGR, so convert it to RGB for the model.
      Cv.imread(data_files.img_style)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  end
:ok
[
["Content image", content_image_mat],
["Style image", style_image_mat]
]
|> Enum.map(fn [label, img] ->
Kino.Layout.grid([img, Kino.Markdown.new("**#{label}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 2)
Pre-process the inputs
- The content image and the style image must be RGB images with float32 pixel values in the range [0..1].
- The style image must have shape (1, 256, 256, 3). We resize the image and then central-crop it.
- The content image must have shape (1, 384, 384, 3). We resize the image and then central-crop it (the arithmetic is worked through in the example below).
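To make the crop geometry concrete, here is the resize-then-crop arithmetic for a hypothetical 1280x853 content image and the 384 px content target. The image size is made up purely for illustration.

# Worked example: 1280x853 (w x h) content image, 384 px target.
{h, w, target} = {853, 1280, 384}

# The shorter side (here the height) is scaled to the target size.
scale = target / h
{resize_w, resize_h} = {round(w * scale), target}

# A target x target square is then cut out of the horizontal centre.
x = round(resize_w / 2 - target / 2)

{resize_w, resize_h, x}
# => {576, 384, 96}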
preprocess_image = fn image, target_dim ->
# Resize the image so that the shorter dimension becomes `target_dim` pixels.
{h, w, _} = image.shape
{resize_h, resize_w} =
if h > w do
scale = target_dim / w
{round(h * scale), target_dim}
else
scale = target_dim / h
{target_dim, round(w * scale)}
end
# Note: OpenCV's dsize is given as {width, height}.
image = Cv.resize(image, {resize_w, resize_h})
# Central crop the image.
{centre_h, centre_w} = {resize_h / 2, resize_w / 2}
x = round(centre_w - target_dim / 2)
y = round(centre_h - target_dim / 2)
# Evision.Mat.roi/2 takes a Rect as {x, y, width, height}.
crop_rect = {x, y, target_dim, target_dim}
cropped_image = Cv.Mat.roi(image, crop_rect)
%{
image: cropped_image,
tensor:
cropped_image
|> Cv.Mat.to_nx(Nx.BinaryBackend)
|> Nx.new_axis(0)
|> Nx.as_type({:f, 32})
# Change the range from [0..255] to [0..1]
|> Nx.divide(255)
}
end
# Preprocess the input images.
preprocessed_content = preprocess_image.(content_image_mat, 384)
preprocessed_style = preprocess_image.(style_image_mat, 256)
IO.puts(["Style Image Shape: ", inspect(preprocessed_style.tensor.shape)])
IO.puts(["Content Image Shape: ", inspect(preprocessed_content.tensor.shape)])
[
["Content image preprocessed", preprocessed_content.image],
["Style image preprocessed", preprocessed_style.image]
]
|> Enum.map(fn [label, img] ->
Kino.Layout.grid([img, Kino.Markdown.new("**#{label}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 2)
Run style prediction
# Run style prediction on preprocessed style image.
run_style_predict = fn <<_::binary>> = preprocessed_style_data ->
# Load the model.
{:ok, interpreter} = TFLite.Interpreter.new(data_files.style_predict)
# Get input and output tensors.
{:ok, _input_tensors} = TFLite.Interpreter.inputs(interpreter)
{:ok, output_tensors} = TFLite.Interpreter.outputs(interpreter)
# Run the model
TFLite.Interpreter.input_tensor(interpreter, 0, preprocessed_style_data)
TFLite.Interpreter.invoke(interpreter)
# Calculate style bottleneck for the preprocessed style image.
TFLite.Interpreter.tensor(interpreter, Enum.at(output_tensors, 0))
|> TFLiteTensor.to_nx(backend: Nx.BinaryBackend)
end
style_bottleneck = run_style_predict.(Nx.to_binary(preprocessed_style.tensor))
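The prediction model reduces the style image to the 100-dimension bottleneck described earlier. As a quick check we can inspect the tensor shape; {1, 1, 1, 100} is the expected value for this model, stated here as an assumption rather than something verified elsewhere in the notebook.

# The style bottleneck should hold 100 values (expected shape: {1, 1, 1, 100}).
Nx.shape(style_bottleneck)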
Run style transform
# Run style transform on the preprocessed content image.
run_style_transform = fn <<_::binary>> = style_bottleneck_data,
<<_::binary>> = preprocessed_content_data ->
# Load the model.
{:ok, interpreter} = TFLite.Interpreter.new(data_files.style_transform)
# Run the model.
TFLite.Interpreter.input_tensor(interpreter, 0, preprocessed_content_data)
TFLite.Interpreter.input_tensor(interpreter, 1, style_bottleneck_data)
TFLite.Interpreter.invoke(interpreter)
# Get output tensors.
{:ok, output_tensors} = TFLite.Interpreter.outputs(interpreter)
# Transform content image.
out_tensor = TFLite.Interpreter.tensor(interpreter, Enum.at(output_tensors, 0))
[1 | shape] = TFLiteTensor.dims(out_tensor)
out_tensor
|> TFLiteTensor.to_nx(backend: Nx.BinaryBackend)
# Change the range from [0..1] to [0..255]
|> Nx.multiply(255)
|> Nx.reshape(List.to_tuple(shape))
|> Nx.as_type({:u, 8})
|> Cv.Mat.from_nx_2d()
|> Cv.cvtColor(Cv.Constant.cv_COLOR_RGB2BGR())
end
# Stylize the content image using the style bottleneck.
run_style_transform.(
Nx.to_binary(style_bottleneck),
Nx.to_binary(preprocessed_content.tensor)
)
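If you want to keep the stylized result, Evision can write the Mat straight to an image file. This is an optional sketch: it re-runs the transform so the output has a name, and the stylized.png path under downloads_dir is an arbitrary choice.

stylized_image =
  run_style_transform.(
    Nx.to_binary(style_bottleneck),
    Nx.to_binary(preprocessed_content.tensor)
  )

# Persist the stylized image next to the downloaded data files.
Cv.imwrite(Path.join(downloads_dir, "stylized.png"), stylized_image)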
Run style blending
# Calculate style bottleneck of the content image.
style_bottleneck_content =
preprocess_image.(content_image_mat, 256).tensor
|> Nx.to_binary()
|> run_style_predict.()
# Define content blending ratio between [0..1].
# 0.0: 0% of the style is extracted from the content image.
# 1.0: 100% of the style is extracted from the content image.
content_blending_ratio = 0.5
# Blend the style bottlenecks of the style image and the content image.
style_bottleneck_blended =
Nx.add(
Nx.multiply(content_blending_ratio, style_bottleneck_content),
Nx.multiply(1 - content_blending_ratio, style_bottleneck)
)
# Stylize the content image using the style bottleneck.
run_style_transform.(
Nx.to_binary(style_bottleneck_blended),
Nx.to_binary(preprocessed_content.tensor)
)
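To get a feel for how the ratio behaves, you can sweep a few values and compare the results side by side. A small sketch reusing the helpers defined above; the three ratio values are arbitrary.

# Compare several blending ratios in a grid (0.0 = only the style image's
# style, 1.0 = only the content image's own style).
[0.0, 0.5, 1.0]
|> Enum.map(fn ratio ->
  blended =
    Nx.add(
      Nx.multiply(ratio, style_bottleneck_content),
      Nx.multiply(1 - ratio, style_bottleneck)
    )

  stylized =
    run_style_transform.(
      Nx.to_binary(blended),
      Nx.to_binary(preprocessed_content.tensor)
    )

  Kino.Layout.grid([stylized, Kino.Markdown.new("**ratio = #{ratio}**")], boxed: true)
end)
|> Kino.Layout.grid(columns: 3)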
img = Cv.imread("/Users/cocoa/Downloads/cat.jpg")
[b, g, r] = Cv.split(img)
g