# Elixir to MLIR
Mix.install([
{:exla, "~> 0.9"},
{:nx, "~> 0.9"},
{:nx_iree, "~> 0.0"},
{:kino, "~> 0.14"},
{:benchee, "~> 1.3"},
{:statistics, "~> 0.6"}
])
## Softmax
# Numerically stable softmax over the last axis of `tensor`.
#
# The per-row maximum is subtracted before exponentiating. The common factor
# exp(-max) cancels between numerator and denominator, so the result is
# mathematically unchanged, but every exponent is <= 0 and Nx.exp can no
# longer overflow to infinity (which would produce NaN after the division).
# The exponentials are computed once and reused for both the numerator and
# the normalizing sum.
softmax = fn tensor ->
  row_max = Nx.reduce_max(tensor, axes: [-1], keep_axes: true)
  exps = Nx.exp(Nx.subtract(tensor, row_max))

  # keep_axes: true keeps the reduced axis with size 1 so the sum broadcasts
  # against `exps` in the division.
  Nx.divide(exps, Nx.sum(exps, axes: [-1], keep_axes: true))
end
# Sample input: an f32 iota tensor of shape {1000, 1000, 5} allocated on
# Nx.BinaryBackend, with every element divided by 1024 * 1024.
input =
  Nx.divide(
    Nx.iota({1000, 1000, 5}, type: :f32, backend: Nx.BinaryBackend),
    1024 * 1024
  )

# Argument list handed to EXLA.to_mlir_module/2 and NxIREE.call/3 below.
args = [input]
# Ask EXLA for the MLIR module of `softmax` applied to `args`, keeping the two
# fields used later: the MLIR module itself and the output container that is
# forwarded to NxIREE.compile/3.
%{output_container: output_container, mlir_module: mlir_module} =
  EXLA.to_mlir_module(softmax, args)

# Render the generated MLIR as text in the notebook.
Kino.Text.new(mlir_module)
# Pick the first device reported by the IREE "metal" driver.
#
# The result is matched explicitly instead of being unpacked with elem/hd, so
# a failed lookup or an empty device list stops right here with a descriptive
# error rather than an obscure crash (or a bogus device value) further down.
# NOTE(review): assumes the success shape is {:ok, devices} — confirm against
# the NxIREE documentation.
dev =
  case NxIREE.list_devices("metal") do
    {:ok, [first_device | _]} ->
      first_device

    other ->
      raise "no usable IREE device for driver \"metal\": #{inspect(other)}"
  end
# Compiler flags passed to NxIREE.compile/3: metal-spirv HAL target backend,
# stablehlo_xla input type, async-internal execution model.
flags = ~w(
  --iree-hal-target-backends=metal-spirv
  --iree-input-type=stablehlo_xla
  --iree-execution-model=async-internal
)
# Compile the MLIR module with the IREE flags, forwarding the output container
# obtained from EXLA.to_mlir_module/2.
module =
  mlir_module
  |> NxIREE.compile(flags, output_container: output_container)

# Run the compiled module on the selected device with the same arguments.
module
|> NxIREE.call(args, device: dev)

# Evaluate the plain Nx closure on the same input for comparison.
input
|> softmax.()
## 速度比較 (Speed comparison)
# Copy of the input living on EXLA.Backend, used by the "exla" job.
exla_input =
  input
  |> Nx.backend_transfer(EXLA.Backend)

# Benchmark the same softmax three ways:
#   "nx"      - the closure applied to the Nx.BinaryBackend tensor
#   "exla"    - the closure applied to the EXLA.Backend tensor
#   "nx_iree" - the IREE-compiled module called on the selected device
%{
  "nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx" => fn -> softmax.(input) end
}
|> Benchee.run()