# SAC Multi-Mark
# Install the local :boat_learner project (located one directory above this
# file) as the only dependency, reusing that project's config and lockfile.
my_app_root = Path.join(__DIR__, "..")

install_opts = [
  config_path: Path.join(my_app_root, "config/config.exs"),
  lockfile: Path.join(my_app_root, "mix.lock"),
  consolidate_protocols: false,
  # XLA_TARGET falls back to "cpu" when it is not set in the environment.
  system_env: %{"XLA_TARGET" => System.get_env("XLA_TARGET", "cpu")},
  force: true
]

Mix.install([{:boat_learner, path: my_app_root, env: :dev}], install_opts)
## Plot
alias VegaLite, as: Vl
# Plot extents come from the environment's bounding box.
{min_x, max_x, min_y, max_y} = BoatLearner.Environments.MultiMark.bounding_box()

r = 250
pi = :math.pi()
theta = 42.7 / 180 * pi

# Marks drawn on the plot. Alternative layouts kept for experimentation:
# valid_points = [[0, r], [0, 0], [0, -r]]
# valid_points = [[0, 0], [r * :math.tan(theta), r]]
valid_points = [[0, r], [0, 0]]

# Coordinates handed to the environment as `coords`.
# NOTE(review): each row looks like [start_x, start_y, target_x, target_y]
# (see the commented-out `start ++ target` generator) - confirm against the environment.
#
# Alternative generators kept for experimentation:
# points =
#   for start <- valid_points, target <- valid_points, start != target do
#     start ++ target
#   end
#
# for angle <- [0, theta, -theta, pi, pi + theta, pi - theta, pi / 2, -pi / 2] do
#   [
#     Float.round(r * :math.cos(angle + pi / 2), 4),
#     Float.round(r * :math.sin(angle + pi / 2), 4)
#   ]
# end
points = [[0, 0, 0, r]]
points_tensor = Nx.tensor(points)
points_probs = :uniform

width = height = 750
scale_max_x = max_x / 250
scale_min_x = min_x / -250

# Diagonal background grid: from every anchor on the x axis, four rules fan out
# to (anchor +/- alfa * r, +/- r). Each rule is emitted in both directions.
alfa = :math.tan(theta)
step = 5
num_lines = 5
offset = alfa * r

segments =
  -(step * num_lines)..(step * num_lines)
  |> Enum.flat_map(fn k ->
    anchor = alfa * r * k / step

    far_ends = [
      {anchor + offset, r},
      {anchor - offset, r},
      {anchor + offset, -r},
      {anchor - offset, -r}
    ]

    for {x_end, y_far} <- far_ends, {y_from, y_to} <- [{0, y_far}, {y_far, 0}] do
      {anchor, x_end, y_from, y_to}
    end
  end)
  |> Enum.reverse()

x = Enum.map(segments, &elem(&1, 0))
x2 = Enum.map(segments, &elem(&1, 1))
y = Enum.map(segments, &elem(&1, 2))
y2 = Enum.map(segments, &elem(&1, 3))
# Builds a thin grey ring layer centred on the origin. `r_scale` scales the
# base radius `r`; the ring spans one data unit either side of `r * r_scale`,
# converted with the plot height over the y extent.
arc = fn r_scale ->
  y_extent = max_y - min_y
  outer_radius = (r * r_scale + 1) * height / y_extent
  inner_radius = (r * r_scale - 1) * height / y_extent

  x_scale = [domain: [min_x, max_x]]
  y_scale = [domain: [min_y, max_y]]

  ring_mark = [
    clip: true,
    radius: outer_radius,
    radius2: inner_radius,
    theta: 2 * :math.pi(),
    color: "#ddd",
    opacity: 0.75
  ]

  Vl.new()
  |> Vl.data_from_values(%{x: [0], y: [0]}, name: "arc_grid")
  |> Vl.mark(:arc, ring_mark)
  |> Vl.encode_field(:x, "x", type: :quantitative, scale: x_scale)
  |> Vl.encode_field(:y, "y", type: :quantitative, scale: y_scale)
