Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

axon

stage2/axon.livemd

axon

# Notebook dependencies. Axon and VegaLite are not listed explicitly; they are
# presumably pulled in transitively (via Bumblebee and kino_vega_lite) -- confirm
# if either of those entries is ever removed.
Mix.install(
  [
    {:nx, "~> 0.6.1"},
    # `:github` takes the "owner/repo" string.
    {:bumblebee, github: "elixir-nx/bumblebee"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:httpoison, "~> 1.8"},
    # NOTE: `XLA_BUILD` is an environment variable read by the :xla package, not
    # a Mix dependency option, so passing it here had no effect and it was
    # dropped. To really build XLA from source, give Mix.install/2 the option
    # `system_env: %{"XLA_BUILD" => "true"}` instead.
    {:xla, "~> 0.5.1"},
    {:exla, "~> 0.6.1"},
    {:adbc, "~> 0.1"},
    {:kino, "~> 0.10.0"}
  ],
  config: [
    # Run every Nx operation on the EXLA backend by default.
    nx: [default_backend: EXLA.Backend]
  ]
)

Section

defmodule View do
  @moduledoc """
  Plotting helper for this notebook, built on VegaLite.
  """

  alias VegaLite, as: Vl

  @doc """
  Builds a 400x400 layered chart.

  The first layer draws `x1`/`y1` as points with tooltips; the second layer
  draws `x2`/`y2` as a line. Both axes are encoded as quantitative fields.
  `x2` and `y2` default to empty lists, which leaves the line layer without
  data.
  """
  def plotxy(x1, y1, x2 \\ [], y2 \\ []) do
    points = xy_layer(x1, y1, :point, tooltip: true)
    line = xy_layer(x2, y2, :line, [])

    Vl.layers(Vl.new(width: 400, height: 400), [points, line])
  end

  # Builds one layer: column data under the "x"/"y" fields, the requested mark,
  # then a quantitative encoding for each of the two axes.
  defp xy_layer(xs, ys, mark, mark_opts) do
    base = Vl.data_from_values(Vl.new(), x: xs, y: ys)
    marked = Vl.mark(base, mark, mark_opts)

    Enum.reduce([{:x, "x"}, {:y, "y"}], marked, fn {channel, field}, spec ->
      Vl.encode_field(spec, channel, field, type: :quantitative)
    end)
  end
end
defmodule Basic do
  @moduledoc """
  Linear-regression demo: fits a single dense unit to noisy samples of
  `y = 2x + 0.5 + noise` and plots the samples together with the fitted line.
  """

  @doc """
  Builds the model: one input named `"x"` with shape `input_shape`, followed by
  a dense layer with a single unit.
  """
  def build_model(input_shape) do
    "x"
    |> Axon.input(shape: input_shape)
    |> Axon.dense(1)
  end

  # Returns `{x, y}`, two `{100, 1}` tensors. `x` holds `:rand.uniform/0`
  # samples; `y` is `2 * x + 0.5` plus `:rand.uniform/0` noise scaled by 0.5.
  defp batch do
    x = Nx.tensor(for _ <- 1..100, do: [:rand.uniform()])
    noise = Nx.tensor(for _ <- 1..100, do: [:rand.uniform()]) |> Nx.multiply(0.5)

    y =
      x
      |> Nx.multiply(2)
      |> Nx.add(0.5)
      |> Nx.add(noise)

    {x, y}
  end

  # Trains `model` on `data` with mean-squared-error loss and SGD, running 1000
  # iterations per epoch, and returns whatever `Axon.Loop.run/4` returns.
  defp train_model(model, data, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  @doc """
  Trains the model for 10 epochs on one fixed batch of 100 samples, prints the
  trained state, and returns the chart produced by `View.plotxy/4`: the samples
  as points and the model's predictions for `x` in `0.0..1.0` as a line.
  """
  def run() do
    # Every tensor fed to the model is rank 2 (`{100, 1}` for training and
    # `{101, 1}` for prediction), so the input is declared as `{nil, 1}`
    # rather than the rank-1 `{nil}`.
    model = build_model({nil, 1})
    samples = batch()

    # The same batch is served on every iteration of the training loop.
    data = Stream.repeatedly(fn -> samples end)

    model_state =
      model
      |> train_model(data, 10)
      |> IO.inspect(label: "model_state")

    {x1_tensor, y1_tensor} = samples
    x2_tensor = Nx.tensor(for i <- 0..100, do: [i / 100])
    y2_tensor = Axon.predict(model, model_state, %{"x" => x2_tensor})

    View.plotxy(
      Nx.to_flat_list(x1_tensor),
      Nx.to_flat_list(y1_tensor),
      Nx.to_flat_list(x2_tensor),
      Nx.to_flat_list(y2_tensor)
    )
  end
end

# Train the linear model and render its chart.
Basic.run()
# Runtime diagnostics: Elixir version, OTP release, then IEx's runtime summary.
System.version()
System.otp_release()
IEx.Helpers.runtime_info()
# XOR example: two single-feature inputs are concatenated and passed through a
# small tanh hidden layer into one sigmoid output unit.
x1_input = Axon.input("x1", shape: {nil, 1})
x2_input = Axon.input("x2", shape: {nil, 1})

model =
  x1_input
  |> Axon.concatenate(x2_input)
  |> Axon.dense(8, activation: :tanh)
  |> Axon.dense(1, activation: :sigmoid)

batch_size = 32

# Infinite stream of `{inputs, targets}` batches. `Nx.random_uniform/3` is
# deprecated in favour of `Nx.Random`, so the integer samples (0 or 1) are
# drawn with `Nx.Random.randint/4`; the random key is threaded through
# `Stream.unfold/2` so that every batch differs.
data =
  Stream.unfold(Nx.Random.key(:rand.uniform(1_000_000)), fn key ->
    {x1, key} = Nx.Random.randint(key, 0, 2, shape: {batch_size, 1})
    {x2, key} = Nx.Random.randint(key, 0, 2, shape: {batch_size, 1})
    y = Nx.logical_xor(x1, x2)
    {{%{"x1" => x1, "x2" => x2}, y}, key}
  end)

epochs = 10

# Train with binary cross-entropy loss and SGD, 1000 iterations per epoch.
params =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
  |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)

