Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Nx.Serving Preprocess Postprocess

preprocess_postprocess.livemd

Nx.Serving Preprocess Postprocess

# Install the dependencies this notebook uses and set the Nx application config.
Mix.install(
  [
    # Nx supplies the tensor, Nx.Defn, Nx.Batch and Nx.Serving APIs called below.
    {:nx, "~> 0.9"},
    # EXLA is the compiler passed to the servings (`compiler: EXLA`).
    {:exla, "~> 0.9"},
    # Kino is used to start the serving processes (`Kino.start_child/1`).
    {:kino, "~> 0.14"}
  ],
  # Make EXLA.Backend the default Nx backend.
  config: [nx: [default_backend: EXLA.Backend]]
)

Default

# Minimal serving: a JIT-compiled function that doubles its input, with no
# client-side or distributed hooks attached.
default_serving =
  Nx.Serving.new(
    fn defn_options ->
      Nx.Defn.jit(fn tensor -> Nx.multiply(tensor, 2) end, defn_options)
    end,
    compiler: EXLA
  )

# Start the serving as a process registered under the name DefaultServing.
{Nx.Serving, [name: DefaultServing, serving: default_serving]}
|> Kino.start_child()

# No client_preprocessing is configured on this serving, so the input is
# built as an Nx.Batch here.
batch =
  [Nx.tensor([1])]
  |> Nx.Batch.stack()

# Run the serving struct directly.
default_serving |> Nx.Serving.run(batch)
# Run the same batch through the DefaultServing process.
Nx.Serving.batched_run(DefaultServing, batch)

Preprocess and Postprocess

# Same doubling computation as the default serving, with client-side and
# distributed hooks attached. Each hook inspects its own name and the current
# node so the output shows which hook ran and on which node.
pre_post_serving =
  Nx.Serving.new(
    fn defn_options ->
      Nx.Defn.jit(fn tensor -> Nx.multiply(tensor, 2) end, defn_options)
    end,
    compiler: EXLA
  )
  |> Nx.Serving.client_preprocessing(fn tensor ->
    # Preprocessing
    IO.inspect("client_preprocessing")
    IO.inspect(node())
    # Convert the input, given as a tensor, into a batch. The second tuple
    # element is the client info.
    {Nx.Batch.stack([tensor]), :client_info}
  end)
  |> Nx.Serving.client_postprocessing(fn {result, _meta}, _info ->
    # Postprocessing
    IO.inspect("client_postprocessing")
    IO.inspect(node())
    # Reduce the dimensions of the output tensor by squeezing axis 0.
    Nx.squeeze(result, axes: [0])
  end)
  |> Nx.Serving.distributed_postprocessing(fn result ->
    # Postprocessing for the distributed case
    IO.inspect("distributed_postprocessing")
    IO.inspect(node())
    # Move the result to the binary backend.
    Nx.backend_transfer(result, Nx.BinaryBackend)
  end)

# Start the hooked serving as a process registered under the name PrePostServing.
{Nx.Serving, [name: PrePostServing, serving: pre_post_serving]}
|> Kino.start_child()

# Name of the current node, for comparison with the node names the hooks inspect.
node()

# A bare tensor is passed here; the serving's client_preprocessing callback
# builds the batch from it.
pre_post_serving |> Nx.Serving.run(Nx.tensor([1]))
# Same input, run through the PrePostServing process.
Nx.Serving.batched_run(PrePostServing, Nx.tensor([1]))
# Callback passed as the third argument of a distributed batched_run call.
distributed_preprocessing = fn tensor ->
  # Preprocessing for the distributed case
  IO.inspect("distributed_preprocessing")
  IO.inspect(node())
  # Move the input to the binary backend.
  tensor |> Nx.backend_transfer(Nx.BinaryBackend)
end

# Distributed call: the serving is addressed as {:distributed, name} and the
# preprocessing callback is passed as the third argument.
Nx.Serving.batched_run(
  {:distributed, PrePostServing},
  Nx.tensor([1]),
  distributed_preprocessing
)

# Cookie of the current node.
Node.get_cookie()