Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

# SAC Multi-Mark

`livebooks/sac_multi_mark.livemd`

## SAC Multi-Mark

# Resolve the project root relative to this livebook so the local
# `:boat_learner` app can be installed from its path.
my_app_root = Path.join(__DIR__, "..")

Mix.install(
  [
    {:boat_learner, path: my_app_root, env: :dev}
  ],
  # Reuse the host project's config and lockfile so dependency versions match.
  config_path: Path.join(my_app_root, "config/config.exs"),
  lockfile: Path.join(my_app_root, "mix.lock"),
  consolidate_protocols: false,
  # Default to CPU XLA builds unless the caller exported XLA_TARGET.
  system_env: %{"XLA_TARGET" => System.get_env("XLA_TARGET", "cpu")},
  force: true
)

## Plot

alias VegaLite, as: Vl

# Plot domain, taken from the environment's own bounding box.
{min_x, max_x, min_y, max_y} = BoatLearner.Environments.MultiMark.bounding_box()

# Course radius (distance from the origin to the upwind mark).
r = 250

pi = :math.pi()
# Tacking angle: 42.7 degrees converted to radians.
theta = 42.7 / 180 * pi

# Candidate mark positions. Alternative layouts from earlier experiments
# are kept below, commented out.
# valid_points = [[0, r], [0, 0], [0, -r]]
valid_points = [[0, r], [0, 0]]

# valid_points = [[0, 0], [r * :math.tan(theta), r]]

# points =
#   for start <- valid_points, target <- valid_points, start != target do
#     start ++ target
#   end

# Each row is [start_x, start_y, target_x, target_y]; currently a single
# leg from the origin to the mark at (0, r).
points = [[0, 0, 0, r]]

# for angle <- [0, theta, -theta, pi, pi + theta, pi - theta, pi / 2, -pi / 2] do
#   [
#     Float.round(r * :math.cos(angle + pi / 2), 4),
#     Float.round(r * :math.sin(angle + pi / 2), 4)
#   ]
# end

points_tensor = Nx.tensor(points)
# Uniform sampling over the rows of `points_tensor`.
points_probs = :uniform

# Square plot, in pixels.
width = height = 750

# NOTE(review): these two scale factors appear unused in this livebook.
scale_max_x = max_x / 250
scale_min_x = min_x / -250

# Slope of the diagonal grid lines (tan of the tacking angle).
alfa = :math.tan(theta)
step = 5
num_lines = 5
# Base x-intercepts for the ladder of diagonal grid lines.
x = Enum.map(-(step * num_lines)..(step * num_lines), fn k -> alfa * r * k / step end)
# Every base point expands into 4 segments below, all starting at y = 0.
y = List.duplicate(0, length(x) * 4)

# Far-end x coordinates: from each base x, slopes +alfa and -alfa toward
# y = r, then the same two toward y = -r (matching y2 below).
x2 =
  Enum.flat_map(x, fn x1 ->
    [x1 + alfa * r, x1 - alfa * r, x1 + alfa * r, x1 - alfa * r]
  end)

y2 = Enum.flat_map(x, fn _ -> [r, r, -r, -r] end)

# Repeat each base x once per generated segment to keep the lists aligned.
x =
  Enum.flat_map(x, fn x ->
    [x, x, x, x]
  end)

# Duplicate every segment with its y endpoints swapped, then unzip back
# into the four parallel coordinate lists the rule-mark layer consumes.
{x, x2, y, y2} =
  Enum.flat_map(Enum.zip([x, x2, y, y2]), fn {x, x2, y, y2} ->
    [
      {x, x2, y, y2},
      {x, x2, y2, y}
    ]
  end)
  |> Enum.reduce({[], [], [], []}, fn {x, x2, y, y2}, {xacc, x2acc, yacc, y2acc} ->
    {[x | xacc], [x2 | x2acc], [y | yacc], [y2 | y2acc]}
  end)

# Returns a VegaLite layer drawing a thin ring (annulus) centered at the
# origin with data-space radius `r * r_scale`; the +/-1 on the inner and
# outer pixel radii gives the ring a fixed ~2 px stroke on screen.
arc = fn r_scale ->
  Vl.new()
  |> Vl.data_from_values(%{x: [0], y: [0]},
    name: "arc_grid"
  )
  |> Vl.mark(:arc,
    clip: true,
    # Convert data units to pixels via the plot height / y-extent ratio.
    radius: (r * r_scale + 1) * height / (max_y - min_y),
    radius2: (r * r_scale - 1) * height / (max_y - min_y),
    # Full circle.
    theta: 2 * :math.pi(),
    color: "#ddd",
    opacity: 0.75
  )
  |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x]])
  |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y]])
