
LLMs and RAG

This notebook builds a small retrieval-augmented generation (RAG) pipeline: it splits a document into chunks, embeds them, indexes the embeddings with HNSW, retrieves the chunks closest to a question, and asks Mistral to answer using that retrieved context.

Mix.install(
  [
    {:bumblebee, "~> 0.5"},
    {:nx, "~> 0.9", override: true},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.14"},
    {:hnswlib, "~> 0.1"},
    {:req, "~> 0.4"}
  ],
  # Targets an NVIDIA GPU via the CUDA 12 build of XLA; remove these env vars to run on CPU
  system_env: [
    {"XLA_TARGET", "cuda12"},
    {"EXLA_TARGET", "cuda"}
  ]
)

Nx.global_default_backend(EXLA.Backend)

Preparing the knowledge base

We use the text of "Momotaro" by Masao Kusuyama, reproduced from Aozora Bunko.

Original: https://www.aozora.gr.jp/cards/000329/files/18376_12100.html

Reproduced copy: https://raw.githubusercontent.com/RyoWakabayashi/elixir-learning/main/livebooks/bumblebee/colab/momotaro.txt

# Download the document text
%{body: text} =
  Req.get!(
    "https://raw.githubusercontent.com/RyoWakabayashi/elixir-learning/main/livebooks/bumblebee/colab/momotaro.txt"
  )

IO.puts("Document length: #{String.length(text)}")
chunks =
  text
  |> String.codepoints()
  |> Enum.chunk_every(1024)
  |> Enum.map(&Enum.join/1)

length(chunks)
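
The chunks above do not overlap, so a sentence cut at a chunk boundary ends up split across two chunks. A common RAG refinement is to overlap adjacent chunks; here is a minimal sketch using Enum.chunk_every/4 with a step smaller than the chunk size (the 128-codepoint overlap and the overlapping_chunks name are illustrative choices, not part of the original notebook):

# Hypothetical variant: 1024-codepoint chunks that overlap by 128 codepoints.
# Passing [] as the leftover argument keeps the final, shorter chunk.
overlapping_chunks =
  text
  |> String.codepoints()
  |> Enum.chunk_every(1024, 1024 - 128, [])
  |> Enum.map(&Enum.join/1)

length(overlapping_chunks)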

# Load the gte-small sentence-embedding model and its tokenizer
repo = {:hf, "thenlper/gte-small"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

:ok

# Embedding serving: mean-pool the hidden state into one vector per input
serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 64, sequence_length: 512],
    defn_options: [compiler: EXLA],
    output_attribute: :hidden_state,
    output_pool: :mean_pooling
  )

Kino.start_child({Nx.Serving, serving: serving, name: GteServing})

# Embed every chunk in a single batched run
results = Nx.Serving.batched_run(GteServing, chunks)
chunk_embeddings = for result <- results, do: result.embedding

List.first(chunk_embeddings)
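
Before building the index, it helps to confirm that each embedding really has 384 dimensions, since that number must match the dimensionality passed to HNSWLib below. A quick check (dim is just a local binding used for illustration):

# gte-small produces 384-dimensional embeddings; the HNSW index below is
# created with the same dimensionality
{dim} = Nx.shape(List.first(chunk_embeddings))
dim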

Creating and searching the index

# Cosine-distance HNSW index for 384-dimensional vectors, with room for 1,000,000 entries
{:ok, index} = HNSWLib.Index.new(:cosine, 384, 1_000_000)

# Add each chunk embedding; labels are assigned in insertion order (0, 1, 2, ...)
for embedding <- chunk_embeddings do
  HNSWLib.Index.add_items(index, embedding)
end

HNSWLib.Index.get_current_count(index)
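
To reuse the index across sessions it can be serialized to disk. This is a minimal sketch assuming the hnswlib binding mirrors the upstream library's save_index/load_index API; the path is arbitrary and the exact signatures should be checked against the version you have installed:

# Assumed API: save_index/2 writes the index to disk, load_index/3 reads it
# back given the same space and dimensionality
path = "/tmp/momotaro_index.bin"

HNSWLib.Index.save_index(index, path)
{:ok, restored} = HNSWLib.Index.load_index(:cosine, 384, path)

HNSWLib.Index.get_current_count(restored)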

# The question (Japanese): "What animals are Momotaro's retainers?"
query = "桃太郎の家来の動物は何ですか?"

# Embed the query with the same serving, then fetch the 4 nearest chunks
%{embedding: embedding} = Nx.Serving.batched_run(GteServing, query)

{:ok, labels, dist} = HNSWLib.Index.knn_query(index, embedding, k: 4)

# Sort the retrieved labels and stitch the matching chunks into a single context string
context =
  labels
  |> Nx.to_flat_list()
  |> Enum.sort()
  |> Enum.map(fn idx -> "[...] " <> Enum.at(chunks, idx) <> " [...]" end)
  |> Enum.join("\n\n")

IO.puts(context)
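
knn_query also returns the cosine distances, which can be used to drop weak matches before they reach the prompt. A small sketch (the 0.5 threshold is an arbitrary illustration, not a value from the notebook):

# Pair each retrieved label with its distance and keep only close matches;
# the threshold is hypothetical and should be tuned for your data
close_chunks =
  Enum.zip(Nx.to_flat_list(labels), Nx.to_flat_list(dist))
  |> Enum.filter(fn {_idx, distance} -> distance < 0.5 end)
  |> Enum.map(fn {idx, _distance} -> Enum.at(chunks, idx) end)

length(close_chunks)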

Generating the answer

# The model is pulled from the mistralai repo using a Hugging Face access token
# (exposed to Livebook as the LB_HF_TOKEN secret)
hf_token = System.fetch_env!("LB_HF_TOKEN")
repo = {:hf, "mistralai/Mistral-7B-Instruct-v0.2", auth_token: hf_token}

{:ok, model_info} = Bumblebee.load_model(repo, type: :bf16)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

:ok

# Generation serving; sequence_length 6000 leaves room for the retrieved context in the prompt
serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 6000],
    defn_options: [compiler: EXLA]
  )

Kino.start_child({Nx.Serving, name: MistralServing, serving: serving})

# The prompt (Japanese) supplies the retrieved context, then asks the model to
# answer the question based only on that context
prompt =
  """
  コンテキスト情報は以下の通りです.
  ---------------------
  #{context}
  ---------------------
  与えられたコンテキスト情報に基づき、事前の知識なしに質問に答えてください.
  質問: #{query}
  回答:
  """

results = Nx.Serving.batched_run(MistralServing, prompt)

And here we have our answer!
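
The serving returns a map with a :results list, and the generated text itself is under the :text key of each entry, so the answer can be pulled out like this:

# Bumblebee text generation servings return %{results: [%{text: ...} | _]}
%{results: [%{text: answer} | _]} = results
IO.puts(answer)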

For additional context, you can also visit the Mistral docs, which walk through a similar example.