Text embeddings
Mix.install(
  [
    {:kino_bumblebee, "~> 0.3.0"},
    {:exla, "~> 0.5.3"},
    {:scholar, "~> 0.1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
Simple Embedding Example
Following https://github.com/elixir-nx/bumblebee/issues/100#issuecomment-1372230339, we apply mean pooling over the token embeddings, weighted by the attention mask, to get a correct sentence embedding.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "sentence-transformers/all-MiniLM-L6-v2"}, log_params_diff: false)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "sentence-transformers/all-MiniLM-L6-v2"})
string_inputs = [
  "Health experts said it is too early to predict whether demand would match up with the 171 million doses of the new boosters the U.S. ordered for the fall."
]
inputs = Bumblebee.apply_tokenizer(tokenizer, string_inputs)
embedding = Axon.predict(model_info.model, model_info.params, inputs, compiler: EXLA)
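Axon.predict/4 returns the model's output map; :hidden_state holds one 384-dimensional vector per token, i.e. a {1, sequence_length, 384} tensor:
Nx.shape(embedding.hidden_state)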
# Expand the attention mask so it broadcasts across the hidden dimension
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

# Mean pooling: zero out padding tokens, average the remaining token
# embeddings, then L2-normalize the sentence embedding
embedding.hidden_state
|> Nx.multiply(input_mask_expanded)
|> Nx.sum(axes: [1])
|> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
|> Scholar.Preprocessing.normalize(norm: :euclidean)
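Because the pooled embedding is L2-normalized, the cosine similarity of two sentences reduces to a dot product. A minimal sketch using the model and tokenizer loaded above (the sentences and bindings are illustrative):

# Embed two sentences in one batch
pair_inputs = Bumblebee.apply_tokenizer(tokenizer, ["A dog plays fetch.", "A puppy chases a ball."])
pair_output = Axon.predict(model_info.model, model_info.params, pair_inputs, compiler: EXLA)

pair_mask = Nx.new_axis(pair_inputs["attention_mask"], -1)

pooled =
  pair_output.hidden_state
  |> Nx.multiply(pair_mask)
  |> Nx.sum(axes: [1])
  |> Nx.divide(Nx.sum(pair_mask, axes: [1]))

# L2-normalize each row with plain Nx; cosine similarity is then a dot product
norms = pooled |> Nx.multiply(pooled) |> Nx.sum(axes: [1], keep_axes: true) |> Nx.sqrt()
normalized = Nx.divide(pooled, norms)

Nx.dot(normalized[0], normalized[1])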
Embedding Example with Serving
Application.put_env(:sentence_transformer, :batch_size, 4)
string_inputs = [
  "Health experts said it is too early to predict whether demand would match up with the 171 million doses of the new boosters the U.S. ordered for the fall.",
  "He was subdued by passengers and crew when he fled to the back of the aircraft after the confrontation, according to the U.S. attorney's office in Los Angeles.",
  "Until you have a dog you don't understand what could be eaten.",
  "Accidentally put grown-up toothpaste on my toddler’s toothbrush and he screamed like I was cleaning his teeth with a Carolina Reaper dipped in Tabasco sauce."
]
batch_size = Application.get_env(:sentence_transformer, :batch_size)
defn_options = [compiler: EXLA]
serving =
  Nx.Serving.new(
    fn _opts ->
      model_name = "sentence-transformers/all-MiniLM-L6-v2"
      {:ok, model_info} = Bumblebee.load_model({:hf, model_name})
      {_init_fun, predict_fun} = Axon.build(model_info.model)

      # Compile the prediction function once, for a fixed {batch_size, 128} input shape
      inputs_template = %{
        "attention_mask" => Nx.template({batch_size, 128}, :u32),
        "input_ids" => Nx.template({batch_size, 128}, :u32),
        "token_type_ids" => Nx.template({batch_size, 128}, :u32)
      }

      template_args = [Nx.to_template(model_info.params), inputs_template]
      predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)

      fn incoming_inputs ->
        # Pad partial batches up to the compiled batch size
        inputs = Nx.Batch.pad(incoming_inputs, batch_size - incoming_inputs.size)
        predict_fun.(model_info.params, inputs)
      end
    end,
    batch_size: batch_size
  )
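Before starting the serving under a supervisor, it can be exercised inline with Nx.Serving.run/2. A minimal sketch (the test sentence is illustrative); the input is padded to the 128-token length the template above expects:

{:ok, smoke_tokenizer} = Bumblebee.load_tokenizer({:hf, "sentence-transformers/all-MiniLM-L6-v2"})

smoke_batch =
  Nx.Batch.concatenate([
    Bumblebee.apply_tokenizer(smoke_tokenizer, ["A quick smoke test."], length: 128)
  ])

# The serving pads the one-item batch to batch_size and runs the compiled model
Nx.Serving.run(serving, smoke_batch)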
{:ok, pid} =
  Supervisor.start_link(
    [
      {Nx.Serving,
       serving: serving,
       name: SentenceTransformer.Serving,
       batch_timeout: 100,
       batch_size: batch_size}
    ],
    strategy: :one_for_one
  )
# Only the tokenizer is needed client-side; the serving loads its own copy of the model
model_name = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
text_inputs =
  for text <- string_inputs do
    # Pad each input to the 128-token length the serving was compiled for,
    # so the per-text tensors can be concatenated into a single batch
    Bumblebee.apply_tokenizer(tokenizer, [text], length: 128)
  end
text_batch = Nx.Batch.concatenate(text_inputs)
text_results = Nx.Serving.batched_run(SentenceTransformer.Serving, text_batch)
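batched_run returns the model output for the whole concatenated batch, so hidden_state here is a {4, 128, 384} tensor (four texts, 128 tokens each, 384 dimensions):
Nx.shape(text_results.hidden_state)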
results =
  for {text_input, i} <- Enum.with_index(text_inputs) do
    # Mean pooling per input, exactly as in the simple example above
    text_attention_mask = text_input["attention_mask"]
    text_input_mask_expanded = Nx.new_axis(text_attention_mask, -1)

    text_results.hidden_state[i]
    |> Nx.multiply(text_input_mask_expanded)
    |> Nx.sum(axes: [1])
    |> Nx.divide(Nx.sum(text_input_mask_expanded, axes: [1]))
    |> Scholar.Preprocessing.normalize(norm: :euclidean)
  end
Supervisor.stop(pid)
results
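With the four normalized embeddings in hand, a pairwise cosine-similarity matrix is a single matrix product; a minimal sketch over the results above:

# Each result is {1, 384}; stack them into a {4, 384} matrix of unit-length rows
embeddings = results |> Enum.map(&Nx.squeeze/1) |> Nx.stack()

# For unit-length rows, E * transpose(E) gives the pairwise cosine similarities
Nx.dot(embeddings, Nx.transpose(embeddings))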