Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Japanese embedding with Ruri

bumblebee_ruri_embedding.livemd

Japanese embedding with Ruri

Mix.install(
  [
    {:bumblebee, "~> 0.6"},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.14"},
    {:hnswlib, "~> 0.1"},
    {:scholar, "~> 0.4"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

前提条件

/tmp/ruri_base/ruri-base のモデルが保存されているものとします

また、トークナイザーが tokenizer.json に変換されているものとします

  • 1_Pooling/config.json
  • config_sentence_transformers.json
  • config.json
  • model.safetensors
  • modules.json
  • README.md
  • sentence_bert_config.json
  • special_tokens_map.json
  • tokenizer_config.json
  • tokenizer.json
  • vocab.txt

変換は以下に示すコンテナで実行できます

https://github.com/RyoWakabayashi/elixir-learning/tree/main/ml_model_conversion/ruri_tokenizer

テキスト埋め込みモデルの読込

{:ok, model_info} = Bumblebee.load_model({:local, "/tmp/ruri_base"})

トークナイザーの読込

{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, "/tmp/ruri_base"})

検索用インデックスの構築

string_inputs =
  [
    "りんごはバラ科の落葉高木が実らせる果実で、世界中で広く栽培される。甘味と酸味のバランスが良く、生食のほかジュースや菓子など多彩な料理に利用される。ビタミンや食物繊維も豊富で、健康維持に役立つ。",
    "コンピューターは情報を高速かつ正確に処理する装置で、計算やデータ分析、通信など多様な分野で活用される。人工知能の発展とともに進化し、人々の生活や産業を大きく支えている。モバイルなど形態も多様化している。",
    "クジラは海洋に生息する巨大な哺乳類で、ヒゲクジラ類とハクジラ類に大別される。水中で呼吸を行うために定期的に海面に浮上し、高度な社会性やコミュニケーション能力を持つ。歌と呼ばれる鳴き声で意思疎通する種もいる。",
    "ラーメンは中国の麺料理を起源とする日本の国民食のひとつ。小麦粉の麺とスープが主体で、醤油・味噌・塩・豚骨など多様な味が楽しめる。具材もチャーシューやメンマ、ネギなど豊富で、地域ごとに特色ある進化を遂げている。ラーメンは中国の麺料理を起源とする日本の国民食のひとつ。小麦粉の麺とスープが主体で、醤油・味噌・塩・豚骨など多様な味が楽しめる。具材もチャーシューやメンマ、ネギなど豊富で、地域ごとに特色ある進化を遂げている。",
  ]
  |> Enum.map(fn input -> "文章: #{input}" end)

inputs = Bumblebee.apply_tokenizer(tokenizer, string_inputs)
embedding = Axon.predict(model_info.model, model_info.params, inputs, compiler: EXLA)

input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

embeddings =
  embedding.hidden_state
  |> Nx.multiply(input_mask_expanded)
  |> Nx.sum(axes: [1])
  |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
  |> Scholar.Preprocessing.normalize(norm: :euclidean)
  |> Nx.to_batched(1)
  |> Enum.to_list()
{:ok, index} = HNSWLib.Index.new(:cosine, 768, 100)

for embedding <- embeddings do
  HNSWLib.Index.add_items(index, embedding)
end

HNSWLib.Index.get_current_count(index)

テキスト検索

search = fn query ->
  query_inputs = Bumblebee.apply_tokenizer(tokenizer, ["クエリ: #{query}"])

  query_embedding = Axon.predict(model_info.model, model_info.params, query_inputs, compiler: EXLA)
  
  input_mask_expanded = Nx.new_axis(query_inputs["attention_mask"], -1)
  
  query_embeddings =
    query_embedding.hidden_state
    |> Nx.multiply(input_mask_expanded)
    |> Nx.sum(axes: [1])
    |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
    |> Scholar.Preprocessing.normalize(norm: :euclidean)
    |> Nx.squeeze()
  
  {:ok, labels, _dist} = HNSWLib.Index.knn_query(index, query_embeddings, k: 1)
  
  labels
  |> Nx.to_flat_list()
  |> hd()
  |> then(&amp;Enum.at(string_inputs, &amp;1))
end
search.("鯨について教えて")
search.("お腹が空いた")
search.("植物")