OpenAI Text Embeddings - fork
# Notebook dependencies, installed when this cell is evaluated.
Mix.install(
[
# HTTP client used by OpenAi.Client (Req.new/1, Req.post/2).
{:req, "~> 0.3.6"},
# Provides EXLA.Backend, which the config below sets as the Nx default backend.
{:exla, "~> 0.5"},
# NOTE(review): presumably renders the VegaLite chart inside Livebook - confirm.
{:kino_vega_lite, "~> 0.1"},
# Tensor library; Nx.tensor/1 and Nx.to_list/1 are used in the usage section.
{:nx, "~> 0.5"},
# NOTE(review): not referenced directly in this notebook; presumably required to
# build a native dependency (tsne?) - confirm, and consider a tighter version
# requirement than "~> 0.0".
{:rustler, "~> 0.0"},
# Provides Scholar.Decomposition.PCA for the dimensionality reduction step.
{:scholar, "~> 0.1"},
# Provides Tsne.barnes_hut/2 for the 2-D projection.
{:tsne, "~> 0.1"},
# Chart builder (aliased as Vl) used to plot the projected points.
{:vega_lite, "~> 0.1.7"}
],
config: [
# Make every Nx operation in this notebook run on the EXLA backend by default.
nx: [default_backend: EXLA.Backend]
]
)
Library
defmodule OpenAi.Client do
  @moduledoc """
  Minimal HTTP client for the OpenAI REST API, built on `Req`.

  Every request is sent to `https://api.openai.com` with a bearer token read
  from the `LB_OPEN_AI_API_KEY` environment variable at call time. Failures are
  logged and collapsed into `{:error, error_code}` tuples so callers never have
  to deal with raw HTTP or transport errors.
  """

  require Logger

  @type ok_tuple(value) :: {:ok, value}
  @type error_tuple :: {:error, any()}

  # Receive timeout handed to Req: 120 seconds, expressed in milliseconds.
  @openai_timeout 120 * 1000

  @doc """
  Requests embeddings for `input` from the `/v1/embeddings` endpoint.

  `model` is the OpenAI model name; `input` is a single string or a list of
  strings.

  Returns `{:ok, embeddings}` where each element is a map of the form
  `%{index: index, embedding: vector}`, taken from the `"index"` and
  `"embedding"` fields of each item in the response's `"data"` list.

  Returns `{:error, :text_embedding_failed}` when the request fails, the
  response status is 400 or above, or the response body does not contain a
  `"data"` list.

  Raises if the `LB_OPEN_AI_API_KEY` environment variable is not set.
  """
  @spec fetch_embedding(String.t(), String.t() | list(String.t())) ::
          ok_tuple(list()) | error_tuple()
  def fetch_embedding(model, input) do
    payload = %{"model" => model, "input" => input}

    build_req()
    |> Req.post(
      url: "/v1/embeddings",
      json: payload,
      retry: fn res -> retry_fun(res, "EMBEDDINGS") end
    )
    |> process_response(:text_embedding_failed)
    |> extract_embeddings()
  end

  # Builds the base request: host, auth header, timeout and retry budget.
  @spec build_req :: Req.Request.t()
  defp build_req do
    host = "https://api.openai.com"
    # fetch_env!/1 raises when the variable is missing, so a misconfigured
    # notebook fails loudly instead of sending an unauthenticated request.
    secret_key = System.fetch_env!("LB_" <> "OPEN_AI_API_KEY")

    headers = %{
      "Authorization" => "Bearer #{secret_key}"
    }

    Req.new(
      base_url: host,
      headers: headers,
      receive_timeout: @openai_timeout,
      max_retries: 2
    )
  end

  # Converts a processed response into the list of %{index, embedding} maps.
  # The body is matched on shape so that an unexpected payload (non-JSON body,
  # missing or non-list "data") yields an error tuple instead of raising.
  @spec extract_embeddings(ok_tuple(term()) | error_tuple()) ::
          ok_tuple(list()) | error_tuple()
  defp extract_embeddings({:ok, %{"data" => data}}) when is_list(data) do
    embeddings =
      Enum.map(data, fn embedding ->
        %{index: embedding["index"], embedding: embedding["embedding"]}
      end)

    {:ok, embeddings}
  end

  defp extract_embeddings({:ok, body}) do
    Logger.error(
      "[OPENAI_REQUEST] :text_embedding_failed: unexpected response body #{inspect(body)}"
    )

    {:error, :text_embedding_failed}
  end

  defp extract_embeddings({:error, _reason} = error), do: error

  # Maps a Req result to {:ok, body} for statuses below 400; anything else is
  # logged and replaced by {:error, error_code}.
  @spec process_response({:ok, Req.Response.t()} | error_tuple(), atom()) ::
          ok_tuple(binary() | term()) | error_tuple()
  defp process_response({:ok, res}, _error_code) when res.status < 400, do: {:ok, res.body}

  defp process_response({:ok, res}, error_code) do
    Logger.error("[OPENAI_REQUEST] #{inspect(error_code)}: res #{inspect(res)}")
    {:error, error_code}
  end

  defp process_response({:error, error}, error_code) do
    Logger.error(inspect(error))
    {:error, error_code}
  end

  # Retry predicate passed to Req. `source` only labels the log lines.
  # Retries on any 4xx/5xx response and on transport errors; everything else
  # (including successful responses) is not retried.
  defp retry_fun(%Req.Response{status: status}, source) when status >= 400 do
    Logger.error("[OPENAI_REQUEST][#{source}] HTTP Error: #{inspect(status)}, retrying...")
    true
  end

  defp retry_fun(%Req.Response{}, _source), do: false

  defp retry_fun(%Mint.TransportError{} = e, source) do
    Logger.error("[OPENAI_REQUEST][#{source}] Transport Error: #{inspect(e)}")
    true
  end

  defp retry_fun(x, source) do
    Logger.debug("[OPENAI_REQUEST][#{source}] Unknown retry function param: #{inspect(x)}")
    false
  end
