Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapter Two

Chapter02.livemd

Chapter Two

# Notebook setup: installs the packages used by the cells below.
#   * :nx      - tensors and numerical definitions (`defn`)
#   * :exla    - XLA-based compiler, used by the JIT softmax benchmark
#   * :benchee - benchmarking harness for the softmax comparison
# NOTE(review): "~> 0.6" accepts any 0.x release >= 0.6.0 (not only 0.6.x),
# so newer Nx/EXLA versions may be resolved.
Mix.install([
  {:nx, "~> 0.6"},
  {:exla, "~> 0.6"},
  {:benchee, "~> 1.1"}
])

Understanding Tensors

# A 2x3 integer matrix.
uta = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# A float literal yields a rank-0 (scalar) tensor.
utb = Nx.tensor(1.1)

# Each extra level of list nesting adds a leading axis of size 1.
utc = Nx.tensor([[[[1, 2, 3]]]])

# Axes can be named so later operations may refer to them by name.
utd = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])

# A 3x3x3 cube of unsigned 8-bit integers with three named axes.
ute =
  [
    [[1, 2, 3], [3, 2, 1], [2, 3, 1]],
    [[4, 5, 6], [6, 5, 4], [5, 4, 6]],
    [[7, 8, 9], [9, 8, 7], [8, 9, 7]]
  ]
  |> Nx.tensor(names: [:x, :y, :z], type: :u8)

# Build a raw native-endian buffer of two u16 values and view it as a tensor.
utf =
  for(value <- [1, 2], into: <<>>, do: <<value::16-unsigned-native>>)
  |> Nx.from_binary({:u, 16})

Enum.each(
  [
    a: uta,
    b: utb,
    c: utc,
    d: utd,
    e: ute,
    e_bin: Nx.to_binary(ute),
    f: utf
  ],
  fn {label, value} -> IO.inspect(value, label: label) end
)

:ok

Size and Shape

# The integers 1..3 as an explicitly 64-bit signed tensor. The type is pinned
# because `Nx.bitcast/2` below requires the source and target types to have
# the same bit width (s64 -> f64), rather than relying on the default int type.
ssa = Nx.tensor(Enum.to_list(1..3), type: {:s, 64})

# Convert the values to bfloat16, then add unit axes around them: {3} -> {1, 3, 1}.
ssb = ssa |> Nx.as_type({:bf, 16}) |> Nx.reshape({1, 3, 1})

# Reinterpret the raw bits of each s64 as an f64 (no numeric conversion).
ssc = Nx.bitcast(ssa, {:f, 64})

IO.inspect(ssa, label: :a)
IO.inspect(ssb, label: :b)
IO.inspect(ssc, label: :c)

:ok

Using Nx Operations

# Element-wise unary op on a {2, 2, 3} tensor, then back to a flat Elixir list.
oa = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
ob = Nx.abs(oa)
oc = ob |> Nx.to_list() |> List.flatten()

# Two 11-element vectors combined element-wise. Shifting `oe` by 10 keeps
# every divisor within 7..17, so the division never divides by zero.
od = 4..14 |> Enum.to_list() |> Nx.tensor()
oe = -3..7 |> Enum.to_list() |> Nx.tensor()
of = Nx.add(od, oe)
og = oe |> Nx.add(10) |> then(&Nx.divide(od, &1))
oh = Nx.multiply(of, og)

# Broadcasting: the {3} vector is added to each row of the {2, 3} matrix,
# and a plain number is broadcast across every element.
oi = Nx.tensor([99, 100, 101]) |> Nx.add(Nx.tensor([[3, 4, 5], [18, 19, 20]]))
oj = Nx.add(oi, 10)

Enum.each(
  [a: oa, b: ob, c: oc, d: od, e: oe, f: of, g: og, h: oh, i: oi, j: oj],
  fn {label, value} -> IO.inspect(value, label: label) end
)

:ok

Reductions

# Fixed seed so the random draw is reproducible.
rk = Nx.Random.key(77_773)

# A {4, 16} f16 tensor with named axes; the returned next key is discarded.
{rr, _} = Nx.Random.uniform(rk, shape: {4, 16}, type: :f16, names: [:x, :y])

# Sum along the named :y axis (one total per :x row), then sum those totals.
ra = rr |> Nx.sum(axes: [:y]) |> Nx.sum()

