Deep Q-Learning on Atari with ALEx
Mix.install(
  [
    # ALEx itself comes from the local checkout: Livebook runs a notebook with the
    # notebook's own folder as the working directory, so ".." is the repository root.
    # To run this notebook standalone, use the published package: {:alex, "~> 0.4"}.
    {:alex, path: ".."},
    # Tensors, the XLA compiler/backend, and the neural-network library.
    {:nx, "~> 0.12"},
    {:exla, "~> 0.12"},
    {:axon, "~> 0.8"},
    # Livebook widgets for the live game view and the reward chart.
    {:kino, "~> 0.12"},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config: [
    # EXLA is both the default tensor backend and the default compiler for every defn.
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ]
  ],
  # Ignore the install cache so that changes in the local :alex checkout are picked up
  # on every run (at the cost of a slower setup cell).
  force: true
)
What this notebook does
This trains a Deep Q-Network to play an Atari game from inside Elixir, using:
- ALEx for the emulator,
- Nx + EXLA for tensors and XLA-compiled math,
- Axon for the neural network, and
- Kino to watch the agent play live while it learns.
The agent plays Pong and learns from the console’s 128 bytes of RAM rather than from pixels. RAM is a compact, information-rich state, so a small multilayer perceptron can make visible progress within a notebook session — whereas a pixel-based convolutional DQN famously needs millions of frames. We stack the last 4 RAM frames so the network can perceive motion (a single frame can’t tell you which way the ball is moving). The live view still shows the real game screen; only the agent’s eyes are the stacked RAM.
A serious caveat: this is a demonstration of the full pipeline and the live tooling, not a tuned agent. Even with these aids (RAM input and frame stacking), expect tens of minutes before the reward curve clearly turns upward — real Atari agents train for tens of millions of frames.
> For the full pixel-based experience, replace Obs.frame/1 with a downsampled grayscale tensor
> (Alex.Screen.grayscale/1) and swap the MLP below for Axon.conv layers.
The environment
game = "pong"

# frame_skip: 4 is the standard Atari setting — each chosen action is held for 4 frames.
env =
  Alex.new(game,
    frame_skip: 4,
    repeat_action_probability: 0.0,
    random_seed: 0,
    rom_dir: Path.expand("~/.alex/roms")
  )

actions = Alex.minimal_actions(env)
n_actions = length(actions)

# A single RAM frame cannot convey velocity, so the network sees the last
# `stack_size` RAM frames concatenated and can infer which way the ball and
# paddles are moving.
stack_size = 4
state_size = env.ram_size * stack_size

Enum.each(
  [
    "Game: #{game}",
    "Actions (#{n_actions}): #{inspect(actions)}",
    "State size (#{stack_size} x #{env.ram_size} RAM bytes): #{state_size}"
  ],
  &IO.puts/1
)
Helpers to turn raw RAM into a normalized :f32 frame, and to stack frames into a state.
defmodule Obs do
  @moduledoc "Turns ALEx RAM snapshots into normalized frames and stacks frames into states."

  import Nx.Defn

  @doc """
  Reads the emulator RAM of `env` and returns it as a 1-D :f32 tensor with one
  element per RAM byte, each scaled from 0..255 into [0, 1].
  """
  def frame(env) do
    raw_ram = Alex.RAM.read(env)
    normalize(Nx.from_binary(raw_ram, :u8))
  end

  @doc "Casts a :u8 tensor to :f32 and divides every element by 255."
  defn normalize(bytes) do
    bytes
    |> Nx.as_type(:f32)
    |> Nx.divide(255.0)
  end

  @doc "Joins a list of frame tensors end to end (along axis 0) into one state vector."
  def stack(frames), do: Nx.concatenate(frames, axis: 0)
end
The Q-network
A plain MLP that maps a state to one Q-value per action.
# State vector in, one Q-value per action out, with two hidden ReLU layers of 256 units.
model =
  "state"
  |> Axon.input(shape: {nil, state_size})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(n_actions)

{init_fn, predict_fn} = Axon.build(model)

# Initialize the weights by tracing the network on a single-row :f32 template input.
params = init_fn.(Nx.template({1, state_size}, :f32), Axon.ModelState.empty())

# The target network starts as an exact copy of the online network; the training
# loop copies the online params into it every `target_update_every` steps, which
# keeps the TD target from chasing a constantly moving network.
target_params = params

IO.inspect(model)
The learning step
A DQN module holds the whole numeric pipeline as defns, exposed through three thin wrappers:
- predict — Q-values for a batch of states (used for action selection),
- td_targets — the temporal-difference targets from the target network (no gradient), and
- train_step — one gradient update of the online network against those targets, using a Huber loss for robustness.
The network (predict_fn) and optimizer (opt_update) are passed into the defns as arguments —
defn accepts function-valued arguments and can call them — so all the tensor math, including the
forward pass and the optimizer step, lives in defn where the + - * / operators dispatch to Nx.
# Discount factor applied to the bootstrapped next-state value in the TD target.
gamma = 0.99

# Adam optimizer. `opt_init` builds the optimizer state from the params; `opt_update`
# is handed to DQN.train_step/7 and applied on every gradient step.
learning_rate = 2.5e-4
{opt_init, opt_update} = Polaris.Optimizers.adam(learning_rate: learning_rate)
opt_state = opt_init.(params)
# The whole numeric pipeline is `defn`, where +, -, * and / dispatch to Nx.
# `predict_fn` (the network forward, a closure from Axon.build/2) and `opt_update`
# (the optimizer) are passed *into* the defns as arguments — defn accepts
# function-valued arguments and can call them — so there are no jitted `fn`
# wrappers around tensor math.
defmodule DQN do
  import Nx.Defn

  # Q-values for a batch of states: runs the network forward on the input named
  # "state" (the Axon.input name used when the model was built).
  # Presumably returns {batch, n_actions}, one row per state — confirm against the
  # model's final dense layer.
  defn predict(predict_fn, params, states) do
    predict_fn.(params, %{"state" => states})
  end

  # TD target: r + gamma * max_a' Q_target(s', a') * (1 - done).
  # Evaluated with the *target* network's params, and computed outside
  # train_step/7's value_and_grad, so no gradient flows through it.
  # `dones` is expected as 0.0/1.0 floats; multiplying by (1 - done) drops the
  # bootstrapped future value on terminal transitions.
  defn td_target(predict_fn, target_params, next_states, rewards, dones, gamma) do
    q_next = predict_fn.(target_params, %{"state" => next_states})
    # Best next-state Q-value per row (axis 1 is the action axis).
    max_next = Nx.reduce_max(q_next, axes: [1])
    rewards + gamma * max_next * (1.0 - dones)
  end

  # Huber loss (delta = 1), elementwise: 0.5 * t^2 for |t| <= 1, else |t| - 0.5.
  # Written branch-free: `quadratic` is |t| clamped to 1 and `linear` is the excess
  # of |t| above 1, so 0.5 * quadratic^2 + linear reproduces both pieces.
  defnp huber(td) do
    abs = Nx.abs(td)
    quadratic = Nx.min(abs, 1.0)
    linear = abs - quadratic
    0.5 * quadratic * quadratic + linear
  end

  # Mean Huber loss between the Q-values of the actions actually taken and the TD
  # targets. `actions` holds one integer action index per row; it is expanded to
  # {batch, 1} so take_along_axis picks q[i, actions[i]], then squeezed back to {batch}.
  defnp loss(predict_fn, params, states, actions, targets) do
    q = predict_fn.(params, %{"state" => states})
    idx = Nx.new_axis(actions, -1)
    q_taken = q |> Nx.take_along_axis(idx, axis: 1) |> Nx.squeeze(axes: [1])
    Nx.mean(huber(targets - q_taken))
  end

  # One optimizer step minimizing the mean Huber loss over the batch (only the
  # taken actions' Q-values contribute). `targets` are captured as constants, so
  # only the online `params` are differentiated.
  # Returns {new_params, new_opt_state, loss_value}.
  defn train_step(predict_fn, opt_update, params, opt_state, states, actions, targets) do
    {loss_value, grads} =
      Nx.Defn.value_and_grad(params, fn params ->
        loss(predict_fn, params, states, actions, targets)
      end)

    {updates, opt_state} = opt_update.(grads, opt_state, params)
    {Polaris.Updates.apply_updates(params, updates), opt_state, loss_value}
  end
end
# Close over the fixed arguments (network forward, optimizer, discount) so the
# training loop only passes what changes between calls. Calling a defn from plain
# Elixir JIT-compiles and caches it with the EXLA compiler configured in
# Mix.install, so no explicit Nx.Defn.jit/2 is needed.
predict = fn params, states -> DQN.predict(predict_fn, params, states) end

td_targets = &DQN.td_target(predict_fn, &1, &2, &3, &4, gamma)

train_step = &DQN.train_step(predict_fn, opt_update, &1, &2, &3, &4, &5)

:ok
Acting and remembering
Epsilon-greedy action selection, and a fixed-size replay buffer that holds at most buffer_size transitions.
defmodule Replay do
  @moduledoc """
  A bounded experience-replay buffer.

  Transitions are stored in a map keyed by slot index (0..max-1) and written in
  ring-buffer order, so once the buffer is full each new transition overwrites the
  oldest one. Each add is a single map update, instead of re-copying a capped list
  with `Enum.take/2` on every step. Sampling looks up only the drawn slots, instead
  of walking the whole buffer.

  A transition is a tuple `{state, action_index, reward, next_state, done?}`.
  The public fields `size` and `max` keep their original meaning.
  """

  @doc "Creates an empty buffer that holds at most `max_size` transitions."
  def new(max_size), do: %{items: %{}, size: 0, max: max_size, next: 0}

  @doc """
  Adds `transition`, overwriting the oldest stored transition when the buffer is
  full. A buffer created with a capacity of 0 stores nothing.
  """
  def add(%{max: 0} = buf, _transition), do: buf

  def add(buf, transition) do
    %{
      buf
      | items: Map.put(buf.items, buf.next, transition),
        # Advance the write cursor, wrapping so the oldest slot is reused next.
        next: rem(buf.next + 1, buf.max),
        size: min(buf.size + 1, buf.max)
    }
  end

  @doc """
  Draws up to `batch_size` distinct transitions uniformly at random and stacks
  them into batched tensors under `states`, `actions` (:s64), `rewards` (:f32),
  `next_states` and `dones` (:f32, 1.0 for terminal transitions).
  """
  def sample(buf, batch_size) do
    batch =
      buf.size
      |> random_slots(batch_size)
      |> Enum.map(&Map.fetch!(buf.items, &1))

    states = Enum.map(batch, &elem(&1, 0)) |> Nx.stack()
    actions = Enum.map(batch, &elem(&1, 1)) |> Nx.tensor(type: :s64)
    rewards = Enum.map(batch, &elem(&1, 2)) |> Nx.tensor(type: :f32)
    next_states = Enum.map(batch, &elem(&1, 3)) |> Nx.stack()
    dones = Enum.map(batch, &if(elem(&1, 4), do: 1.0, else: 0.0)) |> Nx.tensor(type: :f32)

    %{states: states, actions: actions, rewards: rewards, next_states: next_states, dones: dones}
  end

  # Picks min(count, size) distinct slot indices from 0..size-1 by rejection
  # sampling. This is cheap because the batch size is tiny compared to the buffer.
  # When the result would be empty, no random draw is made, so `:rand.uniform(0)`
  # is never reached.
  defp random_slots(size, count) do
    Stream.repeatedly(fn -> :rand.uniform(size) - 1 end)
    |> Stream.uniq()
    |> Enum.take(min(count, size))
  end
end
# Epsilon-greedy policy. With probability `epsilon` it returns a uniformly random
# action; otherwise it returns the argmax of the online network's Q-values for `state`.
# Always returns an integer action *index* in 0..n_actions-1, not the action itself.
select_action = fn params, state, epsilon ->
  cond do
    :rand.uniform() < epsilon ->
      :rand.uniform(n_actions) - 1

    true ->
      q_values = predict.(params, Nx.new_axis(state, 0))
      q_values |> Nx.argmax(axis: 1) |> Nx.squeeze() |> Nx.to_number()
  end
end

# Clip rewards into [-1, 1], as is standard DQN practice, so the reward scale
# is the same across games.
clip_reward = fn reward -> min(max(reward, -1), 1) end

# Starts a new episode on `env` and returns {env, frames, state}:
#   1. reset the emulator;
#   2. press FIRE once if the game has that action — games like Pong and Breakout
#      won't serve the ball otherwise;
#   3. fill the frame stack with copies of the first frame.
start_episode = fn env ->
  env = Alex.reset(env)
  env = if :fire in actions, do: env |> Alex.step(:fire) |> elem(0), else: env

  frames = List.duplicate(Obs.frame(env), stack_size)
  {env, frames, Obs.stack(frames)}
end

# Pushes the newest frame of `env` onto the front of the stack, keeps only the
# first `stack_size` frames, and returns {frames, state}.
push_frame_stack = fn frames, env ->
  newest = Obs.frame(env)
  frames = Enum.take([newest | frames], stack_size)
  {frames, Obs.stack(frames)}
end

:ok
Live views
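Create the widgets here; the training cell below renders them with Kino.render/1 and pushes into them as it runs. The canvas streams the game being trained on (throttled), and the chart plots the reward of each episode. play_and_render, defined here, plays one greedy “evaluation” episode on a separate environment and returns its total reward. Call it any time (the last cell does) to watch the current policy.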
# A dedicated env used only for visualization, so rendering never disturbs training state.
eval_env =
  Alex.new(game,
    frame_skip: 4,
    repeat_action_probability: 0.0,
    random_seed: 1,
    rom_dir: Path.expand("~/.alex/roms")
  )

viewer = Alex.Kino.view(eval_env, scale: 3)

alias VegaLite, as: Vl

reward_chart =
  Vl.new(width: 640, height: 320, title: "Episode reward")
  |> Vl.mark(:line, point: true)
  |> Vl.encode_field(:x, "episode", type: :quantitative)
  |> Vl.encode_field(:y, "reward", type: :quantitative)
  |> Kino.VegaLite.new()

# Safety cap on the length of a rendered episode, and the pause (in ms) between
# rendered frames that slows playback down to a watchable speed.
eval_max_steps = 10_000
eval_frame_delay_ms = 15

# Plays one fully greedy episode (epsilon = 0.0) on eval_env, pushing every frame to
# the live canvas. Returns the episode's total, unclipped reward. It stops when the
# game ends or after `eval_max_steps` steps, whichever comes first.
play_and_render = fn params ->
  {env, frames, state} = start_episode.(eval_env)

  {_env, _frames, _state, total_reward} =
    Enum.reduce_while(1..eval_max_steps, {env, frames, state, 0.0}, fn _, {env, frames, state, total} ->
      action_idx = select_action.(params, state, 0.0)
      {env, info} = Alex.step(env, Enum.at(actions, action_idx))
      {frames, state} = push_frame_stack.(frames, env)
      Alex.Kino.push_frame(viewer, env)
      Process.sleep(eval_frame_delay_ms)

      # Keep the same accumulator shape whether the game ends or the step cap is
      # reached, so the caller always gets the total reward back.
      acc = {env, frames, state, total + info.reward}
      if info.game_over?, do: {:halt, acc}, else: {:cont, acc}
    end)

  total_reward
end

:ok
Train
The loop interleaves environment steps with gradient updates. It renders a live dashboard right
here — the game canvas (streamed from the training env, throttled to ~15 fps), the reward-per-episode
chart, and a status line. It also prints one summary line per episode with IO.puts, so you can follow
progress as it runs.
hparams = %{
  episodes: 400,
  # Transitions to collect in the replay buffer before the first gradient update.
  warmup: 5_000,
  batch_size: 32,
  # Run one gradient update every `train_every` environment steps.
  train_every: 4,
  # Copy the online params into the target network every N environment steps.
  target_update_every: 1_000,
  buffer_size: 20_000,
  # Epsilon decays linearly from eps_start to eps_end over eps_decay_steps, then holds.
  eps_start: 1.0,
  eps_end: 0.05,
  eps_decay_steps: 30_000,
  # Throttle live canvas updates to ~15 fps so we don't flood Livebook with frames.
  render_every_ms: 66,
  max_steps_per_episode: 10_000
}

epsilon_at = fn step ->
  frac = min(step / hparams.eps_decay_steps, 1.0)
  hparams.eps_start + frac * (hparams.eps_end - hparams.eps_start)
end

# Formats the latest loss tensor for display ("—" before the first update).
# For non-finite values, Nx.to_number/1 returns the atoms :nan / :infinity /
# :neg_infinity, which Float.round/2 rejects. A diverging loss would otherwise
# crash the whole run at the logging step, so those atoms are shown as-is.
format_loss = fn
  nil ->
    "—"

  loss_tensor ->
    case Nx.to_number(loss_tensor) do
      value when is_float(value) -> Float.round(value, 4)
      other -> other
    end
end

# --- Live dashboard, rendered inline in this cell ---
status = Kino.Frame.new()
Kino.render(status)
Kino.render(viewer)
Kino.render(reward_chart)

initial = %{
  params: params,
  target_params: target_params,
  opt_state: opt_state,
  buffer: Replay.new(hparams.buffer_size),
  step: 0,
  last_loss: nil,
  rewards: [],
  last_render_ms: System.monotonic_time(:millisecond)
}

final =
  Enum.reduce(1..hparams.episodes, initial, fn episode, agent ->
    {env, frames, state} = start_episode.(env)

    {agent, episode_reward, _frames, _state, _env} =
      Enum.reduce_while(
        1..hparams.max_steps_per_episode,
        {agent, 0.0, frames, state, env},
        fn _, acc ->
          {agent, ep_r, frames, state, env} = acc

          epsilon = epsilon_at.(agent.step)
          action_idx = select_action.(agent.params, state, epsilon)
          {env, info} = Alex.step(env, Enum.at(actions, action_idx))
          {frames, next_state} = push_frame_stack.(frames, env)

          # Store the clipped reward for learning; the episode total below stays unclipped.
          reward = clip_reward.(info.reward)
          buffer = Replay.add(agent.buffer, {state, action_idx, reward, next_state, info.game_over?})
          agent = %{agent | buffer: buffer, step: agent.step + 1}

          # Learn from a minibatch once we have enough experience.
          agent =
            if buffer.size >= hparams.warmup and rem(agent.step, hparams.train_every) == 0 do
              batch = Replay.sample(buffer, hparams.batch_size)
              targets = td_targets.(agent.target_params, batch.next_states, batch.rewards, batch.dones)

              {new_params, new_opt, loss} =
                train_step.(agent.params, agent.opt_state, batch.states, batch.actions, targets)

              %{agent | params: new_params, opt_state: new_opt, last_loss: loss}
            else
              agent
            end

          # Periodically refresh the target network.
          agent =
            if rem(agent.step, hparams.target_update_every) == 0,
              do: %{agent | target_params: agent.params},
              else: agent

          # Stream the live game to the canvas, throttled by wall-clock time.
          now = System.monotonic_time(:millisecond)

          agent =
            if now - agent.last_render_ms >= hparams.render_every_ms do
              Alex.Kino.push_frame(viewer, env)
              %{agent | last_render_ms: now}
            else
              agent
            end

          acc = {agent, ep_r + info.reward, frames, next_state, env}
          if info.game_over?, do: {:halt, acc}, else: {:cont, acc}
        end
      )

    # --- Per-episode progress ---
    rewards = Enum.take([episode_reward | agent.rewards], 100)
    recent = Enum.take(rewards, 10)
    avg10 = Float.round(Enum.sum(recent) / length(recent), 2)
    eps = Float.round(epsilon_at.(agent.step), 3)
    loss = format_loss.(agent.last_loss)
    fill = "#{agent.buffer.size}/#{hparams.buffer_size}"

    Kino.VegaLite.push(reward_chart, %{episode: episode, reward: episode_reward})

    Kino.Frame.render(
      status,
      Kino.Markdown.new("""
      **Episode #{episode}/#{hparams.episodes}** — step #{agent.step}, ε #{eps}, buffer #{fill}
      reward **#{episode_reward}** · avg(10) **#{avg10}** · last loss #{loss}
      """)
    )

    IO.puts("ep #{episode}/#{hparams.episodes}\treward #{episode_reward}\tavg10 #{avg10}\tε #{eps}\tloss #{loss}")

    %{agent | rewards: rewards}
  end)

:ok
Watch the trained agent
Run this cell any time to watch the current policy play a full greedy episode in the live viewer above.
# Greedy playback of the policy produced by the training cell; returns the total reward.
%{params: trained_params} = final
play_and_render.(trained_params)
Where to go next
- Pixels instead of RAM. Replace Obs.frame/1 with a downsampled grayscale tensor (Alex.Screen.grayscale/1 + Nx.window_mean) and swap the MLP for Axon.conv layers. Expect to train far longer.
- Better targets. Add Double DQN (use the online net to pick the action, the target net to evaluate it) — a small change in DQN.td_target.
- Other games. Change game — boxing (dense +1/-1 reward) also shows progress relatively quickly; breakout is iconic but slower to learn.
- Prioritized replay for sample efficiency.
- Snapshots. Use Alex.Snapshot to checkpoint interesting states for debugging or curriculum learning.
> Caveat: this notebook demonstrates the full pipeline and the live tooling; it is not tuned to master a game. Real Atari agents train for tens of millions of frames. Treat the live viewer as a window into learning dynamics, not a leaderboard run.