Japanese embedding with Ruri
Mix.install(
[
{:bumblebee, "~> 0.6"},
{:exla, "~> 0.9"},
{:kino, "~> 0.14"},
{:hnswlib, "~> 0.1"},
{:scholar, "~> 0.4"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
前提条件
/tmp/ruri_base/
に ruri-base のモデルが保存されているものとします
また、トークナイザーが tokenizer.json
に変換されているものとします
- 1_Pooling/config.json
- config_sentence_transformers.json
- config.json
- model.safetensors
- modules.json
- README.md
- sentence_bert_config.json
- special_tokens_map.json
- tokenizer_config.json
- tokenizer.json
- vocab.txt
変換は以下に示すコンテナで実行できます
https://github.com/RyoWakabayashi/elixir-learning/tree/main/ml_model_conversion/ruri_tokenizer
テキスト埋め込みモデルの読込
{:ok, model_info} = Bumblebee.load_model({:local, "/tmp/ruri_base"})
トークナイザーの読込
{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, "/tmp/ruri_base"})
検索用インデックスの構築
string_inputs =
[
"りんごはバラ科の落葉高木が実らせる果実で、世界中で広く栽培される。甘味と酸味のバランスが良く、生食のほかジュースや菓子など多彩な料理に利用される。ビタミンや食物繊維も豊富で、健康維持に役立つ。",
"コンピューターは情報を高速かつ正確に処理する装置で、計算やデータ分析、通信など多様な分野で活用される。人工知能の発展とともに進化し、人々の生活や産業を大きく支えている。モバイルなど形態も多様化している。",
"クジラは海洋に生息する巨大な哺乳類で、ヒゲクジラ類とハクジラ類に大別される。水中で呼吸を行うために定期的に海面に浮上し、高度な社会性やコミュニケーション能力を持つ。歌と呼ばれる鳴き声で意思疎通する種もいる。",
"ラーメンは中国の麺料理を起源とする日本の国民食のひとつ。小麦粉の麺とスープが主体で、醤油・味噌・塩・豚骨など多様な味が楽しめる。具材もチャーシューやメンマ、ネギなど豊富で、地域ごとに特色ある進化を遂げている。ラーメンは中国の麺料理を起源とする日本の国民食のひとつ。小麦粉の麺とスープが主体で、醤油・味噌・塩・豚骨など多様な味が楽しめる。具材もチャーシューやメンマ、ネギなど豊富で、地域ごとに特色ある進化を遂げている。",
]
|> Enum.map(fn input -> "文章: #{input}" end)
inputs = Bumblebee.apply_tokenizer(tokenizer, string_inputs)
embedding = Axon.predict(model_info.model, model_info.params, inputs, compiler: EXLA)
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)
embeddings =
embedding.hidden_state
|> Nx.multiply(input_mask_expanded)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
|> Scholar.Preprocessing.normalize(norm: :euclidean)
|> Nx.to_batched(1)
|> Enum.to_list()
{:ok, index} = HNSWLib.Index.new(:cosine, 768, 100)
for embedding <- embeddings do
HNSWLib.Index.add_items(index, embedding)
end
HNSWLib.Index.get_current_count(index)
テキスト検索
search = fn query ->
query_inputs = Bumblebee.apply_tokenizer(tokenizer, ["クエリ: #{query}"])
query_embedding = Axon.predict(model_info.model, model_info.params, query_inputs, compiler: EXLA)
input_mask_expanded = Nx.new_axis(query_inputs["attention_mask"], -1)
query_embeddings =
query_embedding.hidden_state
|> Nx.multiply(input_mask_expanded)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
|> Scholar.Preprocessing.normalize(norm: :euclidean)
|> Nx.squeeze()
{:ok, labels, _dist} = HNSWLib.Index.knn_query(index, query_embeddings, k: 1)
labels
|> Nx.to_flat_list()
|> hd()
|> then(&Enum.at(string_inputs, &1))
end
search.("鯨について教えて")
search.("お腹が空いた")
search.("植物")