end

# Radius of the rendered target markers in pixels, converted into the
# data-space units used by the plot.
target_point_radius = 5
target_point_radius_px = (max_y - min_y) / height * target_point_radius
# NOTE(review): `pi ** 2 * radius` is an odd formula for a point *area*
# (Vega-Lite `size` is area in px^2); `pi * radius ** 2` may have been
# intended — confirm visually before changing.
target_point_size = pi ** 2 * target_point_radius_px

# Layer rendering the candidate mark positions; the active target
# (is_target == 1) is drawn red, fully opaque and at full size.
target_layer =
  Vl.new()
  |> Vl.data(name: "target")
  |> Vl.mark(:point,
    # Fix: the original passed `size:` twice in this keyword list; only the
    # expression-driven size is kept.
    size: %{expr: "if(datum.is_target == 1, #{target_point_size}, #{target_point_size / 2})"},
    opacity: %{expr: "if(datum.is_target == 1, 1, 0.75)"},
    filled: true,
    tooltip: [content: "data"]
  )
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Vl.encode(:color,
    condition: %{test: "datum.is_target == 1", value: "red"},
    value: "#333"
  )

# Main simulation plot: diagonal grid rules, concentric distance arcs, the
# target layer, and the trajectory line layer, all sharing one x/y domain.
grid_widget =
  Vl.new(width: width, height: height)
  |> Vl.layers([
    Vl.new()
    |> Vl.data_from_values(%{x: x, y: y, x2: x2, y2: y2}, name: "diagonal_grid")
    |> Vl.mark(:rule, clip: true, color: "#ddd", size: 2, opacity: 0.75)
    |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x]])
    |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y]])
    |> Vl.encode_field(:x2, "x2", type: :quantitative)
    |> Vl.encode_field(:y2, "y2", type: :quantitative),
    # Distance rings at 100%, 50% and 25% of the course radius.
    arc.(1),
    arc.(0.5),
    arc.(0.25),
    target_layer,
    Vl.new()
    |> Vl.data(name: "trajectory")
    |> Vl.mark(:line, point: true, clip: true, opacity: 1, tooltip: [content: "data"])
    |> Vl.encode_field(:x, "x",
      legend: false,
      type: :quantitative,
      scale: [domain: [min_x, max_x]]
    )
    |> Vl.encode_field(:y, "y",
      legend: false,
      type: :quantitative,
      scale: [domain: [min_y, max_y]]
    )
    # Fix: `legend: false` was passed twice in this option list.
    |> Vl.encode_field(:color, "episode",
      legend: false,
      type: :nominal,
      scale: [scheme: "blues"]
    )
    |> Vl.encode_field(:order, "index", legend: false)
  ])
  |> Kino.VegaLite.new()

# Builds a Kino-backed scatter+line widget for scalar training metrics.
# When `plot_median` is true, the line layer plots a rolling median over a
# trailing 30-sample window instead of the raw y values.
value_widget_fn = fn title, plot_median ->
  transform =
    if plot_median do
      # Fix: the original contained HTML-escaped capture syntax (`&amp;`),
      # which is not valid Elixir.
      &Vl.transform(&1,
        frame: [-30, 0],
        window: [
          [
            field: "y",
            op: "median",
            as: "rolling_median"
          ]
        ]
      )
    else
      &Function.identity/1
    end

  Vl.new(width: width, height: div(height, 2), title: title)
  |> Vl.data(name: "values")
  |> Vl.layers([
    Vl.new()
    |> Vl.mark(:point,
      grid: true,
      tooltip: [content: "data"],
      opacity: 0.25
    )
    |> Vl.encode_field(:color, "source", type: :nominal)
    |> Vl.encode_field(:x, "x", legend: false, type: :quantitative)
    |> Vl.encode_field(:order, "x", legend: false)
    |> Vl.encode_field(:y, "y",
      legend: false,
      type: :quantitative,
      scale: [
        # The y-domain is user-adjustable via the min_y/max_y params below.
        domain: %{expr: "[min_y, max_y]"},
        clamp: true
      ],
      axis: [tick_count: 10]
    ),
    Vl.new()
    |> Vl.mark(:line, grid: true)
    |> transform.()
    |> Vl.encode_field(:color, "source", type: :nominal)
    |> Vl.encode_field(:x, "x", legend: false, type: :quantitative)
    |> Vl.encode_field(:order, "x", legend: false)
    |> Vl.encode_field(:y, if(plot_median, do: "rolling_median", else: "y"),
      legend: false,
      type: :quantitative,
      scale: [
        domain: %{expr: "[min_y, max_y]"},
        clamp: true
      ]
    )
  ])
  |> Vl.param("max_y", value: 5, bind: [input: "number"])
  |> Vl.param("min_y", value: 0, bind: [input: "number"])
  |> Kino.VegaLite.new()
