Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Generate Embeddings

notebooks/generate_embeddings.livemd

Generate Embeddings

# Installs the notebook's dependencies and sets application config in one call.
Mix.install(
  [
    # Database access: Postgres driver, pgvector support, and Ecto.
    {:postgrex, ">= 0.0.0"},
    {:pgvector, "~> 0.3.0"},
    {:ecto, "~> 3.0"},
    {:ecto_sql, "~> 3.0"},
    # Nx stack: Bumblebee for the model, EXLA as the backend, Kino for Livebook.
    {:bumblebee, "~> 0.6.0"},
    {:nx, "~> 0.9.0"},
    {:exla, "~> 0.9.0"},
    {:axon, "~> 0.7.0"},
    {:kino, "~> 0.14.0"},
    # Paradex itself, loaded from the directory above this notebook in its :test env
    # (presumably so the ParadexApp test modules are available; confirm).
    {:paradex, path: Path.join(__DIR__, "../"), env: :test}
  ],
  # Everything under `config:` is specific to the author's machine (database
  # connection, CUDA device); adjust it to your own setup before running.
  config: [
    paradex: [
      {:ecto_repos, [ParadexApp.Repo]},
      {ParadexApp.Repo, [
        pool: Ecto.Adapters.SQL.Sandbox,
        database: "paradex_test",
        username: "postgres",
        password: "postgres",
        hostname: "localhost",
        # Generous timeouts (milliseconds, i.e. 10 minutes) because the notebook
        # holds a connection for a long-running job.
        timeout: 600_000,
        ownership_timeout: 600_000,
        pool_timeout: 600_000,
        # Not the Postgres default port (5432).
        port: 5433,
        types: ParadexApp.PostgrexTypes
      ]}
    ],
    exla: [
      clients: [
        # Single CUDA client on device 0, limited to 85% of the device memory.
        cuda: [
          platform: :cuda,
          memory_fraction: 0.85,
          device_id: 0
        ]
      ],
      client: :cuda
    ],
    nx: [
      # Tensors default to the CUDA-backed EXLA backend configured above.
      default_backend: {EXLA.Backend, client: :cuda, device_id: 0}
    ]
  ],
  # Use the config file and lockfile of the :paradex path dependency.
  config_path: :paradex,
  lockfile: :paradex
)

IGNORE ME!!!

import Ecto.Query

alias ParadexApp.Repo
alias ParadexApp.Call

# Start the repo for this notebook session. Re-evaluating the cell finds the
# repo already running, which is fine; any other failure is raised here rather
# than surfacing later as a confusing query error.
case Repo.start_link() do
  {:ok, _pid} -> :ok
  {:error, {:already_started, _pid}} -> :ok
  {:error, reason} -> raise "could not start ParadexApp.Repo: #{inspect(reason)}"
end

# Model and tokenizer come from the same Hugging Face repository.
model_repo = {:hf, "sentence-transformers/all-MiniLM-L6-v2"}

{:ok, model_info} = Bumblebee.load_model(model_repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(model_repo)

# NOTE(review): the serving is built with Bumblebee's default options. Confirm
# the default pooling/normalisation matches how query-time embeddings are
# produced before comparing vectors.
serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)

defmodule CallStream do
  @moduledoc """
  Lazily enumerates `Call` rows that have no embedding yet.

  Rows are fetched in pages of up to 100, ordered by ascending `id`; each page
  starts after the last `id` seen in the previous page.
  """

  @doc """
  Returns a stream of `Call` rows that ends once a page comes back empty.
  """
  def resource() do
    Stream.resource(&initial_cursor/0, &fetch_page/1, &cleanup/1)
  end

  # The first page starts after id 0.
  defp initial_cursor(), do: 0

  # Loads the next page of calls with a nil embedding and an id above `cursor`.
  defp fetch_page(cursor) do
    Call
    |> where([c], c.id > ^cursor)
    |> where([c], is_nil(c.embedding))
    |> order_by([c], asc: c.id)
    |> limit(100)
    |> Repo.all()
    |> emit(cursor)
  end

  # An empty page halts the stream; otherwise the page is emitted and the
  # cursor advances to the id of its last row.
  defp emit([], cursor), do: {:halt, cursor}
  defp emit(page, _cursor), do: {page, List.last(page).id}

  # Nothing to release when the stream finishes.
  defp cleanup(_cursor), do: :noop
end

# Computes the embedding of a call's transcript with the serving and writes it
# to that call's row.
update_embedding = fn call ->
  %{embedding: tensor} = Nx.Serving.run(serving, call.transcript)
  vector = Pgvector.new(tensor)

  Call
  |> where([c], c.id == ^call.id)
  |> Repo.update_all(set: [embedding: vector])
end

# Embed every pending call inside a single transaction. The timeout is
# disabled because the run can take far longer than the default allows.
Repo.transaction(
  fn -> Enum.each(CallStream.resource(), update_embedding) end,
  timeout: :infinity
)