end
# Marker size for the target layer.
# NOTE(review): `pi ** 2 * radius` is kept as-is; it looks like it may have been
# meant as a circle area (pi * radius ** 2) - confirm before changing the visuals.
target_point_radius = 5
target_point_radius_px = (max_y - min_y) / height * target_point_radius
target_point_size = pi ** 2 * target_point_radius_px

# Layer fed by the "target" dataset: the active target (is_target == 1) is drawn
# red, opaque and at full size; every other mark is dark grey, slightly
# transparent and half the size.
target_layer =
  Vl.new()
  |> Vl.data(name: "target")
  |> Vl.mark(:point,
    # `size` is given exactly once. A second, constant `size:` entry used to
    # precede this one, leaving the effective value to key de-duplication.
    size: %{expr: "if(datum.is_target == 1, #{target_point_size}, #{target_point_size / 2})"},
    opacity: %{expr: "if(datum.is_target == 1, 1, 0.75)"},
    filled: true,
    tooltip: [content: "data"]
  )
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Vl.encode(:color,
    condition: %{test: "datum.is_target == 1", value: "red"},
    value: "#333"
  )
# Static background: the diagonal rules computed from x/y/x2/y2.
diagonal_grid_layer =
  Vl.new()
  |> Vl.data_from_values(%{x: x, y: y, x2: x2, y2: y2}, name: "diagonal_grid")
  |> Vl.mark(:rule, clip: true, color: "#ddd", size: 2, opacity: 0.75)
  |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x]])
  |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y]])
  |> Vl.encode_field(:x2, "x2", type: :quantitative)
  |> Vl.encode_field(:y2, "y2", type: :quantitative)

# Live layer fed by the "trajectory" dataset; points are joined in "index"
# order and coloured per episode.
trajectory_layer =
  Vl.new()
  |> Vl.data(name: "trajectory")
  |> Vl.mark(:line, point: true, clip: true, opacity: 1, tooltip: [content: "data"])
  |> Vl.encode_field(:x, "x",
    legend: false,
    type: :quantitative,
    scale: [domain: [min_x, max_x]]
  )
  |> Vl.encode_field(:y, "y",
    legend: false,
    type: :quantitative,
    scale: [domain: [min_y, max_y]]
  )
  # `legend: false` is passed once here (it was previously listed twice).
  |> Vl.encode_field(:color, "episode",
    legend: false,
    type: :nominal,
    scale: [scheme: "blues"]
  )
  |> Vl.encode_field(:order, "index", legend: false)

# Main simulation plot: background grid, three reference rings, the marks and
# the current trajectory, layered in that order.
grid_widget =
  Vl.new(width: width, height: height)
  |> Vl.layers([
    diagonal_grid_layer,
    arc.(1),
    arc.(0.5),
    arc.(0.25),
    target_layer,
    trajectory_layer
  ])
  |> Kino.VegaLite.new()