end

# Value-over-episode widgets; loss and reward get the rolling-median overlay.
loss_widget = value_widget_fn.("Loss", true)
reward_widget = value_widget_fn.("Total Reward", true)
misc_value_widget = value_widget_fn.("Misc. Values", false)

# Shared key under which the training loop publishes run metadata.
global_kv_key = :rl_training_metadata_key
:persistent_term.put(global_kv_key, %{optimizer_steps: nil, total_reward: nil})

metadata_widget = Kino.Frame.new()

# Poll the shared metadata every 250 ms and re-render the markdown table
# only when the value actually changed since the previous tick.
Kino.listen(250, nil, fn _tick, previous ->
  current = :persistent_term.get(global_kv_key)

  if current != previous do
    rows =
      Enum.map_join(current, fn {key, value} ->
        "|" <> to_string(key) <> " | " <> if(value, do: to_string(value), else: " ") <> " |\n"
      end)

    markdown = "| Field | Value |\n" <> "| ----- | ----- |\n" <> rows
    Kino.Frame.render(metadata_widget, Kino.Markdown.new(markdown))
  end

  # Carry the latest snapshot forward as the listener accumulator.
  {:cont, current}
end)

# Wrap a single widget in a boxed grid cell.
boxed = fn widget -> Kino.Layout.grid([widget], boxed: true) end

simulation_widget =
  Kino.Layout.grid([
    boxed.(grid_widget),
    boxed.(metadata_widget)
  ])

training_widget =
  Kino.Layout.grid([
    boxed.(loss_widget),
    boxed.(reward_widget),
    boxed.(misc_value_widget)
  ])

Kino.Layout.tabs([{"Simulation", simulation_widget}, {"Training", training_widget}])
defmodule AccumulateAndExportData do
  @moduledoc """
  GenServer that accumulates trajectory rows across training epochs and
  serializes the collected data to a file on a `{:save, filename}` call.
  """
  use GenServer
  import Nx.Defn
  import Nx.Constants

  # Radians -> degrees, wrapped into the (-180, 180] range.
  defnp rad_to_deg(angle) do
    angle = angle * 180 / pi()
    Nx.select(angle > 180, angle - 360, angle)
  end

  @doc """
  Packs the environment fields of a step into a single rank-1 tensor row.

  The field order must stay in sync with the destructuring done in the
  `:add_epoch` handler below.
  """
  def state_to_trajectory_entry(%{environment_state: env, agent_state: _agent}) do
    # Euclidean distance from the boat to the current target mark.
    distance_to_mark =
      Nx.sqrt(
        Nx.pow(Nx.subtract(env.target_y, env.y), 2)
        |> Nx.add(Nx.pow(Nx.subtract(env.target_x, env.x), 2))
      )

    Nx.stack([
      env.x,
      env.y,
      env.remaining_seconds,
      env.reward,
      rad_to_deg(env.heading),
      env.speed,
      env.vmg,
      distance_to_mark,
      env.target_x,
      env.target_y,
      env.is_terminal
    ])
  end

  def init(_opts), do: {:ok, %{trajectories: [], marks: %{x: [], y: [], epoch: [], index: []}}}

  def start_link(opts \\ []), do: GenServer.start_link(__MODULE__, opts, name: __MODULE__)

  def reset, do: GenServer.cast(__MODULE__, :reset)

  def handle_cast(:reset, _) do
    {:ok, state} = init([])
    {:noreply, state}
  end

  # Nothing to record when the epoch produced no iterations.
  def handle_cast({:add_epoch, _epoch, 0, _trajectory_tensor}, state) do
    {:noreply, state}
  end

  def handle_cast({:add_epoch, epoch, iterations, trajectory_tensor}, state) do
    trajectory_data =
      trajectory_tensor
      |> Nx.revectorize([x: :auto], target_shape: trajectory_tensor.shape)
      |> Nx.devectorize()
      # Only the first vectorized environment is exported.
      |> Nx.take(0)
      |> Nx.slice_along_axis(0, iterations, axis: 0)
      |> Nx.to_list()
      |> Enum.with_index(fn [
                              x,
                              y,
                              remaining_seconds,
                              reward,
                              heading,
                              speed,
                              vmg,
                              distance_to_mark,
                              target_x,
                              target_y,
                              is_terminal
                            ],
                            index ->
        %{
          x: x,
          y: y,
          epoch: epoch,
          heading: heading,
          boat_speed: speed,
          distance_to_mark: distance_to_mark,
          vmg: vmg,
          index: index,
          remaining_seconds: remaining_seconds,
          reward: reward,
          target_x: target_x,
          target_y: target_y,
          is_terminal: is_terminal
        }
      end)

    # Fix: prepend the new epoch (O(1)); the save handler reverses once to
    # restore chronological order. The original `[state.trajectories, data]`
    # nesting combined with `Enum.reverse/1` in `:save` emitted epochs out
    # of order, and it also clobbered the `marks` map with `[]`.
    {:noreply, %{state | trajectories: [trajectory_data | state.trajectories]}}
  end

  def handle_call({:save, filename}, _from, state) do
    # Restore chronological order, flatten into one row list, then pivot
    # into a column-oriented map keyed by field name.
    data =
      state.trajectories
      |> Enum.reverse()
      |> List.flatten()
      |> Enum.flat_map(&Map.to_list/1)
      |> Enum.group_by(&elem(&1, 0), &elem(&1, 1))

    layers = [
      %{
        mark: %{type: :line, opts: [tooltip: [content: "data"], point: true]},
        data: data
      }
    ]

    contents = :erlang.term_to_binary(layers)
    File.write!(filename, contents)
    {:reply, :ok, state}
  end
