Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

OpenAI Text Embeddings - fork

books/ai/openai/tsne.livemd

OpenAI Text Embeddings - fork

# Notebook dependencies: HTTP client (req), numerical stack (nx + exla + scholar),
# Rust-backed t-SNE bindings (tsne via rustler), and VegaLite plotting (kino_vega_lite).
Mix.install(
  [
    {:req, "~> 0.3.6"},
    {:exla, "~> 0.5"},
    {:kino_vega_lite, "~> 0.1"},
    {:nx, "~> 0.5"},
    {:rustler, "~> 0.0"},
    {:scholar, "~> 0.1"},
    {:tsne, "~> 0.1"},
    {:vega_lite, "~> 0.1.7"}
  ],
  # Run all Nx tensor operations on the compiled EXLA (XLA) backend by default.
  config: [
    nx: [default_backend: EXLA.Backend]
  ]
)

Library

defmodule OpenAi.Client do
  # Thin HTTP client around the OpenAI REST API (embeddings endpoint only).

  @type ok_tuple(value) :: {:ok, value}
  @type error_tuple :: {:error, any()}

  require Logger

  # Embedding requests for large batches can be slow; allow up to two minutes.
  @openai_timeout 120 * 1000

  # Base request: bearer auth from the Livebook secret, generous receive
  # timeout, and up to 2 automatic retries (retry predicate supplied per call).
  @spec build_req :: Req.Request.t()
  defp build_req do
    api_key = System.fetch_env!("LB_" <> "OPEN_AI_API_KEY")

    Req.new(
      base_url: "https://api.openai.com",
      headers: %{"Authorization" => "Bearer #{api_key}"},
      receive_timeout: @openai_timeout,
      max_retries: 2
    )
  end

  @doc """
  Requests embeddings for `input` (a string or list of strings) from `model`.

  Returns `{:ok, [%{index: i, embedding: [...]}, ...]}` on success, or
  `{:error, :text_embedding_failed}` when the request or response fails.
  """
  @spec fetch_embedding(String.t(), String.t() | list(String.t())) ::
          ok_tuple(list()) | error_tuple()
  def fetch_embedding(model, input) do
    response =
      build_req()
      |> Req.post(
        url: "/v1/embeddings",
        json: %{"model" => model, "input" => input},
        retry: &retry_fun(&1, "EMBEDDINGS")
      )
      |> process_response(:text_embedding_failed)

    with {:ok, body} <- response do
      embeddings =
        for item <- body["data"] do
          %{index: item["index"], embedding: item["embedding"]}
        end

      {:ok, embeddings}
    end
  end

  # Success (< 400): hand back the decoded body. Otherwise log the full
  # response/error and return the caller-supplied error code.
  @spec process_response({:ok, Req.Response.t()} | error_tuple(), atom()) ::
          ok_tuple(binary() | term()) | error_tuple()
  defp process_response({:ok, res}, _error_code) when res.status < 400, do: {:ok, res.body}

  defp process_response({:ok, res}, error_code) do
    Logger.error("[OPENAI_REQUEST] #{inspect(error_code)}: res #{inspect(res)}")
    {:error, error_code}
  end

  defp process_response({:error, error}, error_code) do
    Logger.error(inspect(error))
    {:error, error_code}
  end

  # Retry predicate handed to Req: retry on any 4xx/5xx status and on
  # transport-level failures; everything else is left alone.
  defp retry_fun(%Req.Response{status: status} = res, source) when status >= 400 do
    Logger.error("[OPENAI_REQUEST][#{source}] HTTP Error: #{inspect(res.status)}, retrying...")
    true
  end

  defp retry_fun(%Req.Response{}, _source), do: false

  defp retry_fun(%Mint.TransportError{} = e, source) do
    Logger.error("[OPENAI_REQUEST][#{source}] Transport Error: #{inspect(e)}")
    true
  end

  defp retry_fun(x, source) do
    Logger.debug("[OPENAI_REQUEST][#{source}] Unknown retry function param: #{inspect(x)}")
    false
  end
end
defmodule OpenAi do
  # Public facade for the OpenAI calls used in this notebook.

  @type ok_tuple(value) :: {:ok, value}
  @type error_tuple :: {:error, any()}

  require Logger

  alias OpenAi.Client

  @text_embedding_model "text-embedding-ada-002"

  @doc """
  Embeds `input` (a string or list of strings) with the ada-002 model,
  logging the outcome and passing the client's result tuple through.
  """
  @spec text_embedding(String.t() | list(String.t())) :: ok_tuple(list()) | error_tuple()
  def text_embedding(input) do
    case Client.fetch_embedding(@text_embedding_model, input) do
      {:ok, _result} = ok ->
        Logger.debug("[OPENAI][TEXT_EMBEDDING] fetch_text_embedding success")
        ok

      {:error, reason} = error ->
        Logger.error("[OPENAI][ERROR] fetch_text_embedding failed, #{inspect(reason)}")
        error
    end
  end
end

Usage

# Fetch embeddings for two related sentences; the `{:ok, _}` match makes the
# cell crash with a MatchError if the API call fails.
{:ok, result} =
  OpenAi.text_embedding([
    "an aerial view of the campus of harvard university in cambridge, massachusetts, united states",
    "aerial view of harvard university"
  ])
# Collect the raw embedding vectors and, when fewer than 16 are present,
# pad with all-zero vectors of length 1536 (the embedding dimension used
# throughout this notebook) so the downstream t-SNE step has enough rows.
embeddings =
  result
  |> Enum.map(& &1.embedding)
  |> then(fn vectors ->
    missing = max(16 - length(vectors), 0)
    vectors ++ List.duplicate(List.duplicate(0, 1536), missing)
  end)
defmodule Math do
  @doc "Clamps `num` into the inclusive range [`min`, `max`]."
  def clamp(num, min, max) do
    cond do
      num > max -> max
      num < min -> min
      true -> num
    end
  end
end

# Heuristic perplexity: roughly (n_points - 1) / 3, clamped to the range [5, 30].
perplexity = Math.clamp(floor((length(embeddings) - 1) / 3), 5, 30)

# Reduce the 1536-d embeddings to 2-D: PCA first, then Barnes-Hut t-SNE,
# finally reshape each [x, y] pair into a map for plotting.
e =
  embeddings
  |> Nx.tensor()
  |> Scholar.Decomposition.PCA.fit_transform(num_components: 2)
  |> Nx.to_list()
  |> Tsne.barnes_hut(perplexity: perplexity / 1.0)
  |> Enum.map(fn [x, y] -> %{x: x, y: y} end)
alias VegaLite, as: Vl
# Scatter plot of the 2-D projection: one point per embedded sentence.
Vl.new(width: 800, height: 800)
|> Vl.data_from_values(e)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)