# Spot checks; XOR of the inputs is 0, 1 and 0 respectively.
Axon.predict(model, params, %{
  "x1" => Nx.tensor([[1]]),
  "x2" => Nx.tensor([[1]])
})
Axon.predict(model, params, %{
  "x1" => Nx.tensor([[0]]),
  "x2" => Nx.tensor([[1]])
})
Axon.predict(model, params, %{
  "x1" => Nx.tensor([[0]]),
  "x2" => Nx.tensor([[0]])
})
defmodule Sin do
  @moduledoc """
  Curve-fitting demo: trains a small tanh network on 100 samples of `sin(x)`
  and plots the samples together with the network's predictions.
  """

  @doc """
  Builds the model: one input named `"x"` with shape `input_shape`, an 8-unit
  tanh hidden layer, and a single tanh output unit.
  """
  def build_model(input_shape) do
    "x"
    |> Axon.input(shape: input_shape)
    |> Axon.dense(8, activation: :tanh)
    |> Axon.dense(1, activation: :tanh)
  end

  # Returns `{x, y}`, two `{100, 1}` tensors, where `x` is `k / 15` for `k` in
  # `1..100` and `y` is `:math.sin/1` of the same values.
  defp batch do
    positions = Enum.map(1..100, &(&1 / 15))

    x = Nx.tensor(Enum.map(positions, &[&1]))
    y = Nx.tensor(Enum.map(positions, &[:math.sin(&1)]))
    {x, y}
  end

  # Trains `model` on `data` with mean-squared-error loss and SGD, running 1000
  # iterations per epoch, and returns whatever `Axon.Loop.run/4` returns.
  defp train_model(model, data, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  @doc """
  Trains the model for 10 epochs on one fixed batch, prints the trained state,
  and returns the chart produced by `View.plotxy/4`: the sine samples as points
  and the model's predictions for `x = k / 15`, `k` in `0..100`, as a line.
  """
  def run() do
    # Every tensor fed to the model is rank 2 (`{100, 1}` for training and
    # `{101, 1}` for prediction), so the input is declared as `{nil, 1}`
    # rather than the rank-1 `{nil}`.
    model = build_model({nil, 1})
    samples = batch()

    # The same batch is served on every iteration of the training loop.
    data = Stream.repeatedly(fn -> samples end)

    model_state =
      model
      |> train_model(data, 10)
      |> IO.inspect(label: "model_state")

    {x1_tensor, y1_tensor} = samples
    x2_tensor = Nx.tensor(for i <- 0..100, do: [i / 15])
    y2_tensor = Axon.predict(model, model_state, %{"x" => x2_tensor})

    View.plotxy(
      Nx.to_flat_list(x1_tensor),
      Nx.to_flat_list(y1_tensor),
      Nx.to_flat_list(x2_tensor),
      Nx.to_flat_list(y2_tensor)
    )
  end
end

# Train the sine model and render its chart.
Sin.run()