Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Demo 2: Inference

sorry-dave/inference_demo.livemd

Demo 2: Inference

# Notebook setup: installs the dependencies below, exports the listed
# environment variables, and applies the Nx application config.
Mix.install(
  [
    {:axon, "~> 0.5"},
    {:bumblebee, "~> 0.5"},
    # `override: true` makes these requirements win over the ones declared by
    # other dependencies in this list.
    {:nx, "~> 0.8.0", override: true},
    {:exla, "~> 0.8.0", override: true},
    {:kino, "~> 0.14.0"},
    {:kino_flame, github: "hugobarauna/kino_flame"},
    {:flame, "~> 0.5.1"},
    {:explorer, "~> 0.9.2"},
    {:kino_vega_lite, "~> 0.1.13"},
    {:vega_lite, "~> 0.1.9"},
    {:table_rex, "~> 4.0", override: true},
    {:req, "~> 0.5.6"},
    {:req_s3, "~> 0.2.3"}
  ],
  system_env: [
    XLA_TARGET: "cuda12",
    # The AWS_* credentials are copied from the LB_LIVEBOOK_DEMOS_* variables.
    # `System.get_env/1` returns nil when a variable is unset.
    # NOTE(review): the inference pool later reads these names with
    # `System.fetch_env!/1`, so confirm both secrets are configured.
    AWS_ACCESS_KEY_ID: System.get_env("LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID"),
    AWS_SECRET_ACCESS_KEY: System.get_env("LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY"),
    AWS_REGION: "auto",
    AWS_ENDPOINT_URL_S3: "https://fly.storage.tigris.dev"
  ],
  config: [
    nx: [
      # Use EXLA as the default tensor backend and defn compiler, with the
      # :cuda client.
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA, client: :cuda]
    ]
  ]
)

Configurações iniciais

require Explorer.DataFrame, as: DataFrame
require Explorer.Series, as: Series
import Kino.Shorts
defmodule RemoteMacOsNotifier do
  @moduledoc """
  Shows a macOS desktop notification by running `osascript` on a hidden node
  whose name contains `127.0.0.1` (presumably the Livebook application node).
  """

  @doc """
  Displays `message` as a notification titled "Livebook" with the "Pop" sound.

  `message` is escaped before being embedded in the AppleScript string literal,
  so double quotes and backslashes are shown verbatim instead of breaking (or
  injecting into) the script.

  Returns the `System.cmd/2` result from the remote node, or
  `{:error, :livebook_node_not_found}` when no matching hidden node is visible.
  """
  def notify(message) do
    case livebook_node() do
      nil ->
        {:error, :livebook_node_not_found}

      lb_node ->
        :erpc.call(lb_node, System, :cmd, [
          "osascript",
          [
            "-e",
            "display notification \"#{escape_applescript(message)}\" with title \"Livebook\" sound name \"Pop\""
          ]
        ])
    end
  end

  # Finds the first hidden node whose name contains "127.0.0.1" and connects to
  # it. Returns `nil` when there is no such node.
  defp livebook_node do
    lb_node =
      Enum.find(Node.list(:hidden), &String.contains?(Atom.to_string(&1), "127.0.0.1"))

    if lb_node, do: Node.connect(lb_node)
    lb_node
  end

  # Escapes backslashes and double quotes for an AppleScript string literal.
  # Backslashes go first so the ones added for quotes are not escaped again.
  defp escape_applescript(message) do
    message
    |> to_string()
    |> String.replace("\\", "\\\\")
    |> String.replace("\"", "\\\"")
  end
end
# Locate the Bumblebee cache directory of the Hugging Face repository; the
# directory name is the repository id with "/" replaced by "--".
repo = "google-bert/bert-base-uncased"
repo_cache_dir_name = String.replace(repo, "/", "--")
repo_cache_dir = Enum.join([Bumblebee.cache_dir(), "huggingface", repo_cache_dir_name], "/")

# Full path of every entry found in that directory.
paths = Enum.map(File.ls!(repo_cache_dir), fn file -> repo_cache_dir <> "/" <> file end)

Configuração do cluster de dataset distribuído

