Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Get Comfortable with Nx

GettingComfortableWithNx.livemd

Get Comfortable with Nx

# Notebook dependencies: Nx (tensors and defn), EXLA (the compiler/backend
# used by EXLA.jit and the global defn compiler below) and Benchee (benchmarks).
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:benchee, "~> 1.0"}
])

Thinking in Tensors

Understanding Nx Tensors

# Build a rank-1 tensor from a list of integers (the type is inferred, here s64).
Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
# Three tensors of different rank: a 2x3 matrix, a scalar, and a rank-6 tensor.
# Each one is printed with a label as soon as it is bound; the cell's value is `c`.
a = IO.inspect(Nx.tensor([[1, 2, 3], [4, 5, 6]]), label: :a)
b = IO.inspect(Nx.tensor(1.0), label: :b)
c = IO.inspect(Nx.tensor([[[[[[1.0, 2]]]]]]), label: :c)
a: #Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
b: #Nx.Tensor<
  f32
  1.0
>
c: #Nx.Tensor<
  f32[1][1][1][1][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.0]
          ]
        ]
      ]
    ]
  ]
>
#Nx.Tensor<
  f32[1][1][1][1][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.0]
          ]
        ]
      ]
    ]
  ]
>

Tensors Have a Type

# Integer literals give an integer tensor; float literals give a float tensor.
# Each tensor is printed with a label as it is bound; the cell's value is `b`.
a = IO.inspect(Nx.tensor([1, 2, 3]), label: :a)
b = IO.inspect(Nx.tensor([1.0, 2.0, 3.0]), label: :b)
a: #Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
b: #Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>
#Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>
# This value is too small for the default f32 type, so it underflows to 0.0.
Nx.tensor(0.0000000000000000000000000000000000000000000001)
#Nx.Tensor<
  f32
  0.0
>
# Requesting f64 explicitly gives enough range to keep 1.0e-45.
Nx.tensor(1.0e-45, type: {:f, 64})
#Nx.Tensor<
  f64
  1.0e-45
>
# 128 is outside the s8 range (-128..127); the value wraps around to -128.
Nx.tensor(128, type: {:s, 8})
#Nx.Tensor<
  s8
  -128
>
# Mixing a float with integers yields a float tensor; the integers are converted.
Nx.tensor([1.0, 2, 3])
#Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>

Tensors Have Shape

# Shape {2}: rank 1.
a = Nx.tensor([1, 2])
# Shape {2, 2}: rank 2.
b = Nx.tensor([[1, 2], [3, 4]])
# Shape {2, 2, 2}: rank 3.
c = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
#Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
# Print each tensor; the header shows one bracketed size per dimension.
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
  s64[2]
  [1, 2]
>
b: #Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
c: #Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
#Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
# A scalar: a rank-0 tensor with no dimensions in its shape.
Nx.tensor(10)
#Nx.Tensor<
  s64
  10
>
# Name the axes; the names appear in the shape header (x: 2, y: 3).
Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
#Nx.Tensor<
  s64[x: 2][y: 3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

Tensors Have Data

# A 2x3 integer tensor used to show how its data is stored.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
# The raw data is a flat binary in row order; each s64 element takes 8 bytes
# (least-significant byte first on the machine that produced this output).
Nx.to_binary(a)
<<1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5,
  0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0>>
# Pack the integers 1, 2 and 3 as native-endian signed 64-bit values, then
# interpret those bytes as an s64 tensor.
bytes = for i <- 1..3, into: <<>>, do: <<i::64-signed-native>>
Nx.from_binary(bytes, {:s, 64})
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
# Nx.from_binary/2 produces a flat tensor; reshape gives it the desired shape.
<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
|> Nx.reshape({1, 3})
#Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>

Using Nx Operations

Shape and Type Operations

# Input for the shape/type operation examples below.
a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
# Add unit dimensions around the data ({3} -> {1, 3, 1}) and convert to f32.
# Reshaping and casting are independent, so either order gives the same tensor.
a
|> Nx.reshape({1, 3, 1})
|> Nx.as_type({:f, 32})
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [1.0],
      [2.0],
      [3.0]
    ]
  ]
