Hair Segmentation
# Work from this notebook's directory so that the relative paths used below
# ("..", "./model/...", "photo_girl.jpg") resolve against it.
File.cd!(__DIR__)

# for windows JP: switch the console code page to UTF-8 (65001).
# `chcp` exists only on Windows, so the call is skipped on every other OS
# instead of spawning a shell for a command that cannot be found.
if match?({:win32, _}, :os.type()) do
  System.shell("chcp 65001")
end

# Must be set before Mix.install/1 runs.
# NOTE(review): presumably read by the tfl_interp build — confirm.
System.put_env("NNCOMPILED", "YES")

Mix.install([
  {:tfl_interp, path: ".."},
  {:cimg, "~> 0.1.19"},
  {:nx, "~> 0.4.0"},
  {:kino, "~> 0.7.0"}
])
0. Original work
“Real-time hair segmentation and recoloring on Mobile GPUs”, Google Research.
Thanks a lot!!!
Implementation with TflInterp in Elixir
1. Defining the inference module: HairSegmentation
- Model
hair_segmentation.tflite: downloaded from “https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite” if it does not already exist.
- Pre-processing
Resize the input image to the size {@width, @height} and normalize the pixel values to the range {0.0, 1.0}.
- Post-processing
Overlay the mask on the input image.
defmodule HairSegmentation do
  @moduledoc """
  Hair segmentation with the MediaPipe `hair_segmentation.tflite` model, run
  through TflInterp.

  The model file is expected at `./model/hair_segmentation.tflite`; the `url`
  option given to `use NNInterp` points at the location it can be fetched from.
  """

  @width 512
  @height 512

  alias TflInterp, as: NNInterp

  use NNInterp,
    model: "./model/hair_segmentation.tflite",
    url: "https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite",
    inputs: [f32: {1, @width, @height, 4}],
    outputs: [f32: {1, @width, @height, 2}]

  @doc """
  Computes the hair mask for `img`.

  The image is resized to `{@width, @height}`, extended with one extra
  zero-filled channel (the model input has 4 channels), and converted to a
  binary normalized to `{0.0, 1.0}`. The model output is split into its two
  channels (background, hair), and a pixel is marked as hair when the hair
  score is greater than the background score.

  Returns the mask as a CImg image resized to the width and height of `img`.
  """
  def apply(img) do
    # preprocess
    input0 =
      CImg.builder(img)
      |> CImg.resize({@width, @height})
      |> CImg.append(CImg.create(@width, @height, 1, 1, 0), :c)
      |> CImg.to_binary([{:range, {0.0, 1.0}}])

    # prediction
    output =
      session()
      |> NNInterp.set_input_tensor(0, input0)
      |> NNInterp.invoke()
      |> NNInterp.get_output_tensor(0)
      |> Nx.from_binary(:f32)
      |> Nx.reshape({@height, @width, :auto})

    # postprocess: channel 0 is the background score, channel 1 the hair score
    [background, hair] =
      Enum.map(0..1, fn i ->
        Nx.slice_along_axis(output, i, 1, axis: 2) |> Nx.squeeze()
      end)

    {w, h, _, _} = CImg.shape(img)

    # make image
    # NOTE(review): the dtype literal was truncated in the source; "<u1" is
    # reconstructed on the assumption that Nx.greater/2 yields a u8 mask — confirm.
    Nx.greater(hair, background)
    |> Nx.to_binary()
    |> CImg.from_binary(@width, @height, 1, 1, dtype: "<u1")
    |> CImg.resize({w, h})
  end

  @doc """
  Paints the hair region of `img` with `color`.

  The mask comes from `apply/1` and is handed to `CImg.paint_mask/4` together
  with `color` and `opacity`. `opacity` defaults to `0.5`.
  """
  def coloring(img, color, opacity \\ 0.5) do
    mask = __MODULE__.apply(img)
    CImg.paint_mask(img, mask, color, opacity)
  end
end
Launch HairSegmentation.
# Start HairSegmentation.
# NOTE(review): start_link/1 is not defined in the module body, so it is
# presumably injected by `use NNInterp` — confirm.
# To restart an instance that is already running, uncomment the stop call below.
# TflInterp.stop(HairSegmentation)
HairSegmentation.start_link([])
Display the properties of the HairSegmentation model.
# Ask TflInterp for information about the HairSegmentation instance started above.
TflInterp.info(HairSegmentation)
2. Let’s try it
Load a photo and apply HairSegmentation to it.
# Load the sample photo and paint the detected hair region green at 30% opacity.
img = CImg.load("photo_girl.jpg")
colored = HairSegmentation.coloring(img, [{0, 255, 0}], 0.3)

# Render the original and the recolored image side by side as JPEGs.
[img, colored]
|> Enum.map(fn picture -> CImg.display_kino(picture, :jpeg) end)
|> Kino.Layout.grid(columns: 2)
□