Mobile BERT Q&A
File.cd!(__DIR__)
# for Windows (Japanese locale): switch the console code page to UTF-8
System.shell("chcp 65001")
System.put_env("NNCOMPILED", "YES")
# CAUTION: TflInterp takes several tens of minutes on its first build, because it downloads the TensorFlow source code.
# In this setup, we use a TflInterp that has already been downloaded and built.
Mix.install([
{:tfl_interp, path: ".."},
{:nx, "~> 0.4.1"}
])
0. Original work
“BERT Question and Answer”
“Google Research/MobileBERT”
Thanks a lot!!!
An implementation for Elixir using TflInterp.
1. Defining companion modules
defmodule MBertQA.Tokenizer do
  @moduledoc """
  Mini tokenizer for the TensorFlow Lite MobileBERT example.
  """

  @dic1 :word2token1
  @dic2 :word2token2
  @dic3 :token2word

  @doc """
  Load a vocabulary dictionary.
  """
  def load_dic(fname) do
    Enum.each([@dic1, @dic2, @dic3], fn name ->
      if :ets.whereis(name) != :undefined, do: :ets.delete(name)
      :ets.new(name, [:named_table])
    end)

    File.stream!(fname)
    |> Stream.map(&String.trim_trailing/1)
    |> Stream.with_index()
    |> Enum.each(fn {word, index} ->
      # init dic for encoding: "##xxx" sub-words go to @dic2, whole words to @dic1.
      case word do
        <<"##", trailing::binary>> -> :ets.insert(@dic2, {trailing, index})
        _ -> :ets.insert(@dic1, {word, index})
      end

      # init dic for decoding.
      :ets.insert(@dic3, {index, word})
    end)
  end

  @doc """
  Tokenize `text` into a list of token ids.
  """
  def tokenize(text, tolower \\ true) do
    if tolower do
      String.downcase(text)
    else
      text
    end
    |> text2words()
    |> words2tokens()
  end

  defp text2words(text) do
    text
    # cleansing: remove control characters and the Unicode replacement character
    |> String.replace(~r/[[:cntrl:]]|\x{FFFD}/u, "")
    # separate punctuation with whitespace
    |> String.replace(~r/([[:punct:]])/, " \\1 ")
    # split on whitespace
    |> String.split()
  end

  defp words2tokens(words) do
    Enum.reduce(words, [], fn word, tokens ->
      len = String.length(word)

      case wordpiece1(word, len) do
        # the whole word is in the dictionary
        {^len, token} ->
          [token | tokens]

        # no prefix matched; the word is unknown
        {0, token} ->
          [token | tokens]

        # a prefix of length `n` matched; continue with "##" sub-words
        {n, token} ->
          rest = len - n
          wordpiece2(String.slice(word, n, rest), rest, [token | tokens])
      end
    end)
    |> Enum.reverse()
  end

  defp wordpiece1(_, 0), do: {0, token("[UNK]")}

  defp wordpiece1(word, n) do
    case lookup_1(String.slice(word, 0, n)) do
      {_piece, token} ->
        {n, token}

      nil ->
        wordpiece1(word, n - 1)
    end
  end

  defp wordpiece2(_, 0, tokens), do: [token("[UNK]") | tokens]

  defp wordpiece2(word, n, tokens) do
    case lookup_2(String.slice(word, 0, n)) do
      {_piece, token} ->
        rest = String.length(word) - n

        if rest == 0 do
          [token | tokens]
        else
          wordpiece2(String.slice(word, n, rest), rest, [token | tokens])
        end

      nil ->
        wordpiece2(word, n - 1, tokens)
    end
  end

  defp lookup_1(x), do: lookup(@dic1, x)
  defp lookup_2(x), do: lookup(@dic2, x)
  defp lookup(dic, x), do: List.first(:ets.lookup(dic, x))

  @doc """
  Get the token id of `word`, or nil if it is not in the vocabulary.
  """
  def token(word) do
    try do
      :ets.lookup_element(@dic1, word, 2)
    rescue
      ArgumentError -> nil
    end
  end

  @doc """
  Decode the token list `tokens` back into words.
  """
  def decode(tokens) do
    Enum.map(tokens, fn x -> :ets.lookup_element(@dic3, x, 2) end)
  end
end
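As a quick sanity check, the tokenizer can be exercised on its own. This is a sketch: it assumes `./model/vocab.txt` is present, and the pieces shown in the comments are illustrative, since the ids depend on the vocabulary file.

# Load the vocabulary, then tokenize a word that is unlikely to appear whole.
MBertQA.Tokenizer.load_dic("./model/vocab.txt")

tokens = MBertQA.Tokenizer.tokenize("Googleplex")
# "googleplex" is probably not a vocabulary entry, so WordPiece falls back to
# the longest known prefix plus "##" sub-words, e.g. ["google", "##plex"].
MBertQA.Tokenizer.decode(tokens)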
defmodule MBertQA.Feature do
  @moduledoc """
  Feature converter for the TensorFlow Lite MobileBERT example.

  The feature tensor[4][] consists of:

    [0] list of token ids converted from the query and context text with the dictionary.
    [1] list of token masks. All are 1.
    [2] segment type indicator. 0 for query, 1 for context.
    [3] index of the token's position in the context map. See below.

  The context map is the list of words obtained by splitting the context on whitespace.
  """

  alias MBertQA.Tokenizer

  @max_seq 384
  @max_query 64

  defdelegate load_dic(fname), to: Tokenizer, as: :load_dic

  @doc """
  Convert `query` and `context` into the model's input feature tensor,
  returning it together with the context map.
  """
  def convert(query, context) do
    {t_query, room} = convert_query(query)
    {t_context, room, context_map} = convert_context(context, room)

    {
      Nx.concatenate([cls(), t_query, sep(0), t_context, sep(1), padding(room)], axis: 1),
      context_map
    }
  end

  # Reserve three slots for the special tokens [CLS], [SEP] and [SEP].
  defp convert_query(text, room \\ @max_seq - 3) do
    t_query = convert_text(text, @max_query, 0)

    {
      t_query,
      room - Nx.axis_size(t_query, 1)
    }
  end

  defp convert_context(text, room) do
    context_map = String.split(text)

    {t_context, room} =
      context_map
      |> Stream.with_index()
      |> Enum.reduce_while({[], room}, fn
        {_, _}, {_, 0} = res ->
          {:halt, res}

        {word, index}, {t_context, max_len} ->
          t_word = convert_text(word, max_len, 1, index)
          {:cont, {[t_word | t_context], max_len - Nx.axis_size(t_word, 1)}}
      end)

    {
      Enum.reverse(t_context) |> Nx.concatenate(axis: 1),
      room,
      context_map
    }
  end

  defp convert_text(text, max_len, segment, index \\ -1) do
    tokens =
      text
      |> Tokenizer.tokenize()
      |> Enum.take(max_len)
      |> Nx.tensor()

    len = Nx.axis_size(tokens, 0)

    Nx.stack([
      tokens,
      Nx.broadcast(1, {len}),
      Nx.broadcast(segment, {len}),
      Nx.broadcast(index, {len})
    ])
    |> Nx.as_type({:s, 32})
  end

  defp cls() do
    Nx.tensor([[Tokenizer.token("[CLS]")], [1], [0], [-1]], type: {:s, 32})
  end

  defp sep(segment) do
    Nx.tensor([[Tokenizer.token("[SEP]")], [1], [segment], [-1]], type: {:s, 32})
  end

  defp padding(n) do
    Nx.tensor([0, 0, 0, -1], type: {:s, 32})
    |> Nx.broadcast({4, n}, axes: [0])
  end
end
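For intuition, here is a sketch of what convert/2 produces (it assumes the vocabulary has been loaded; the exact token ids depend on vocab.txt):

{feature, context_map} =
  MBertQA.Feature.convert("Who founded Google?", "Google was founded by Larry Page.")

# `feature` always has shape {4, 384}: token ids, mask, segment ids and
# context-map indices, padded out to @max_seq columns.
Nx.shape(feature)
# `context_map` keeps the original (non-lowercased) words of the context:
# ["Google", "was", "founded", "by", "Larry", "Page."]
context_map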
2. Defining the inference module: MBertQA
- Pre-processing: convert the query and context into the 4x384 feature tensor and the context map with MBertQA.Feature.convert/2.
- Post-processing: rank candidate answer spans by the sum of their begin/end logits, then map the winning spans back to words in the context (see the score-ranking sketch after the module).
defmodule MBertQA do
  @moduledoc """
  Documentation for `MBertQA`.
  """

  alias TflInterp, as: NNInterp

  use NNInterp,
    model: "./model/lite-model_mobilebert_1_metadata_1.tflite",
    url:
      "https://github.com/shoz-f/tfl_interp/releases/download/0.0.1/lite-model_mobilebert_1_metadata_1.tflite"

  alias MBertQA.Feature

  @max_ans 32
  @predict_num 5

  def setup() do
    Feature.load_dic("./model/vocab.txt")
  end

  def apply(query, context, predict_num \\ @predict_num) do
    # pre-processing
    {feature, context_map} = Feature.convert(query, context)

    # prediction
    mod =
      __MODULE__
      |> NNInterp.set_input_tensor(0, Nx.to_binary(feature[0]))
      |> NNInterp.set_input_tensor(1, Nx.to_binary(feature[1]))
      |> NNInterp.set_input_tensor(2, Nx.to_binary(feature[2]))
      |> NNInterp.invoke()

    [end_logits, beg_logits] =
      Enum.map(0..1, fn x ->
        NNInterp.get_output_tensor(mod, x)
        |> Nx.from_binary({:f, 32})
      end)

    # post-processing: take the top-`predict_num` positions for each side,
    # keeping only positions that map back into the context (feature[3] >= 0).
    [beg_index, end_index] =
      Enum.map([beg_logits, end_logits], fn t ->
        Nx.argsort(t, direction: :desc)
        |> Nx.slice_along_axis(0, predict_num)
        |> Nx.to_flat_list()
        |> Enum.filter(&(Nx.to_number(feature[3][&1]) >= 0))
      end)

    # score every candidate span by the sum of its begin/end logits
    for b <- beg_index, e <- end_index, b <= e, e - b + 1 < @max_ans do
      {b, e, Nx.to_number(Nx.add(beg_logits[b], end_logits[e]))}
    end
    |> Enum.sort(&(elem(&1, 2) >= elem(&2, 2)))
    |> Enum.take(predict_num)
    # make answer text with score
    |> Enum.map(fn {b, e, score} ->
      b = Nx.to_number(feature[3][b])
      e = Nx.to_number(feature[3][e])

      {
        Enum.slice(context_map, b..e) |> Enum.join(" "),
        score
      }
    end)
  end
end
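To make the post-processing step concrete, here is a minimal standalone sketch of the span scoring with made-up logits (not real model output); apply/3 does the same, but additionally filters out positions that do not map back into the context.

beg_logits = Nx.tensor([0.1, 2.3, 0.2, 1.5])
end_logits = Nx.tensor([0.0, 0.3, 2.1, 0.4])

# top-2 candidate positions for each side
[beg_index, end_index] =
  Enum.map([beg_logits, end_logits], fn t ->
    Nx.argsort(t, direction: :desc) |> Nx.slice_along_axis(0, 2) |> Nx.to_flat_list()
  end)

# every (begin, end) pair with begin <= end, scored by the sum of its logits
for b <- beg_index, e <- end_index, b <= e do
  {b, e, Nx.to_number(Nx.add(beg_logits[b], end_logits[e]))}
end
|> Enum.sort(&(elem(&1, 2) >= elem(&2, 2)))
# => approximately [{1, 2, 4.4}, {1, 3, 2.7}, {3, 3, 1.9}]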
Launch MBertQA.
# TflInterp.stop(MBertQA)
MBertQA.start_link([])
Displays the properties of the MBertQA model.
TflInterp.info(MBertQA)
3. Let’s try it
MBertQA.setup()
context = """
Google LLC is an American multinational technology company that specializes in
Internet-related services and products, which include online advertising
technologies, search engine, cloud computing, software, and hardware. It is
considered one of the Big Four technology companies, alongside Amazon, Apple,
and Facebook.
Google was founded in September 1998 by Larry Page and Sergey Brin while they
were Ph.D. students at Stanford University in California. Together they own
about 14 percent of its shares and control 56 percent of the stockholder voting
power through supervoting stock. They incorporated Google as a California
privately held company on September 4, 1998, in California. Google was then
reincorporated in Delaware on October 22, 2002. An initial public offering (IPO)
took place on August 19, 2004, and Google moved to its headquarters in Mountain
View, California, nicknamed the Googleplex. In August 2015, Google announced
plans to reorganize its various interests as a conglomerate called Alphabet Inc.
Google is Alphabet's leading subsidiary and will continue to be the umbrella
company for Alphabet's Internet interests. Sundar Pichai was appointed CEO of
Google, replacing Larry Page who became the CEO of Alphabet.
"""
"What is CEO of Google?"
|> MBertQA.apply(context, 3)
|> Enum.each(fn {ans, score} ->
IO.puts("\n>ANS: \"#{ans}\", score:#{score}")
end)
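If everything is wired up correctly, the top-ranked answer should be a span such as "Sundar Pichai", taken from the sentence "Sundar Pichai was appointed CEO of Google" in the context above; the exact spans and scores depend on the model output.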
4. TIL ;-)