Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

NxとVegaの基礎

vega_basics.livemd

NxとVegaの基礎

import IEx.Helpers

Mix.install([
  {:nx, "~> 0.4.0"},
  {:exla, "~> 0.4.0"},
  {:kino_vega_lite, "~> 0.1.7"}
])

alias VegaLite, as: Vl

概要

資料

Nx の基本

exp

Nx.exp(2)

log

Nx.exp(1) |> Nx.log()

sin

# これが正しいらしい
:math.pi() |> :math.sin() |> Nx.tensor() |> dbg()

# これの結果がなぜ正しくないのか謎
# :math.pi() |> Nx.sin() |> dbg()

:ok

sqrt

Nx.sqrt(3)

1 次元配列

  • 添字で値を取得

array

a = Nx.tensor([2, 3, 5, 7, 8])

a[0] |> dbg()
a[1..2] |> dbg()
a[-2..-1//1] |> dbg()
a[2..-2//1] |> dbg()

:ok

iota

Nx.iota({5})
arange = fn start, stop, step ->
  how_many = trunc((stop - start) / step)

  Nx.iota({how_many})
  |> Nx.multiply(step)
  |> Nx.add(start)
end

arange.(1, 3, 0.2)
linspace = fn start, stop, how_many ->
  step = (stop - start) / how_many

  Nx.iota({how_many})
  |> Nx.multiply(step)
  |> Nx.add(start)
end

linspace.(1, 3, 10)

###

2 次元配列

a =
  Nx.tensor([
    [2, 3, 4],
    [5, 6, 7]
  ])

a[0][1] |> dbg()

# 縦方向で値を取得
a[[0..1, 1]] |> dbg()
a[[0..-1//1, 1]] |> dbg()

# 横方向で値を取得
a[1] |> dbg()
a[1][1..2] |> dbg()
a[1][1..-1//1] |> dbg()

:ok

形状の変更

import Nx, only: :sigils

a = ~M"""
0 1 2 3 4
5 6 7 8 9
10 11 12 13 14
"""

# 行列の形
Nx.shape(a) |> dbg

# 行列の深さ
Nx.rank(a) |> dbg

# 要素数
Nx.size(a) |> dbg

:ok
Nx.iota({15})
|> Nx.reshape({3, 5})
|> Nx.flatten()
|> dbg()

:ok
# 縦方向に変換
b = Nx.iota({4}) |> Nx.reshape({4, 1}) |> dbg()

:ok

zeros and ones

# Pythonのnp.zerosに相当
Nx.random_normal({3, 4}, 0.0, 0.0)
# Pythonのnp.onesに相当
Nx.random_normal({2, 2}, 1.0, 0.0)

行列の連結

a = Nx.iota({2, 3})
b = a |> Nx.add(6) |> dbg()

# 縦方向に連結
Nx.concatenate([a, b]) |> dbg()

# 横方向に連結
Nx.concatenate([a, b], axis: 1) |> dbg()

:ok
c = Nx.iota({3})
d = c |> Nx.add(3) |> dbg()

Nx.concatenate([c, d]) |> dbg()

Nx.concatenate(
  [
    Nx.reshape(c, {3, 1}),
    Nx.reshape(d, {3, 1})
  ],
  axis: 1
)
|> dbg()

:ok

形状が違うものを連結する場合はnew_axisで軸を足すか、reshapeで変形して合わせる。

a = Nx.iota({2, 3})
b = Nx.iota({3})

Nx.concatenate([
  a |> dbg(),
  b |> Nx.new_axis(0) |> dbg()
])
|> dbg()

Nx.concatenate([
  a |> dbg(),
  b |> Nx.reshape({1, 3}) |> dbg()
])
|> dbg()

:ok

配列操作の基本

a = Nx.iota({5})
b = Nx.iota({3, 3})

# 1次元配列の合計
a |> Nx.sum() |> dbg()

# 多次元配列の合計
b |> Nx.sum() |> dbg()

# 多次元配列の縦軸で合計
b |> Nx.sum(axes: [0]) |> dbg()

# 多次元配列の横軸で合計
b |> Nx.sum(axes: [1]) |> dbg()

:ok
a = Nx.iota({5})

# 平均
a |> Nx.mean() |> dbg()

:ok
a = Nx.iota({5})

a |> Nx.reduce_min() |> dbg()
a |> Nx.reduce_max() |> dbg()

:ok

ブロードキャスト

# 1次元配列
a = Nx.iota({5}) |> Nx.add(3)
a |> Nx.exp() |> dbg()
a |> Nx.log() |> dbg()
a |> Nx.sqrt() |> dbg()

# 多次元配列
b = Nx.iota({3, 3})
b |> Nx.exp() |> dbg()

:ok

配列とスカラの演算

# 1次元配列
a = Nx.iota({5})
a |> Nx.add(3) |> dbg()
a |> Nx.multiply(3) |> dbg()
a |> Nx.power(2) |> dbg()
a |> Nx.greater_equal(2) |> dbg()
a |> Nx.not_equal(3) |> dbg()

# 多次元配列
b = Nx.iota({3, 3})
b |> Nx.greater(3) |> dbg()

:ok

線形代数

Nx.tensor([
  [3, 1, 1],
  [1, 2, 1],
  [0, -1, 1]
])
|> Nx.LinAlg.invert()
|> dbg()

:ok
Nx.tensor([
  [3, 1, 1],
  [1, 2, 1],
  [0, -1, 1]
])
|> Nx.LinAlg.solve(Nx.tensor([1, 2, 3]))
|> dbg()

:ok
## LU分解により連立方程式を解く
# Note: XLA does not currently support the LU operation

a =
  Nx.tensor([
    [3, 1, 1],
    [1, 2, 1],
    [0, -1, 1]
  ])

b = Nx.tensor([1, 2, 3])

{_p, l, u} = Nx.LinAlg.lu(a) |> dbg()

l |> Nx.dot(u) |> Nx.LinAlg.solve(b) |> dbg()

:ok

乱数

こちらは第2引数から第3引数までの間の値をランダムに返します マイナスの範囲にしたい場合は第2引数をマイナス値にすれば大丈夫です

key = Nx.Random.key(1701)

Nx.Random.uniform(key) |> dbg()
Nx.Random.uniform(key, shape: {3, 2}, type: :f16) |> dbg()
Nx.Random.uniform(key, 0, 5) |> dbg()
Nx.Random.uniform(key, 0, 5, shape: {3, 3}) |> dbg()

:ok

折れ線グラフ

Vega Lite の値は Nx は対応していないのでリストで渡す必要があります

x = [0, 1, 2, 3]
y = [3, 7, 4, 8]

Vl.new(width: 200, height: 200)
|> Vl.data_from_values(x: x, y: y)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)

散布図

x = [0, 1, 2, 3]
y = [3, 7, 4, 8]

Vl.new(width: 200, height: 200)
|> Vl.data_from_values(x: x, y: y)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)

曲線のグラフ

  • 階差数列(linspace)がないので値から逆算して iota で作る
x = linspace.(-5, 5, 300) |> Nx.reshape({1, 300}) |> dbg()
y = Nx.power(x, 2) |> dbg()

Vl.new(width: 200, height: 200)
|> Vl.data_from_values(x: Nx.to_flat_list(x), y: Nx.to_flat_list(y))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)

複数の線を表示する

複数の線を1つのグラフに描画する場合はlayers関数を使用

x = linspace.(-5, 5, 300) |> Nx.reshape({1, 300})
y1 = Nx.power(x, 2) |> Nx.to_flat_list()
y2 = Nx.subtract(x, 2) |> Nx.power(2) |> Nx.to_flat_list()

Vl.new(width: 200, height: 200)
|> Vl.data_from_values(x: Nx.to_flat_list(x), y1: y1, y2: y2)
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:line, color: :red)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y1", type: :quantitative),
  Vl.new()
  |> Vl.mark(:line, color: :black, stroke_dash: [6, 4])
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y2", type: :quantitative)
])