>
# bitcast reinterprets the raw bits as f64 without converting values, so the
# small integers show up as tiny subnormal floats.
Nx.bitcast(a, {:f, 64})
#Nx.Tensor<
  f64[3]
  [5.0e-324, 1.0e-323, 1.5e-323]
>

Element-wise Unary Operations

# Plain Elixir: apply abs/1 to each list element, one at a time.
a = [-1, -2, -3, 0, 1, 2, 3]
Enum.map(a, &abs/1)
[1, 2, 3, 0, 1, 2, 3]
# A rank-3 tensor ({2, 2, 3}) holding negative and positive values.
a = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [-1, -2, -3],
      [-4, -5, -6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>
# Element-wise absolute value over the whole tensor, with no explicit loop.
Nx.abs(a)
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>

Element-wise Binary Operations

# Plain Elixir: pair up the two lists element by element and add each pair.
a = [1, 2, 3]
b = [4, 5, 6]
Enum.zip_with(a, b, &+/2)
[5, 7, 9]
# Two 2x3 tensors of the same shape for the element-wise examples.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[6, 7, 8], [9, 10, 11]])
#Nx.Tensor<
  s64[2][3]
  [
    [6, 7, 8],
    [9, 10, 11]
  ]
>
# Element-wise addition of two tensors with the same shape.
Nx.add(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [7, 9, 11],
    [13, 15, 17]
  ]
>
# Element-wise (not matrix) multiplication.
Nx.multiply(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [6, 14, 24],
    [36, 50, 66]
  ]
>
# Broadcasting: the scalar 5 is added to every element.
Nx.add(5, Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  [6, 7, 8]
>
# Broadcasting: the {3} tensor is added to each row of the {2, 3} tensor.
Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[4, 5, 6], [7, 8, 9]]))
#Nx.Tensor<
  s64[2][3]
  [
    [5, 7, 9],
    [8, 10, 12]
  ]
>

Reductions

