Chapter Two
Mix.install([
{:nx, "~> 0.6"},
{:exla, "~> 0.6"},
{:benchee, "~> 1.1"}
])
Understanding Tensors
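# A rank-2 tensor with shape {2, 3}; integer values default to signed 64-bit (s64).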
uta =
Nx.tensor([
[1, 2, 3],
[4, 5, 6]
])
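# A scalar tensor: rank 0, shape {}; float values default to f32.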
utb = Nx.tensor(1.1)
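# Each level of nesting adds a dimension: four levels of brackets give a rank-4 tensor of shape {1, 1, 1, 3}.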
utc = Nx.tensor([[[[1, 2, 3]]]])
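# :names labels the axes (:x for rows, :y for columns) so axis-aware operations can refer to them by name.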
utd = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
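# A named 3x3x3 tensor stored as unsigned 8-bit integers (:u8).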
ute =
Nx.tensor(
[
[
[1, 2, 3],
[3, 2, 1],
[2, 3, 1]
],
[
[4, 5, 6],
[6, 5, 4],
[5, 4, 6]
],
[
[7, 8, 9],
[9, 8, 7],
[8, 9, 7]
]
],
names: [:x, :y, :z],
type: :u8
)
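# from_binary/2 interprets raw bytes as a flat tensor of the given type: two native-endian u16 values become a shape {2} tensor.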
utf = <<1::16-unsigned-native, 2::16-unsigned-native>> |> Nx.from_binary({:u, 16})
IO.inspect(uta, label: :a)
IO.inspect(utb, label: :b)
IO.inspect(utc, label: :c)
IO.inspect(utd, label: :d)
IO.inspect(ute, label: :e)
IO.inspect(Nx.to_binary(ute), label: :e_bin)
IO.inspect(utf, label: :f)
:ok
Size and Shape
ssa = Nx.tensor(for n <- 1..3, do: n)
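# as_type/2 converts the values to bf16; reshape/2 changes the shape to {1, 3, 1} without touching the data.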
ssb = ssa |> Nx.as_type({:bf, 16}) |> Nx.reshape({1, 3, 1})
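# bitcast/2 keeps the underlying bits and only relabels the type, so the s64 integers 1, 2, 3 reappear as denormal f64 values near zero.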
ssc = Nx.bitcast(ssa, {:f, 64})
IO.inspect(ssa, label: :a)
IO.inspect(ssb, label: :b)
IO.inspect(ssc, label: :c)
:ok
Using Nx Operations
oa = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
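# abs/1 applies element-wise; to_list/1 converts a tensor back into nested Elixir lists.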
ob = Nx.abs(oa)
oc = Nx.to_list(ob) |> List.flatten()
od = Nx.tensor(for n <- 4..14, do: n)
oe = Nx.tensor(for n <- -3..7, do: n)
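# Element-wise arithmetic between two tensors of the same shape {11}.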
of = Nx.add(od, oe)
og = Nx.divide(od, Nx.add(oe, 10))
oh = Nx.multiply(of, og)
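# Broadcasting: the shape {3} tensor is stretched across each row of the {2, 3} tensor before the add.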
oi =
Nx.add(
Nx.tensor([99, 100, 101]),
Nx.tensor([[3, 4, 5], [18, 19, 20]])
)
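# Plain numbers broadcast as scalars: 10 is added to every element.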
oj = Nx.add(oi, 10)
IO.inspect(oa, label: :a)
IO.inspect(ob, label: :b)
IO.inspect(oc, label: :c)
IO.inspect(od, label: :d)
IO.inspect(oe, label: :e)
IO.inspect(of, label: :f)
IO.inspect(og, label: :g)
IO.inspect(oh, label: :h)
IO.inspect(oi, label: :i)
IO.inspect(oj, label: :j)
:ok
Reductions
rk = Nx.Random.key(77_773)
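# Nx.Random is stateless: functions take a key and return {tensor, new_key}.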
{rr, _} = Nx.Random.uniform(rk, shape: {4, 16}, type: :f16, names: [:x, :y])
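# Sum along the named :y axis, then reduce what remains to a scalar.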
ra = Nx.sum(rr, axes: [:y]) |> Nx.sum()
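# reshape/3 cannot change the element type, so cast with as_type/2 first.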
rs = rr |> Nx.as_type(:bf16) |> Nx.reshape({4, 4, 4}, names: [:x, :y, :z])
IO.inspect(rk, label: :k)
IO.inspect(rr, label: :r)
IO.inspect(ra, label: :a)
IO.inspect(rs, label: :s)
:ok
defn
defmodule DefnModule do
import Nx.Defn
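# print_expr/1 prints the expression graph that defn builds before returning the result.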
defn(add(x, y), do: Nx.add(x, y) |> print_expr())
end
na = DefnModule.add(Nx.tensor([1, 2, 3], type: :u8), Nx.tensor([4, 5, 6], type: :u8))
IO.inspect(na, label: :a)
:ok
Optimized Softmax
defmodule Softmax do
import Nx.Defn
defn(softmax(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n)))
end
# Nx.Defn.global_default_options(compiler: EXLA)
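# Build a large random input, then benchmark the default Nx.Defn evaluator against an EXLA JIT-compiled version.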
key = Nx.Random.key(77_773)
{tensor, _key} = Nx.Random.uniform(key, shape: {1_000_000})
Benchee.run(
%{
"JIT with EXLA" => fn -> apply(EXLA.jit(&Softmax.softmax/1), [tensor]) end,
"Regular Elixir" => fn -> Softmax.softmax(tensor) end
},
time: 10
)
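# Vector and matrix exercises on random bf16 data: element-wise multiply, dot products, reshapes, and slices.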
κ = Nx.Random.key(77_773)
{α, _} = Nx.Random.uniform(κ, shape: {128}, type: :bf16)
{β, _} = Nx.Random.normal(κ, shape: {128}, type: :bf16)
γ = Nx.multiply(α, β)
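# multiply/2 is element-wise; dot/2 of two rank-1 tensors is the inner product (a scalar).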
δ = Nx.dot(α, β)
ε = Nx.reshape(α, {8, 16})
ζ = Nx.reshape(β, {16, 8})
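# dot/2 on the reshaped {8, 16} and {16, 8} matrices performs matrix multiplication, yielding an {8, 8} result.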
η = Nx.dot(ε, ζ)
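# slice/3 takes start indices and lengths; transpose/1 on a rank-1 tensor is a no-op.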
θ = Nx.slice(α, [0], [64])
ι = Nx.slice(α, [64], [64]) |> Nx.transpose()
λ = Nx.dot(θ, ι)
IO.inspect(α, label: :α)
IO.inspect(β, label: :β)
IO.inspect(γ, label: :γ)
IO.inspect(δ, label: :δ)
IO.inspect(ε, label: :ε)
IO.inspect(ζ, label: :ζ)
IO.inspect(η, label: :η)
IO.inspect(θ, label: :θ)
IO.inspect(ι, label: :ι)
IO.inspect(λ, label: :λ)
:ok