end
# 250 max_iter * 15 episodes
max_points = 1000

# Candidate mark coordinates split into parallel x/y column lists for the
# "target" dataset. (Fix: the original had HTML-escaped `&amp;` captures.)
target_data = %{
  x: Enum.map(valid_points, &Enum.at(&1, 0)),
  y: Enum.map(valid_points, &Enum.at(&1, 1))
}

# Extracts a plain number from a (possibly vectorized) scalar tensor; for
# vectorized tensors `index` selects which vector entry to read.
to_number = fn
  %{vectorized_axes: []} = t, _index ->
    Nx.to_number(Nx.backend_copy(t, Nx.BinaryBackend))

  t, index ->
    t
    |> Nx.backend_copy(Nx.BinaryBackend)
    |> Nx.devectorize()
    |> Nx.take(Nx.backend_copy(index))
    |> Nx.to_number()
end

# Callback invoked after every episode: plots the best trajectory across the
# vectorized environments and pushes loss/reward/misc metrics to the widgets.
plot_fn = fn state ->
  episode = state.episode
  trajectory = state.step_state.trajectory
  env_state = state.step_state.environment_state
  agent_state = state.step_state.agent_state
  IO.inspect("Episode #{episode} ended")

  devec_trajectory =
    Nx.devectorize(trajectory) |> Nx.backend_copy(Nx.BinaryBackend) |> Nx.to_list()

  # When the trajectory is not vectorized, wrap it so the code below can
  # treat both cases uniformly as a list of per-vector trajectories.
  devec_trajectory =
    if trajectory.vectorized_axes == [] do
      [devec_trajectory]
    else
      devec_trajectory
    end

  num_vectors = length(devec_trajectory)

  # Pick the vectorized run with the highest cumulative reward; `iteration`
  # is the number of steps up to and including its terminal step.
  {{_, iteration}, index} =
    devec_trajectory
    |> Enum.map(fn traj ->
      # Drop NaN-padded rows (iterations that were never written).
      {traj, _nan_rows} =
        Enum.split_while(traj, fn row ->
          hd(row) != :nan
        end)

      Enum.reduce_while(traj, {0, 0}, fn row, {reward_acc, len_acc} ->
        is_terminal = Enum.at(row, 10)
        reward = Enum.at(row, 3)

        if is_terminal == 1 do
          {:halt, {reward_acc + reward, len_acc + 1}}
        else
          {:cont, {reward_acc + reward, len_acc + 1}}
        end
      end)
    end)
    |> Enum.with_index()
    |> Enum.max_by(fn {{sum, _len}, _idx} -> sum end)

  # `rem(episode, 1) == 0` is always true; kept as a knob to lower the
  # plotting frequency (e.g. `rem(episode, 10)`).
  if rem(episode, 1) == 0 and iteration > 0 do
    Kino.VegaLite.clear(grid_widget, dataset: "trajectory")
    Kino.VegaLite.clear(grid_widget, dataset: "target")

    # Slice out the winning run's rows (all 11 fields, valid steps only).
    # Fix: the original capture used HTML-escaped `&amp;` syntax.
    traj =
      trajectory
      |> Nx.backend_copy(Nx.BinaryBackend)
      |> Nx.revectorize([x: :auto], target_shape: trajectory.shape)
      |> Nx.devectorize()
      |> then(& &1[[index, 0..(iteration - 1)//1, 0..10]])

    {points, terminal_points} =
      traj
      |> Nx.to_list()
      |> Enum.with_index(fn [
                              x,
                              y,
                              remaining_time,
                              reward,
                              heading,
                              _speed,
                              vmg,
                              distance_to_mark,
                              target_x,
                              target_y,
                              is_terminal
                            ],
                            index ->
        %{
          x: x,
          y: y,
          heading: heading,
          index: index,
          episode: episode,
          remaining_time: remaining_time,
          reward: reward,
          distance: distance_to_mark,
          vmg: vmg,
          target_x: target_x,
          target_y: target_y,
          is_terminal: is_terminal == 1
        }
      end)
      |> Enum.split_while(fn row -> not row.is_terminal end)

    # Keep at most one terminal step at the end of the plotted run.
    points = points ++ Enum.take(terminal_points, 1)

    total_reward = Enum.reduce(points, 0, fn row, acc -> acc + row.reward end)

    Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory")

    target_x = to_number.(env_state.target_x, index)
    target_y = to_number.(env_state.target_y, index)

    # Flag whichever candidate mark matches the active target (within a
    # small tolerance) so the target layer can highlight it.
    target_data =
      Enum.zip_with(target_data.x, target_data.y, fn x, y ->
        is_target =
          if abs(x - target_x) < 0.05 and abs(y - target_y) < 0.05 do
            1
          else
            0
          end

        %{x: x, y: y, is_target: is_target}
      end)

    Kino.VegaLite.push_many(grid_widget, target_data, dataset: "target")

    loss = to_number.(agent_state.loss, index)
    loss_den = to_number.(agent_state.loss_denominator, index)

    if loss_den > 0 do
      # Average the accumulated loss; -1 marks a non-numeric (e.g. NaN) loss.
      loss = if is_number(loss), do: loss / loss_den, else: -1

      Kino.VegaLite.push(
        loss_widget,
        %{
          x: episode,
          y: loss
        },
        dataset: "values"
      )
    end

    Kino.VegaLite.push(
      reward_widget,
      %{
        x: state.episode,
        y: total_reward,
        iterations: state.iteration
      },
      dataset: "values"
    )

    log_entropy_coefficient =
      to_number.(agent_state.log_entropy_coefficient, index)

    exp_buf_index =
      to_number.(agent_state.experience_replay_buffer.index, index)

    Kino.VegaLite.push_many(
      misc_value_widget,
      [
        %{
          x: state.episode,
          y: :math.exp(log_entropy_coefficient),
          source: "entropy_coefficient",
          iterations: state.iteration
        },
        %{
          x: state.episode,
          y: exp_buf_index,
          source: "IDX(exp. replay buffer)",
          iterations: state.iteration
        }
      ],
      dataset: "values"
    )

    # Publish run metadata for the polling metadata widget.
    global_kv_key = :rl_training_metadata_key

    :persistent_term.put(global_kv_key, %{
      optimizer_steps: episode * num_vectors,
      total_reward: total_reward
    })
  end
end
model_name = "sac_multi_mark_v2"
checkpoint_path = Path.join(System.fetch_env!("HOME"), "Desktop/checkpoints/")
filename = Path.join(checkpoint_path, "#{model_name}_latest.ckpt")

# Load the previous checkpoint (backing it up first); start from an empty
# map when the file is missing or its contents fail to deserialize.
saved_state =
  try do
    bytes = File.read!(filename)
    File.write!(filename <> "_bak", bytes)

    Nx.with_default_backend({EXLA.Backend, client: :host}, fn -> Nx.deserialize(bytes) end)
  rescue
    _e in [File.Error, MatchError] -> %{}
  end

# Environment observation fields fed to the networks, in order.
fields = [
  :distance,
  :vmg,
  :heading,
  :angle_to_mark,
  :has_tacked,
  :has_reached_target
]

num_actions = 1
state_features_size = length(fields)
# Number of past observations kept in the state memory window.
state_features_memory_length = 1

state_input = Axon.input("state", shape: {nil, state_features_memory_length, state_features_size})

action_input = Axon.input("actions", shape: {nil, num_actions})

# Full f32 policy (no actual mixed precision; kept for easy experimentation).
policy = Axon.MixedPrecision.create_policy(params: {:f, 32}, compute: {:f, 32}, output: {:f, 32})

# NOTE(review): `state_size` appears unused below.
state_size = state_features_size * state_features_memory_length

# Shared trunk for the actor network: two 256-unit ReLU layers over the
# flattened state input.
actor_net_base =
  state_input
  |> Axon.flatten()
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(256, activation: :relu)

# Mean head: one linear output per action, reshaped to {batch, actions, 1}.
actor_net_mean_out =
  actor_net_base
  |> Axon.dense(num_actions, activation: :linear)
  |> Axon.reshape({:batch, num_actions, 1})

# Second head, clipped to [-20, 2] — presumably a log-stddev range as is
# conventional for SAC; confirm against the agent implementation.
# (Fix: the original capture used HTML-escaped `&amp;` syntax.)
actor_net_stddev_out =
  actor_net_base
  |> Axon.dense(num_actions, activation: :linear)
  |> Axon.nx(&Nx.clip(&1, -20, 2))
  |> Axon.reshape({:batch, num_actions, 1})

# Actor output: both heads concatenated on the last axis.
actor_net =
  [actor_net_mean_out, actor_net_stddev_out]
  |> Axon.concatenate(name: "actor_net_output")
  |> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])

# Critic Q(s, a): state features concatenated with the action input, passed
# through two dropout-regularized ReLU layers to a scalar output.
critic_net =
  state_input
  |> Axon.flatten()
  |> then(&Axon.concatenate([&1, action_input], name: "critic_combined_input"))
  |> Axon.flatten()
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout()
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout()
  |> Axon.dense(1, activation: :linear)
  |> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])

