Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Nx ライブラリに慣れよう

ch02/02_origin.livemd

Nx ライブラリに慣れよう

Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:benchee, "~> 1.0"}
])

テンソルについて考える

Nx ライブラリにおけるテンソルについて

Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor(1.0)
c = Nx.tensor([[[[[[1.0, 2]]]]]])
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
b: #Nx.Tensor<
  f32
  1.0
>
c: #Nx.Tensor<
  f32[1][1][1][1][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.0]
          ]
        ]
      ]
    ]
  ]
>
#Nx.Tensor<
  f32[1][1][1][1][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.0]
          ]
        ]
      ]
    ]
  ]
>

テンソルの型情報

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([1.0, 2.0, 3.0])
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
a: #Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
b: #Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>
#Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>
Nx.tensor(0.0000000000000000000000000000000000000000000001)
#Nx.Tensor<
  f32
  0.0
>
Nx.tensor(1.0e-45, type: {:f, 64})
#Nx.Tensor<
  f64
  1.0e-45
>
Nx.tensor(128, type: {:s, 8})
#Nx.Tensor<
  s8
  -128
>
Nx.tensor([1.0, 2, 3])
#Nx.Tensor<
  f32[3]
  [1.0, 2.0, 3.0]
>

テンソルの次元数

a = Nx.tensor([1, 2])
b = Nx.tensor([[1, 2], [3, 4]])
c = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
#Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
  s64[2]
  [1, 2]
>
b: #Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
c: #Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
#Nx.Tensor<
  s64[2][2][2]
  [
    [
      [1, 2],
      [3, 4]
    ],
    [
      [5, 6],
      [7, 8]
    ]
  ]
>
Nx.tensor(10)
#Nx.Tensor<
  s64
  10
>
Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
#Nx.Tensor<
  s64[x: 2][y: 3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

テンソルに含まれるデータ

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
Nx.to_binary(a)
<<1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5,
  0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0>>
<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
|> Nx.reshape({1, 3})
#Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>

Nx を使う上での必須の操作

shape と type

a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
a
|> Nx.as_type({:f, 32})
|> Nx.reshape({1, 3, 1})
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [1.0],
      [2.0],
      [3.0]
    ]
  ]
>
Nx.bitcast(a, {:f, 64})
#Nx.Tensor<
  f64[3]
  [5.0e-324, 1.0e-323, 1.5e-323]
>

要素ごとの単項演算

# Plain Elixir list: apply Kernel.abs/1 to every element with Enum.map/2.
a = [-1, -2, -3, 0, 1, 2, 3]
# `&abs/1` is the function-capture syntax (the `&` must be a literal ampersand).
Enum.map(a, &abs/1)
[1, 2, 3, 0, 1, 2, 3]
a = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [-1, -2, -3],
      [-4, -5, -6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>
Nx.abs(a)
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>

要素ごとの二項演算

a = [1, 2, 3]
b = [4, 5, 6]
Enum.zip_with(a, b, fn x, y -> x + y end)
[5, 7, 9]
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[6, 7, 8], [9, 10, 11]])
#Nx.Tensor<
  s64[2][3]
  [
    [6, 7, 8],
    [9, 10, 11]
  ]
>
Nx.add(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [7, 9, 11],
    [13, 15, 17]
  ]
>
Nx.multiply(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [6, 14, 24],
    [36, 50, 66]
  ]
>
Nx.add(5, Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  [6, 7, 8]
>
Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[4, 5, 6], [7, 8, 9]]))
#Nx.Tensor<
  s64[2][3]
  [
    [5, 7, 9],
    [8, 10, 12]
  ]
>

Reductions

