# OR train
# Install the notebook's dependencies at runtime.
Mix.install([
# `override: true` lets this Nx requirement take precedence over the Nx
# requirements declared by the other dependencies.
{:nx, "~> 0.9", override: true},
# Axon is fetched from its GitHub repository.
{:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
# EXLA is passed as the `compiler:` when running the training loops.
{:exla, "~> 0.9"},
# Kino and its VegaLite integration render the charts, tables and tabs.
{:kino, "~> 0.14"},
{:kino_vega_lite, "~> 0.1"}
])
## 準備
require Axon
## 学習データ生成
# Builds one random training batch for the logical-OR task.
#
# Returns `{inputs, labels}`:
#   * `inputs` - a map with the keys "input1" and "input2"; each value is a
#     `{32, 1}` tensor of random 0/1 integers.
#   * `labels` - the element-wise `Nx.logical_or/2` of the two input tensors.
generate_train_data = fn ->
  # One column of 32 random bits, with a trailing axis added so each row
  # holds a single value.
  random_column = fn ->
    bits = for _ <- 1..32, do: Enum.random(0..1)
    Nx.new_axis(Nx.tensor(bits), 1)
  end

  inputs = Map.new(1..2, fn index -> {"input#{index}", random_column.()} end)
  %{"input1" => left, "input2" => right} = inputs

  {inputs, Nx.logical_or(left, right)}
end
# Preview a single generated batch.
generate_train_data.()

# Materialise 1000 independently generated batches as the training set.
train_data = Enum.map(1..1000, fn _ -> generate_train_data.() end)

# Sanity check: number of batches in the training set.
Enum.count(train_data)
## モデル定義
# Two named inputs, one per OR operand; each row carries a single value.
input1 = Axon.input("input1", shape: {nil, 1})
input2 = Axon.input("input2", shape: {nil, 1})

# Concatenate both inputs, pass them through an 8-unit ReLU layer and finish
# with a single sigmoid unit.
model =
  input1
  |> Axon.concatenate(input2)
  |> Axon.dense(8, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
# Builds an empty live line chart with "step" on the x axis and the given
# metric name on the y axis.
new_metric_plot = fn metric ->
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, metric, type: :quantitative)
  |> Kino.VegaLite.new()
end

loss_plot = new_metric_plot.("loss")
acc_plot = new_metric_plot.("accuracy")

# Show both charts side by side.
Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
# Supervised loop: binary cross-entropy loss, SGD optimizer, accuracy metric.
training_loop =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy, "accuracy")

# Attach one live chart per metric (updated when an epoch completes), then
# train for 5 epochs of 1000 iterations each, compiled with EXLA.
trained_state =
  [{loss_plot, "loss"}, {acc_plot, "accuracy"}]
  |> Enum.reduce(training_loop, fn {chart, metric}, loop ->
    Axon.Loop.kino_vega_lite_plot(loop, chart, metric, event: :epoch_completed)
  end)
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(),
    epochs: 5,
    iterations: 1000,
    compiler: EXLA
  )
# A single-row sample for the case 0 OR 0.
test_datum = Map.new(["input1", "input2"], fn name -> {name, Nx.tensor([[0]])} end)

# Run the trained model on that sample.
Axon.predict(model, trained_state, test_datum)
# Runs the model on one `{input_1, input_2}` pair and returns the prediction
# as a plain number.
#
#   * `model`         - the Axon model to evaluate
#   * `trained_state` - the model state passed to `Axon.predict/3`
#   * `{input_1, input_2}` - the two scalar operands
predict = fn model, trained_state, {input_1, input_2} ->
  # Wrap each scalar as a 1x1 tensor under the model's input names.
  features = %{
    "input1" => Nx.tensor([[input_1]]),
    "input2" => Nx.tensor([[input_2]])
  }

  output = Axon.predict(model, trained_state, features)

  # Take the element at row 0, column 0 and convert it to a number.
  Nx.to_number(output[[0, 0]])
end
# Spot check: prediction for 0 OR 0.
predict.(model, trained_state, {0, 0})

# Evaluate the model on the full truth table, in the order
# {0, 0}, {0, 1}, {1, 0}, {1, 1}.
rows =
  for input_1 <- 0..1, input_2 <- 0..1 do
    predicted_value = predict.(model, trained_state, {input_1, input_2})

    # Threshold the prediction at 0.5 to obtain a hard 0/1 label.
    predicted_label = if predicted_value < 0.5, do: 0, else: 1

    %{
      "input1" => input_1,
      "input2" => input_2,
      "value" => predicted_value,
      "label" => predicted_label
    }
  end

Kino.DataTable.new(rows)
## 推論の可視化
# Plots the model output when both inputs take the same value, sweeping that
# value from 0.0 to 0.99 in steps of 0.01.
#
#   * `trained_state` - the model state passed to `Axon.predict/3`
#   * `model`         - the Axon model to evaluate
#
# Returns a `Kino.VegaLite` line chart of prediction (y) against input (x).
plot = fn trained_state, model ->
  grid = for step <- 0..99, do: step / 100
  x = grid |> Nx.tensor() |> Nx.new_axis(1)

  # The same tensor is fed to both inputs.
  y = Axon.predict(model, trained_state, %{"input1" => x, "input2" => x})

  # Pair every input value with its prediction as a chart data point.
  points =
    Enum.zip_with(Nx.to_flat_list(x), Nx.to_flat_list(y), fn px, py ->
      %{x: px, y: py}
    end)

  VegaLite.new(width: 600, height: 400)
  |> VegaLite.data_from_values(points)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
end

# Chart for the state trained above.
plot.(trained_state, model)
## 学習率による変化
# Trains `model` from an empty state for one epoch of 1000 iterations using
# SGD with the given learning rate, and returns the result of
# `Axon.Loop.run/4`.
#
#   * `learning_rate` - learning rate handed to `Polaris.Optimizers.sgd/1`
#   * `model`         - the Axon model to train
fit = fn learning_rate, model ->
  optimizer = Polaris.Optimizers.sgd(learning_rate: learning_rate)

  loop =
    model
    |> Axon.Loop.trainer(:binary_cross_entropy, optimizer)
    |> Axon.Loop.metric(:accuracy, "accuracy")

  Axon.Loop.run(loop, train_data, Axon.ModelState.empty(),
    epochs: 1,
    iterations: 1000,
    compiler: EXLA
  )
end
# Train once per learning rate from 0.01 to 0.10 and show each resulting
# prediction curve in its own tab.
tabs =
  for percent <- 1..10 do
    learning_rate = percent / 100
    chart = plot.(fit.(learning_rate, model), model)

    {"lr=#{learning_rate}", chart}
  end

Kino.Layout.tabs(tabs)