# Builds a live chart backed by the "values" dataset.
#
#   * `title`       - chart title.
#   * `plot_median` - when truthy, the line layer plots a "rolling_median" of
#     "y" (window transform, frame [-30, 0]); otherwise it plots "y" directly.
#
# The y domain is driven by the "min_y"/"max_y" params, which are bound to
# number inputs. Returns a Kino.VegaLite widget.
value_widget_fn = fn title, plot_median ->
  smooth =
    if plot_median do
      fn layer ->
        Vl.transform(layer,
          frame: [-30, 0],
          window: [[field: "y", op: "median", as: "rolling_median"]]
        )
      end
    else
      &Function.identity/1
    end

  line_field = if plot_median, do: "rolling_median", else: "y"
  y_scale = [domain: %{expr: "[min_y, max_y]"}, clamp: true]

  # Raw samples, drawn as faint points.
  scatter_layer =
    Vl.new()
    |> Vl.mark(:point, grid: true, tooltip: [content: "data"], opacity: 0.25)
    |> Vl.encode_field(:color, "source", type: :nominal)
    |> Vl.encode_field(:x, "x", legend: false, type: :quantitative)
    |> Vl.encode_field(:order, "x", legend: false)
    |> Vl.encode_field(:y, "y",
      legend: false,
      type: :quantitative,
      scale: y_scale,
      axis: [tick_count: 10]
    )

  # Line through either the raw values or their rolling median.
  line_layer =
    Vl.new()
    |> Vl.mark(:line, grid: true)
    |> then(smooth)
    |> Vl.encode_field(:color, "source", type: :nominal)
    |> Vl.encode_field(:x, "x", legend: false, type: :quantitative)
    |> Vl.encode_field(:order, "x", legend: false)
    |> Vl.encode_field(:y, line_field,
      legend: false,
      type: :quantitative,
      scale: y_scale
    )

  Vl.new(width: width, height: div(height, 2), title: title)
  |> Vl.data(name: "values")
  |> Vl.layers([scatter_layer, line_layer])
  |> Vl.param("max_y", value: 5, bind: [input: "number"])
  |> Vl.param("min_y", value: 0, bind: [input: "number"])
  |> Kino.VegaLite.new()
end
loss_widget = value_widget_fn.("Loss", true)
reward_widget = value_widget_fn.("Total Reward", true)
misc_value_widget = value_widget_fn.("Misc. Values", false)

# Training metadata is shared through :persistent_term under a fixed key and
# starts out with empty values.
global_kv_key = :rl_training_metadata_key
data = %{optimizer_steps: nil, total_reward: nil}
:persistent_term.put(global_kv_key, data)

metadata_widget = Kino.Frame.new()

# Renders the metadata map as a two-column markdown table inside the frame.
render_metadata = fn metadata ->
  rows =
    Enum.map_join(metadata, fn {field, value} ->
      shown = if value, do: to_string(value), else: " "
      "|#{field} | #{shown} |\n"
    end)

  table = Kino.Markdown.new("| Field | Value |\n| ----- | ----- |\n" <> rows)
  Kino.Frame.render(metadata_widget, table)
end

# Every 250 ms, re-read the metadata and re-render only when it has changed
# since the previous tick (the listener state holds the last seen value).
Kino.listen(250, nil, fn _tick, prev_data ->
  current = :persistent_term.get(global_kv_key)

  if current != prev_data do
    render_metadata.(current)
  end

  {:cont, current}
end)

boxed = fn widget -> Kino.Layout.grid([widget], boxed: true) end

simulation_widget = Kino.Layout.grid(Enum.map([grid_widget, metadata_widget], boxed))

training_widget =
  Kino.Layout.grid(Enum.map([loss_widget, reward_widget, misc_value_widget], boxed))

