Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Elixir to MLIR

livebooks/nx_iree/elixir_to_mlir.livemd

Elixir to MLIR

Mix.install([
  {:exla, "~> 0.9"},
  {:nx, "~> 0.9"},
  {:nx_iree, "~> 0.0"},
  {:kino, "~> 0.14"},
  {:benchee, "~> 1.3"},
  {:statistics, "~> 0.6"}
])

Softmax

# Numerically stable softmax over the last axis.
#
# Takes a tensor and returns a tensor of the same shape in which each slice
# along the last axis is exp(x) / sum(exp(x)).
softmax = fn tensor ->
  # Subtracting the per-slice maximum leaves the result mathematically
  # unchanged, but keeps Nx.exp/1 from overflowing to infinity (which would
  # turn the division into NaN) when the input contains large values.
  shifted = Nx.subtract(tensor, Nx.reduce_max(tensor, axes: [-1], keep_axes: true))

  # Compute the exponentials once and reuse them for numerator and denominator.
  exps = Nx.exp(shifted)

  Nx.divide(exps, Nx.sum(exps, axes: [-1], keep_axes: true))
end

# Sample input: a {1000, 1000, 5} f32 iota tensor allocated on the binary
# backend, with every element divided by 1024 * 1024.
input =
  Nx.divide(
    Nx.iota({1000, 1000, 5}, type: :f32, backend: Nx.BinaryBackend),
    1024 * 1024
  )

# Argument list handed to the MLIR export and to the NxIREE call.
args = [input]

# Export `softmax` for `args` through EXLA and keep the two fields of the
# result that the rest of the notebook uses.
%{mlir_module: mlir_module, output_container: output_container} =
  EXLA.to_mlir_module(softmax, args)

# Render the MLIR module as text.
Kino.Text.new(mlir_module)
# Take the head of the second element of the tuple returned by
# NxIREE.list_devices/1 for the "metal" driver.
# NOTE(review): presumably the tuple is {:ok, devices} — confirm against NxIREE.
dev = "metal" |> NxIREE.list_devices() |> elem(1) |> hd()

# Compiler flags forwarded to NxIREE.compile/3.
flags = ~w(
  --iree-hal-target-backends=metal-spirv
  --iree-input-type=stablehlo_xla
  --iree-execution-model=async-internal
)

module = NxIREE.compile(mlir_module, flags, output_container: output_container)

# Run the compiled module on the selected device, then evaluate the same
# function directly through Nx on the same input.
NxIREE.call(module, args, device: dev)
softmax.(input)

Speed comparison (速度比較)

# Transfer the input tensor to the EXLA backend for the "exla" scenario.
exla_input = Nx.backend_transfer(input, EXLA.Backend)

# Benchmark the same softmax computation through three paths: the closure on
# the original input, the closure on the EXLA-backed input, and the compiled
# NxIREE module.
%{
  "nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx" => fn -> softmax.(input) end
}
|> Benchee.run()