Question answering

Mix.install(
  [
    {:bumblebee, "~> 0.4"},
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:kino, "~> 0.12"}
  ],
  # Build EXLA against CUDA 11.4 so inference runs on the GPU
  system_env: [
    {"XLA_TARGET", "cuda114"}
  ],
  # Use EXLA as the default backend for all Nx operations
  config: [
    nx: [
      default_backend: EXLA.Backend
    ]
  ]
)
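
If no CUDA 11.4 GPU is available, the same setup runs on the CPU; a minimal sketch, assuming you simply drop the XLA_TARGET override so EXLA falls back to its default CPU build:

Mix.install(
  [
    {:bumblebee, "~> 0.4"},
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:kino, "~> 0.12"}
  ],
  # Without XLA_TARGET, EXLA compiles for the CPU
  config: [nx: [default_backend: EXLA.Backend]]
)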

Setup

# Cache downloaded model files so repeated runs skip the download
cache_dir = "/tmp/bumblebee_cache"

Downloading the model

# Load RoBERTa fine-tuned on SQuAD 2.0 for extractive question answering
{:ok, roberta} =
  Bumblebee.load_model({
    :hf,
    "deepset/roberta-base-squad2",
    cache_dir: cache_dir
  })

# The fine-tuned checkpoint uses the base roberta-base tokenizer
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({
    :hf,
    "roberta-base",
    cache_dir: cache_dir
  })
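
Bumblebee.load_model returns a map with :model, :params, and :spec. A quick sanity check, assuming the Bumblebee 0.4 spec layout, that the checkpoint loaded with its question-answering head:

# The spec records which architecture the checkpoint was loaded with
roberta.spec.architecture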

Preparing the text

question_input =
  Kino.Input.text("QUESTION",
    default: "What industries does Elixir help?"
  )

context_input =
  Kino.Input.textarea("CONTEXT",
    default:
      ~s/Elixir is a dynamic, functional language for building scalable and maintainable applications. Elixir runs on the Erlang VM, known for creating low-latency, distributed, and fault-tolerant systems. These capabilities and Elixir tooling allow developers to be productive in several domains, such as web development, embedded software, data pipelines, and multimedia processing, across a wide range of industries./
  )

question = Kino.Input.read(question_input)
context = Kino.Input.read(context_input)

# Tokenize the question/context pair into model inputs
inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})

# The model returns start/end logits scoring each token position
outputs = Axon.predict(roberta.model, roberta.params, inputs)

# The answer span runs from the highest-scoring start position
# to the highest-scoring end position
answer_start_index =
  outputs.start_logits
  |> Nx.argmax()
  |> Nx.to_number()

answer_end_index =
  outputs.end_logits
  |> Nx.argmax()
  |> Nx.to_number()

# Slice the answer tokens out of the input ids and decode them back to text
answer_tokens =
  inputs["input_ids"][[0, answer_start_index..answer_end_index]]
  |> Nx.to_flat_list()

Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
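
The tokenize, predict, and decode steps above are what Bumblebee's built-in serving wraps for you; a minimal sketch, assuming the Bumblebee.Text.question_answering serving available in Bumblebee 0.4:

# Build a serving that bundles tokenization, prediction, and span decoding
serving = Bumblebee.Text.question_answering(roberta, tokenizer)

# Returns %{results: [%{text: ..., start: ..., end: ..., score: ...}]}
Nx.Serving.run(serving, %{question: question, context: context})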

Timing

# Wrap the whole pipeline in a function so it can be timed end to end
proc = fn question, context ->
  inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})
  outputs = Axon.predict(roberta.model, roberta.params, inputs)

  answer_start_index =
    outputs.start_logits
    |> Nx.argmax()
    |> Nx.to_number()

  answer_end_index =
    outputs.end_logits
    |> Nx.argmax()
    |> Nx.to_number()

  answer_tokens =
    inputs["input_ids"][[0, answer_start_index..answer_end_index]]
    |> Nx.to_flat_list()

  Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
end

# Run the pipeline ten times and average the wall-clock time
# (:timer.tc reports microseconds)
1..10
|> Enum.map(fn _ ->
  {time, _} = :timer.tc(proc, [question, context])
  time
end)
|> then(&(Enum.sum(&1) / 10))
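
One caveat: the first run is typically slower while EXLA initializes and compiles kernels, which can skew the average. A minimal sketch that discards a warm-up run first (the warm-up/measure split is my own addition, not part of the original notebook):

# Warm up once so one-time initialization is excluded from the measurement
proc.(question, context)

times =
  for _ <- 1..10 do
    {time, _} = :timer.tc(proc, [question, context])
    time
  end

# Average inference time in microseconds
Enum.sum(times) / length(times)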