# Twelve revenue values (presumably one per month) for the reduction examples.
revs = Nx.tensor([85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51])
#Nx.Tensor<
  s64[12]
  [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
>
# With no axes given, sum reduces over every element to a scalar.
Nx.sum(revs)
#Nx.Tensor<
  s64
  647
>
# 4 x 12 values with named axes, so reductions can refer to :year / :month.
revs =
  Nx.tensor(
    [
      [21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
      [64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
      [68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
      [47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
    ],
    names: [:year, :month]
  )
#Nx.Tensor<
  s64[year: 4][month: 12]
  [
    [21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
    [64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
    [68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
    [47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
  ]
>
# Reduce away the :year axis, leaving one total per month.
Nx.sum(revs, axes: [:year])
#Nx.Tensor<
  s64[month: 12]
  [200, 294, 270, 194, 337, 261, 233, 189, 246, 229, 246, 214]
>
# Reduce away the :month axis, leaving one total per year.
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
  s64[year: 4]
  [705, 711, 731, 766]
>

Going from def to defn

defmodule MyModule do
  import Nx.Defn

  @doc """
  Adds one to every element of `x`, printing the traced expression on the way.
  """
  defn adds_one(x) do
    # `print_expr/1` is used in place of the deprecated `inspect_expr/1`: it
    # prints the expression built while tracing and passes its argument through.
    x
    |> Nx.add(1)
    |> print_expr()
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# Calling the defn traces it (printing the expression) and then evaluates it.
MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  
  Nx.Defn.Expr
  parameter a:0   s64[3]
  b = add 1, a    s64[3]
>
#Nx.Tensor<
  s64[3]
  EXLA.Backend
  [2, 3, 4]
>
defmodule Softmax do
  import Nx.Defn

  @doc """
  Computes the softmax of `n` over all of its elements: `exp(n) / sum(exp(n))`.

  The maximum element is subtracted before exponentiating. Softmax does not
  change when the same constant is added to every input, so the result is
  mathematically the same, but `Nx.exp/1` can no longer overflow to infinity
  for large inputs (which would otherwise turn the output into NaN). The
  exponential is also computed only once.
  """
  defn softmax(n) do
    exps = Nx.exp(n - Nx.reduce_max(n))
    exps / Nx.sum(exps)
  end
end
{:module, Softmax, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# The older `Nx.random_uniform/1` call is replaced by `Nx.Random`, which takes
# an explicit PRNG key, so the generated input is reproducible.
key = Nx.Random.key(1701)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000}, type: :f32)

# Build the EXLA-compiled function once, outside the measured closure, so each
# benchmark iteration measures only the call itself.
softmax_jit = EXLA.jit(&Softmax.softmax/1)

Benchee.run(
  %{
    "JIT with EXLA" => fn -> softmax_jit.(tensor) end,
    # Calls the defn directly, so it runs under the current default defn compiler.
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
Error trying to determine erlang version enoent, falling back to overall OTP version
Warning: the benchmark JIT with EXLA is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark Regular Elixir is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: macOS
CPU Information: Apple M2 Pro
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.17.2
Erlang 27
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...

Name                     ips        average  deviation         median         99th %
Regular Elixir        1.41 K      711.36 μs    ±44.07%      636.09 μs     1839.09 μs
JIT with EXLA         1.40 K      714.18 μs    ±40.02%      665.50 μs     1668.41 μs

Comparison: 
Regular Elixir        1.41 K
JIT with EXLA         1.40 K - 1.00x slower +2.83 μs
%Benchee.Suite{
  system: %Benchee.System{
    elixir: "1.17.2",
    erlang: "27",
    jit_enabled?: true,
    num_cores: 10,
    os: :macOS,
    available_memory: "32 GB",
    cpu_speed: "Apple M2 Pro"
  },
  configuration: %Benchee.Configuration{
    parallel: 1,
    time: 10000000000.0,
    warmup: 2000000000.0,
    memory_time: 0.0,
    reduction_time: 0.0,
    pre_check: false,
    formatters: [Benchee.Formatters.Console],
    percentiles: ~c"2c",
    print: %{configuration: true, fast_warning: true, benchmarking: true},
    inputs: nil,
    input_names: [],
    save: false,
    load: false,
    unit_scaling: :best,
    assigns: %{},
    before_each: nil,
    after_each: nil,
    before_scenario: nil,
    after_scenario: nil,
    measure_function_call_overhead: false,
    title: nil,
    profile_after: false
  },
  scenarios: [
    %Benchee.Scenario{
      name: "Regular Elixir",
      job_name: "Regular Elixir",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 711357.6566303652,
          ips: 1405.7626155834278,
          std_dev: 313469.2960531113,
          std_dev_ratio: 0.4406634175247176,
          std_dev_ips: 619.4681584114791,
          median: 636086.0,
          percentiles: %{50 => 636086.0, 99 => 1839087.0},
          mode: 435751,
          minimum: 355250,
          maximum: 8143352,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 14049
        },
        samples: [1161793, 579919, 573710, 712377, 611084, 918794, 529751, 543334, 419835, 958419,
         668917, 683085, 801626, 1199753, 657377, 445501, 375500, 1069045, 502918, 828794, 688834,
         390042, 888669, 515793, 402168, 643585, 1320670, 417918, 799334, 752127, 1963795, 532835,
         607168, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    },
    %Benchee.Scenario{
      name: "JIT with EXLA",
      job_name: "JIT with EXLA",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 714182.8157067314,
          ips: 1400.201710272787,
          std_dev: 285783.1655020827,
          std_dev_ratio: 0.4001540771031871,
          std_dev_ips: 560.2964231325113,
          median: 665501.5,
          percentiles: %{50 => 665501.5, 99 => 1668414.7499999984},
          mode: 430709,
          minimum: 336959,
          maximum: 7904933,
          relative_more: 1.003971503012632,
          relative_less: 0.9960442074294793,
          absolute_difference: 2825.1590763662243,
          sample_size: 13994
        },
        samples: [870419, 868752, 498626, 933502, 640336, 501126, 968002, 544293, 858586, 481084,
         761502, 711335, 812001, 673710, 663917, 568585, 547126, 752169, 611126, 876335, 888543,
         599418, 657710, 865127, 531542, 535460, 952961, 766835, 701543, 864209, 926628, 854626,
         ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    }
  ]
}
# Make EXLA the default compiler for defn calls made after this point.
Nx.Defn.global_default_options(compiler: EXLA)
[compiler: EXLA]
# Regenerate the input with the same key and options as the first benchmark,
# so both runs use identical data.
key = Nx.Random.key(1701)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000}, type: :f32)

# Build the EXLA-compiled function once, outside the measured closure.
softmax_jit = EXLA.jit(&Softmax.softmax/1)

# EXLA is now the global default defn compiler, so "Regular Elixir" is also
# compiled by EXLA here; the two scenarios should perform similarly.
Benchee.run(
  %{
    "JIT with EXLA" => fn -> softmax_jit.(tensor) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
Error trying to determine erlang version enoent, falling back to overall OTP version
Warning: the benchmark JIT with EXLA is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark Regular Elixir is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: macOS
CPU Information: Apple M2 Pro
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.17.2
Erlang 27
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...

Name                     ips        average  deviation         median         99th %
Regular Elixir        1.43 K      700.07 μs    ±42.80%      623.09 μs     1769.44 μs
JIT with EXLA         1.41 K      710.46 μs    ±49.29%      647.27 μs     1645.45 μs

Comparison: 
Regular Elixir        1.43 K
JIT with EXLA         1.41 K - 1.01x slower +10.39 μs
%Benchee.Suite{
  system: %Benchee.System{
    elixir: "1.17.2",
    erlang: "27",
    jit_enabled?: true,
    num_cores: 10,
    os: :macOS,
    available_memory: "32 GB",
    cpu_speed: "Apple M2 Pro"
  },
  configuration: %Benchee.Configuration{
    parallel: 1,
    time: 10000000000.0,
    warmup: 2000000000.0,
    memory_time: 0.0,
    reduction_time: 0.0,
    pre_check: false,
    formatters: [Benchee.Formatters.Console],
    percentiles: ~c"2c",
    print: %{configuration: true, fast_warning: true, benchmarking: true},
    inputs: nil,
    input_names: [],
    save: false,
    load: false,
    unit_scaling: :best,
    assigns: %{},
    before_each: nil,
    after_each: nil,
    before_scenario: nil,
    after_scenario: nil,
    measure_function_call_overhead: false,
    title: nil,
    profile_after: false
  },
  scenarios: [
    %Benchee.Scenario{
      name: "Regular Elixir",
      job_name: "Regular Elixir",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 700065.2609274307,
          ips: 1428.4382554209624,
          std_dev: 299652.91088792804,
          std_dev_ratio: 0.42803568126056507,
          std_dev_ips: 611.4225417977647,
          median: 623085.0,
          percentiles: %{50 => 623085.0, 99 => 1769436.5},
          mode: [444751, 548751, 441876, 540043, 465251],
          minimum: 347876,
          maximum: 5780887,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 14276
        },
        samples: [607834, 885293, 479043, 750709, 562834, 961128, 496418, 685292, 460126, 878585,
         830461, 450626, 813585, 1202212, 874960, 780751, 396626, 1071753, 476293, 816127, 941461,
         370625, 541668, 616293, 825086, 479543, 734752, 842877, 912419, 925835, 984086, 530959,
         876419, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    },
    %Benchee.Scenario{
      name: "JIT with EXLA",
      job_name: "JIT with EXLA",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 710460.1069092976,
          ips: 1407.5385658883547,
          std_dev: 350209.4821209828,
          std_dev_ratio: 0.4929333522250715,
          std_dev_ips: 693.8227036694163,
          median: 647272.5,
          percentiles: %{50 => 647272.5, 99 => 1645447.0399999982},
          mode: [433251, 490501],
          minimum: 343416,
          maximum: 18455999,
          relative_more: 1.0148483956597076,
          relative_less: 0.9853688533940527,
          absolute_difference: 10394.845981866936,
          sample_size: 14068
        },
        samples: [655460, 601626, 505459, 524792, 489126, 485668, 465751, 492001, 868877, 717667,
         521501, 459459, 808585, 557626, 737460, 750710, 560168, 875960, 913668, 459043, 743544,
         509709, 519793, 709544, 550001, 483001, 894336, 488043, 556918, 500084, 800252, 484084,
         ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    }
  ]
}