# 4PL (four-parameter logistic curve fitting)
Mix.install([
{:nx, "~> 0.9"},
{:exla, "~> 0.9"},
{:kino, "~> 0.14"},
{:kino_vega_lite, "~> 0.1"},
{:statistics, "~> 0.6"}
])
# Set EXLA.Backend as the global default backend for Nx.
Nx.global_default_backend(EXLA.Backend)
## データ作成 (data generation)
# Ground-truth 4PL parameters used to synthesise the data set:
#   y = (a - d) / (1 + (x / c)^b) + d
{true_a, true_b, true_c, true_d} = {0.5, 2.5, 8, 7.3}

# Factor applied to every random draw before it is added to the clean curve.
noise_scale = 0.2

# Sample positions 0, 1, ..., 19.
x_data = Nx.iota({20})

# Noise-free curve evaluated at every sample position.
true_y_data =
  Nx.divide(
    true_a - true_d,
    x_data
    |> Nx.divide(true_c)
    |> Nx.pow(true_b)
    |> Nx.add(1)
  )
  |> Nx.add(true_d)

# One random draw per sample. The count is taken from x_data instead of being
# repeated as a literal, so the noise vector always has as many entries as the
# curve it is added to.
noise =
  1..Nx.size(x_data)
  |> Enum.map(fn _ -> Statistics.Distributions.Normal.rand() end)
  |> Nx.tensor()
  |> Nx.multiply(noise_scale)

# Observations the model is trained on: clean curve plus scaled noise.
y_data = Nx.add(true_y_data, noise)
# Flatten the x tensor once; both series are plotted against the same values.
xs = Nx.to_flat_list(x_data)

# Noisy observations.
plot_data = %{x: xs, y: Nx.to_flat_list(y_data)}

# Noise-free reference curve.
true_plot_data = %{x: xs, y: Nx.to_flat_list(true_y_data)}

# Observations drawn as points.
observed_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(plot_data)
  |> VegaLite.mark(:point)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Reference curve drawn as a red line.
reference_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(true_plot_data)
  |> VegaLite.mark(:line, color: "#ff0000")
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Stack both layers in one 600px-wide chart and hand it to Kino.
VegaLite.new(width: 600)
|> VegaLite.layers([observed_layer, reference_layer])
|> Kino.VegaLite.new()
## トレーニング (training)
defmodule FPL do
  @moduledoc """
  Four-parameter logistic (4PL) curve fitted with plain gradient descent.

  The model is `y = (a - d) / (1 + (x / c)^b) + d`. Parameters travel through
  the module as the tuple `{a, b, c, d}`.
  """

  import Nx.Defn

  @doc """
  Evaluates the 4PL curve at `x` for the parameters `{a, b, c, d}`.
  """
  # NOTE(review): a negative `x / c` raised to a fractional `b` is presumably
  # NaN, so training likely relies on `c` staying positive - confirm.
  defn pred({a, b, c, d}, x) do
    (a - d) / (1.0 + Nx.pow(x / c, b)) + d
  end

  @doc """
  Mean squared error between the predictions `yp` and the targets `y`.
  """
  defn mse(yp, y) do
    (yp - y)
    |> Nx.pow(2)
    |> Nx.mean()
  end

  @doc """
  Mean squared error of the curve described by `params` on the data `x`, `y`.
  """
  defn loss(params, x, y) do
    yp = pred(params, x)
    mse(yp, y)
  end

  @doc """
  Performs one gradient-descent step with learning rate `lr` and returns the
  updated `{a, b, c, d}`.
  """
  defn update({a, b, c, d} = params, x, y, lr) do
    {grad_a, grad_b, grad_c, grad_d} = grad(params, &loss(&1, x, y))

    {
      a - grad_a * lr,
      b - grad_b * lr,
      c - grad_c * lr,
      d - grad_d * lr
    }
  end

  @doc """
  Same step as `update/4`, but also returns the loss at the incoming
  parameters: `{loss, {a, b, c, d}}`.

  The loss and its gradient come from a single `value_and_grad` call, so the
  forward pass is not repeated just to report the loss.
  """
  defn loss_and_update({a, b, c, d} = params, x, y, lr) do
    {lv, {grad_a, grad_b, grad_c, grad_d}} = value_and_grad(params, &loss(&1, x, y))

    {
      lv,
      {
        a - grad_a * lr,
        b - grad_b * lr,
        c - grad_c * lr,
        d - grad_d * lr
      }
    }
  end

  @doc """
  Initial parameters: every one of `a`, `b`, `c`, `d` starts at `1.0`.
  """
  defn init_params do
    {Nx.tensor(1.0), Nx.tensor(1.0), Nx.tensor(1.0), Nx.tensor(1.0)}
  end

  @doc """
  Advances the training state `{lvs, a, b, c, d}` by one step.

  `lvs` is the loss history, newest value first. The loss measured at the
  incoming parameters is converted to a plain number and prepended to it, and
  the parameters are replaced by their updated values.
  """
  def loss_update({lvs, a, b, c, d}, x, y, lr) do
    {lv, {a, b, c, d}} = loss_and_update({a, b, c, d}, x, y, lr)
    {[Nx.to_number(lv) | lvs], a, b, c, d}
  end