Kino.Layout.tabs([{"Simulation", simulation_widget}, {"Training", training_widget}])
defmodule AccumulateAndExportData do
  @moduledoc """
  GenServer that accumulates per-epoch trajectory rows and writes them to disk.

  Supported messages:

    * `GenServer.cast(server, :reset)` - drops everything accumulated so far.
    * `GenServer.cast(server, {:add_epoch, epoch, iterations, trajectory_tensor})` -
      appends the first `iterations` rows of the first vector of
      `trajectory_tensor`; a cast with `iterations == 0` is ignored.
    * `GenServer.call(server, {:save, filename})` - writes the accumulated data
      to `filename` with `:erlang.term_to_binary/1` and replies `:ok`.

  Row columns follow the order produced by `state_to_trajectory_entry/1`.
  """

  use GenServer
  import Nx.Defn
  import Nx.Constants

  # Converts radians to degrees, mapping angles above 180 to their negative
  # equivalent (angle - 360).
  defnp rad_to_deg(angle) do
    angle = angle * 180 / pi()
    Nx.select(angle > 180, angle - 360, angle)
  end

  @doc """
  Builds one trajectory row from a training step state.

  The stacked columns are, in order: `x`, `y`, `remaining_seconds`, `reward`,
  heading in degrees, `speed`, `vmg`, distance to the target, `target_x`,
  `target_y` and `is_terminal`.
  """
  def state_to_trajectory_entry(%{environment_state: env, agent_state: _agent}) do
    distance_to_mark =
      Nx.sqrt(
        Nx.pow(Nx.subtract(env.target_y, env.y), 2)
        |> Nx.add(Nx.pow(Nx.subtract(env.target_x, env.x), 2))
      )

    Nx.stack([
      env.x,
      env.y,
      env.remaining_seconds,
      env.reward,
      rad_to_deg(env.heading),
      env.speed,
      env.vmg,
      distance_to_mark,
      env.target_x,
      env.target_y,
      env.is_terminal
    ])
  end

  @doc "Starts the server registered under the module name."
  def start_link(opts \\ []), do: GenServer.start_link(__MODULE__, opts, name: __MODULE__)

  @doc "Asynchronously clears all accumulated data."
  def reset, do: GenServer.cast(__MODULE__, :reset)

  @impl true
  def init(_opts), do: {:ok, %{trajectories: [], marks: %{x: [], y: [], epoch: [], index: []}}}

  @impl true
  def handle_cast(:reset, _) do
    {:ok, state} = init([])
    {:noreply, state}
  end

  def handle_cast({:add_epoch, _epoch, 0, _trajectory_tensor}, state) do
    {:noreply, state}
  end

  def handle_cast({:add_epoch, epoch, iterations, trajectory_tensor}, state) do
    trajectory_data =
      trajectory_tensor
      |> Nx.revectorize([x: :auto], target_shape: trajectory_tensor.shape)
      |> Nx.devectorize()
      |> Nx.take(0)
      |> Nx.slice_along_axis(0, iterations, axis: 0)
      |> Nx.to_list()
      |> Enum.with_index(fn [
                              x,
                              y,
                              remaining_seconds,
                              reward,
                              heading,
                              speed,
                              vmg,
                              distance_to_mark,
                              target_x,
                              target_y,
                              is_terminal
                            ],
                            index ->
        %{
          x: x,
          y: y,
          epoch: epoch,
          heading: heading,
          boat_speed: speed,
          distance_to_mark: distance_to_mark,
          vmg: vmg,
          index: index,
          remaining_seconds: remaining_seconds,
          reward: reward,
          target_x: target_x,
          target_y: target_y,
          is_terminal: is_terminal
        }
      end)

    # Prepend the newest epoch (O(1)); `:save` reverses once to restore
    # chronological order. Only `trajectories` is updated so `marks` keeps the
    # map shape set up by `init/1`.
    {:noreply, %{state | trajectories: [trajectory_data | state.trajectories]}}
  end

  @impl true
  def handle_call({:save, filename}, _from, state) do
    # Oldest epoch first, then pivot the rows into a map of column lists.
    data =
      state.trajectories
      |> Enum.reverse()
      |> Enum.concat()
      |> Enum.flat_map(&Map.to_list/1)
      |> Enum.group_by(&elem(&1, 0), &elem(&1, 1))

    layers = [
      %{
        mark: %{type: :line, opts: [tooltip: [content: "data"], point: true]},
        data: data
      }
    ]

    contents = :erlang.term_to_binary(layers)
    File.write!(filename, contents)
    {:reply, :ok, state}
  end
end
# NOTE(review): this was annotated as "250 max_iter * 15 episodes", which does
# not equal 1000; the value is not referenced anywhere in the visible code.
max_points = 1000

# Column-wise view of the marks: all x coordinates and all y coordinates.
target_data = %{
  x: Enum.map(valid_points, fn [x | _] -> x end),
  y: Enum.map(valid_points, fn [_, y | _] -> y end)
}

