Nx
# Install the notebook's dependencies: Nx itself plus the two compilers /
# backends exercised below (EXLA and EMLX).
# NOTE(review): EMLX is fetched from the moving `main` branch, so the code it
# resolves to can change between sessions — consider pinning a `ref:`.
Mix.install([
  {:nx, "~> 0.9"},
  {:exla, "~> 0.9"},
  {:emlx, github: "elixir-nx/emlx", branch: "main"}
])
Getting started with Nx
# Use EMLX on the GPU device as the default backend for tensors created after
# this call; the call itself returns the previously configured backend.
Nx.default_backend({EMLX.Backend, [device: :gpu]})
{Nx.BinaryBackend, []}
# Compile `defn` functions with EMLX unless a compiler is given explicitly;
# the call returns the previous default options.
Nx.Defn.default_options([compiler: EMLX])
[]
# A 2x2 integer tensor, written one row per line, reused by the cells below.
t =
  Nx.tensor([
    [1, 2],
    [3, 4]
  ])
#Nx.Tensor<
s32[2][2]
EMLX.Backend
[
[1, 2],
[3, 4]
]
>
# Shape of `t` as a tuple of dimension sizes ({rows, columns} for this 2x2 tensor).
Nx.shape(t)
{2, 2}
# Eager (non-`defn`) softmax over every element of `t`.
# The maximum is subtracted before exponentiating: softmax is invariant to
# that shift, and it keeps `exp` from overflowing to infinity (which would make
# the result NaN) for large inputs. The exponentials are computed once and
# reused, instead of running `Nx.exp/1` twice on the backend.
t
|> Nx.subtract(Nx.reduce_max(t))
|> Nx.exp()
|> then(&Nx.divide(&1, Nx.sum(&1)))
#Nx.Tensor<
f32[2][2]
EMLX.Backend
[
[0.032058604061603546, 0.08714432269334793],
[0.23688285052776337, 0.6439142227172852]
]
>
defmodule MyModule do
  @moduledoc """
  Example numerical definitions compiled through `Nx.Defn`.
  """
  import Nx.Defn

  @doc """
  Computes the softmax of `t` over all of its elements.

  Returns a float tensor with the same shape as `t` whose entries are
  non-negative and sum to 1 (up to floating-point rounding).

  The maximum element is subtracted before exponentiating. Softmax is
  invariant to that shift, and it prevents `exp` from overflowing to
  infinity (and the result from becoming NaN) for large inputs.
  """
  defn softmax(t) do
    # Compute the shifted exponentials once and reuse them for the normalizer.
    exp_t = Nx.exp(t - Nx.reduce_max(t))
    exp_t / Nx.sum(exp_t)
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# Call the `defn` function directly; it is compiled with the process's
# default `defn` compiler.
[1, 2, 3]
|> Nx.tensor()
|> MyModule.softmax()
#Nx.Tensor<
f32[3]
EMLX.Backend
[0.09003056585788727, 0.2447284609079361, 0.665241003036499]
>
# JIT-compile softmax with EXLA explicitly, overriding the default `defn`
# compiler, then apply the compiled function to the 2x2 tensor `t`.
will_jit = Nx.Defn.jit(&MyModule.softmax/1, compiler: EXLA)
will_jit.(t)
#Nx.Tensor<
f32[2][2]
EXLA.Backend
[
[0.032058604061603546, 0.08714432269334793],
[0.23688282072544098, 0.6439142227172852]
]
>