# Start the FLAME pool used to load the dataset, supervised by Kino.
# Matching on {:ok, _} makes a failed start raise here, so the success
# notification below is only sent when the pool is actually running.
{:ok, _dataset_pool} =
  Kino.start_child(
    {FLAME.Pool,
     name: :dataset_pool,
     code_sync: [
       start_apps: true,
       sync_beams: Kino.beam_paths(),
       compress: true,
       copy_paths: [],
       verbose: true
     ],
     # Scale from zero up to 16 runners, one concurrent call per runner.
     min: 0,
     max: 16,
     max_concurrency: 1,
     boot_timeout: :timer.minutes(3),
     idle_shutdown_after: :timer.minutes(1),
     timeout: :infinity,
     track_resources: true,
     log: :info,
     backend:
       {FLAME.FlyBackend,
        cpu_kind: "performance",
        cpus: 2,
        memory_mb: 8192,
        # Forward the node cookie and bucket credentials to the runners;
        # fetch_env! raises if a credential is missing locally.
        env: %{
          "LIVEBOOK_COOKIE" => Node.get_cookie(),
          "LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID" =>
            System.fetch_env!("LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID"),
          "LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY" =>
            System.fetch_env!("LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY")
        }}}
  )

RemoteMacOsNotifier.notify("Dataset cluster setup ✅")

Carregando o dataset de modo distribuído no cluster

# List the objects of the `livebook-demos` bucket through Req with the ReqS3
# plugin attached.
req = Req.new() |> ReqS3.attach()

%{
  "ListBucketResult" => %{
    "Contents" => keys
  }
} = Req.get!(req, url: "s3://livebook-demos").body

# Keep only the object keys under the "cannabinoid-edibles/eval" prefix.
# (The capture operator was HTML-escaped as `&amp;`, which is not valid Elixir.)
parquet_file_keys =
  keys
  |> Enum.map(&Map.get(&1, "Key"))
  |> Enum.filter(&String.starts_with?(&1, "cannabinoid-edibles/eval"))

# Ask how many parquet files to process and keep only that many keys.
number_files_to_process = read_number("Nº de arquivos parquet para processar", default: 16)
Kino.nothing()
parquet_file_keys = Enum.take(parquet_file_keys, number_files_to_process)
# Load every parquet file through the :dataset_pool FLAME pool, up to 16 at a
# time. The credentials are read inside the function passed to FLAME.call, so
# they come from the runner's environment rather than the notebook's.
dataframes =
  parquet_file_keys
  |> Task.async_stream(
    fn key ->
      FLAME.call(:dataset_pool, fn ->
        DataFrame.from_parquet!("s3://livebook-demos/#{key}",
          config: [
            region: "auto",
            endpoint: "https://fly.storage.tigris.dev",
            access_key_id: System.get_env("LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID"),
            secret_access_key: System.get_env("LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY")
          ]
        )
      end)
    end,
    ordered: false,
    timeout: :infinity,
    max_concurrency: 16
  )
  |> Enum.map(fn {:ok, df} -> df end)

RemoteMacOsNotifier.notify("Dataset cluster online 👍")

# Number of dataframes loaded.
length(dataframes)

# First dataframe, reused by the single-dataframe examples.
dataframe = Enum.at(dataframes, 0)

# Total number of rows across all dataframes.
dataframes
|> Enum.map(&DataFrame.n_rows/1)
|> Enum.sum()

# Rows of the first dataframe whose abstract contains "cannabinoid".
DataFrame.filter(dataframe, contains(abstract, "cannabinoid"))

# Number of abstracts containing "cannabinoid" across all dataframes.
dataframes
|> Enum.map(fn df ->
  df
  |> DataFrame.mutate(mentions?: Series.contains(abstract, "cannabinoid"))
  |> DataFrame.pull(:mentions?)
  |> Series.sum()
end)
|> Enum.sum()

Servidor de inferência distribuído