# Reads a scalar out of a tensor as a plain number. Tensors without vectorized
# axes are converted directly; otherwise the tensor is devectorized and the
# entry at `index` is taken.
to_number = fn tensor, index ->
  on_host = Nx.backend_copy(tensor, Nx.BinaryBackend)

  case tensor do
    %{vectorized_axes: []} ->
      Nx.to_number(on_host)

    _vectorized ->
      selected = on_host |> Nx.devectorize() |> Nx.take(Nx.backend_copy(index))
      Nx.to_number(selected)
  end
end
# Per-episode plotting callback. Reads the episode number and step state from
# `state`, picks the vectorized entry with the highest summed reward, and pushes
# its trajectory, the marks, the loss, the total reward and a few agent values
# to the Kino widgets. Also publishes summary metadata through :persistent_term.
plot_fn = fn state ->
episode = state.episode
trajectory = state.step_state.trajectory
env_state = state.step_state.environment_state
agent_state = state.step_state.agent_state
IO.inspect("Episode #{episode} ended")
# Bring the trajectory to the host as nested lists.
devec_trajectory =
Nx.devectorize(trajectory) |> Nx.backend_copy(Nx.BinaryBackend) |> Nx.to_list()
# Without vectorized axes there is a single trajectory; wrap it so the code
# below always iterates over a list of trajectories.
devec_trajectory =
if trajectory.vectorized_axes == [] do
[devec_trajectory]
else
devec_trajectory
end
num_vectors = length(devec_trajectory)
# For every trajectory compute {summed reward, row count}, then keep the entry
# with the highest summed reward: `iteration` is its row count and `index` its
# position among the vectors.
{{_, iteration}, index} =
devec_trajectory
|> Enum.map(fn traj ->
# Keep only the rows before the first one whose first column is :nan.
{traj, _nan_rows} =
Enum.split_while(traj, fn row ->
hd(row) != :nan
end)
# Column 3 is the reward and column 10 the terminal flag (same column order
# as the destructuring further down). Accumulation stops after the first
# terminal row, which is still counted.
Enum.reduce_while(traj, {0, 0}, fn row, {reward_acc, len_acc} ->
is_terminal = Enum.at(row, 10)
reward = Enum.at(row, 3)
if is_terminal == 1 do
{:halt, {reward_acc + reward, len_acc + 1}}
else
{:cont, {reward_acc + reward, len_acc + 1}}
end
end)
end)
|> Enum.with_index()
|> Enum.max_by(fn {{sum, _len}, _idx} -> sum end)
# `rem(episode, 1) == 0` holds for every integer episode, so in practice this
# only skips episodes whose best trajectory has no rows.
if rem(episode, 1) == 0 and iteration > 0 do
Kino.VegaLite.clear(grid_widget, dataset: "trajectory")
Kino.VegaLite.clear(grid_widget, dataset: "target")
# Slice the selected vector down to its first `iteration` rows and the 11
# trajectory columns.
traj =
trajectory
|> Nx.backend_copy(Nx.BinaryBackend)
|> Nx.revectorize([x: :auto], target_shape: trajectory.shape)
|> Nx.devectorize()
|> then(& &1[[index, 0..(iteration - 1)//1, 0..10]])
# Turn each row into a map for plotting, then split at the first terminal row.
{points, terminal_points} =
traj
|> Nx.to_list()
|> Enum.with_index(fn [
x,
y,
remaining_time,
reward,
heading,
_speed,
vmg,
distance_to_mark,
target_x,
target_y,
is_terminal
],
index ->
%{
x: x,
y: y,
heading: heading,
index: index,
episode: episode,
remaining_time: remaining_time,
reward: reward,
distance: distance_to_mark,
vmg: vmg,
target_x: target_x,
target_y: target_y,
is_terminal: is_terminal == 1
}
end)
|> Enum.split_while(fn row -> not row.is_terminal end)
# Include the first terminal row (if any) so the plotted line reaches it.
points = points ++ Enum.take(terminal_points, 1)
total_reward = Enum.reduce(points, 0, fn row, acc -> acc + row.reward end)
Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory")
target_x = to_number.(env_state.target_x, index)
target_y = to_number.(env_state.target_y, index)
# Flag the mark that lies within 0.05 of the environment's current target on
# both axes.
target_data =
Enum.zip_with(target_data.x, target_data.y, fn x, y ->
is_target =
if abs(x - target_x) < 0.05 and abs(y - target_y) < 0.05 do
1
else
0
end
%{x: x, y: y, is_target: is_target}
end)
Kino.VegaLite.push_many(grid_widget, target_data, dataset: "target")
loss = to_number.(agent_state.loss, index)
loss_den = to_number.(agent_state.loss_denominator, index)
# Plot the loss divided by its denominator; a non-numeric loss is plotted as -1.
# Nothing is pushed while the denominator is not positive.
if loss_den > 0 do
loss = if is_number(loss), do: loss / loss_den, else: -1
Kino.VegaLite.push(
loss_widget,
%{
x: episode,
y: loss
},
dataset: "values"
)
end
Kino.VegaLite.push(
reward_widget,
%{
x: state.episode,
y: total_reward,
iterations: state.iteration
},
dataset: "values"
)
log_entropy_coefficient =
to_number.(agent_state.log_entropy_coefficient, index)
exp_buf_index =
to_number.(agent_state.experience_replay_buffer.index, index)
# Misc chart: the entropy coefficient (exp of its stored log) and the
# experience replay buffer index, distinguished by "source".
Kino.VegaLite.push_many(
misc_value_widget,
[
%{
x: state.episode,
y: :math.exp(log_entropy_coefficient),
source: "entropy_coefficient",
iterations: state.iteration
},
%{
x: state.episode,
y: exp_buf_index,
source: "IDX(exp. replay buffer)",
iterations: state.iteration
}
],
dataset: "values"
)
# Publish summary metadata under the fixed :persistent_term key;
# optimizer_steps is reported as episode * num_vectors.
global_kv_key = :rl_training_metadata_key
:persistent_term.put(global_kv_key, %{
optimizer_steps: episode * num_vectors,
total_reward: total_reward
})
end
end
model_name = "sac_multi_mark_v2"
checkpoint_path = Path.join(System.fetch_env!("HOME"), "Desktop/checkpoints/")
filename = Path.join(checkpoint_path, "#{model_name}_latest.ckpt")

# Load the latest checkpoint: read it, copy the raw bytes to "<filename>_bak",
# then deserialize on the EXLA host client. A File.Error or MatchError during
# any of these steps yields an empty state instead.
saved_state =
  try do
    filename
    |> File.read!()
    |> tap(fn serialized -> File.write!(filename <> "_bak", serialized) end)
    |> then(fn serialized ->
      Nx.with_default_backend({EXLA.Backend, client: :host}, fn ->
        Nx.deserialize(serialized)
      end)
    end)
  rescue
    _ in [File.Error, MatchError] -> %{}
  end
# Names of the state features, one per entry of the feature vector. The list
# length defines the size of the network's state input.
fields = [
:distance,
:vmg,
:heading,
:angle_to_mark,
:has_tacked,
:has_reached_target
]
num_actions = 1
state_features_size = length(fields)
state_features_memory_length = 1
# Model inputs: "state" is {batch, memory_length, features}; "actions" is
# {batch, num_actions}.
state_input = Axon.input("state", shape: {nil, state_features_memory_length, state_features_size})
action_input = Axon.input("actions", shape: {nil, num_actions})
# Params, compute and output all stay in f32.
policy = Axon.MixedPrecision.create_policy(params: {:f, 32}, compute: {:f, 32}, output: {:f, 32})
# Not referenced elsewhere in the visible code.
state_size = state_features_size * state_features_memory_length
# Shared actor trunk: flatten, then two 256-unit ReLU layers.
actor_net_base =
state_input
|> Axon.flatten()
|> Axon.dense(256, activation: :relu)
|> Axon.dense(256, activation: :relu)
# Mean head, reshaped to {batch, num_actions, 1}.
actor_net_mean_out =
actor_net_base
|> Axon.dense(num_actions, activation: :linear)
|> Axon.reshape({:batch, num_actions, 1})
# Second head, clipped to [-20, 2] before the reshape.
# NOTE(review): the clip range suggests this is a log standard deviation rather
# than a raw standard deviation - confirm against the agent implementation.
actor_net_stddev_out =
actor_net_base
|> Axon.dense(num_actions, activation: :linear)
|> Axon.nx(&Nx.clip(&1, -20, 2))
|> Axon.reshape({:batch, num_actions, 1})
# Actor output: both heads concatenated under the name "actor_net_output".
actor_net =
[actor_net_mean_out, actor_net_stddev_out]
|> Axon.concatenate(name: "actor_net_output")
|> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])
# Critic: the flattened state concatenated with the actions, followed by two
# 256-unit ReLU layers (each followed by dropout) and a single linear output.
critic_net =
state_input
|> Axon.flatten()
|> then(&Axon.concatenate([&1, action_input], name: "critic_combined_input"))
|> Axon.flatten()
|> Axon.dense(256, activation: :relu)
|> Axon.dropout()
|> Axon.dense(256, activation: :relu)
|> Axon.dropout()
|> Axon.dense(1, activation: :linear)
|> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])
# These might seem redundant, but will make more sense for multi-input models.
#
# The normalizers swap Kernel's arithmetic operators for the Nx.Defn.Kernel
# ones inside their bodies.