# These might seem redundant, but will make more sense for multi-input models
# Each normalizer shadows the default Kernel operators with Nx.Defn.Kernel's,
# so the arithmetic below also works when the argument is an Nx tensor.
normalize_x = fn x ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [-: 2, /: 2]
  # Min-max normalization into [0, 1] over the plot's x-domain.
  (x - min_x) / (max_x - min_x)
end

normalize_y = fn y ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [-: 2, /: 2]
  # Min-max normalization into [0, 1] over the plot's y-domain.
  (y - min_y) / (max_y - min_y)
end

normalize_speed = fn s ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [-: 2, +: 2, /: 2]
  # Fixed 1/10 scaling — assumes speeds stay below 10 (confirm units).
  s / 10
end

normalize_angle = fn a ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [*: 2, /: 2]
  # Radians scaled by 1/(2*pi).
  a / (2 * :math.pi())
end

# Maps the raw environment state into the normalized feature vector the
# networks consume; the order must match `fields` above.
environment_to_state_features_fn = fn env_state ->
  [
    # Distance normalized by the diagonal of the bounding box.
    Nx.divide(env_state.distance, :math.sqrt((max_x - min_x) ** 2 + (max_y - min_y) ** 2)),
    normalize_speed.(env_state.vmg),
    normalize_angle.(env_state.heading),
    normalize_angle.(env_state.angle_to_mark),
    env_state.has_tacked,
    env_state.has_reached_target
  ]
  |> Nx.stack()
