Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Nx

livebooks/nx.livemd

Nx

# Notebook dependencies: Nx itself plus two compiler/backends (EXLA and EMLX).
deps = [
  {:nx, "~> 0.9"},
  {:exla, "~> 0.9"},
  # NOTE(review): emlx tracks the moving `main` branch, so installs are not
  # reproducible over time — consider pinning a `ref:` or a released version.
  {:emlx, github: "elixir-nx/emlx", branch: "main"}
]

Mix.install(deps)

Getting started with Nx

# Make EMLX (GPU device) the backend for eager Nx operations; the call
# returns the previously configured default backend.
gpu_backend = {EMLX.Backend, device: :gpu}
Nx.default_backend(gpu_backend)
{Nx.BinaryBackend, []}
# Compile `defn` functions with EMLX unless a caller picks another compiler;
# the call returns the previous default options.
defn_opts = [compiler: EMLX]
Nx.Defn.default_options(defn_opts)
[]
 # Build a 2x2 integer tensor from nested row lists; later cells reuse `t`.
 rows = [[1, 2], [3, 4]]
 t = Nx.tensor(rows)
#Nx.Tensor<
  s32[2][2]
  EMLX.Backend
  [
    [1, 2],
    [3, 4]
  ]
>
# Matching on a 2-tuple documents (and asserts) that `t` is a matrix;
# the match expression still evaluates to the shape tuple.
{_height, _width} = Nx.shape(t)
{2, 2}
# Eager softmax over every element of `t`.
# Subtracting the maximum first keeps Nx.exp/1 from overflowing to inf/NaN
# on large inputs; the shift cancels in the ratio, so the result is
# mathematically unchanged. The exponentials are computed once and reused
# for both the numerator and the normalizing sum.
exp_t =
  t
  |> Nx.subtract(Nx.reduce_max(t))
  |> Nx.exp()

Nx.divide(exp_t, Nx.sum(exp_t))
#Nx.Tensor<
  f32[2][2]
  EMLX.Backend
  [
    [0.032058604061603546, 0.08714432269334793],
    [0.23688285052776337, 0.6439142227172852]
  ]
>
defmodule MyModule do
  @moduledoc """
  Numerical definitions used in the "Getting started with Nx" notebook.
  """
  import Nx.Defn

  @doc """
  Computes the softmax of `t` across all of its elements.

  The input is shifted by its maximum before exponentiation so that large
  values do not overflow `Nx.exp/1`. Softmax is invariant to adding a
  constant, so this does not change the mathematical result.
  """
  defn softmax(t) do
    exps = Nx.exp(t - Nx.reduce_max(t))
    exps / Nx.sum(exps)
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# Run the defn softmax on a small vector using the default (EMLX) compiler.
[1, 2, 3]
|> Nx.tensor()
|> MyModule.softmax()
#Nx.Tensor<
  f32[3]
  EMLX.Backend
  [0.09003056585788727, 0.2447284609079361, 0.665241003036499]
>
# Wrap the defn with EXLA.jit/1 so it is compiled by EXLA rather than the
# default compiler set through Nx.Defn.default_options/1, then call it on `t`.
# (The capture operator is a literal `&`; the previous `&amp;` was an HTML
# entity left over from extraction and does not parse.)
will_jit = EXLA.jit(&MyModule.softmax/1)
will_jit.(t)
#Nx.Tensor<
  f32[2][2]
  EXLA.Backend
  [
    [0.032058604061603546, 0.08714432269334793],
    [0.23688282072544098, 0.6439142227172852]
  ]
>