# Builds a min-max normalizer for the range [lower, upper].
min_max_normalizer = fn lower, upper ->
  fn value ->
    import Kernel, only: []
    import Nx.Defn.Kernel, only: [-: 2, /: 2]
    (value - lower) / (upper - lower)
  end
end

normalize_x = min_max_normalizer.(min_x, max_x)
normalize_y = min_max_normalizer.(min_y, max_y)

# Scales speeds down by a fixed factor of 10.
normalize_speed = fn s ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [/: 2]
  s / 10
end

# Expresses an angle as a fraction of a full turn (2 * pi).
normalize_angle = fn a ->
  import Kernel, only: []
  import Nx.Defn.Kernel, only: [*: 2, /: 2]
  a / (2 * :math.pi())
end

# Maps an environment state to the stacked feature vector, in the order:
# distance (divided by the bounding-box diagonal), vmg, heading, angle to
# mark, has_tacked, has_reached_target.
environment_to_state_features_fn = fn env_state ->
  bounding_box_diagonal = :math.sqrt((max_x - min_x) ** 2 + (max_y - min_y) ** 2)

  Nx.stack([
    Nx.divide(env_state.distance, bounding_box_diagonal),
    normalize_speed.(env_state.vmg),
    normalize_angle.(env_state.heading),
    normalize_angle.(env_state.angle_to_mark),
    env_state.has_tacked,
    env_state.has_reached_target
  ])