end

# Reshapes the flat feature memory into the {batch, memory, features} input
# map expected by the Axon networks.
state_features_memory_to_input_fn = fn state_features ->
  %{
    "state" =>
      Nx.reshape(state_features, {:auto, state_features_memory_length, state_features_size})
  }
end

# total_episodes = 0
# NOTE(review): this discards the checkpoint loaded earlier and forces a
# fresh start — confirm this is intentional before long training runs.
saved_state = %{}
# saved_state = Map.delete(saved_state, :log_entropy_coefficient)
IO.inspect(saved_state, label: "{saved_state, total_episodes}")

# Side-by-side graph renderings of the critic and actor networks.
Kino.Layout.grid(
  [
    Kino.Layout.tabs([
      {"Critic",
       Axon.Display.as_graph(critic_net, %{
         "actions" => Nx.template({1, num_actions}, :f32),
         "state" => Nx.template({1, state_features_memory_length, state_features_size}, :f32)
       })}
    ]),
    Kino.Layout.tabs([
      {"Actor",
       Axon.Display.as_graph(actor_net, %{
         "state" => Nx.template({1, state_features_memory_length, state_features_size}, :f32)
       })}
    ])
  ],
  columns: 2
)
monitor_frame = Kino.Frame.new()

# Poll `nvidia-smi` once per second and render its text output into the
# frame. (Fix: the original capture used HTML-escaped `&amp;` syntax.)
# NOTE(review): System.cmd/2 raises if the binary is missing — fine on GPU
# hosts, but this cell will crash on machines without the NVIDIA tools.
Kino.listen(1000, fn _ ->
  {text, _exit_status} = System.cmd("nvidia-smi", [])

  text
  |> Kino.Text.new()
  |> then(&Kino.Frame.render(monitor_frame, &1))
end)

