
First steps with Gridworld

guides/gridworld.livemd

my_app_root = Path.join(__DIR__, "..")

Mix.install(
  [
    {:rein, path: my_app_root},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config_path: Path.join(my_app_root, "config/config.exs"),
  lockfile: Path.join(my_app_root, "mix.lock"),
  # change to "cuda118" or "cuda120" to use CUDA
  system_env: %{"XLA_TARGET" => "cpu"}
)

Initializing the plot

In the code block below, we initialize some meta variables and configure our VegaLite plot so that it can be updated incrementally as the algorithm iterates.

alias VegaLite, as: Vl

{min_x, max_x, min_y, max_y} = Rein.Environments.Gridworld.bounding_box()

possible_targets_l = [[round((min_x + max_x) / 2), max_y]]

# possible_targets_l =
#   for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do
#     [x, y]
#   end

possible_targets = Nx.tensor(Enum.shuffle(possible_targets_l))

width = 600
height = 600

grid_widget =
  Vl.new(width: width, height: height)
  |> Vl.layers([
    Vl.new()
    |> Vl.data(name: "target")
    |> Vl.mark(:point,
      fill: true,
      tooltip: [content: "data"],
      grid: true,
      size: [expr: "height * 4 * #{:math.pi()} / #{max_y - min_y}"]
    )
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative)
    |> Vl.encode_field(:color, "episode",
      type: :nominal,
      scale: [scheme: "blues"],
      legend: false
    ),
    Vl.new()
    |> Vl.data(name: "trajectory")
    |> Vl.mark(:line, point: true, opacity: 1, tooltip: [content: "data"])
    |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x], clamp: true])
    |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y], clamp: true])
    |> Vl.encode_field(:order, "index")
  ])
  |> Kino.VegaLite.new()
  |> Kino.render()

nil
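
To confirm that the widget really does accept incremental updates, we can push a throwaway point into the "trajectory" dataset and immediately clear it again. This is just a quick sanity check, not part of the training flow:

Kino.VegaLite.push(grid_widget, %{x: min_x, y: min_y, index: 0}, dataset: "trajectory")
Kino.VegaLite.clear(grid_widget, dataset: "trajectory")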

Configuring and running the Q-Learning Agent

Now we’re ready to start configuring our agent. The plot_fn function defined below is a callback that Rein calls at the end of each iteration, giving us access to the training state as it evolves.

Usually, this means extracting data to plot, report, or save somewhere.

# Upper bound on the number of plotted points
max_points = 1000

plot_fn = fn axon_state ->
  if axon_state.iteration > 1 do
    episode = axon_state.episode

    # Clear the previous episode's target and trajectory before redrawing
    Kino.VegaLite.clear(grid_widget, dataset: "target")
    Kino.VegaLite.clear(grid_widget, dataset: "trajectory")

    # Plot the target position for the episode that just ended
    Kino.VegaLite.push(
      grid_widget,
      %{
        x: Nx.to_number(axon_state.step_state.environment_state.target_x),
        y: Nx.to_number(axon_state.step_state.environment_state.target_y)
      },
      dataset: "target"
    )

    IO.puts("Episode #{episode} ended")

    trajectory = axon_state.step_state.trajectory

    iteration = Nx.to_number(axon_state.step_state.iteration)

    # Only the first `iteration` rows of the trajectory buffer contain valid data
    points =
      trajectory[0..(iteration - 1)//1]
      |> Nx.to_list()
      |> Enum.with_index(fn [x, y], index ->
        %{
          x: x,
          y: y,
          index: index
        }
      end)

    Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory")
  end

  axon_state
end
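
The callback isn't limited to plotting. As a minimal sketch of the "report" case, a hypothetical log_fn that only prints the episode and iteration counters (using the same axon_state fields as above) could look like this:

log_fn = fn axon_state ->
  episode = axon_state.episode
  iteration = Nx.to_number(axon_state.step_state.iteration)

  # Report progress instead of plotting it
  IO.puts("Episode #{episode}, iteration #{iteration}")

  axon_state
end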

Now, we get to the actual training!

The code below calls Rein.train, configuring the Gridworld environment to be solved by a QLearning agent.

This will return the whole Axon.Loop struct in the result variable, so that we can inspect and/or save it afterwards.

Kino.VegaLite.clear(grid_widget)

episodes = 15_000
max_iter = 20

# Maps the environment state into a vector of zero-based coordinates
environment_to_state_vector_fn = fn %{x: x, y: y, target_x: target_x, target_y: target_y} ->
  delta_x = Nx.subtract(x, min_x)
  delta_y = Nx.subtract(y, min_y)

  Nx.stack([delta_x, delta_y, Nx.subtract(target_x, min_x), Nx.subtract(target_y, min_y)])
end

# Records only the agent's (x, y) position at each step, for the trajectory plot
state_to_trajectory_fn = fn %{environment_state: %{x: x, y: y}} ->
  Nx.stack([x, y])
end

delta_x = max_x - min_x + 1
delta_y = max_y - min_y + 1

# The state is (x, y, target_x, target_y), with each coordinate offset to start at zero
state_space_shape = {delta_x, delta_y, delta_x, delta_y}

{t, result} =
  :timer.tc(fn ->
    Rein.train(
      {Rein.Environments.Gridworld, possible_targets: possible_targets},
      {Rein.Agents.QLearning,
       state_space_shape: state_space_shape,
       num_actions: 4,
       environment_to_state_vector_fn: environment_to_state_vector_fn,
       learning_rate: 1.0e-2,
       gamma: 0.99,
       exploration_eps: 1.0e-4},
      plot_fn,
      state_to_trajectory_fn,
      checkpoint_path: "/tmp/gridworld",
      num_episodes: episodes,
      max_iter: max_iter
    )
  end)

"#{Float.round(t / 1_000_000, 3)} s"

With the code below, we can check some points of interest in the learned Q matrix.

In particular, we can see below that for a target at x = 2, y = 4:

  • For the position x = 2, y = 3, the selected action is to go up;
  • For the position x = 1, y = 4, the selected action is to go right;
  • For the position x = 3, y = 4, the selected action is to go left.

This shows that, at least for the positions closest to the target, our agent has already learned the best policy for those states!

# Converts a state vector into a linear, row-major index into the Q matrix
state_vector_to_index = fn state_vector, shape ->
  {linear_indices_offsets_list, _} =
    shape
    |> Tuple.to_list()
    |> Enum.reverse()
    |> Enum.reduce({[], 1}, fn x, {acc, multiplier} ->
      {[multiplier | acc], multiplier * x}
    end)

  linear_indices_offsets = Nx.tensor(linear_indices_offsets_list)

  Nx.dot(state_vector, linear_indices_offsets)
end

# Actions are [up, down, right, left]

# up
idx = state_vector_to_index.(Nx.tensor([2, 3, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

# right
idx = state_vector_to_index.(Nx.tensor([1, 4, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

# left
idx = state_vector_to_index.(Nx.tensor([3, 4, 2, 4]), {5, 5, 5, 5})
IO.inspect(result.step_state.agent_state.q_matrix[idx])

nil
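
Going one step further, we can map each Q row to its greedy action name instead of reading the raw values. This is a small sketch that reuses result and state_vector_to_index from above, along with the [up, down, right, left] action ordering:

action_names = [:up, :down, :right, :left]

greedy_action = fn state_vector ->
  idx = state_vector_to_index.(Nx.tensor(state_vector), {5, 5, 5, 5})

  # Pick the action with the highest Q value for this state
  action_index =
    result.step_state.agent_state.q_matrix[idx]
    |> Nx.argmax()
    |> Nx.to_number()

  Enum.at(action_names, action_index)
end

# Should agree with the observations above, e.g. :up for this state
greedy_action.([2, 3, 2, 4])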