AxonのHello world
import IEx.Helpers
# Install the dependencies this notebook needs:
#   * nx             - tensor library
#   * exla           - backend for Nx
#   * axon           - neural-network library used to define and train the model
#   * kino_vega_lite - VegaLite charts, used for the final plot
Mix.install(
[
{:nx, "~> 0.4.0"},
{:exla, "~> 0.4.0"},
{:axon, "~> 0.3.0"},
{:kino_vega_lite, "~> 0.1.0"}
],
# Configure Nx to use EXLA.Backend as its default tensor backend.
config: [nx: [default_backend: EXLA.Backend]]
)
About
-
Axonを使って線形回帰のパラメータ(傾きと切片)の値を求める
-
最も簡単な、
y = ax + b
にフィッティングする
Resources
処理の流れ
-
y = 2 * x + 0.5 + 乱数
の{x,y}
の組を100個つくる
-
train_model()
で、フィッティングする
-
結果表示
Everything together
defmodule Main do
  @moduledoc """
  Minimal Axon example: fits a single dense unit (`y = a * x + b`) to noisy
  samples of a straight line and returns the samples together with the
  fitted line.
  """

  require Axon

  ## Model

  @doc """
  Builds the regression model: one named input `"x"` followed by a dense
  layer with a single output unit.

  `input_shape` is passed to `Axon.input/2` as the `:shape` option; use `nil`
  for a dimension of dynamic size (for example `{nil, 1}` for batches of
  one-feature rows).
  """
  @spec build_model(tuple()) :: Axon.t()
  def build_model(input_shape) do
    # Axon.dense computes input * kernel + bias, i.e. exactly y = a * x + b
    # when there is one input feature and one output unit.
    Axon.input("x", shape: input_shape)
    |> Axon.dense(1)
  end

  ## Data

  @doc """
  Generates one batch of `n` training samples.

  Returns `{x, y}` where both tensors have shape `{n, 1}`.
  """
  @spec build_batch(pos_integer()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
  def build_batch(n \\ 100) do
    x = generate_random(n)
    y = generate_data(x, n)
    {x, y}
  end

  @doc """
  Computes the noisy targets `2 * x + 0.5 + noise` for the inputs `x`.

  `n` is the number of noise values to draw. It is expected to equal the
  number of rows in `x`; a different value is not checked here and is left to
  Nx's broadcasting rules.

  The noise comes from `generate_random/1`, so it is uniform on `[0.0, 1.0)`
  and has mean 0.5 rather than 0. A least-squares fit is therefore expected
  to recover an intercept near 1.0, not 0.5.
  """
  @spec generate_data(Nx.Tensor.t(), pos_integer()) :: Nx.Tensor.t()
  def generate_data(x, n) do
    # f(x) = 2 * x + 0.5 + noise
    x
    |> Nx.multiply(2)
    |> Nx.add(0.5)
    |> Nx.add(generate_random(n))
  end

  @doc """
  Returns an `{n, 1}` tensor of floats drawn with `:rand.uniform/0`.

  `n` must be a positive integer. Without the guard, `1..0` would be a
  descending range and silently produce two rows for `n = 0`.
  """
  @spec generate_random(pos_integer()) :: Nx.Tensor.t()
  def generate_random(n) when is_integer(n) and n > 0 do
    Nx.tensor(for _ <- 1..n, do: [:rand.uniform()])
  end

  ## Training

  @doc """
  Trains `model` on `data` with mean-squared-error loss and SGD.

  `data` is an enumerable of `{input, target}` batches. `epochs` is the
  number of epochs and `iterations` caps the batches consumed per epoch,
  which is what lets an infinite stream be used as `data`.

  Returns the value produced by `Axon.Loop.run/4`, which is used as the
  trained model state by `Axon.predict/3`.
  """
  @spec train_model(Axon.t(), Enumerable.t(), pos_integer(), pos_integer()) :: map()
  def train_model(model, data, epochs \\ 5, iterations \\ 1000) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: iterations)
  end

  ## Evaluation of the model

  @doc """
  Runs the whole example: generate data, train, and predict.

  Returns `{{x1, y1}, {x2, y2}}` where `{x1, y1}` are the training samples
  and `{x2, y2}` are 101 evenly spaced inputs in `0.0..1.0` with the model's
  predictions for them.
  """
  @spec main() :: {{Nx.Tensor.t(), Nx.Tensor.t()}, {Nx.Tensor.t(), Nx.Tensor.t()}}
  def main do
    # Prepare the training data: one fixed batch, repeated forever.
    # train_model/4 bounds how many batches are actually consumed.
    {x1, y1} = build_batch()
    data = Stream.repeatedly(fn -> {x1, y1} end)

    # Every batch is rank 2 ({n, 1}), so declare a dynamic batch dimension
    # followed by one feature instead of the rank-1 shape {nil}.
    model = build_model({nil, 1})

    # Learn the relationship between x and y. dbg/1 prints the trained
    # parameters and passes the value through unchanged.
    model_state = train_model(model, data) |> dbg()

    # Predict on an evenly spaced grid over 0.0..1.0.
    x2 = Nx.tensor(for i <- 0..100, do: [i / 100])
    y2 = Axon.predict(model, model_state, %{"x" => x2})

    {{x1, y1}, {x2, y2}}
  end
end
# Run the whole example and bind the training samples ({x1, y1}) and the
# fitted line ({x2, y2}) for the plotting code further down.
{{x1, y1}, {x2, y2}} = Main.main()
# Finish with :ok so this cell evaluates to :ok rather than to the tensors.
:ok
Visualize the results
# Convert each tensor into a flat list of numbers for plotting.
data_to_plot =
  Map.new(
    [x1: x1, y1: y1, x2: x2, y2: y2],
    fn {series, tensor} -> {series, Nx.to_flat_list(tensor)} end
  )

# Layer 1: the training samples, drawn as points with tooltips.
samples_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(x: data_to_plot.x1, y: data_to_plot.y1)
  |> VegaLite.mark(:point, tooltip: true)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Layer 2: the model's predictions, drawn as a line.
prediction_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(x: data_to_plot.x2, y: data_to_plot.y2)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Overlay both layers on one 600x600 chart; this is the cell's result.
VegaLite.layers(
  VegaLite.new(width: 600, height: 600),
  [samples_layer, prediction_layer]
)