Get Comfortable with Nx
Mix.install([
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:benchee, "~> 1.0"}
])
Thinking in Tensors
Understanding Nx Tensors
Nx.tensor([1, 2, 3])
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor(1.0)
c = Nx.tensor([[[[[[1.0, 2]]]]]])
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
s64[2][3]
[
[1, 2, 3],
[4, 5, 6]
]
>
b: #Nx.Tensor<
f32
1.0
>
c: #Nx.Tensor<
f32[1][1][1][1][1][2]
[
[
[
[
[
[1.0, 2.0]
]
]
]
]
]
>
#Nx.Tensor<
f32[1][1][1][1][1][2]
[
[
[
[
[
[1.0, 2.0]
]
]
]
]
]
>
Tensors Have a Type
a = Nx.tensor([1, 2, 3])
b = Nx.tensor([1.0, 2.0, 3.0])
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
a: #Nx.Tensor<
s64[3]
[1, 2, 3]
>
b: #Nx.Tensor<
f32[3]
[1.0, 2.0, 3.0]
>
#Nx.Tensor<
f32[3]
[1.0, 2.0, 3.0]
>
Nx.tensor(0.0000000000000000000000000000000000000000000001)
#Nx.Tensor<
f32
0.0
>
Nx.tensor(1.0e-45, type: {:f, 64})
#Nx.Tensor<
f64
1.0e-45
>
Nx.tensor(128, type: {:s, 8})
#Nx.Tensor<
s8
-128
>
Nx.tensor([1.0, 2, 3])
#Nx.Tensor<
f32[3]
[1.0, 2.0, 3.0]
>
Tensors Have Shape
a = Nx.tensor([1, 2])
b = Nx.tensor([[1, 2], [3, 4]])
c = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
#Nx.Tensor<
s64[2][2][2]
[
[
[1, 2],
[3, 4]
],
[
[5, 6],
[7, 8]
]
]
>
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
s64[2]
[1, 2]
>
b: #Nx.Tensor<
s64[2][2]
[
[1, 2],
[3, 4]
]
>
c: #Nx.Tensor<
s64[2][2][2]
[
[
[1, 2],
[3, 4]
],
[
[5, 6],
[7, 8]
]
]
>
#Nx.Tensor<
s64[2][2][2]
[
[
[1, 2],
[3, 4]
],
[
[5, 6],
[7, 8]
]
]
>
Nx.tensor(10)
#Nx.Tensor<
s64
10
>
Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
#Nx.Tensor<
s64[x: 2][y: 3]
[
[1, 2, 3],
[4, 5, 6]
]
>
Tensors Have Data
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
s64[2][3]
[
[1, 2, 3],
[4, 5, 6]
]
>
Nx.to_binary(a)
<<1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5,
0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0>>
<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
|> Nx.reshape({1, 3})
#Nx.Tensor<
s64[1][3]
[
[1, 2, 3]
]
>
Using Nx Operations
Shape and Type Operations
a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
a
|> Nx.as_type({:f, 32})
|> Nx.reshape({1, 3, 1})
#Nx.Tensor<
f32[1][3][1]
[
[
[1.0],
[2.0],
[3.0]
]
]
>
Nx.bitcast(a, {:f, 64})
#Nx.Tensor<
f64[3]
[5.0e-324, 1.0e-323, 1.5e-323]
>
Element-wise Unary Operations
a = [-1, -2, -3, 0, 1, 2, 3]
Enum.map(a, &abs/1)
[1, 2, 3, 0, 1, 2, 3]
a = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
#Nx.Tensor<
s64[2][2][3]
[
[
[-1, -2, -3],
[-4, -5, -6]
],
[
[1, 2, 3],
[4, 5, 6]
]
]
>
Nx.abs(a)
#Nx.Tensor<
s64[2][2][3]
[
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
]
>
Element-wise Binary Operations
a = [1, 2, 3]
b = [4, 5, 6]
Enum.zip_with(a, b, fn x, y -> x + y end)
[5, 7, 9]
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[6, 7, 8], [9, 10, 11]])
#Nx.Tensor<
s64[2][3]
[
[6, 7, 8],
[9, 10, 11]
]
>
Nx.add(a, b)
#Nx.Tensor<
s64[2][3]
[
[7, 9, 11],
[13, 15, 17]
]
>
Nx.multiply(a, b)
#Nx.Tensor<
s64[2][3]
[
[6, 14, 24],
[36, 50, 66]
]
>
Nx.add(5, Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
[6, 7, 8]
>
Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[4, 5, 6], [7, 8, 9]]))
#Nx.Tensor<
s64[2][3]
[
[5, 7, 9],
[8, 10, 12]
]
>
Reductions
revs = Nx.tensor([85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51])
#Nx.Tensor<
s64[12]
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
>
Nx.sum(revs)
#Nx.Tensor<
s64
647
>
revs =
Nx.tensor(
[
[21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
[64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
[68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
[47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
],
names: [:year, :month]
)
#Nx.Tensor<
s64[year: 4][month: 12]
[
[21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
[64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
[68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
[47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
]
>
Nx.sum(revs, axes: [:year])
#Nx.Tensor<
s64[month: 12]
[200, 294, 270, 194, 337, 261, 233, 189, 246, 229, 246, 214]
>
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
s64[year: 4]
[705, 711, 731, 766]
>
Going from def to defn
defmodule MyModule do
  import Nx.Defn

  # Adds 1 to `x` and prints the expression built for the computation.
  #
  # `print_expr/1` is used instead of the deprecated form:
  #   Nx.add(x, 1) |> inspect_expr()
  defn adds_one(x) do
    x
    |> Nx.add(1)
    |> print_expr()
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
Nx.Defn.Expr
parameter a:0 s64[3]
b = add 1, a s64[3]
>
#Nx.Tensor<
s64[3]
EXLA.Backend
[2, 3, 4]
>
defmodule Softmax do
  @moduledoc """
  Numerical definition of the softmax function, used by the benchmarks in this notebook.
  """

  import Nx.Defn

  @doc """
  Computes the softmax of `n` over all of its elements.

  Every element is exponentiated and divided by the sum of all exponentials,
  so the result has the same shape as `n` and its elements sum to 1.
  """
  defn softmax(n) do
    # Shift by the maximum so the largest exponent is exp(0) = 1. The shift cancels
    # out in the division, but it keeps exp/1 from overflowing to infinity (and the
    # result from becoming NaN) for large inputs, e.g. above roughly 88 in f32.
    exps = Nx.exp(n - Nx.reduce_max(n))
    exps / Nx.sum(exps)
  end
end
{:module, Softmax, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# Older form, kept for reference; the sample is drawn through Nx.Random with an explicit key.
# tensor = Nx.random_uniform({1_000_000})
key = Nx.Random.key(1701)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000}, type: :f32)
# IO.inspect(tensor)

# NOTE(review): "Regular Elixir" runs with whatever defn compiler is the current default.
# If `Nx.Defn.global_default_options(compiler: EXLA)` was already evaluated in this
# runtime, both entries presumably go through EXLA and report near-identical timings.
# Confirm by evaluating this cell on a fresh runtime.
Benchee.run(
  %{
    # The arity is known, so the jitted function is called directly rather than via apply/2.
    "JIT with EXLA" => fn -> EXLA.jit(&Softmax.softmax/1).(tensor) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
Error trying to determine erlang version enoent, falling back to overall OTP version
Warning: the benchmark JIT with EXLA is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Warning: the benchmark Regular Elixir is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Operating System: macOS
CPU Information: Apple M2 Pro
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.17.2
Erlang 27
JIT enabled: true
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s
Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...
Name ips average deviation median 99th %
Regular Elixir 1.41 K 711.36 μs ±44.07% 636.09 μs 1839.09 μs
JIT with EXLA 1.40 K 714.18 μs ±40.02% 665.50 μs 1668.41 μs
Comparison:
Regular Elixir 1.41 K
JIT with EXLA 1.40 K - 1.00x slower +2.83 μs
%Benchee.Suite{
system: %Benchee.System{
elixir: "1.17.2",
erlang: "27",
jit_enabled?: true,
num_cores: 10,
os: :macOS,
available_memory: "32 GB",
cpu_speed: "Apple M2 Pro"
},
configuration: %Benchee.Configuration{
parallel: 1,
time: 10000000000.0,
warmup: 2000000000.0,
memory_time: 0.0,
reduction_time: 0.0,
pre_check: false,
formatters: [Benchee.Formatters.Console],
percentiles: ~c"2c",
print: %{configuration: true, fast_warning: true, benchmarking: true},
inputs: nil,
input_names: [],
save: false,
load: false,
unit_scaling: :best,
assigns: %{},
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
measure_function_call_overhead: false,
title: nil,
profile_after: false
},
scenarios: [
%Benchee.Scenario{
name: "Regular Elixir",
job_name: "Regular Elixir",
function: #Function<43.39164016/0 in :erl_eval.expr/6>,
input_name: :__no_input,
input: :__no_input,
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
tag: nil,
run_time_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: 711357.6566303652,
ips: 1405.7626155834278,
std_dev: 313469.2960531113,
std_dev_ratio: 0.4406634175247176,
std_dev_ips: 619.4681584114791,
median: 636086.0,
percentiles: %{50 => 636086.0, 99 => 1839087.0},
mode: 435751,
minimum: 355250,
maximum: 8143352,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 14049
},
samples: [1161793, 579919, 573710, 712377, 611084, 918794, 529751, 543334, 419835, 958419,
668917, 683085, 801626, 1199753, 657377, 445501, 375500, 1069045, 502918, 828794, 688834,
390042, 888669, 515793, 402168, 643585, 1320670, 417918, 799334, 752127, 1963795, 532835,
607168, ...]
},
memory_usage_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
},
reductions_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
}
},
%Benchee.Scenario{
name: "JIT with EXLA",
job_name: "JIT with EXLA",
function: #Function<43.39164016/0 in :erl_eval.expr/6>,
input_name: :__no_input,
input: :__no_input,
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
tag: nil,
run_time_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: 714182.8157067314,
ips: 1400.201710272787,
std_dev: 285783.1655020827,
std_dev_ratio: 0.4001540771031871,
std_dev_ips: 560.2964231325113,
median: 665501.5,
percentiles: %{50 => 665501.5, 99 => 1668414.7499999984},
mode: 430709,
minimum: 336959,
maximum: 7904933,
relative_more: 1.003971503012632,
relative_less: 0.9960442074294793,
absolute_difference: 2825.1590763662243,
sample_size: 13994
},
samples: [870419, 868752, 498626, 933502, 640336, 501126, 968002, 544293, 858586, 481084,
761502, 711335, 812001, 673710, 663917, 568585, 547126, 752169, 611126, 876335, 888543,
599418, 657710, 865127, 531542, 535460, 952961, 766835, 701543, 864209, 926628, 854626,
...]
},
memory_usage_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
},
reductions_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
}
}
]
}
Nx.Defn.global_default_options(compiler: EXLA)
[compiler: EXLA]
# Older form, kept for reference; the sample is drawn through Nx.Random with an explicit key.
# tensor = Nx.random_uniform({1_000_000})
key = Nx.Random.key(1701)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000}, type: :f32)

