
# Run Nx.Serving with Enumerable

This notebook compares ways of feeding an enumerable of tensors into an `Nx.Serving`: sequential `Enum.map/2`, parallel `Flow`, and concurrent `Task`s, both against a local serving process and a distributed one.

```elixir
Mix.install([
  {:nx, "~> 0.9"},
  {:flow, "~> 1.2"},
  {:kino, "~> 0.14"}
])
```

## Serve

```elixir
serving =
  # The computation is a jitted identity function: it echoes its input batch.
  fn opts -> Nx.Defn.jit(& &1, opts) end
  |> Nx.Serving.new()
  |> Nx.Serving.client_preprocessing(fn input ->
    IO.inspect("client_preprocessing #{Nx.to_number(input[0])}")
    # Wrap the single tensor into a batch of size 1.
    {Nx.Batch.stack([input]), :client_info}
  end)
  |> Nx.Serving.client_postprocessing(fn {output, _metadata}, _client_info ->
    IO.inspect("client_postprocessing #{Nx.to_number(output[[0, 0]])}")
    # Drop the batch axis to return a tensor shaped like the input.
    Nx.squeeze(output, axes: [0])
  end)
```

```elixir
# Start the serving process under the Livebook supervision tree.
Kino.start_child({Nx.Serving, name: Echo, serving: serving})
```
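
For comparison, `Nx.Serving.run/2` executes the serving inline in the caller process, skipping the named process (and therefore any cross-request batching):

```elixir
# Runs the serving directly in this process; no process hop, no batching.
Nx.Serving.run(serving, Nx.tensor([1]))
```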

## Local run

```elixir
# A single request through the serving process.
Nx.Serving.batched_run(Echo, Nx.tensor([1]))
```

```elixir
# Sequential requests: each call blocks until the previous one returns.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)
```

```elixir
# Concurrent requests: three Flow stages each pull one tensor and call
# the serving in parallel.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)
|> Enum.to_list()
```
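
Because the three Flow stages call `batched_run/2` concurrently, the serving process has a chance to collect them into a single batch. A minimal sketch of making that explicit, assuming the `:batch_size` and `:batch_timeout` start options (the `EchoBatched` name is made up for this illustration):

```elixir
# Hypothetical serving that gathers up to 3 concurrent requests,
# waiting at most 50 ms for the batch to fill before running it.
Kino.start_child(
  {Nx.Serving, name: EchoBatched, serving: serving, batch_size: 3, batch_timeout: 50}
)

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(&Nx.Serving.batched_run(EchoBatched, &1))
|> Enum.to_list()
```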

## Distributed run
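
`{:distributed, Echo}` dispatches each request to any node in the cluster running an `Echo` serving, so the nodes must be connected first. A minimal sketch, where the peer name is a placeholder for a real node:

```elixir
# Placeholder node name; replace it with a peer that runs the same serving.
Node.connect(:"peer@127.0.0.1")
Node.list()
```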

```elixir
# A single request dispatched to any node in the cluster that runs Echo.
Nx.Serving.batched_run({:distributed, Echo}, Nx.tensor([1]))
```

```elixir
# Sequential distributed requests.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)
```

```elixir
# Concurrent distributed requests via Flow.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)
|> Enum.to_list()
```

```elixir
# Concurrent distributed requests via tasks: start them all,
# then await the results in input order.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn tensor ->
  Task.async(fn ->
    Nx.Serving.batched_run({:distributed, Echo}, tensor)
  end)
end)
|> Enum.map(&Task.await/1)
```
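
`Task.async_stream/3` expresses the same fan-out more compactly and adds back-pressure via `:max_concurrency`; a sketch under the same assumptions:

```elixir
# Same fan-out as above, but the stream limits in-flight requests to 3.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Task.async_stream(
  fn tensor -> Nx.Serving.batched_run({:distributed, Echo}, tensor) end,
  max_concurrency: 3
)
|> Enum.map(fn {:ok, result} -> result end)
```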