end
defmodule OpenAi do
  @moduledoc """
  Convenience entry point for OpenAI calls used in this notebook.

  Delegates to `OpenAi.Client` with a fixed embedding model and logs the
  outcome of each call.
  """

  require Logger

  alias OpenAi.Client

  @type ok_tuple(value) :: {:ok, value}
  @type error_tuple :: {:error, any()}

  @text_embedding_model "text-embedding-ada-002"

  @doc """
  Fetches embeddings for `input` (a string or a list of strings) using the
  `text-embedding-ada-002` model.

  Returns whatever `OpenAi.Client.fetch_embedding/2` returns, unchanged:
  `{:ok, result}` on success or `{:error, reason}` on failure. A debug line is
  logged on success and an error line on failure.
  """
  @spec text_embedding(String.t() | list(String.t())) :: ok_tuple(list()) | error_tuple()
  def text_embedding(input) do
    @text_embedding_model
    |> Client.fetch_embedding(input)
    |> report()
  end

  # Logs the outcome and hands the tuple back untouched.
  defp report({:ok, _embeddings} = success) do
    Logger.debug("[OPENAI][TEXT_EMBEDDING] fetch_text_embedding success")
    success
  end

  defp report({:error, reason} = failure) do
    Logger.error("[OPENAI][ERROR] fetch_text_embedding failed, #{inspect(reason)}")
    failure
  end
end
Usage
# Request embeddings for the sample captions. The match is assertive: the cell
# crashes if the call returns an error tuple.
{:ok, result} =
  OpenAi.text_embedding([
    "an aerial view of the campus of harvard university in cambridge, massachusetts, united states",
    "aerial view of harvard university"
  ])

# Keep only the raw vectors. When fewer than 16 come back, append all-zero rows
# of 1536 elements until there are exactly 16.
# NOTE(review): the 16-row minimum presumably exists so the later t-SNE step has
# enough points for its perplexity - confirm.
embeddings =
  case Enum.map(result, & &1.embedding) do
    vectors when length(vectors) < 16 ->
      zero_rows = List.duplicate(List.duplicate(0, 1536), 16 - length(vectors))
      vectors ++ zero_rows

    vectors ->
      vectors
  end
defmodule Math do
  @moduledoc """
  Small numeric helpers for this notebook.
  """

  @doc """
  Restricts `num` to the range `min..max`.

  Returns `max` when `num` is greater than `max`, `min` when `num` is less than
  `min`, and `num` itself otherwise. The upper bound is checked first.
  """
  def clamp(num, min, max) do
    cond do
      num > max -> max
      num < min -> min
      true -> num
    end
  end
end
# Perplexity is a third of (point count - 1), rounded down and kept within 5..30.
perplexity =
  embeddings
  |> length()
  |> then(&floor((&1 - 1) / 3))
  |> Math.clamp(5, 30)

# Reduce the embeddings to two PCA components, run Barnes-Hut t-SNE on the
# result, and shape each [x, y] pair as a map for plotting.
e =
  embeddings
  |> Nx.tensor()
  |> Scholar.Decomposition.PCA.fit_transform(num_components: 2)
  |> Nx.to_list()
  # Dividing by 1.0 turns the integer perplexity into a float.
  |> Tsne.barnes_hut(perplexity: perplexity / 1.0)
  |> Enum.map(fn [x, y] -> %{x: x, y: y} end)

alias VegaLite, as: Vl

# Scatter plot of the projected points; as the last expression of the cell it
# is the cell's result.
[width: 800, height: 800]
|> Vl.new()
|> Vl.data_from_values(e)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)