monitor_frame

## Train

# Reset all plot datasets before a new training run.
Kino.VegaLite.clear(grid_widget)
Kino.VegaLite.clear(loss_widget, dataset: "values")
Kino.VegaLite.clear(reward_widget, dataset: "values")
Kino.VegaLite.clear(misc_value_widget, dataset: "values")

# Checkpoints serialize only the agent state of the training loop.
checkpoint_serialization_fn = fn loop_state ->
  to_serialize = loop_state.agent_state
  Nx.serialize(to_serialize)
end

episodes = 10_000

max_iter = 10000

# Derive `num_vectors` independent RNG keys from one seed, then revectorize
# them so each vectorized environment instance gets its own key.
random_key_init = Nx.Random.key(42)
num_vectors = 20

random_key_devec =
  Nx.Random.randint_split(random_key_init, 0, Nx.Constants.max_finite(:u32),
    type: :u32,
    shape: {num_vectors, 2}
  )

vectorized_axes = [vectors: num_vectors]
random_key = Nx.revectorize(random_key_devec, vectorized_axes, target_shape: {2})
# random_key = Nx.Random.key(42)
# random_key = Nx.Random.key(42)
# Run SAC training on the MultiMark environment, timing the whole loop
# (:timer.tc reports microseconds).
{t, result} =
  :timer.tc(fn ->
    Rein.train(
      {
        BoatLearner.Environments.MultiMark,
        max_remaining_seconds: 500, coords: points_tensor, coord_probabilities: points_probs
      },
      {
        Rein.Agents.SAC,
        tau: 0.005,
        gamma: 0.99,
        actor_net: actor_net,
        critic_net: critic_net,
        state_features_memory_length: state_features_memory_length,
        experience_replay_buffer_max_size: 1_000_000,
        environment_to_state_features_fn: environment_to_state_features_fn,
        state_features_memory_to_input_fn: state_features_memory_to_input_fn,
        state_features_size: state_features_size,
        training_frequency: 1,
        batch_size: 32,
        entropy_coefficient: 5,
        saved_state: saved_state,
        actor_optimizer: Polaris.Optimizers.sgd(learning_rate: 3.0e-4),
        critic_optimizer: Polaris.Optimizers.sgd(learning_rate: 3.0e-4),
        entropy_coefficient_optimizer: Polaris.Optimizers.sgd(learning_rate: 1.0e-3)
      },
      plot_fn,
      # Fix: the original function capture used HTML-escaped `&amp;` syntax.
      &AccumulateAndExportData.state_to_trajectory_entry/1,
      num_episodes: episodes,
      max_iter: max_iter,
      accumulated_episodes: 0,
      random_key: random_key,
      checkpoint_serialization_fn: checkpoint_serialization_fn,
      checkpoint_path: checkpoint_path,
      model_name: model_name
    )
  end)

"#{Float.round(t / 1_000_000, 3)} s" |> IO.puts()

# Persist the final agent state as the latest checkpoint.
contents = checkpoint_serialization_fn.(result)
File.write!(filename, contents)