
Text embeddings

Mix.install(
  [
    {:kino_bumblebee, "~> 0.3.0"},
    {:exla, "~> 0.5.3"},
    {:scholar, "~> 0.1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Simple Embedding Example

Following https://github.com/elixir-nx/bumblebee/issues/100#issuecomment-1372230339, we use mean pooling over the token embeddings to produce a correct sentence embedding.
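
Mean pooling averages the token vectors, using the attention mask to skip padding positions. Here is a minimal sketch with made-up numbers (not model output) to show the arithmetic:

hidden = Nx.tensor([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]])
mask = Nx.new_axis(Nx.tensor([[1, 1, 0]]), -1)

hidden
|> Nx.multiply(mask)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(mask, axes: [1]))
# => [[2.0, 3.0]], the mean of the two unmasked token vectors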

{:ok, model_info} =
  Bumblebee.load_model({:hf, "sentence-transformers/all-MiniLM-L6-v2"}, log_params_diff: false)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "sentence-transformers/all-MiniLM-L6-v2"})

string_inputs = [
  "Health experts said it is too early to predict whether demand would match up with the 171 million doses of the new boosters the U.S. ordered for the fall."
]

inputs = Bumblebee.apply_tokenizer(tokenizer, string_inputs)

embedding = Axon.predict(model_info.model, model_info.params, inputs, compiler: EXLA)
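
For this model, embedding.hidden_state has shape {batch_size, sequence_length, 384}: one 384-dimensional vector per token.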

# Expand the mask to {batch_size, sequence_length, 1} so it broadcasts over features
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

# Mean pooling: zero out padded tokens, average the rest, then L2-normalize
embedding.hidden_state
|> Nx.multiply(input_mask_expanded)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
|> Scholar.Preprocessing.normalize(norm: :euclidean)
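
Because the embeddings are L2-normalized, the dot product of two of them equals their cosine similarity. As a quick sketch (the embed helper and both sentences are illustrative additions, assuming the bindings above):

embed = fn text ->
  inputs = Bumblebee.apply_tokenizer(tokenizer, [text])
  embedding = Axon.predict(model_info.model, model_info.params, inputs, compiler: EXLA)
  mask = Nx.new_axis(inputs["attention_mask"], -1)

  embedding.hidden_state
  |> Nx.multiply(mask)
  |> Nx.sum(axes: [1])
  |> Nx.divide(Nx.sum(mask, axes: [1]))
  |> Scholar.Preprocessing.normalize(norm: :euclidean)
end

a = embed.("A dog is running in the park.")
b = embed.("A puppy sprints across the grass.")

# Dot product of unit vectors = cosine similarity
Nx.dot(a[0], b[0])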

Embedding Example with Serving

Application.put_env(:sentence_transformer, :batch_size, 4)

string_inputs = [
  "Health experts said it is too early to predict whether demand would match up with the 171 million doses of the new boosters the U.S. ordered for the fall.",
  "He was subdued by passengers and crew when he fled to the back of the aircraft after the confrontation, according to the U.S. attorney's office in Los Angeles.",
  "Until you have a dog you don't understand what could be eaten.",
  "Accidentally put grown-up toothpaste on my toddler’s toothbrush and he screamed like I was cleaning his teeth with a Carolina Reaper dipped in Tabasco sauce."
]

batch_size = Application.get_env(:sentence_transformer, :batch_size)
defn_options = [compiler: EXLA]

serving =
  Nx.Serving.new(
    fn _opts ->
      model_name = "sentence-transformers/all-MiniLM-L6-v2"
      {:ok, model_info} = Bumblebee.load_model({:hf, model_name})

      # Axon.build/1 returns {init_fun, predict_fun}; the parameters are
      # already loaded, so only the prediction function is needed
      {_init_fun, predict_fun} = Axon.build(model_info.model)

      # Templates describing the fixed {batch_size, 128} input shapes the
      # computation will be compiled for
      inputs_template = %{
        "attention_mask" => Nx.template({batch_size, 128}, :u32),
        "input_ids" => Nx.template({batch_size, 128}, :u32),
        "token_type_ids" => Nx.template({batch_size, 128}, :u32)
      }

      template_args = [Nx.to_template(model_info.params), inputs_template]

      # Compile the prediction function ahead of time for those shapes
      predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)

      fn incoming_inputs ->
        # Pad partial batches up to the compiled batch size
        inputs = Nx.Batch.pad(incoming_inputs, batch_size - incoming_inputs.size)
        predict_fun.(model_info.params, inputs)
      end
    end,
    batch_size: batch_size
  )
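
Compiling ahead of time with Nx.Defn.compile avoids a compilation pause on the first request; the trade-off is that every input must match the fixed {batch_size, 128} template, which is why partial batches are padded above and each text is tokenized to length 128 below.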

{:ok, pid} =
  Supervisor.start_link(
    [
      {Nx.Serving,
       serving: serving,
       name: SentenceTransformer.Serving,
       batch_timeout: 100,
       batch_size: batch_size}
    ],
    strategy: :one_for_one
  )
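
The serving is registered under the name SentenceTransformer.Serving, and batch_timeout: 100 tells Nx.Serving to wait at most 100 ms for more requests before running a partial batch.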

model_name = "sentence-transformers/all-MiniLM-L6-v2"

{:ok, model_info} = Bumblebee.load_model({:hf, model_name})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
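
Each text is tokenized to a fixed length of 128 so its shape matches the {batch_size, 128} templates the serving was compiled with.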

text_inputs =
  for text <- string_inputs do
    # Pad/truncate to the template length of 128 tokens
    Bumblebee.apply_tokenizer(tokenizer, [text], length: 128)
  end

text_batch = Nx.Batch.concatenate(text_inputs)

text_results = Nx.Serving.batched_run(SentenceTransformer.Serving, text_batch)
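
Nx.Serving.batched_run/2 routes the batch to the named serving process, which can also merge requests from concurrent callers and split the results back out to each of them.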

results =
  for {text_input, i} <- Enum.with_index(text_inputs) do
    text_attention_mask = text_input["attention_mask"]
    text_input_mask_expanded = Nx.new_axis(text_attention_mask, -1)

    # Same mean pooling and normalization as in the simple example,
    # applied to each input's slice of the batched hidden state
    text_results.hidden_state[i]
    |> Nx.multiply(text_input_mask_expanded)
    |> Nx.sum(axes: [1])
    |> Nx.divide(Nx.sum(text_input_mask_expanded, axes: [1]))
    |> Scholar.Preprocessing.normalize(norm: :euclidean)
  end

Supervisor.stop(pid)

results
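
Each entry in results is a unit-length {1, 384} vector, so stacking them and multiplying by the transpose gives a pairwise cosine-similarity matrix. A minimal sketch under that assumption:

matrix = Nx.concatenate(results, axis: 0)

# similarity[i][j] is the cosine similarity between sentence i and sentence j
similarity = Nx.dot(matrix, Nx.transpose(matrix))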