# Run Nx.Serving with Enumerable
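This notebook builds a minimal echo `Nx.Serving` and compares ways of feeding it an enumerable of tensors: sequential `Enum.map/2`, concurrent `Flow` stages, and explicit `Task`s, first against the local serving and then through the cluster-wide `{:distributed, name}` routing.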
```elixir
Mix.install([
  {:nx, "~> 0.9"},
  {:flow, "~> 1.2"},
  {:kino, "~> 0.14"}
])
```
## Serve
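The serving below wraps an identity function, so whatever tensor goes in comes back out. The `IO.inspect` calls make it easy to observe when the preprocessing and postprocessing hooks fire for each request.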
```elixir
serving =
  # The compute function is the identity, JIT-compiled, so the serving echoes its input.
  fn opts -> Nx.Defn.jit(& &1, opts) end
  |> Nx.Serving.new()
  |> Nx.Serving.client_preprocessing(fn input ->
    IO.inspect("client_preprocessing #{Nx.to_number(input[0])}")
    # Wrap each request in a batch of one.
    {Nx.Batch.stack([input]), :client_info}
  end)
  |> Nx.Serving.client_postprocessing(fn {output, _metadata}, _client_info ->
    IO.inspect("client_postprocessing #{Nx.to_number(output[[0, 0]])}")
    # Drop the batch axis before returning the result to the caller.
    Nx.squeeze(output, axes: [0])
  end)
```
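Before starting a named process, a serving can be exercised inline with `Nx.Serving.run/2`, which executes the whole preprocessing/execution/postprocessing pipeline in the calling process. A minimal sketch:

```elixir
# Runs the serving inline in the current process; no shared batching across callers.
Nx.Serving.run(serving, Nx.tensor([1]))
```

Starting the serving under Kino's supervision tree registers it by name, so later cells can reach it with `batched_run`: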
```elixir
Kino.start_child({Nx.Serving, name: Echo, serving: serving})
```
## Local run
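`Nx.Serving.batched_run/2` sends the input to the process registered as `Echo` on the current node, which accumulates requests into batches before executing them. First a single request, then a sequential map over a list: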
```elixir
Nx.Serving.batched_run(Echo, Nx.tensor([1]))
```
```elixir
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)
```
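A plain `Enum.map/2` blocks on each call, so requests arrive one at a time and every batch holds a single item. Issuing the calls concurrently gives the serving a chance to batch requests together. Besides Flow (next cell), the standard library's `Task.async_stream/3` does this too; a minimal sketch:

```elixir
# Up to three requests are in flight at once, so the serving may batch them.
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Task.async_stream(fn input -> Nx.Serving.batched_run(Echo, input) end,
  max_concurrency: 3
)
|> Enum.map(fn {:ok, result} -> result end)
```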
```elixir
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)
|> Enum.to_list()
```
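`stages: 3` with `max_demand: 1` keeps all three requests in flight at once, one per stage. Note that Flow does not guarantee output order across stages.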
## Distributed run
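`{:distributed, Echo}` asks `Nx.Serving` to dispatch over the cluster rather than pinning the request to the local instance; any connected node running the `Echo` serving can pick it up. On this notebook's single node, the request still ends up at the local process, so the cells below behave like the local run.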
```elixir
Nx.Serving.batched_run({:distributed, Echo}, Nx.tensor([1]))
```
```elixir
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)
```
```elixir
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)
|> Enum.to_list()
```
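The Flow pipeline is unchanged apart from the name; only the routing differs. The same concurrency can also be had with plain tasks, spawning one per tensor and awaiting them all: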
```elixir
[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn tensor ->
  Task.async(fn ->
    Nx.Serving.batched_run({:distributed, Echo}, tensor)
  end)
end)
|> Enum.map(&Task.await/1)
```
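`Task.await/2` uses a 5-second timeout by default; for a list of tasks, `Task.await_many/2` collects all replies under a single shared deadline.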