Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us
Notesclub

OR train

livebooks/axon/or_train.livemd

OR train

# Notebook dependencies, all used by the cells below:
#   nx / exla        - tensors and the EXLA compiler passed to Axon.Loop.run
#   axon             - model definition and the training loop
#   kino / kino_vega_lite - Livebook widgets (data table, layout, live charts)
notebook_deps = [
  {:nx, "~> 0.6"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.6"},
  {:kino, "~> 0.12"},
  {:kino_vega_lite, "~> 0.1"}
]

Mix.install(notebook_deps)

準備

require Axon

学習データ生成

# Builds one training batch for the logical-OR task.
#
# Returns `{inputs, labels}` where `inputs` is a map with keys "input1" and
# "input2", each a {32, 1} tensor of random 0/1 values, and `labels` is the
# element-wise logical OR of the two input tensors.
generate_train_data = fn ->
  # One column of 32 random bits, reshaped to {32, 1}.
  random_bit_column = fn ->
    bits = Enum.map(1..32, fn _ -> Enum.random(0..1) end)

    bits
    |> Nx.tensor()
    |> Nx.new_axis(1)
  end

  # Columns are drawn in key order: "input1" first, then "input2".
  inputs = Map.new(1..2, fn index -> {"input#{index}", random_bit_column.()} end)

  {inputs, Nx.logical_or(inputs["input1"], inputs["input2"])}
end

# Draw a single batch to inspect its structure.
generate_train_data.()
# Materialise 1000 batches, each produced by a separate call to
# generate_train_data, into a list the training loop can iterate over.
train_data =
  Stream.repeatedly(generate_train_data)
  |> Enum.take(1000)

# Number of batches (a list, so length/1 gives the same count).
length(train_data)

モデル定義

# Two named inputs, each of shape {nil, 1} (batch dimension left open).
input1 = Axon.input("input1", shape: {nil, 1})
input2 = Axon.input("input2", shape: {nil, 1})

# Concatenate both inputs, run them through an 8-unit ReLU hidden layer and
# finish with a single sigmoid unit.
model =
  input1
  |> Axon.concatenate(input2)
  |> Axon.dense(8, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
# Creates an empty 300px-wide line chart with "step" on the x axis and the
# given field on the y axis; the training loop streams points into it later.
new_metric_chart = fn y_field ->
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, y_field, type: :quantitative)
  |> Kino.VegaLite.new()
end

loss_plot = new_metric_chart.("loss")
acc_plot = new_metric_chart.("accuracy")

# Show the two charts side by side.
Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
# Supervised training with binary cross-entropy and the :sgd optimizer,
# tracking accuracy as a metric.
training_loop =
  Axon.Loop.trainer(model, :binary_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy, "accuracy")

# Push the loss and accuracy values into the live charts after every epoch.
chart_opts = [event: :epoch_completed]

training_loop =
  training_loop
  |> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", chart_opts)
  |> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", chart_opts)

# 5 epochs, iterations capped at 1000, compiled with EXLA; starts from an
# empty initial state (%{}).
trained_state =
  Axon.Loop.run(training_loop, train_data, %{},
    epochs: 5,
    iterations: 1000,
    compiler: EXLA
  )
# Single-sample sanity check for OR(0, 0); each input is a {1, 1} tensor.
zero_input = Nx.tensor([[0]])

test_datum = %{"input1" => zero_input, "input2" => Nx.tensor([[0]])}

model
|> Axon.predict(trained_state, test_datum)
# Runs `model` with `trained_state` on one (input_1, input_2) pair and returns
# element [0, 0] of the prediction as a plain Elixir number.
predict = fn model, trained_state, {input_1, input_2} ->
  single_row_batch = %{
    "input1" => Nx.tensor([[input_1]]),
    "input2" => Nx.tensor([[input_2]])
  }

  prediction = Axon.predict(model, trained_state, single_row_batch)

  Nx.to_number(prediction[[0, 0]])
end

predict.(model, trained_state, {0, 0})
# Evaluate every row of the OR truth table and show the raw model output next
# to the 0/1 label obtained by thresholding it at 0.5.
truth_table_rows =
  for {input_1, input_2} <- [{0, 0}, {0, 1}, {1, 0}, {1, 1}] do
    predicted_value = predict.(model, trained_state, {input_1, input_2})
    predicted_label = if predicted_value >= 0.5, do: 1, else: 0

    %{
      "input1" => input_1,
      "input2" => input_2,
      "value" => predicted_value,
      "label" => predicted_label
    }
  end

Kino.DataTable.new(truth_table_rows)

推論の可視化

# Visualises the model's output along the diagonal input1 = input2 = x for
# x = 0.00, 0.01, ..., 0.99 and returns a 600x400 Kino.VegaLite line chart.
#
# Parameters:
#   trained_state - parameters returned by Axon.Loop.run
#   model         - the Axon model to evaluate
plot = fn trained_state, model ->
  # {100, 1} column of evenly spaced values in [0.0, 0.99].
  x =
    0..99
    |> Enum.map(&(&1 / 100))
    |> Nx.tensor()
    |> Nx.new_axis(1)

  # The same tensor feeds both inputs, so this traces a single line through
  # the 2-D input space rather than the whole surface.
  y = Axon.predict(model, trained_state, %{"input1" => x, "input2" => x})

  points =
    Enum.zip_with(Nx.to_flat_list(x), Nx.to_flat_list(y), fn x_value, y_value ->
      %{x: x_value, y: y_value}
    end)

  VegaLite.new(width: 600, height: 400)
  |> VegaLite.data_from_values(points)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
end

plot.(trained_state, model)

学習率による変化

# Trains `model` for one epoch (iterations capped at 1000) with SGD at the
# given learning rate, starting from an empty initial state (%{}), and returns
# the resulting trained state.
fit = fn learning_rate, model ->
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.sgd(learning_rate))
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.run(train_data, %{}, epochs: 1, iterations: 1000, compiler: EXLA)
end

# Sweep learning rates 0.01, 0.02, ..., 0.1. Each model is trained separately
# and plotted with `plot`; each chart goes into its own tab labelled "lr=<rate>".
1..10
|> Enum.map(&(&1 / 100))
|> Enum.map(fn learning_rate ->
  {
    "lr=#{learning_rate}",
    learning_rate
    |> fit.(model)
    |> plot.(model)
  }
end)
|> Kino.Layout.tabs()