end

# Reshapes the feature memory into the {batch, memory_length, features} layout
# of the "state" model input.
state_input_shape = {:auto, state_features_memory_length, state_features_size}

state_features_memory_to_input_fn = fn state_features ->
  %{"state" => Nx.reshape(state_features, state_input_shape)}
end
# total_episodes = 0
# Training starts from an empty agent state: this rebinding replaces any
# previously bound `saved_state`.
saved_state = %{}
# saved_state = Map.delete(saved_state, :log_entropy_coefficient)
IO.inspect(saved_state, label: "{saved_state, total_episodes}")

# Input templates used only to render the model graphs.
state_template = Nx.template({1, state_features_memory_length, state_features_size}, :f32)
actions_template = Nx.template({1, num_actions}, :f32)

critic_graph =
  Axon.Display.as_graph(critic_net, %{
    "actions" => actions_template,
    "state" => state_template
  })

actor_graph = Axon.Display.as_graph(actor_net, %{"state" => state_template})

# Show both graphs side by side, each in its own tab group.
Kino.Layout.grid(
  [
    Kino.Layout.tabs([{"Critic", critic_graph}]),
    Kino.Layout.tabs([{"Actor", actor_graph}])
  ],
  columns: 2
)
# Frame showing the output of `nvidia-smi`, refreshed once per second.
monitor_frame = Kino.Frame.new()

