NxIREE.Backend
Mix.install([
  {:nx_iree, "~> 0.0"},
  {:benchee, "~> 1.3"}
])
Getting the device
# NxIREE.list_devices/1 returns {:ok, devices}; take the first Metal device
dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()
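If you want to see everything the Metal driver exposes before picking one, a minimal sketch such as the following prints the full list wrapped by the `{:ok, _}` tuple above (`devices` is just a local name introduced here):

# Sketch: list every Metal device reported by the IREE runtime
{:ok, devices} = NxIREE.list_devices("metal")
IO.inspect(devices, label: "Metal devices")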
Compiling the softmax function
# Target Apple GPUs through the Metal/SPIR-V backend, taking StableHLO as input
flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

# Route defn compilation through IREE on the device selected above
Nx.Defn.default_options(
  compiler: NxIREE.Compiler,
  iree_compiler_flags: flags,
  iree_runtime_options: [device: dev]
)
softmax = fn tensor ->
  Nx.divide(
    Nx.exp(tensor),
    Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
  )
end
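As an aside, `Nx.exp/1` can overflow for large inputs; a common remedy is to subtract the row-wise maximum before exponentiating. The variant below is only a sketch using standard `Nx` primitives (`stable_softmax` is a name introduced here, not part of the original notebook) and is not used in the benchmark:

# Sketch: numerically stable softmax that subtracts the per-row max before exp
stable_softmax = fn tensor ->
  shifted = Nx.subtract(tensor, Nx.reduce_max(tensor, axes: [-1], keep_axes: true))
  Nx.divide(Nx.exp(shifted), Nx.sum(Nx.exp(shifted), axes: [-1], keep_axes: true))
end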
iree_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: NxIREE.Backend)
  |> Nx.divide(1024 * 1024)

softmax.(iree_input)
# Compile once for the given input template, then call the compiled function
compiled_softmax = Nx.Defn.compile(softmax, [Nx.template({1000, 1000, 5}, :f32)])

compiled_softmax.(iree_input)
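To sanity-check the compiled IREE result, one option (not in the original notebook; `reference` and `result` are illustrative names) is to copy both tensors to the host and compare them with `Nx.all_close/3`:

# Sketch: compare the IREE output against a plain BinaryBackend computation
reference = softmax.(Nx.backend_copy(iree_input, Nx.BinaryBackend))
result = Nx.backend_copy(compiled_softmax.(iree_input), Nx.BinaryBackend)

# Returns a u8 scalar tensor: 1 when every element is within tolerance
Nx.all_close(result, reference, rtol: 1.0e-4)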
Speed comparison
binary_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: Nx.BinaryBackend)
  |> Nx.divide(1024 * 1024)

exla_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: EXLA.Backend)
  |> Nx.divide(1024 * 1024)
Benchee.run(%{
  "nx" => fn -> softmax.(binary_input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> compiled_softmax.(iree_input) end
})
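Benchee also accepts tuning options as a second argument; the values below are only illustrative. Note that `memory_time` tracks memory allocated on the BEAM, so device-side buffers will not show up:

# Sketch: same comparison with explicit warmup, measurement time, and memory stats
Benchee.run(
  %{
    "nx" => fn -> softmax.(binary_input) end,
    "exla" => fn -> softmax.(exla_input) end,
    "nx_iree" => fn -> compiled_softmax.(iree_input) end
  },
  warmup: 2,
  time: 10,
  memory_time: 2
)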