Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

4PL

livebooks/nx/4pl.livemd

4PL (four-parameter logistic) curve fitting

Mix.install([
  {:nx, "~> 0.8"},
  {:exla, "~> 0.8"},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"},
  {:statistics, "~> 0.6"}
])

# Eager tensor operations (outside `defn`) run on the EXLA backend.
Nx.global_default_backend(EXLA.Backend)

# Also JIT-compile `defn` functions with EXLA. Setting only the backend leaves
# `defn` on the default evaluator, which dispatches each tensor op separately
# instead of compiling the whole function once and reusing it.
Nx.Defn.global_default_options(compiler: EXLA)

データ作成 (Data generation)

# Sample points: the integers 0..19.
x_data = Nx.iota({20})

# Ground-truth 4PL curve y = (a - d) / (1 + (x / c)^b) + d
# with a = 0.5, b = 2.5, c = 8, d = 7.3.
true_y_data =
  Nx.divide(
    0.5 - 7.3,
    x_data
    |> Nx.divide(8)
    |> Nx.pow(2.5)
    |> Nx.add(1)
  )
  |> Nx.add(7.3)

# Observation noise: one draw from Statistics.Distributions.Normal.rand/0 per
# sample point, scaled by 0.2. The count is taken from x_data so it cannot
# drift from the number of sample points if the iota size is changed.
noise =
  Stream.repeatedly(&Statistics.Distributions.Normal.rand/0)
  |> Enum.take(Nx.size(x_data))
  |> Nx.tensor()
  |> Nx.multiply(0.2)

# Noisy observations the model is fitted to.
y_data = Nx.add(true_y_data, noise)
# Column-oriented chart data (a list of x values and a list of y values):
# the noisy observations and the noise-free ground-truth curve.
plot_data = %{x: Nx.to_flat_list(x_data), y: Nx.to_flat_list(y_data)}

true_plot_data = %{x: Nx.to_flat_list(x_data), y: Nx.to_flat_list(true_y_data)}

# Preview: noisy observations as points, overlaid with the ground-truth
# curve drawn as a red line.
observed_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(plot_data)
  |> VegaLite.mark(:point)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

truth_layer =
  VegaLite.new()
  |> VegaLite.data_from_values(true_plot_data)
  |> VegaLite.mark(:line, color: "#ff0000")
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)

VegaLite.new(width: 600)
|> VegaLite.layers([observed_layer, truth_layer])
|> Kino.VegaLite.new()

トレーニング (Training)

defmodule FPL do
  @moduledoc """
  Four-parameter logistic (4PL) model fitted by plain gradient descent.

  The curve is `y = (a - d) / (1 + (x / c)^b) + d`; parameters are passed
  around as a 4-tuple `{a, b, c, d}`.
  """
  import Nx.Defn

  @doc "Evaluates the 4PL curve with parameters `{a, b, c, d}` at `x`."
  defn pred({a, b, c, d}, x) do
    (a - d) / (1.0 + Nx.pow(x / c, b)) + d
  end

  @doc "Mean squared error between predictions `yp` and targets `y`."
  defn mse(yp, y) do
    (yp - y)
    |> Nx.pow(2)
    |> Nx.mean()
  end

  @doc "MSE of the 4PL prediction for `params` at `x` against targets `y`."
  defn loss(params, x, y) do
    yp = pred(params, x)
    mse(yp, y)
  end

  @doc """
  Runs one gradient-descent step with learning rate `lr`.

  Returns `{loss, new_params}`. `loss` is evaluated at the incoming `params`,
  before the step. Loss and gradient come from a single `value_and_grad/2`
  pass, so the forward computation is not repeated.
  """
  defn loss_and_update({a, b, c, d} = params, x, y, lr) do
    {lv, {grad_a, grad_b, grad_c, grad_d}} = value_and_grad(params, &loss(&1, x, y))

    new_params = {
      a - grad_a * lr,
      b - grad_b * lr,
      c - grad_c * lr,
      d - grad_d * lr
    }

    {lv, new_params}
  end

  @doc "Runs one gradient-descent step and returns only the updated `{a, b, c, d}`."
  defn update(params, x, y, lr) do
    {_lv, new_params} = loss_and_update(params, x, y, lr)
    new_params
  end

  @doc "Initial parameters: every entry is the scalar tensor `1.0`."
  defn init_params do
    {Nx.tensor(1.0), Nx.tensor(1.0), Nx.tensor(1.0), Nx.tensor(1.0)}
  end

  @doc """
  One training-loop step over the accumulator `{losses, a, b, c, d}`.

  Prepends the loss at the current parameters, converted to a plain number,
  to `losses`, so the list is newest-first. Replaces `a..d` with the
  parameters after one gradient-descent step. A single compiled call
  produces both values.
  """
  def loss_update({lvs, a, b, c, d}, x, y, lr) do
    {lv, {a, b, c, d}} = loss_and_update({a, b, c, d}, x, y, lr)
    {[Nx.to_number(lv) | lvs], a, b, c, d}
  end
