Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

AxonのHello world

axon_basics.livemd

AxonのHello world

import IEx.Helpers

# Install the notebook's dependencies: Nx, the EXLA backend, Axon and the
# VegaLite integration for Kino.
Mix.install(
  [
    {:nx, "~> 0.4.0"},
    {:exla, "~> 0.4.0"},
    {:axon, "~> 0.3.0"},
    {:kino_vega_lite, "~> 0.1.0"}
  ],
  # Make EXLA.Backend the default backend for Nx.
  config: [nx: [default_backend: EXLA.Backend]]
)

About

  • Axonを使って線形回帰のパラメータ(傾きと切片)の値を求める
  • 最も簡単な、y = ax + bにフィッティング

Resources

処理の流れ

  • y = 2 * x + 0.5 + 乱数 を満たす {x, y} の組を100個つくる
  • train_model()で、フィッティングする
  • 結果表示

Everything together

defmodule Main do
  @moduledoc """
  Minimal Axon example: fits a straight line `y = a * x + b` to noisy samples
  using a model that consists of a single dense unit.
  """

  ## Model

  @doc """
  Builds the model: an input named `"x"` with the given `input_shape`,
  followed by one dense unit.

  `Axon.dense/2` computes `input * kernel + bias`, so the unit's kernel and
  bias play the roles of the slope and the intercept of the fitted line.
  """
  def build_model(input_shape) do
    Axon.input("x", shape: input_shape)
    |> Axon.dense(1)
  end

  ## Data

  @doc """
  Builds one batch `{x, y}` of `n` samples (100 by default).

  `x` comes from `generate_random/1` and `y` from `generate_data/2`.
  Raises `FunctionClauseError` when `n` is not a positive integer.
  """
  def build_batch(n \\ 100) do
    x = generate_random(n)
    y = generate_data(x, n)

    {x, y}
  end

  @doc """
  Computes `2 * x + 0.5 + noise`, where the noise is a fresh tensor from
  `generate_random(n)`.

  `n` is expected to be the number of rows of `x` so that the noise lines up
  with `x` element by element.

  Because the noise is drawn with `:rand.uniform/0` (values in `[0.0, 1.0)`),
  it has a mean of 0.5 rather than 0; a least-squares fit of this data should
  therefore land near an intercept of 1.0, not 0.5.
  """
  def generate_data(x, n) do
    x
    |> Nx.multiply(2)
    |> Nx.add(0.5)
    |> Nx.add(generate_random(n))
  end

  @doc """
  Returns a tensor with `n` rows and one column, each value drawn with
  `:rand.uniform/0`.

  `n` must be a positive integer. Without the guard, `1..0` would be treated
  as the decreasing range `1, 0` and silently produce two rows instead of
  none, so non-positive sizes are rejected with a `FunctionClauseError`.
  """
  def generate_random(n) when is_integer(n) and n > 0 do
    Nx.tensor(for _ <- 1..n, do: [:rand.uniform()])
  end

  ## Training

  @doc """
  Trains `model` on `data` with a mean-squared-error loss and the SGD
  optimizer, for `epochs` epochs (5 by default).

  Each epoch is capped at 1000 iterations; `main/0` passes an endless stream,
  so this cap is what ends an epoch there. Returns the result of
  `Axon.Loop.run/4`, which `main/0` uses as the trained model state.
  """
  def train_model(model, data, epochs \\ 5) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  ## Evaluation of the model

  @doc """
  Runs the whole example: builds the data, trains the model and predicts.

  Returns `{{x1, y1}, {x2, y2}}`, where `{x1, y1}` are the training samples
  and `{x2, y2}` are 101 evenly spaced inputs from 0.0 to 1.0 together with
  the model's predictions for them.
  """
  def main do
    # Prepare the training data: the same batch is served over and over.
    {x1, y1} = build_batch()
    data = Stream.repeatedly(fn -> {x1, y1} end)

    # NOTE(review): the batches built above have one column per row; confirm
    # that `{nil}` is the intended input shape for this Axon version.
    model = build_model({nil})

    # Learn the relationship between x and y; dbg/1 prints the trained state
    # so the learned parameters can be inspected.
    model_state = train_model(model, data) |> dbg()

    # Predict over an evenly spaced grid: 0.00, 0.01, ..., 1.00.
    x2 = Nx.tensor(for i <- 0..100, do: [i / 100])
    y2 = Axon.predict(model, model_state, %{"x" => x2})

    {{x1, y1}, {x2, y2}}
  end
end

# Run the whole example and keep the training samples ({x1, y1}) and the
# model's predictions ({x2, y2}) for the plotting code further down.
{{x1, y1}, {x2, y2}} = Main.main()

# NOTE(review): presumably here so the cell evaluates to :ok instead of
# displaying the tensors above — confirm.
:ok

Visualize the results

# Flatten each tensor into a plain list of numbers, keyed by the same names,
# so the values can be handed to the chart.
data_to_plot =
  Map.new(
    [x1: x1, y1: y1, x2: x2, y2: y2],
    fn {key, tensor} -> {key, Nx.to_flat_list(tensor)} end
  )

# Layer 1: the training samples, drawn as points with tooltips.
samples_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(x: data_to_plot.x1, y: data_to_plot.y1)
  |> VegaLite.mark(:point, tooltip: true)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Layer 2: the model's predictions, drawn as a line.
prediction_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(x: data_to_plot.x2, y: data_to_plot.y2)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Stack both layers on a single 600x600 chart.
VegaLite.new(width: 600, height: 600)
|> VegaLite.layers([samples_layer, prediction_layer])