end
# Loss-per-epoch line chart. The spec carries no data of its own.
loss_spec =
  VegaLite.new(width: 600)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "epoch")
  |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "loss")

loss_widget = Kino.VegaLite.new(loss_spec)

# Noisy observations as points.
points_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(plot_data)
  |> VegaLite.mark(:point)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Noise-free reference curve as a red line.
target_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(true_plot_data)
  |> VegaLite.mark(:line, color: "#ff0000")
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

# Line layer without its own data set.
# NOTE(review): presumably drawn from the rows later pushed to the widget - confirm.
fit_layer =
  VegaLite.new()
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

fpl_widget =
  VegaLite.new(width: 600)
  |> VegaLite.layers([points_layer, target_layer, fit_layer])
  |> Kino.VegaLite.new()

# Clear both widgets, loss chart first.
Enum.each([loss_widget, fpl_widget], &Kino.VegaLite.clear/1)

# Show the two charts stacked in a single column.
Kino.Layout.grid([loss_widget, fpl_widget], columns: 1)
# Redraws both live charts from the training state `{epoch, lvs, a, b, c, d}`.
update_plots = fn {epoch, lvs, a, b, c, d} ->
  # The history is reversed before being paired with 1..epoch.
  loss_rows =
    Enum.zip_with(1..epoch, Enum.reverse(lvs), fn step, value ->
      %{x: step, y: value}
    end)

  Kino.VegaLite.clear(loss_widget)
  Kino.VegaLite.push_many(loss_widget, loss_rows)

  # Curve predicted by the current parameters at every sample position.
  fitted = FPL.pred({a, b, c, d}, x_data)

  fit_rows =
    for {x, y} <- Enum.zip(Nx.to_flat_list(x_data), Nx.to_flat_list(fitted)) do
      %{x: x, y: y}
    end

  Kino.VegaLite.clear(fpl_widget)
  Kino.VegaLite.push_many(fpl_widget, fit_rows)
end
# Hyper-parameters of the run.
epochs = 2500
lr = 0.02

# Starting point for the four curve parameters.
{a, b, c, d} = FPL.init_params()

# Fold over the epochs; the accumulator is {loss_history, a, b, c, d}.
for epoch <- 1..epochs, reduce: {[], a, b, c, d} do
  state ->
    {losses, a, b, c, d} = next_state = FPL.loss_update(state, x_data, y_data, lr)

    # Refresh the charts on every tenth epoch only.
    if rem(epoch, 10) == 0 do
      update_plots.({epoch, losses, a, b, c, d})
    end

    next_state
end