# `Nx.reshape/3` only takes a `:names` option and does not change the type,
# so cast to bf16 explicitly, then fold the 64 elements into a {4, 4, 4} cube.
rs =
  rr
  |> Nx.as_type({:bf, 16})
  |> Nx.reshape({4, 4, 4}, names: [:x, :y, :z])

IO.inspect(rk, label: :k)
IO.inspect(rr, label: :r)
IO.inspect(ra, label: :a)
IO.inspect(rs, label: :s)

:ok

defn

defmodule DefnModule do
  @moduledoc """
  Minimal `defn` example: element-wise addition whose traced expression is
  printed with `print_expr/1`.
  """
  import Nx.Defn

  # `print_expr/1` prints the expression built while tracing and passes it
  # through unchanged, so `add/2` still returns the sum.
  defn add(x, y) do
    x
    |> Nx.add(y)
    |> print_expr()
  end
end

# Add two u8 vectors through the `defn` function, then show the result.
na =
  DefnModule.add(
    Nx.tensor([1, 2, 3], type: :u8),
    Nx.tensor([4, 5, 6], type: :u8)
  )
  |> tap(&IO.inspect(&1, label: :a))

:ok

Optimized Softmax

defmodule Softmax do
  @moduledoc """
  Softmax over all elements of a tensor, written with `defn` so it can run on
  the default `Nx.Defn` evaluator or be JIT-compiled (for example with EXLA).
  """
  import Nx.Defn

  @doc """
  Returns `exp(n) / sum(exp(n))`, where the sum is taken over every element
  of `n`.

  The maximum of `n` is subtracted before exponentiating. This is
  mathematically identical (the common factor `exp(-max)` cancels), but it
  keeps `Nx.exp/1` from overflowing to infinity on large inputs. Without it,
  the result would become NaN.
  """
  defn softmax(n) do
    exps = Nx.exp(n - Nx.reduce_max(n))
    exps / Nx.sum(exps)
  end
end

# Uncomment to make EXLA the compiler for every `defn` call in the notebook.
# Nx.Defn.global_default_options(compiler: EXLA)

# One million f32 samples in [0, 1). `Nx.random_uniform/1` is deprecated in
# favour of the key-based `Nx.Random` API (and may be missing from newer Nx
# releases that "~> 0.6" can resolve to), so use `Nx.Random` as other cells do.
{tensor, _key} = Nx.Random.uniform(Nx.Random.key(77_773), shape: {1_000_000})

# Build the JIT-compiled function once, outside the timed closure. The first
# call triggers compilation, which lands in Benchee's (default) warmup phase.
jit_softmax = EXLA.jit(&Softmax.softmax/1)

Benchee.run(
  %{
    "JIT with EXLA" => fn -> jit_softmax.(tensor) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
# Start from a fixed seed. Each `Nx.Random` call returns `{tensor, next_key}`.
# Threading the next key into the following draw keeps α and β from being
# derived from the same random bits (which reusing κ for both would do).
κ = Nx.Random.key(77_773)
{α, next_key} = Nx.Random.uniform(κ, shape: {128}, type: :bf16)
{β, _} = Nx.Random.normal(next_key, shape: {128}, type: :bf16)

# Element-wise product (shape {128}) vs. dot product of two vectors (scalar).
γ = Nx.multiply(α, β)
δ = Nx.dot(α, β)

# Matrix product: {8, 16} . {16, 8} -> {8, 8}.
ε = Nx.reshape(α, {8, 16})
ζ = Nx.reshape(β, {16, 8})
η = Nx.dot(ε, ζ)

# Dot product of α's first and second halves. Both slices are rank-1, so no
# transpose is needed (transposing a rank-1 tensor would be a no-op).
θ = Nx.slice(α, [0], [64])
ι = Nx.slice(α, [64], [64])
λ = Nx.dot(θ, ι)

IO.inspect(α, label: :α)
IO.inspect(β, label: :β)
IO.inspect(γ, label: :γ)
IO.inspect(δ, label: :δ)
IO.inspect(ε, label: :ε)
IO.inspect(ζ, label: :ζ)
IO.inspect(η, label: :η)
IO.inspect(θ, label: :θ)
IO.inspect(ι, label: :ι)
IO.inspect(λ, label: :λ)

:ok