# EXLA was set as the global defn compiler just before this cell, so the plain
# `Softmax.softmax/1` call is expected to be compiled by EXLA as well. The
# "Regular Elixir" label is kept only so the results line up with the earlier run.
Benchee.run(
  %{
    # The arity is known, so the jitted function is called directly rather than via apply/2.
    "JIT with EXLA" => fn -> EXLA.jit(&Softmax.softmax/1).(tensor) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
Error trying to determine erlang version enoent, falling back to overall OTP version
Warning: the benchmark JIT with EXLA is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Warning: the benchmark Regular Elixir is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Operating System: macOS
CPU Information: Apple M2 Pro
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.17.2
Erlang 27
JIT enabled: true
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s
Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...
Name ips average deviation median 99th %
Regular Elixir 1.43 K 700.07 μs ±42.80% 623.09 μs 1769.44 μs
JIT with EXLA 1.41 K 710.46 μs ±49.29% 647.27 μs 1645.45 μs
Comparison:
Regular Elixir 1.43 K
JIT with EXLA 1.41 K - 1.01x slower +10.39 μs
%Benchee.Suite{
system: %Benchee.System{
elixir: "1.17.2",
erlang: "27",
jit_enabled?: true,
num_cores: 10,
os: :macOS,
available_memory: "32 GB",
cpu_speed: "Apple M2 Pro"
},
configuration: %Benchee.Configuration{
parallel: 1,
time: 10000000000.0,
warmup: 2000000000.0,
memory_time: 0.0,
reduction_time: 0.0,
pre_check: false,
formatters: [Benchee.Formatters.Console],
percentiles: ~c"2c",
print: %{configuration: true, fast_warning: true, benchmarking: true},
inputs: nil,
input_names: [],
save: false,
load: false,
unit_scaling: :best,
assigns: %{},
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
measure_function_call_overhead: false,
title: nil,
profile_after: false
},
scenarios: [
%Benchee.Scenario{
name: "Regular Elixir",
job_name: "Regular Elixir",
function: #Function<43.39164016/0 in :erl_eval.expr/6>,
input_name: :__no_input,
input: :__no_input,
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
tag: nil,
run_time_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: 700065.2609274307,
ips: 1428.4382554209624,
std_dev: 299652.91088792804,
std_dev_ratio: 0.42803568126056507,
std_dev_ips: 611.4225417977647,
median: 623085.0,
percentiles: %{50 => 623085.0, 99 => 1769436.5},
mode: [444751, 548751, 441876, 540043, 465251],
minimum: 347876,
maximum: 5780887,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 14276
},
samples: [607834, 885293, 479043, 750709, 562834, 961128, 496418, 685292, 460126, 878585,
830461, 450626, 813585, 1202212, 874960, 780751, 396626, 1071753, 476293, 816127, 941461,
370625, 541668, 616293, 825086, 479543, 734752, 842877, 912419, 925835, 984086, 530959,
876419, ...]
},
memory_usage_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
},
reductions_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
}
},
%Benchee.Scenario{
name: "JIT with EXLA",
job_name: "JIT with EXLA",
function: #Function<43.39164016/0 in :erl_eval.expr/6>,
input_name: :__no_input,
input: :__no_input,
before_each: nil,
after_each: nil,
before_scenario: nil,
after_scenario: nil,
tag: nil,
run_time_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: 710460.1069092976,
ips: 1407.5385658883547,
std_dev: 350209.4821209828,
std_dev_ratio: 0.4929333522250715,
std_dev_ips: 693.8227036694163,
median: 647272.5,
percentiles: %{50 => 647272.5, 99 => 1645447.0399999982},
mode: [433251, 490501],
minimum: 343416,
maximum: 18455999,
relative_more: 1.0148483956597076,
relative_less: 0.9853688533940527,
absolute_difference: 10394.845981866936,
sample_size: 14068
},
samples: [655460, 601626, 505459, 524792, 489126, 485668, 465751, 492001, 868877, 717667,
521501, 459459, 808585, 557626, 737460, 750710, 560168, 875960, 913668, 459043, 743544,
509709, 519793, 709544, 550001, 483001, 894336, 488043, 556918, 500084, 800252, 484084,
...]
},
memory_usage_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
},
reductions_data: %Benchee.CollectionData{
statistics: %Benchee.Statistics{
average: nil,
ips: nil,
std_dev: nil,
std_dev_ratio: nil,
std_dev_ips: nil,
median: nil,
percentiles: nil,
mode: nil,
minimum: nil,
maximum: nil,
relative_more: nil,
relative_less: nil,
absolute_difference: nil,
sample_size: 0
},
samples: []
}
}
]
}