end
# Line chart of training loss (y) against epoch (x). It starts with no data;
# points are pushed into it with Kino.VegaLite.push_many/2.
loss_widget =
  Kino.VegaLite.new(
    VegaLite.new(width: 600)
    |> VegaLite.mark(:line)
    |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "epoch")
    |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "loss")
  )

# Live fit chart with three layers:
#   1. noisy observations (points),
#   2. ground-truth curve (red line),
#   3. fitted curve (line) with no inline data; its points are pushed in
#      later with Kino.VegaLite.push_many/2.
fpl_layers = [
  VegaLite.new()
  |> VegaLite.data_from_values(plot_data)
  |> VegaLite.mark(:point)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative),
  VegaLite.new()
  |> VegaLite.data_from_values(true_plot_data)
  |> VegaLite.mark(:line, color: "#ff0000")
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative),
  VegaLite.new()
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
]

fpl_widget =
  VegaLite.new(width: 600)
  |> VegaLite.layers(fpl_layers)
  |> Kino.VegaLite.new()

# Drop any points already pushed to the two live charts (loss first, then
# fit), then stack them vertically in a single column.
Enum.each([loss_widget, fpl_widget], &Kino.VegaLite.clear/1)

Kino.Layout.grid([loss_widget, fpl_widget], columns: 1)
# Redraws both live charts from the training state
# {epoch, losses_newest_first, a, b, c, d}:
#   - loss_widget gets the whole loss history, re-indexed 1..epoch oldest-first;
#   - fpl_widget gets the current fitted curve evaluated at x_data.
update_plots = fn {epoch, lvs, a, b, c, d} ->
  history =
    Enum.zip_with(1..epoch, Enum.reverse(lvs), fn step, loss -> %{x: step, y: loss} end)

  Kino.VegaLite.clear(loss_widget)
  Kino.VegaLite.push_many(loss_widget, history)

  fitted_ys = {a, b, c, d} |> FPL.pred(x_data) |> Nx.to_flat_list()
  fitted_points = Enum.zip_with(Nx.to_flat_list(x_data), fitted_ys, &%{x: &1, y: &2})

  Kino.VegaLite.clear(fpl_widget)
  Kino.VegaLite.push_many(fpl_widget, fitted_points)
end
{a, b, c, d} = FPL.init_params()

epochs = 2500
lr = 0.02
# Redraw the live charts every `plot_every` epochs.
plot_every = 10

# Gradient descent. The accumulator is {losses_newest_first, a, b, c, d};
# the final accumulator is the cell's result.
Enum.reduce(1..epochs, {[], a, b, c, d}, fn epoch, acc ->
  {lvs, a, b, c, d} = FPL.loss_update(acc, x_data, y_data, lr)

  # Also redraw on the last epoch, so the charts show the final fit even when
  # `epochs` is not a multiple of `plot_every`.
  if rem(epoch, plot_every) == 0 or epoch == epochs do
    update_plots.({epoch, lvs, a, b, c, d})
  end

  {lvs, a, b, c, d}
end)