# `System.cmd/3` raises when the executable cannot be found, which happens on
# CPU-only hosts. Resolve it once and show a static notice instead of polling a
# command that does not exist.
case System.find_executable("nvidia-smi") do
  nil ->
    Kino.Frame.render(
      monitor_frame,
      Kino.Text.new("nvidia-smi was not found on PATH; GPU monitoring is unavailable.")
    )

  nvidia_smi ->
    Kino.listen(1000, fn _tick ->
      {text, _exit_status} = System.cmd(nvidia_smi, [])
      Kino.Frame.render(monitor_frame, Kino.Text.new(text))
    end)
end

monitor_frame
## Train
# Reset the plots before a new training run.
Kino.VegaLite.clear(grid_widget)

Enum.each([loss_widget, reward_widget, misc_value_widget], fn widget ->
  Kino.VegaLite.clear(widget, dataset: "values")
end)

# Only the agent state of the loop state is serialized into checkpoints.
checkpoint_serialization_fn = fn loop_state -> Nx.serialize(loop_state.agent_state) end

episodes = 10_000
max_iter = 10_000
num_vectors = 20

# Derive one u32 key pair per vector from a fixed seed, then expose the
# leading axis as the vectorized axis `vectors` so each entry is a {2} key.
random_key_init = Nx.Random.key(42)

random_key_devec =
  Nx.Random.randint_split(
    random_key_init,
    0,
    Nx.Constants.max_finite(:u32),
    type: :u32,
    shape: {num_vectors, 2}
  )

vectorized_axes = [vectors: num_vectors]
random_key = Nx.revectorize(random_key_devec, vectorized_axes, target_shape: {2})
# random_key = Nx.Random.key(42)
# Make sure the checkpoint directory exists before training starts. Otherwise
# the final `File.write!/2` raises once training has finished and the trained
# agent is lost.
File.mkdir_p!(Path.dirname(filename))

# Run training and measure the wall-clock time (`t` is in microseconds).
{t, result} =
  :timer.tc(fn ->
    Rein.train(
      # Environment module and its options.
      {
        BoatLearner.Environments.MultiMark,
        max_remaining_seconds: 500, coords: points_tensor, coord_probabilities: points_probs
      },
      # Agent module and its options.
      {
        Rein.Agents.SAC,
        tau: 0.005,
        gamma: 0.99,
        actor_net: actor_net,
        critic_net: critic_net,
        state_features_memory_length: state_features_memory_length,
        experience_replay_buffer_max_size: 1_000_000,
        environment_to_state_features_fn: environment_to_state_features_fn,
        state_features_memory_to_input_fn: state_features_memory_to_input_fn,
        state_features_size: state_features_size,
        training_frequency: 1,
        batch_size: 32,
        entropy_coefficient: 5,
        saved_state: saved_state,
        actor_optimizer: Polaris.Optimizers.sgd(learning_rate: 3.0e-4),
        critic_optimizer: Polaris.Optimizers.sgd(learning_rate: 3.0e-4),
        entropy_coefficient_optimizer: Polaris.Optimizers.sgd(learning_rate: 1.0e-3)
      },
      # Per-episode plotting callback and per-step trajectory row builder.
      plot_fn,
      &AccumulateAndExportData.state_to_trajectory_entry/1,
      num_episodes: episodes,
      max_iter: max_iter,
      accumulated_episodes: 0,
      random_key: random_key,
      checkpoint_serialization_fn: checkpoint_serialization_fn,
      checkpoint_path: checkpoint_path,
      model_name: model_name
    )
  end)

IO.puts("#{Float.round(t / 1_000_000, 3)} s")

# Persist the final agent state as the latest checkpoint.
contents = checkpoint_serialization_fn.(result)
File.write!(filename, contents)