Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

NxIREE.Backend

livebooks/nx_iree/nxiree_backend.livemd

NxIREE.Backend

# Livebook dependencies.
# :exla was missing even though EXLA.Backend is used in the benchmark
# section below — without it that cell fails with an undefined-module error.
Mix.install([
  {:nx_iree, "~> 0.0"},
  {:exla, "~> 0.0"},
  {:benchee, "~> 1.3"}
])

デバイスの取得

# Pick the first Metal device reported by NxIREE.
# NxIREE.list_devices/1 returns a tuple whose element 1 is the device list
# (presumably an {:ok, devices} pair — TODO confirm against NxIREE docs).
{_, metal_devices} = {nil, NxIREE.list_devices("metal") |> elem(1)}
dev = hd(metal_devices)

Softmax 関数のコンパイル

# IREE compiler flags: target Apple Metal via SPIR-V, accept StableHLO/XLA
# input, and use the async-internal execution model.
flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

# Make NxIREE the default Nx.Defn compiler, pinned to the device chosen above.
defn_options = [
  compiler: NxIREE.Compiler,
  iree_compiler_flags: flags,
  iree_runtime_options: [device: dev]
]

Nx.Defn.default_options(defn_options)
# Softmax over the last axis: exp(x) / sum(exp(x), last axis).
# The exponentials are computed once and reused for both numerator and
# denominator (the original evaluated Nx.exp(tensor) twice).
# NOTE(review): no max-subtraction, so very large inputs could overflow
# f32 — fine for the small scaled inputs used in this notebook.
softmax = fn tensor ->
  exps = Nx.exp(tensor)
  Nx.divide(exps, Nx.sum(exps, axes: [-1], keep_axes: true))
end
# A {1000, 1000, 5} f32 iota tensor on the IREE backend, scaled by
# 1 / (1024 * 1024), then pushed through softmax once (JIT path).
iree_input =
  Nx.iota({1000, 1000, 5}, type: :f32, backend: NxIREE.Backend)
  |> Nx.divide(1024 * 1024)

softmax.(iree_input)

# Ahead-of-time compile softmax for the template shape, then invoke it.
compiled_softmax =
  Nx.Defn.compile(softmax, [Nx.template({1000, 1000, 5}, :f32)])

compiled_softmax.(iree_input)

速度比較

# Build the same scaled iota tensor on different backends so the benchmark
# compares backend speed rather than input differences.
build_input = fn backend ->
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: backend)
  |> Nx.divide(1024 * 1024)
end

binary_input = build_input.(Nx.BinaryBackend)
exla_input = build_input.(EXLA.Backend)

# "nx" and "exla" run the uncompiled closure via backend dispatch;
# "nx_iree" runs the ahead-of-time compiled function on the IREE device.
Benchee.run(%{
  "nx" => fn -> softmax.(binary_input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> compiled_softmax.(iree_input) end
})