LLMs and RAG
# Notebook dependencies. XLA is requested as the CUDA 12 build and EXLA is
# compiled for CUDA — NOTE(review): this presumably expects an NVIDIA GPU with
# CUDA 12 on the runtime host; confirm for your environment.
deps = [
  {:bumblebee, "~> 0.5"},
  {:nx, "~> 0.9", override: true},
  {:exla, "~> 0.9"},
  {:kino, "~> 0.14"},
  {:hnswlib, "~> 0.1"},
  {:req, "~> 0.4"}
]

Mix.install(deps,
  system_env: [
    {"XLA_TARGET", "cuda12"},
    {"EXLA_TARGET", "cuda"}
  ]
)

# Make EXLA the default backend for every Nx tensor created from here on.
Nx.global_default_backend(EXLA.Backend)
知識の準備
青空文庫から、楠山正雄さんの書いた「桃太郎」を転載したテキストを使用します。
転載元: https://www.aozora.gr.jp/cards/000329/files/18376_12100.html
# Download the plain-text copy of the story.
# `Req.get!/1` only raises on transport failures; an HTTP error such as 404
# still returns a response. Matching on `status: 200` makes the cell fail
# loudly instead of treating an error page body as the document.
%{status: 200, body: text} =
  Req.get!(
    "https://raw.githubusercontent.com/RyoWakabayashi/elixir-learning/main/livebooks/bumblebee/colab/momotaro.txt"
  )

IO.puts("Document length: #{String.length(text)}")
# Split the document into fixed-size, non-overlapping chunks for embedding.
#
# The gte-small embedding serving is compiled with `sequence_length: 512`, so
# any tokens past 512 in a chunk are truncated and never embedded.
# NOTE(review): gte-small's tokenizer is assumed to emit roughly one token per
# Japanese character. In that case the previous 1024-character chunks lost
# about half of their text. 500 characters leaves room for the special tokens
# within the limit. Re-check with the tokenizer if the text or model changes.
chunk_size = 500

chunks =
  text
  # Graphemes rather than codepoints, so a chunk boundary never separates a
  # character from its combining marks.
  |> String.graphemes()
  |> Enum.chunk_every(chunk_size)
  |> Enum.map(&Enum.join/1)

length(chunks)
# Embedding model for retrieval. The weights and the tokenizer are both loaded
# from the same Hugging Face repository.
gte_repo = {:hf, "thenlper/gte-small"}

{:ok, model_info} = Bumblebee.load_model(gte_repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(gte_repo)

# Evaluate to a bare atom so the cell does not render the loaded structs.
:ok
# Text-embedding serving options. Compilation is fixed to batches of 64 inputs
# of 512 tokens. Each embedding is the mean-pooled hidden state.
embedding_opts = [
  compile: [batch_size: 64, sequence_length: 512],
  defn_options: [compiler: EXLA],
  output_attribute: :hidden_state,
  output_pool: :mean_pooling
]

serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, embedding_opts)

# Run the serving as a process named `GteServing` so later cells can reach it
# through `Nx.Serving.batched_run/2`.
Kino.start_child({Nx.Serving, name: GteServing, serving: serving})
# Embed all chunks in a single call to the named serving, then keep only the
# embedding tensor from each per-chunk result.
results = Nx.Serving.batched_run(GteServing, chunks)
chunk_embeddings = Enum.map(results, & &1.embedding)

# Show the first embedding as a sanity check.
List.first(chunk_embeddings)
インデックスの作成と検索
# Build an approximate nearest-neighbour index over the chunk embeddings.
# The dimension is read from the embeddings instead of being hard-coded
# (gte-small yields 384). Swapping the embedding model therefore cannot
# silently build a mismatched index. Matching `[first | _]` asserts that at
# least one chunk was embedded.
[first | _] = chunk_embeddings
dim = Nx.axis_size(first, -1)

# 1_000_000 is presumably the index capacity (max elements) — left unchanged.
{:ok, index} = HNSWLib.Index.new(:cosine, dim, 1_000_000)

# Insert the chunks in order. The retrieval cell treats each returned label as
# an index into `chunks`.
# NOTE(review): this assumes HNSWLib assigns sequential labels starting at 0
# when no `ids:` are given — confirm against the hnswlib docs.
# Matching on `:ok` makes a failed insertion raise instead of being discarded
# in the comprehension's result list.
for embedding <- chunk_embeddings do
  :ok = HNSWLib.Index.add_items(index, embedding)
end

HNSWLib.Index.get_current_count(index)
query = "桃太郎の家来の動物は何ですか?"

# Embed the question with the same serving used for the chunks, then look up
# the 4 nearest chunks. `labels` identifies the chunks; `dist` holds their
# distances under the index's cosine space.
%{embedding: query_embedding} = Nx.Serving.batched_run(GteServing, query)
{:ok, labels, dist} = HNSWLib.Index.knn_query(index, query_embedding, k: 4)
# Assemble the LLM context from the retrieved chunks.
# The chunks do not overlap. Sorting the labels restores document order, so the
# model reads the excerpts in sequence. Each excerpt is wrapped in "[...]" to
# mark it as a fragment of the larger text.
context =
  labels
  |> Nx.to_flat_list()
  |> Enum.sort()
  |> Enum.map_join("\n\n", fn idx -> "[...] " <> Enum.at(chunks, idx) <> " [...]" end)

IO.puts(context)
回答の生成
# Generator model. This repository is loaded with an auth token read from
# LB_HF_TOKEN; `System.fetch_env!/1` raises if the variable is unset.
hf_token = System.fetch_env!("LB_HF_TOKEN")
mistral_repo = {:hf, "mistralai/Mistral-7B-Instruct-v0.2", auth_token: hf_token}

# Load the weights in bfloat16 to reduce memory use.
{:ok, model_info} = Bumblebee.load_model(mistral_repo, type: :bf16)
{:ok, tokenizer} = Bumblebee.load_tokenizer(mistral_repo)

# Start from the repository's generation settings, capped at 100 new tokens.
{:ok, base_generation_config} = Bumblebee.load_generation_config(mistral_repo)
generation_config = Bumblebee.configure(base_generation_config, max_new_tokens: 100)

# Evaluate to a bare atom so the cell does not render the loaded structs.
:ok
# Text-generation serving, compiled for one prompt at a time with an input
# sequence length of 6000, and started as the process `MistralServing`.
generation_opts = [
  compile: [batch_size: 1, sequence_length: 6000],
  defn_options: [compiler: EXLA]
]

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config, generation_opts)

Kino.start_child({Nx.Serving, serving: serving, name: MistralServing})
# Build the RAG prompt and generate the answer.
# Mistral-7B-Instruct expects each user turn wrapped in [INST] ... [/INST].
# Bumblebee's generation serving does not apply a chat template, so without
# the wrapper the model receives plain text to continue rather than an
# instruction. NOTE(review): the tokenizer is assumed to prepend the <s> BOS
# token itself, so it is not written here — confirm with
# `Bumblebee.apply_tokenizer/2`.
prompt =
  """
  [INST] コンテキスト情報は以下の通りです.
  ---------------------
  #{context}
  ---------------------
  与えられたコンテキスト情報に基づき、事前の知識なしに質問に答えてください.
  質問: #{query}
  回答: [/INST]
  """

results = Nx.Serving.batched_run(MistralServing, prompt)
And here we have our answer!
For more background, the Mistral documentation walks through a similar retrieval-augmented generation (RAG) example.