# Start the FLAME pool that hosts the inference server, supervised by Kino.
Kino.start_child(
  {FLAME.Pool,
   name: :inference_pool,
   code_sync: [
     start_apps: true,
     sync_beams: Kino.beam_paths(),
     compress: false,
     # `paths` holds the cached model files listed earlier in the notebook.
     copy_paths: paths,
     verbose: true
   ],
   # min == max == 1: the pool is pinned to a single runner.
   min: 1,
   max: 1,
   max_concurrency: 1,
   boot_timeout: :timer.minutes(3),
   idle_shutdown_after: :timer.minutes(1),
   timeout: :infinity,
   track_resources: true,
   log: :info,
   backend:
     {FLAME.FlyBackend,
      cpu_kind: "performance",
      cpus: 4,
      memory_mb: 32768,
      gpu_kind: "l40s",
      gpus: 4,
      # Forward the node cookie, storage credentials and XLA target to the
      # runner; fetch_env! raises if any of them is missing locally.
      env: %{
        "LIVEBOOK_COOKIE" => Node.get_cookie(),
        "AWS_ACCESS_KEY_ID" => System.fetch_env!("AWS_ACCESS_KEY_ID"),
        "AWS_ENDPOINT_URL_S3" => System.fetch_env!("AWS_ENDPOINT_URL_S3"),
        "AWS_REGION" => System.fetch_env!("AWS_REGION"),
        "AWS_SECRET_ACCESS_KEY" => System.fetch_env!("AWS_SECRET_ACCESS_KEY"),
        "XLA_TARGET" => System.fetch_env!("XLA_TARGET")
      }}}
)
defmodule InferenceServing do
  @moduledoc """
  Child-spec wrapper that starts an `Nx.Serving` for BERT text classification,
  registered under this module's name.
  """

  # Hugging Face repository used for both the model and the tokenizer.
  @repo {:hf, "google-bert/bert-base-uncased"}

  @doc """
  Child specification that starts the serving through `start_link/0`.
  """
  def child_spec(_) do
    %{id: __MODULE__, start: {__MODULE__, :start_link, []}, type: :supervisor}
  end

  @doc """
  Loads the model and tokenizer, builds the text-classification serving and
  starts it as a named, partitioned `Nx.Serving`.
  """
  def start_link do
    # The model is loaded with the EXLA backend's :host client.
    {:ok, bert} =
      Bumblebee.load_model(@repo,
        architecture: :for_sequence_classification,
        backend: {EXLA.Backend, client: :host}
      )

    {:ok, tokenizer} = Bumblebee.load_tokenizer(@repo)

    bert
    |> Bumblebee.Text.text_classification(tokenizer,
      compile: [batch_size: 1, sequence_length: 512],
      defn_options: [compiler: EXLA],
      preallocate_params: true
    )
    |> then(&Nx.Serving.start_link(name: __MODULE__, serving: &1, partitions: true))
  end
end
# Start InferenceServing on a runner of :inference_pool. Matching on {:ok, _}
# makes a failed placement raise here instead of being followed by a
# notification that claims the server is running.
{:ok, _serving_pid} = FLAME.place_child(:inference_pool, InferenceServing)
RemoteMacOsNotifier.notify("Processo servidor de inferência rodando ✅")

Inferência distribuída com Nx.Serving. É transparente se o processo está rodando localmente ou em outra máquina do cluster

Primeiro rodando com apenas um input (patent abstract)

# Pick one random non-null abstract: shuffle the rows, then take the first
# value of the "abstract" column.
sample_abstract =
  dataframe
  |> DataFrame.filter(not is_nil(abstract))
  |> DataFrame.shuffle()
  |> DataFrame.pull("abstract")
  |> Series.at(0)

# Run it through the serving registered as InferenceServing.
Nx.Serving.batched_run(InferenceServing, sample_abstract)

Agora fazendo a inferência usando todo o dataset distribuído no cluster

# Classifies up to 100 non-null abstracts of a dataframe, sending them to
# InferenceServing in batches of 20, and returns one `predictions` value per
# abstract. (The capture operator was HTML-escaped as `&amp;` in the original,
# which is not valid Elixir.)
classify_abstracts = fn dataframe ->
  dataframe
  |> DataFrame.filter(not is_nil(abstract))
  |> DataFrame.pull(:abstract)
  |> Series.to_enum()
  |> Stream.take(100)
  |> Stream.chunk_every(20)
  |> Stream.flat_map(&Nx.Serving.batched_run(InferenceServing, &1))
  |> Enum.map(fn %{predictions: predictions} -> predictions end)
end

# First on a single dataframe.
predictions = classify_abstracts.(dataframe)

# Then on every dataframe, up to 16 concurrently.
predictions =
  dataframes
  |> Task.async_stream(
    fn dataframe ->
      # Read the backend resource through the struct field. Map key order is
      # not guaranteed, so `Map.values/1 |> List.first/1` could return a
      # different field of the struct.
      %Explorer.PolarsBackend.DataFrame{resource: resource} = dataframe.data

      IO.puts("Processing dataset from node #{node(resource)}")

      classify_abstracts.(dataframe)
    end,
    max_concurrency: 16,
    timeout: :infinity,
    ordered: false
  )
  |> Enum.map(fn {:ok, predictions} -> predictions end)

RemoteMacOsNotifier.notify("Inferência em dataset distribuido ✅")

# Total number of individual predictions across all dataframes.
predictions
|> List.flatten()
|> Enum.count()