Deep Q-Learning on Atari with ALEx

Mix.install(
  [
    # This notebook lives in the repo, so we use the local checkout (Livebook sets
    # the working directory to the notebook's folder). Swap for the published
    # package with {:alex, "~> 0.4"} if running it standalone.
    {:alex, path: ".."},
    {:nx, "~> 0.12"},
    {:axon, "~> 0.8"},
    {:exla, "~> 0.12"},
    {:kino, "~> 0.12"},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config: [
    # Run all Nx work on the EXLA (XLA) backend, and JIT-compile every defn with it.
    nx: [default_backend: EXLA.Backend, default_defn_options: [compiler: EXLA]]
  ],
  force: true
)

What this notebook does

This trains a Deep Q-Network to play an Atari game from inside Elixir, using:

  • ALEx for the emulator,
  • Nx + EXLA for tensors and XLA-compiled math,
  • Axon for the neural network, and
  • Kino to watch the agent play live while it learns.

The agent plays Pong and learns from the console’s 128 bytes of RAM rather than from pixels. RAM is a compact, information-rich state, so a small multilayer perceptron can make visible progress within a notebook session, whereas a pixel-based convolutional DQN famously needs millions of frames. We stack the last 4 RAM frames so the network can perceive motion (a single frame can’t tell you which way the ball is moving). The live view still shows the real game screen; only the agent’s input is the stacked RAM.
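To make the state layout concrete, here is a minimal sketch on plain Elixir lists rather than Nx tensors. The module name FrameStackSketch is hypothetical and not part of ALEx. Each 128-byte RAM frame is scaled to [0, 1], the newest frame is pushed to the front, only the last 4 are kept, and the stack is concatenated into one 512-element state. It mirrors what the Obs helpers and push_frame_stack do later with tensors.

```elixir
defmodule FrameStackSketch do
  # Hypothetical helper for illustration; the notebook does this with Nx tensors.

  # Scale raw bytes (0..255) to floats in [0, 1].
  def normalize(bytes) when is_list(bytes), do: Enum.map(bytes, &(&1 / 255.0))

  # Seed the stack with copies of the first frame, as at episode start.
  def seed(frame, stack_size), do: List.duplicate(frame, stack_size)

  # Push the newest frame to the front and drop the oldest.
  def push(frames, frame, stack_size), do: Enum.take([frame | frames], stack_size)

  # Flatten the stacked frames (newest first) into one state vector.
  def state(frames), do: Enum.concat(frames)
end
```

With 128-byte frames and a stack of 4, the state has 512 entries. The newest frame comes first, so the value at a RAM address in the first 128 entries minus the same address 128 entries later is that byte's change over one agent step. A single frame cannot provide that velocity signal.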

A serious caveat: this is a demonstration of the full pipeline and the live tooling, not a tuned agent. Even with RAM input and frame stacking, expect tens of minutes before the reward curve clearly turns upward; real Atari agents train for tens of millions of frames.

> For the full pixel-based experience, replace Obs.frame/1 with a downsampled grayscale tensor (Alex.Screen.grayscale/1) and swap the MLP below for Axon.conv layers.
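As a rough illustration of the downsampling step, here is a minimal sketch of 2x2 mean pooling on a grayscale image stored as a list of rows of 0..255 integers. It assumes nothing about the format Alex.Screen.grayscale/1 actually returns, and the module name PoolSketch is hypothetical. In practice you would pool a tensor (for example with Nx.window_mean) rather than lists.

```elixir
defmodule PoolSketch do
  # Hypothetical helper: halve height and width by averaging each 2x2 block.
  # A trailing odd row or column is dropped.
  def pool2x2(rows) do
    rows
    |> Enum.chunk_every(2, 2, :discard)
    |> Enum.map(fn [top, bottom] ->
      top
      |> Enum.zip(bottom)
      |> Enum.chunk_every(2, 2, :discard)
      |> Enum.map(fn [{a, c}, {b, d}] -> (a + b + c + d) / 4 end)
    end)
  end
end
```

One pass turns the Atari 2600's 210x160 frame into 105x80. That cuts the input size by a factor of four before it reaches the convolutional layers.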

The environment

game = "pong"

# frame_skip: 4 is the standard Atari setting — each action is held for 4 frames.
env = Alex.new(game,
  frame_skip: 4,
  repeat_action_probability: 0.0,
  random_seed: 0,
  rom_dir: Path.expand("~/.alex/roms")
)

actions = Alex.minimal_actions(env)
n_actions = length(actions)

# Stack the last few RAM frames so the network can infer motion (which way the
# ball/paddle is moving) — a single frame can't convey velocity.
stack_size = 4
state_size = stack_size * env.ram_size

IO.puts("Game: #{game}")
IO.puts("Actions (#{n_actions}): #{inspect(actions)}")
IO.puts("State size (#{stack_size} x #{env.ram_size} RAM bytes): #{state_size}")
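frame_skip: 4 holds each action for 4 emulator frames, so agent steps and emulator frames differ by a factor of four. The quick arithmetic below uses two values. One is the Atari 2600's 60 Hz NTSC frame rate, which is an assumption about the ROM's video standard. The other is the 30_000 value of eps_decay_steps in the training cell. Together they show how little game time a notebook-scale run covers compared with the tens of millions of frames mentioned above.

```elixir
frame_skip = 4
fps = 60

# Agent decisions per second of game time.
decisions_per_second = fps / frame_skip

# Emulator frames and minutes of game time covered by `steps` agent steps.
steps = 30_000
frames = steps * frame_skip
minutes = frames / fps / 60

IO.puts("#{decisions_per_second} decisions/s; #{steps} steps = #{frames} frames, about #{Float.round(minutes, 1)} min of play")
```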

Helpers to turn raw RAM into a normalized :f32 frame, and to stack frames into a state.

defmodule Obs do
  import Nx.Defn

  @doc "Reads the env's RAM (128 bytes on the Atari 2600) -> {128} f32 tensor scaled to [0, 1]."
  def frame(env) do
    env
    |> Alex.RAM.read()
    |> Nx.from_binary(:u8)
    |> normalize()
  end

  defn normalize(bytes), do: Nx.as_type(bytes, :f32) / 255.0

  @doc "Concatenate a list of frame tensors into a single stacked state vector."
  def stack(frames), do: Nx.concatenate(frames)
end

The Q-network

A plain MLP that maps a state to one Q-value per action.

model =
  Axon.input("state", shape: {nil, state_size})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(n_actions)

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, state_size}, :f32)
params = init_fn.(template, Axon.ModelState.empty())

# A separate, slowly-updated "target" network stabilizes the learning target.
target_params = params

IO.inspect model
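You can work out the size of this network by hand. Each Axon.dense layer has inputs × units weights plus units biases. The sketch below assumes state_size = 4 × 128 = 512 and that Pong's minimal action set has 6 actions; check the n_actions printed by the environment cell. ParamCount is a hypothetical helper, not an Axon function.

```elixir
defmodule ParamCount do
  # Sum of (inputs * units + units) over consecutive layer widths.
  def dense([_last]), do: 0
  def dense([inputs, units | rest]), do: inputs * units + units + dense([units | rest])
end

# state_size 512 -> 256 -> 256 -> n_actions (assumed 6 for Pong)
ParamCount.dense([512, 256, 256, 6])
# => 198_662
```

About 200k parameters is tiny by deep-learning standards. That is part of why a RAM-based MLP is practical inside a notebook.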

The learning step

A DQN module holds the whole numeric pipeline as defns, exposed through three thin wrappers:

  • predict: Q-values for a batch of states (used for action selection),
  • td_targets: the temporal-difference targets from the target network (no gradient), and
  • train_step: one gradient update of the online network against those targets, using a Huber loss, which grows linearly rather than quadratically once a TD error exceeds 1, so a few large errors cannot dominate the gradient.

The network (predict_fn) and optimizer (opt_update) are passed into the defns as arguments — defn accepts function-valued arguments and can call them — so all the tensor math, including the forward pass and the optimizer step, lives in defn where the + - * / operators dispatch to Nx.
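The two formulas behind td_target and huber can be checked on a single transition with plain floats. This is a minimal scalar sketch, and the module name TDSketch is hypothetical. The notebook's defns apply the same arithmetic elementwise across a batch.

```elixir
defmodule TDSketch do
  # y = r + gamma * max_a' Q_target(s', a') * (1 - done)
  def td_target(reward, gamma, next_q_values, done?) do
    not_done = if done?, do: 0.0, else: 1.0
    reward + gamma * Enum.max(next_q_values) * not_done
  end

  # Huber loss with delta = 1: 0.5 * t^2 for |t| <= 1, else |t| - 0.5.
  def huber(t) do
    a = abs(t)
    quadratic = min(a, 1.0)
    linear = a - quadratic
    0.5 * quadratic * quadratic + linear
  end
end
```

With reward 1.0, gamma 0.99 and next-state Q-values [0.2, 2.0, -0.5], the target is 1.0 + 0.99 × 2.0 = 2.98. On a terminal transition it is just the reward. A TD error of 0.5 costs 0.125, which falls in the quadratic region. An error of 3.0 costs 2.5 rather than the 4.5 that 0.5 × t² would give.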

gamma = 0.99
{opt_init, opt_update} = Polaris.Optimizers.adam(learning_rate: 2.5e-4)
opt_state = opt_init.(params)

# The whole numeric pipeline is `defn`, where +, -, * and / dispatch to Nx.
# `predict_fn` (the network forward, a closure from Axon.build/2) and `opt_update`
# (the optimizer) are passed *into* the defns as arguments — defn accepts
# function-valued arguments and can call them — so there are no jitted `fn`
# wrappers around tensor math.
defmodule DQN do
  import Nx.Defn

  # Q-values for a batch of states.
  defn predict(predict_fn, params, states) do
    predict_fn.(params, %{"state" => states})
  end

  # TD target: r + gamma * max_a' Q_target(s', a') * (1 - done).
  defn td_target(predict_fn, target_params, next_states, rewards, dones, gamma) do
    q_next = predict_fn.(target_params, %{"state" => next_states})
    max_next = Nx.reduce_max(q_next, axes: [1])
    rewards + gamma * max_next * (1.0 - dones)
  end

  # Huber loss, elementwise: 0.5 * t^2 for |t| <= 1, else |t| - 0.5.
  defnp huber(td) do
    abs = Nx.abs(td)
    quadratic = Nx.min(abs, 1.0)
    linear = abs - quadratic
    0.5 * quadratic * quadratic + linear
  end

  # Mean Huber loss between the chosen-action Q-values and the TD targets.
  defnp loss(predict_fn, params, states, actions, targets) do
    q = predict_fn.(params, %{"state" => states})
    idx = Nx.new_axis(actions, -1)
    q_taken = q |> Nx.take_along_axis(idx, axis: 1) |> Nx.squeeze(axes: [1])
    Nx.mean(huber(targets - q_taken))
  end

  # One optimizer step on the mean Huber loss of the Q-values for the taken actions.
  defn train_step(predict_fn, opt_update, params, opt_state, states, actions, targets) do
    {loss_value, grads} =
      Nx.Defn.value_and_grad(params, fn params ->
        loss(predict_fn, params, states, actions, targets)
      end)

    {updates, opt_state} = opt_update.(grads, opt_state, params)
    {Polaris.Updates.apply_updates(params, updates), opt_state, loss_value}
  end
end

# Thin partial applications so the training loop doesn't repeat the fixed args.
# Calling a defn from regular Elixir JIT-compiles it (once) with the global EXLA
# compiler set in Mix.install — no explicit Nx.Defn.jit needed.
predict = &DQN.predict(predict_fn, &1, &2)
td_targets = fn target_params, next_states, rewards, dones ->
  DQN.td_target(predict_fn, target_params, next_states, rewards, dones, gamma)
end

train_step = fn params, opt_state, states, actions, targets ->
  DQN.train_step(predict_fn, opt_update, params, opt_state, states, actions, targets)
end

:ok

Acting and remembering

Epsilon-greedy action selection, and a fixed-size replay buffer (a list capped to buffer_size).
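Epsilon-greedy selection and its decay schedule can be sketched on plain lists. EpsGreedySketch is a hypothetical module. In the notebook, select_action gets its Q-values from the network, and epsilon_at in the training cell implements the same linear schedule.

```elixir
defmodule EpsGreedySketch do
  # Linear decay from eps_start to eps_end over decay_steps, then held at eps_end.
  def epsilon(step, eps_start, eps_end, decay_steps) do
    frac = min(step / decay_steps, 1.0)
    eps_start + frac * (eps_end - eps_start)
  end

  # With probability epsilon, a uniformly random action index; otherwise the greedy one.
  def select(q_values, epsilon) do
    if :rand.uniform() < epsilon do
      :rand.uniform(length(q_values)) - 1
    else
      q_values |> Enum.with_index() |> Enum.max_by(&elem(&1, 0)) |> elem(1)
    end
  end
end
```

The training cell uses eps_start 1.0, eps_end 0.05 and eps_decay_steps 30_000. With those settings, epsilon is still about 0.84 when learning begins after the 5_000-step warmup, and it reaches 0.05 at step 30_000.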

defmodule Replay do
  @moduledoc "A bounded experience-replay buffer backed by a list."

  def new(max_size), do: %{items: [], size: 0, max: max_size}

  def add(buf, transition) do
    items = [transition | buf.items]

    if buf.size >= buf.max do
      %{buf | items: Enum.take(items, buf.max)}
    else
      %{buf | items: items, size: buf.size + 1}
    end
  end

  def sample(buf, batch_size) do
    batch = Enum.take_random(buf.items, batch_size)

    states = Enum.map(batch, &elem(&1, 0)) |> Nx.stack()
    actions = Enum.map(batch, &elem(&1, 1)) |> Nx.tensor(type: :s64)
    rewards = Enum.map(batch, &elem(&1, 2)) |> Nx.tensor(type: :f32)
    next_states = Enum.map(batch, &elem(&1, 3)) |> Nx.stack()
    dones = Enum.map(batch, &(if elem(&1, 4), do: 1.0, else: 0.0)) |> Nx.tensor(type: :f32)

    %{states: states, actions: actions, rewards: rewards, next_states: next_states, dones: dones}
  end
end
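Once the buffer is full, Replay.add rebuilds the list with Enum.take on every step. Enum.take_random also walks the whole list. Both therefore cost time proportional to buffer_size. A common alternative is a ring buffer that overwrites the oldest slot in a map keyed by slot index. It is sketched here as an assumption, not a drop-in replacement. RingReplaySketch is hypothetical. Unlike Enum.take_random, it samples with replacement. It returns only the raw transitions, so the Nx.stack/Nx.tensor conversion from Replay.sample would still apply to its output.

```elixir
defmodule RingReplaySketch do
  @moduledoc "Hypothetical fixed-capacity replay buffer: slot index -> transition."

  defstruct items: %{}, next: 0, size: 0, max: 0

  def new(max_size) when max_size > 0, do: %__MODULE__{max: max_size}

  # Write into the next slot, overwriting the oldest transition once full.
  def add(%__MODULE__{} = buf, transition) do
    %{
      buf
      | items: Map.put(buf.items, buf.next, transition),
        next: rem(buf.next + 1, buf.max),
        size: min(buf.size + 1, buf.max)
    }
  end

  # Uniform sampling with replacement over the filled slots 0..size-1.
  def sample(%__MODULE__{size: size} = buf, batch_size) when size > 0 do
    for _ <- 1..batch_size, do: Map.fetch!(buf.items, :rand.uniform(size) - 1)
  end
end
```

Each add and each sampled index costs a map operation, which is logarithmic in the buffer size, instead of a full pass over a 20_000-element list.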
# Returns an action *index* in 0..n_actions-1.
select_action = fn params, state, epsilon ->
  if :rand.uniform() < epsilon do
    :rand.uniform(n_actions) - 1
  else
    params
    |> predict.(Nx.new_axis(state, 0))
    |> Nx.argmax(axis: 1)
    |> Nx.squeeze()
    |> Nx.to_number()
  end
end

# Reward clipping to [-1, 1] is standard DQN practice across games.
clip_reward = fn r -> r |> max(-1) |> min(1) end

# Start an episode: reset, press FIRE once to serve the ball (games like Pong and
# Breakout won't start otherwise), and seed the frame stack with the first frame.
# Returns {env, frames, state}.
start_episode = fn env ->
  env = Alex.reset(env)
  env = if :fire in actions, do: elem(Alex.step(env, :fire), 0), else: env
  frames = List.duplicate(Obs.frame(env), stack_size)
  {env, frames, Obs.stack(frames)}
end

# Advance the frame stack with a new frame, returning {frames, state}.
push_frame_stack = fn frames, env ->
  frames = [Obs.frame(env) | frames] |> Enum.take(stack_size)
  {frames, Obs.stack(frames)}
end

:ok

Live views

Create the widgets in their own cells so Livebook renders them. As the training loop below runs, it streams frames from the training env to the canvas and pushes each episode's reward to the chart. The play_and_render helper defined here plays a fully greedy episode on a separate eval env, so you can watch the current policy whenever you call it.

# A dedicated env used only for visualization, so rendering never disturbs training state.
eval_env = Alex.new(game,
  frame_skip: 4,
  repeat_action_probability: 0.0,
  random_seed: 1,
  rom_dir: Path.expand("~/.alex/roms")
)
viewer = Alex.Kino.view(eval_env, scale: 3)

alias VegaLite, as: Vl

reward_chart =
  Vl.new(width: 640, height: 320, title: "Episode reward")
  |> Vl.mark(:line, point: true)
  |> Vl.encode_field(:x, "episode", type: :quantitative)
  |> Vl.encode_field(:y, "reward", type: :quantitative)
  |> Kino.VegaLite.new()

# Plays one fully-greedy episode on eval_env, pushing each frame to the live canvas.
# Returns the episode's total (unclipped) reward.
play_and_render = fn params ->
  {env, frames, state} = start_episode.(eval_env)

  result =
    Enum.reduce_while(1..10_000, {env, frames, state, 0.0}, fn _, {env, frames, state, total} ->
      action_idx = select_action.(params, state, 0.0)
      {env, info} = Alex.step(env, Enum.at(actions, action_idx))
      {frames, state} = push_frame_stack.(frames, env)
      Alex.Kino.push_frame(viewer, env)
      # Slow the playback down to a watchable speed.
      Process.sleep(15)

      if info.game_over?,
        do: {:halt, total + info.reward},
        else: {:cont, {env, frames, state, total + info.reward}}
    end)

  # If the 10_000-step cap is hit before game over, reduce_while returns the whole
  # accumulator tuple instead of the halted total; unwrap it so callers always get a number.
  case result do
    {_env, _frames, _state, total} -> total
    total -> total
  end
end

:ok

Train

The loop interleaves environment steps with gradient updates. It renders a live dashboard in this cell with three parts: the game canvas (streamed from the training env and throttled to about 15 fps), the reward-per-episode chart, and a status line. It also prints one line per episode with IO.puts, so you can follow progress as it runs.
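Isolating the wall-clock throttle as a pure function makes the "about 15 fps" figure easy to check. ThrottleSketch is a hypothetical module that mirrors the last_render_ms comparison in the loop; it is not part of Kino or ALEx.

```elixir
defmodule ThrottleSketch do
  # Decide whether enough wall-clock time has passed since the last render.
  # Returns {true, now_ms} to render (resetting the clock) or {false, last_ms} to skip.
  def tick(last_ms, now_ms, every_ms) do
    if now_ms - last_ms >= every_ms, do: {true, now_ms}, else: {false, last_ms}
  end
end
```

With render_every_ms: 66, the ceiling is 1000 / 66 ≈ 15 renders per second. The check only runs once per agent step, so a render happens on the first step at or after each 66 ms boundary. The real rate is therefore at or slightly below that ceiling. For a hypothetical 10 ms step, the sketch renders every 70 ms, which is 14 times per second.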

hparams = %{
  episodes: 400,
  warmup: 5_000,
  batch_size: 32,
  train_every: 4,
  target_update_every: 1_000,
  buffer_size: 20_000,
  eps_start: 1.0,
  eps_end: 0.05,
  eps_decay_steps: 30_000,
  # Throttle live canvas updates to ~15 fps so we don't flood Livebook with frames.
  render_every_ms: 66,
  max_steps_per_episode: 10_000
}

epsilon_at = fn step ->
  frac = min(step / hparams.eps_decay_steps, 1.0)
  hparams.eps_start + frac * (hparams.eps_end - hparams.eps_start)
end

# --- Live dashboard, rendered inline in this cell ---
status = Kino.Frame.new()
Kino.render(status)
Kino.render(viewer)
Kino.render(reward_chart)

initial = %{
  params: params,
  target_params: target_params,
  opt_state: opt_state,
  buffer: Replay.new(hparams.buffer_size),
  step: 0,
  last_loss: nil,
  rewards: [],
  last_render_ms: System.monotonic_time(:millisecond)
}

final =
  Enum.reduce(1..hparams.episodes, initial, fn episode, agent ->
    {env, frames, state} = start_episode.(env)

    {agent, episode_reward, _frames, _state, _env} =
      Enum.reduce_while(1..hparams.max_steps_per_episode, {agent, 0.0, frames, state, env}, fn _, acc ->
        {agent, ep_r, frames, state, env} = acc

        epsilon = epsilon_at.(agent.step)
        action_idx = select_action.(agent.params, state, epsilon)

        {env, info} = Alex.step(env, Enum.at(actions, action_idx))
        {frames, next_state} = push_frame_stack.(frames, env)
        reward = clip_reward.(info.reward)

        buffer = Replay.add(agent.buffer, {state, action_idx, reward, next_state, info.game_over?})
        agent = %{agent | buffer: buffer, step: agent.step + 1}

        # Learn from a minibatch once we have enough experience.
        agent =
          if buffer.size >= hparams.warmup and rem(agent.step, hparams.train_every) == 0 do
            batch = Replay.sample(buffer, hparams.batch_size)
            targets = td_targets.(agent.target_params, batch.next_states, batch.rewards, batch.dones)

            {new_params, new_opt, loss} =
              train_step.(agent.params, agent.opt_state, batch.states, batch.actions, targets)

            %{agent | params: new_params, opt_state: new_opt, last_loss: loss}
          else
            agent
          end

        # Periodically refresh the target network.
        agent =
          if rem(agent.step, hparams.target_update_every) == 0,
            do: %{agent | target_params: agent.params},
            else: agent

        # Stream the live game to the canvas, throttled by wall-clock time.
        now = System.monotonic_time(:millisecond)

        agent =
          if now - agent.last_render_ms >= hparams.render_every_ms do
            Alex.Kino.push_frame(viewer, env)
            %{agent | last_render_ms: now}
          else
            agent
          end

        acc = {agent, ep_r + info.reward, frames, next_state, env}
        if info.game_over?, do: {:halt, acc}, else: {:cont, acc}
      end)

    # --- Per-episode progress ---
    rewards = Enum.take([episode_reward | agent.rewards], 100)
    recent = Enum.take(rewards, 10)
    avg10 = Float.round(Enum.sum(recent) / length(recent), 2)
    eps = Float.round(epsilon_at.(agent.step), 3)
    loss = if agent.last_loss, do: Float.round(Nx.to_number(agent.last_loss), 4)
    fill = "#{agent.buffer.size}/#{hparams.buffer_size}"

    Kino.VegaLite.push(reward_chart, %{episode: episode, reward: episode_reward})

    Kino.Frame.render(
      status,
      Kino.Markdown.new("""
      **Episode #{episode}/#{hparams.episodes}** — step #{agent.step}, ε #{eps}, buffer #{fill}

      reward **#{episode_reward}** · avg(10) **#{avg10}** · last loss #{loss || "—"}
      """)
    )

    IO.puts("ep #{episode}/#{hparams.episodes}\treward #{episode_reward}\tavg10 #{avg10}\tε #{eps}\tloss #{loss || "—"}")

    %{agent | rewards: rewards}
  end)

:ok

Watch the trained agent

Run this cell any time to watch the current policy play a full greedy episode in the live viewer above.

play_and_render.(final.params)

Where to go next

  • Pixels instead of RAM. Replace Obs.frame/1 with a downsampled grayscale tensor (Alex.Screen.grayscale/1 + Nx.window_mean) and swap the MLP for Axon.conv layers. Expect to train far longer.
  • Better targets. Add Double DQN (use the online net to pick the action, the target net to evaluate it) — a small change in DQN.td_target.
  • Other games. Change game to "boxing", whose dense +1/-1 rewards also show progress relatively quickly; "breakout" is iconic but slower to learn.
  • Prioritized replay for sample efficiency.
  • Snapshots. Use Alex.Snapshot to checkpoint interesting states for debugging or curriculum learning.
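Double DQN changes only how the bootstrap value is chosen. The scalar sketch below uses plain lists; DoubleDQNSketch is a hypothetical module. It contrasts the two targets for one transition. In the notebook, the equivalent change inside DQN.td_target would presumably involve four steps:

  • pass the online params into td_target as well,
  • call predict_fn with both params and target_params on next_states,
  • take Nx.argmax of the online Q-values along axis 1, and
  • gather the target Q-values at those indices with Nx.take_along_axis.

```elixir
defmodule DoubleDQNSketch do
  # Standard DQN target: the target net both selects and evaluates a'.
  def dqn_target(reward, gamma, target_q_next, done?) do
    reward + gamma * Enum.max(target_q_next) * not_done(done?)
  end

  # Double DQN target: the online net selects a', the target net evaluates it.
  def double_dqn_target(reward, gamma, online_q_next, target_q_next, done?) do
    a_star = argmax(online_q_next)
    reward + gamma * Enum.at(target_q_next, a_star) * not_done(done?)
  end

  # Index of the largest value in a list.
  defp argmax(values) do
    values |> Enum.with_index() |> Enum.max_by(&elem(&1, 0)) |> elem(1)
  end

  defp not_done(true), do: 0.0
  defp not_done(false), do: 1.0
end
```

Take online Q-values [1.0, 0.0] and target Q-values [0.5, 3.0], with reward 0 and gamma 1. The target net's inflated value for action 1 becomes the standard DQN target (3.0). Double DQN evaluates the online net's choice, action 0, and gets 0.5.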

> Caveat: this notebook demonstrates the full pipeline and the live tooling; it is not tuned to master a game. Real Atari agents train for tens of millions of frames. Treat the live viewer as a window into learning dynamics, not a leaderboard run.