revs = Nx.tensor([85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51])
#Nx.Tensor<
  s64[12]
  [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
>
Nx.sum(revs)
#Nx.Tensor<
  s64
  647
>
revs =
  Nx.tensor(
    [
      [21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
      [64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
      [68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
      [47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
    ],
    names: [:year, :month]
  )
#Nx.Tensor<
  s64[year: 4][month: 12]
  [
    [21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
    [64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
    [68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
    [47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
  ]
>
Nx.sum(revs, axes: [:year])
#Nx.Tensor<
  s64[month: 12]
  [200, 294, 270, 194, 337, 261, 233, 189, 246, 229, 246, 214]
>
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
  s64[year: 4]
  [705, 711, 731, 766]
>

def 定義を defn 定義に変えると…

defmodule MyModule do
  import Nx.Defn

  @doc """
  Adds 1 to every element of `x`.

  While the function is being traced, the expression built for the addition
  is printed, so the numerical definition's expression graph can be seen in
  the cell output. The printed expression is passed through unchanged, so the
  function still evaluates to the tensor `x + 1`.
  """
  defn adds_one(x) do
    # print_expr/2 is the replacement for inspect_expr/1, which the compiler
    # reports as deprecated ("Use print_expr/2 instead").
    Nx.add(x, 1) |> print_expr()
  end
end
warning: Nx.Defn.Kernel.inspect_expr/1 is deprecated. Use print_expr/2 instead
  code/GettingComfortableWithNx.livemd#cell:rlqcehi4parqdbgcejnwfuwvfogu37uj:5: MyModule."__defn:adds_one__"/1
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  
  Nx.Defn.Expr
  parameter a:0   s64[3]
  b = add 1, a    s64[3]
>
#Nx.Tensor<
  s64[3]
  [2, 3, 4]
>
defmodule Softmax do
  import Nx.Defn

  @doc """
  Computes the softmax of `n` over all of its elements.

  Each element is exponentiated and divided by the sum of all exponentials,
  so the result has the same shape as `n` and its elements sum to 1.
  """
  defn softmax(n) do
    # Build the exponentials once and reuse them for both the numerator and
    # the normalising sum instead of constructing Nx.exp(n) twice.
    exps = Nx.exp(n)
    exps / Nx.sum(exps)
  end
end
{:module, Softmax, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
# Benchmark input: one million uniformly distributed samples.
# Nx.random_uniform/1 is deprecated ("Use Nx.Random.uniform/2 instead");
# Nx.Random.uniform/2 takes an explicit PRNG key and returns {tensor, new_key}.
key = Nx.Random.key(42)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000})

# Compare the same defn function called through EXLA's JIT against a direct
# call that uses whatever defn compiler is currently the default.
Benchee.run(
  %{
    "JIT with EXLA" => fn -> apply(EXLA.jit(&Softmax.softmax/1), [tensor]) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
warning: Nx.random_uniform/1 is deprecated. Use Nx.Random.uniform/2 instead
  code/GettingComfortableWithNx.livemd#cell:2ds5qkfh7ogzfqj7zequ7jh3dqzuipvv:1

Operating System: macOS
CPU Information: Apple M1 Max
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.14.2
Erlang 25.0.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...

07:32:24.688 [info] TfrtCpuClient created.
Benchmarking Regular Elixir ...

Name                     ips        average  deviation         median         99th %
JIT with EXLA         1.18 K        0.85 ms    ±59.87%        0.71 ms        3.34 ms
Regular Elixir     0.00306 K      327.12 ms     ±4.01%      331.48 ms      350.93 ms

Comparison: 
JIT with EXLA         1.18 K
Regular Elixir     0.00306 K - 385.80x slower +326.27 ms
%Benchee.Suite{
  system: %{
    available_memory: "32 GB",
    cpu_speed: "Apple M1 Max",
    elixir: "1.14.2",
    erlang: "25.0.2",
    num_cores: 10,
    os: :macOS
  },
  configuration: %Benchee.Configuration{
    parallel: 1,
    time: 10000000000.0,
    warmup: 2000000000.0,
    memory_time: 0.0,
    reduction_time: 0.0,
    pre_check: false,
    formatters: [Benchee.Formatters.Console],
    percentiles: '2c',
    print: %{benchmarking: true, configuration: true, fast_warning: true},
    inputs: nil,
    save: false,
    load: false,
    unit_scaling: :best,
    assigns: %{},
    before_each: nil,
    after_each: nil,
    before_scenario: nil,
    after_scenario: nil,
    measure_function_call_overhead: false,
    title: nil,
    profile_after: false
  },
  scenarios: [
    %Benchee.Scenario{
      name: "JIT with EXLA",
      job_name: "JIT with EXLA",
      function: #Function<43.3316493/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 847904.7806369427,
          ips: 1179.3777117859909,
          std_dev: 507620.6462510765,
          std_dev_ratio: 0.5986764762309205,
          std_dev_ips: 706.0656926373231,
          median: 708049.0,
          percentiles: %{50 => 708049.0, 99 => 3337619.8399999985},
          mode: 736132,
          minimum: 398920,
          maximum: 5893470,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 11775
        },
        samples: [591756, 751715, 769090, 744215, 794674, 775590, 694423, 1011010, 619422, 834966,
         861966, 706840, 738465, 579755, 646672, 896967, 643255, 923051, 664381, 944217, 640131,
         1142677, 834508, 919466, 757132, 807049, 787050, 631089, 1045885, 648797, 660547, 573463,
         696590, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    },
    %Benchee.Scenario{
      name: "Regular Elixir",
      job_name: "Regular Elixir",
      function: #Function<43.3316493/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 327118284.32258064,
          ips: 3.0569981805537707,
          std_dev: 13126287.776793953,
          std_dev_ratio: 0.04012703785108431,
          std_dev_ips: 0.12266828170177703,
          median: 331484659.0,
          percentiles: %{50 => 331484659.0, 99 => 350929753.0},
          mode: nil,
          minimum: 300087997,
          maximum: 350929753,
          relative_more: 385.7960136477243,
          relative_less: 0.002592043371689978,
          absolute_difference: 326270379.54194367,
          sample_size: 31
        },
        samples: [309212580, 317130194, 310805929, 346726341, 307113061, 304059033, 328127253,
         322722787, 330623068, 306861059, 306479221, 300087997, 323321750, 333835473, 330243023,
         327888501, 331935247, 339329022, 331484659, 333029923, 334954732, 335370111, 330404357,
         334286018, 336033742, 343203641, 333621471, 332058039, 335043192, 333745637, 350929753]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    }
  ]
}
Nx.Defn.global_default_options(compiler: EXLA)
[]
# Benchmark input: one million uniformly distributed samples.
# Nx.random_uniform/1 is deprecated ("Use Nx.Random.uniform/2 instead");
# Nx.Random.uniform/2 takes an explicit PRNG key and returns {tensor, new_key}.
key = Nx.Random.key(42)
{tensor, _new_key} = Nx.Random.uniform(key, shape: {1_000_000})

# Same benchmark as before, re-run after EXLA was set as the global default
# defn compiler, so the direct Softmax.softmax/1 call no longer goes through
# the default evaluator.
Benchee.run(
  %{
    "JIT with EXLA" => fn -> apply(EXLA.jit(&Softmax.softmax/1), [tensor]) end,
    "Regular Elixir" => fn -> Softmax.softmax(tensor) end
  },
  time: 10
)
warning: Nx.random_uniform/1 is deprecated. Use Nx.Random.uniform/2 instead
  code/GettingComfortableWithNx.livemd#cell:6pn5gelypjrncbr65vihwswamhywsstt:1

Operating System: macOS
CPU Information: Apple M1 Max
Number of Available Cores: 10
Available memory: 32 GB
Elixir 1.14.2
Erlang 25.0.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...

Name                     ips        average  deviation         median         99th %
JIT with EXLA         1.22 K      816.65 μs    ±82.49%      603.70 μs     4194.74 μs
Regular Elixir        1.17 K      857.83 μs    ±93.16%      638.37 μs     4788.49 μs

Comparison: 
JIT with EXLA         1.22 K
Regular Elixir        1.17 K - 1.05x slower +41.18 μs
%Benchee.Suite{
  system: %{
    available_memory: "32 GB",
    cpu_speed: "Apple M1 Max",
    elixir: "1.14.2",
    erlang: "25.0.2",
    num_cores: 10,
    os: :macOS
  },
  configuration: %Benchee.Configuration{
    parallel: 1,
    time: 10000000000.0,
    warmup: 2000000000.0,
    memory_time: 0.0,
    reduction_time: 0.0,
    pre_check: false,
    formatters: [Benchee.Formatters.Console],
    percentiles: '2c',
    print: %{benchmarking: true, configuration: true, fast_warning: true},
    inputs: nil,
    save: false,
    load: false,
    unit_scaling: :best,
    assigns: %{},
    before_each: nil,
    after_each: nil,
    before_scenario: nil,
    after_scenario: nil,
    measure_function_call_overhead: false,
    title: nil,
    profile_after: false
  },
  scenarios: [
    %Benchee.Scenario{
      name: "JIT with EXLA",
      job_name: "JIT with EXLA",
      function: #Function<43.3316493/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 816652.9643470439,
          ips: 1224.5103411821342,
          std_dev: 673678.5772711856,
          std_dev_ratio: 0.8249263845014342,
          std_dev_ips: 1010.1308885359956,
          median: 603704.0,
          percentiles: %{50 => 603704.0, 99 => 4194738.100000007},
          mode: 575746,
          minimum: 385122,
          maximum: 7120618,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 12229
        },
        samples: [557914, 623828, 618579, 974951, 594288, 1006659, 700953, 1535406, 635955, 1383157,
         1013452, 1186326, 638746, 607454, 713328, 1735405, 962660, 860619, 576329, 1437865,
         1174242, 1088327, 629953, 1519365, 1254324, 897036, 644078, 1455240, 1122658, 1064160,
         606538, 1453782, 871661, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    },
    %Benchee.Scenario{
      name: "Regular Elixir",
      job_name: "Regular Elixir",
      function: #Function<43.3316493/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 857830.8879058581,
          ips: 1165.7309314673967,
          std_dev: 799114.1825977226,
          std_dev_ratio: 0.9315521204284506,
          std_dev_ips: 1085.939121057486,
          median: 638370.0,
          percentiles: %{50 => 638370.0, 99 => 4788488.199999998},
          mode: 592746,
          minimum: 361455,
          maximum: 11309296,
          relative_more: 1.0504227932261754,
          relative_less: 0.9519976208138903,
          absolute_difference: 41177.92355881422,
          sample_size: 11642
        },
        samples: [633953, 1054951, 772078, 762328, 932202, 892535, 1055992, 1127326, 784495, 721703,
         739662, 782869, 779578, 847286, 869535, 1072659, 981452, 1196241, 770995, 804369, 832160,
         740578, 766453, 965743, 782828, 1105742, 766745, 683370, 776078, 778952, 691370, 716829,
         ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    }
  ]
}