ヒストグラム

## サイコロを10回振って目の合計を計算するを1000回行う

x =
  Nx.iota({1000})
  |> Nx.map(fn _ ->
    Nx.tensor(for _ <- 0..9, do: Enum.random(1..6))
    |> Nx.sum()
  end)
  |> Nx.to_flat_list()

y = Nx.iota({1000}) |> Nx.to_flat_list()

Vl.new(width: 200, height: 200)
|> Vl.data_from_values(x: x, y: y)
|> Vl.mark(:bar, color: :gray)
|> Vl.encode_field(:x, "x", type: :quantitative, bin: %{maxbins: 20})
|> Vl.encode_field(:y, "y", type: :quantitative, aggregate: :count)

複数のグラフを並べて表示する

  • 複数のグラフを並べて表示する場合はconcat関数を使う
  • デフォルトは横方向に並べられる(:horizontal
  • :vertical を指定すると縦方向に並べられる
x = linspace.(-5, 5, 300)
sin_x = Nx.sin(x) |> Nx.to_flat_list()
cos_x = Nx.cos(x) |> Nx.to_flat_list()

# 縦方向
Vl.new(width: 400, height: 200)
|> Vl.data_from_values(x: Nx.to_flat_list(x), y: sin_x, y2: cos_x)
|> Vl.transform(filter: %{field: "y", range: [-1.5, 1.5]})
|> Vl.transform(filter: %{field: "y2", range: [-1.5, 1.5]})
|> Vl.concat(
  [
    Vl.new(width: 600)
    |> Vl.mark(:line, color: :red)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative),
    Vl.new(width: 600)
    |> Vl.mark(:line, color: :black)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y2", type: :quantitative)
  ],
  :vertical
)
x = linspace.(-5, 5, 300)
sin_x = Nx.sin(x) |> Nx.to_flat_list()
cos_x = Nx.cos(x) |> Nx.to_flat_list()

# 横方向
Vl.new(width: 400, height: 200)
|> Vl.data_from_values(x: Nx.to_flat_list(x), y: sin_x, y2: cos_x)
|> Vl.transform(filter: %{field: "y", range: [-1.5, 1.5]})
|> Vl.transform(filter: %{field: "y2", range: [-1.5, 1.5]})
|> Vl.concat(
  [
    Vl.new(width: 600)
    |> Vl.mark(:line, color: :red)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative),
    Vl.new(width: 600)
    |> Vl.mark(:line, color: :black)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y2", type: :quantitative)
  ],
  :horizontal
)