ExPhil Evaluation Dashboard
What is ExPhil?
ExPhil is an Elixir-based AI that learns to play Super Smash Bros. Melee by watching human replay files. It uses imitation learning (also called behavioral cloning) to copy how humans play, then can be improved with reinforcement learning (PPO) to develop its own strategies.
This notebook lets you:
- Inspect trained models - See the neural network architecture and parameters
- Analyze training data - Visualize button press rates and stick positions from replays
- Test inference - Run the model on game states and see what actions it chooses
- Compare to humans - Measure how closely the model matches human play
Setup
Run this cell first to load all dependencies. This may take a minute on first run.
# Install the local ExPhil project plus the Kino widgets this notebook uses.
# Mix.install/1 compiles dependencies on first run, so this cell can take a
# minute the first time; later runs hit the cache.
Mix.install([
  {:exphil, path: Path.expand(".."), env: :dev},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"}
])

# Shorthand aliases for the ExPhil modules used throughout the notebook.
alias ExPhil.{Agents, Embeddings, Training}
alias ExPhil.Training.{Imitation, Data}
alias ExPhil.Bridge.{GameState, Player, ControllerState}
alias ExPhil.Data.Peppi
# Vl is the conventional short alias for VegaLite chart building.
alias VegaLite, as: Vl

IO.puts("Setup complete!")
1. Load a Trained Policy
A policy is a neural network that takes a game state as input and outputs controller actions. The policy file (.bin) contains:
- Model weights - The learned parameters of the neural network
- Architecture config - How the network is structured (layer sizes, etc.)
- Training config - Settings used during training (temporal mode, backbone type)
# Text input for the path to a serialized policy checkpoint (.bin).
# Kino.render/1 returns its argument, so rendering separately binds the
# same input struct as the original pipe did.
policy_path = Kino.Input.text("Policy Path", default: "checkpoints/imitation_latest_policy.bin")
Kino.render(policy_path)
:ok
# Read the chosen path and try to deserialize the policy. On success the
# policy is stashed in the process dictionary so later cells can reuse it.
policy_file = Kino.Input.read(policy_path)

case Training.load_policy(policy_file) do
  {:ok, policy} ->
    Process.put(:loaded_policy, policy)

    # Human-readable explanations, computed up front to keep the table tidy.
    temporal_note =
      if policy.config.temporal,
        do: "Uses past frames for context (better for combos)",
        else: "Single-frame decisions only"

    backbone_note =
      case policy.config.backbone do
        :sliding_window -> "Transformer attention over recent frames"
        :lstm -> "LSTM recurrent memory"
        :hybrid -> "LSTM + attention combined"
        _ -> "Simple feedforward MLP"
      end

    Kino.Markdown.new("""
    ### Policy Loaded Successfully!
    | Setting | Value | What it means |
    |---------|-------|---------------|
    | **Temporal** | `#{policy.config.temporal}` | #{temporal_note} |
    | **Backbone** | `#{policy.config.backbone}` | #{backbone_note} |
    | **Embed Size** | `#{policy.config.embed_size}` | Dimensions of the game state representation |
    | **Hidden Sizes** | `#{inspect(policy.config.hidden_sizes)}` | Neurons in each hidden layer |
    """)

  {:error, reason} ->
    Kino.Markdown.new("""
    ### Failed to load policy
    **Error:** `#{inspect(reason)}`
    Make sure the path is correct and the file exists.
    """)
end
2. Model Architecture
Neural networks are made of layers. Each layer has weights (learned parameters) that transform the input. Here we show:
- Total parameters - How many numbers the model learned (more = more capacity but slower)
- Layer breakdown - Each layer’s shape and size
# Section 2 cell: summarize the loaded policy's parameter count and layers.
# Reads the policy stashed by the Section 1 cell; renders a fallback note
# if none was loaded yet.
policy = Process.get(:loaded_policy)

if policy do
  # Count total parameters
  # `policy.params.data` maps layer name -> %{param name -> tensor}; summing
  # Nx.size/1 over every tensor gives the total learned-parameter count.
  param_count = policy.params.data
  |> Enum.flat_map(fn {_layer, params} ->
    Enum.map(params, fn {_name, tensor} -> Nx.size(tensor) end)
  end)
  |> Enum.sum()

  # Get layer details
  # For each layer, collect {param name, tensor shape, element count},
  # sorted by layer name for a stable listing.
  layer_info = policy.params.data
  |> Enum.map(fn {layer, params} ->
    layer_params = Enum.map(params, fn {name, tensor} ->
      {name, Nx.shape(tensor), Nx.size(tensor)}
    end)
    {layer, layer_params}
  end)
  |> Enum.sort()

  layer_docs = Enum.map(layer_info, fn {layer, params} ->
    # Explain what each layer type does
    # (matched by substring of the layer name; first matching clause wins)
    layer_explanation = cond do
      String.contains?(layer, "backbone") -> "Hidden layer - learns game patterns"
      String.contains?(layer, "buttons") -> "Button output - predicts A/B/X/Y/Z/L/R/D-pad"
      String.contains?(layer, "main_x") -> "Main stick X - predicts left/right movement"
      String.contains?(layer, "main_y") -> "Main stick Y - predicts up/down movement"
      String.contains?(layer, "c_x") or String.contains?(layer, "c_y") -> "C-stick - predicts aerial/smash direction"
      String.contains?(layer, "shoulder") -> "Shoulder - predicts L/R trigger pressure"
      String.contains?(layer, "embed") -> "Embedding - encodes game state"
      true -> "Neural network layer"
    end
    # NOTE(review): `Number.delimit/1` comes from the `number` hex package —
    # assumed to arrive via :exphil's dependency tree; confirm it is declared.
    param_details = Enum.map(params, fn {name, shape, size} ->
      " - `#{name}`: shape=`#{inspect(shape)}`, params=#{Number.delimit(size)}"
    end) |> Enum.join("\n")
    """
    **#{layer}** - #{layer_explanation}
    #{param_details}
    """
  end) |> Enum.join("\n")

  Kino.Markdown.new("""
  ### Model Summary
  | Metric | Value |
  |--------|-------|
  | **Total Parameters** | #{Number.delimit(param_count)} |
  | **Model Size** | ~#{Float.round(param_count * 4 / 1024 / 1024, 2)} MB (float32) |
  ### What do the parameters mean?
  - **kernel** = Weight matrix that transforms inputs
  - **bias** = Offset added after transformation
  - Shape `{input_dim, output_dim}` means: takes `input_dim` features, outputs `output_dim` features
  ### Layer Breakdown
  #{layer_docs}
  """)
else
  Kino.Markdown.new("**No policy loaded.** Run the cell above first.")
end
3. Training Data Analysis
Before we can understand the model, we need to understand the training data. Melee replays (.slp files) contain frame-by-frame recordings of:
- Player positions, damage, stocks
- Controller inputs (buttons, sticks)
- Game state (stage, items, etc.)
What to look for:
- Button rates show which buttons humans press most often
- Stick distributions show common movement patterns
- A well-trained model should roughly match these distributions
# Inputs for locating replay files and capping how many we parse.
# Bind first, render after — Kino.render/1 returns its argument, so the
# bound values are identical to the original piped version.
replay_dir = Kino.Input.text("Replay Directory", default: "./replays")
Kino.render(replay_dir)

max_files = Kino.Input.number("Max Files to Parse", default: 5)
Kino.render(max_files)
:ok
# Read the form inputs from the cell above.
replay_path = Kino.Input.read(replay_dir)
# Named `file_limit` rather than `max` to avoid shadowing Kernel.max/2.
file_limit = Kino.Input.read(max_files)

# Recursively find .slp replays under the chosen directory, capped at the limit.
replay_files =
  replay_path
  |> Path.join("**/*.slp")
  |> Path.wildcard()
  |> Enum.take(file_limit)

# `== []` is the idiomatic (and O(1)) emptiness check for lists.
if replay_files == [] do
  Kino.Markdown.new("""
  **No replay files found** in `#{replay_path}`
  Make sure you have `.slp` files in that directory.
  """)
else
  IO.puts("Parsing #{length(replay_files)} replay files...")

  # Parse each replay; unreadable files are skipped deliberately (best effort).
  all_frames =
    Enum.flat_map(replay_files, fn path ->
      case Peppi.parse(path) do
        {:ok, replay} -> Peppi.to_training_frames(replay)
        {:error, _} -> []
      end
    end)

  # Stash the frames for the analysis cells below.
  Process.put(:training_frames, all_frames)

  Kino.Markdown.new("""
  ### Replays Parsed
  | Stat | Value |
  |------|-------|
  | Files parsed | #{length(replay_files)} |
  | Total frames | #{Number.delimit(length(all_frames))} |
  | Duration | ~#{Float.round(length(all_frames) / 60 / 60, 1)} minutes of gameplay |
  *Melee runs at 60 FPS, so each frame is ~16.7ms*
  """)
end
Button Press Rates
Each bar shows what percentage of frames had that button pressed. In Melee:
- A = Normal attacks (jab, tilts, aerials)
- B = Special moves (character-specific)
- X/Y = Jump
- Z = Grab / aerial attack
- L/R = Shield, wavedash, L-cancel
- D-pad = Taunt (rarely used competitively)
# Section 3 cell: chart how often each button is pressed across all frames.
frames = Process.get(:training_frames, [])

if frames != [] do
  # Buttons we report on; keys must match the frame's button map.
  tracked = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  # Count, per tracked button, how many frames had it pressed.
  #
  # BUGFIX: the previous version used Map.merge/3 with the frame's full
  # button map, which silently copied any *untracked* button key (e.g.
  # :d_down) into the accumulator as a raw boolean; the percentage math
  # below (`v / total`) would then crash on the boolean. Folding over the
  # tracked keys only keeps the accumulator shape fixed.
  button_counts =
    Enum.reduce(frames, Map.new(tracked, &{&1, 0}), fn frame, acc ->
      buttons = frame.action.buttons

      Enum.reduce(tracked, acc, fn btn, acc ->
        if Map.get(buttons, btn, false) do
          Map.update!(acc, btn, &(&1 + 1))
        else
          acc
        end
      end)
    end)

  # Convert raw counts to percentages of all frames.
  total = length(frames)
  button_rates = Map.new(button_counts, fn {k, v} -> {k, v / total * 100} end)

  chart_data =
    Enum.map(button_rates, fn {button, rate} ->
      %{button: button |> to_string() |> String.upcase(), rate: Float.round(rate, 2)}
    end)

  Vl.new(width: 500, height: 250)
  |> Vl.data_from_values(chart_data)
  |> Vl.mark(:bar, color: "#4C78A8")
  |> Vl.encode_field(:x, "button",
    type: :nominal,
    sort: "-y",
    title: "Button",
    axis: [label_angle: 0])
  |> Vl.encode_field(:y, "rate",
    type: :quantitative,
    title: "Press Rate (%)")
  |> Vl.config(title: [text: "Human Button Press Rates", fontSize: 16])
else
  Kino.Markdown.new("**No frames loaded.** Run the replay parsing cell first.")
end
Stick Position Distributions
The GameCube controller has two analog sticks:
- Main Stick (left) - Movement, attacks direction
- C-Stick (right) - Smash attacks, aerial direction
Positions are discretized to 0-16:
- 0 = Full left/down
- 8 = Neutral (center)
- 16 = Full right/up
Competitive players often use:
- Cardinal directions (0, 8, 16) for precise inputs
- Slight angles for DI (directional influence)
- Neutral position when not moving
# Section 3 cell: side-by-side scatter plots of main-stick and C-stick
# positions sampled from the parsed frames.
frames = Process.get(:training_frames, [])

if length(frames) > 0 do
  # Scatter plots with too many points render slowly — cap at 3000 samples.
  sample_size = min(3000, length(frames))
  sampled = Enum.take_random(frames, sample_size)

  stick_data =
    for frame <- sampled do
      %{
        main_x: frame.action.main_x,
        main_y: frame.action.main_y,
        c_x: frame.action.c_x,
        c_y: frame.action.c_y
      }
    end

  # Both panels are identical apart from field names, colour, and title,
  # so build them through one shared helper.
  scatter = fn x_field, y_field, colour, panel_title ->
    Vl.new(width: 280, height: 280)
    |> Vl.data_from_values(stick_data)
    |> Vl.mark(:circle, opacity: 0.15, size: 8, color: colour)
    |> Vl.encode_field(:x, x_field,
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "X (0=left, 16=right)")
    |> Vl.encode_field(:y, y_field,
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "Y (0=down, 16=up)")
    |> Vl.config(title: [text: panel_title, fontSize: 14])
  end

  main_chart = scatter.("main_x", "main_y", "#E45756", "Main Stick")
  c_chart = scatter.("c_x", "c_y", "#72B7B2", "C-Stick")

  Vl.concat([main_chart, c_chart], :horizontal)
else
  Kino.Markdown.new("**No frames loaded.**")
end
4. Live Inference Testing
Now let’s see what the model actually outputs! We create a synthetic game state and ask the model what action it would take.
Understanding the outputs:
- Buttons - Which buttons the model wants to press (sampled from probabilities)
- Main Stick - Movement direction (0-16 discretized)
- C-Stick - Smash/aerial direction
- Shoulder - L/R trigger pressure (for shields, wavedash)
Since the policy is stochastic (probabilistic), running it multiple times on the same state gives different actions. This is intentional - it prevents the AI from being predictable.
# Section 4 cell: run the loaded policy on a synthetic neutral-game state
# and summarize the distribution of 50 sampled (stochastic) actions.
policy = Process.get(:loaded_policy)

if policy do
  # Create a neutral game state: both players standing on stage
  game_state = %GameState{
    frame: 0,
    stage: 2, # Final Destination (flat stage, no platforms)
    players: %{
      1 => %Player{
        x: 0.0, # Center stage
        y: 0.0, # On ground
        percent: 0.0, # No damage
        stock: 4, # Full stocks
        facing: 1, # Facing right
        character: 9, # Mewtwo
        action: 14, # Standing/idle action
        action_frame: 0,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      },
      2 => %Player{
        x: 20.0, # Opponent to the right
        y: 0.0,
        percent: 50.0, # Opponent has some damage
        stock: 4,
        facing: -1, # Facing left (toward us)
        character: 2, # Fox (common opponent)
        action: 14,
        action_frame: 0,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      }
    }
  }

  # Start a throwaway agent process around the policy for inference;
  # it is stopped as soon as sampling finishes.
  {:ok, agent} = Agents.Agent.start_link(policy: policy)

  # Sample 50 actions to see the distribution
  # (the policy is stochastic, so repeated queries on the same state differ)
  samples = for _ <- 1..50 do
    {:ok, action} = Agents.Agent.get_action(agent, game_state, player_port: 1)
    action
  end
  GenServer.stop(agent)

  # Analyze the samples
  # NOTE(review): this ordering is assumed to match the policy's button-head
  # output order when buttons arrive as a tensor — confirm against the network.
  button_names = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  # Tally per-button press counts. Actions may carry buttons either as an
  # Nx tensor (one 0/1 entry per button, in `button_names` order) or as a
  # plain map of button -> boolean; anything else counts as "not pressed".
  button_counts = samples
  |> Enum.reduce(Map.new(button_names, fn b -> {b, 0} end), fn action, acc ->
    Enum.reduce(button_names, acc, fn btn, inner_acc ->
      pressed = case action.buttons do
        %Nx.Tensor{} = t -> Nx.to_flat_list(t) |> Enum.at(Enum.find_index(button_names, &(&1 == btn))) == 1
        map when is_map(map) -> Map.get(map, btn, false)
        _ -> false
      end
      Map.update!(inner_acc, btn, &(&1 + if(pressed, do: 1, else: 0)))
    end)
  end)

  # Normalize a stick value to a plain number (scalar tensors are squeezed);
  # unrecognized values fall back to 8, the neutral stick position.
  get_stick_value = fn val ->
    case val do
      %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(val))
      n when is_number(n) -> n
      _ -> 8
    end
  end

  main_x_samples = Enum.map(samples, &get_stick_value.(&1.main_x))
  main_y_samples = Enum.map(samples, &get_stick_value.(&1.main_y))
  # Mode (most frequent value) of each stick axis across the 50 samples.
  main_x_mode = main_x_samples |> Enum.frequencies() |> Enum.max_by(fn {_, v} -> v end) |> elem(0)
  main_y_mode = main_y_samples |> Enum.frequencies() |> Enum.max_by(fn {_, v} -> v end) |> elem(0)

  Kino.Markdown.new("""
  ### Scenario: Neutral Game
  **Game State:**
  - Player (Mewtwo) at center stage, 0% damage
  - Opponent (Fox) 20 units to the right, 50% damage
  - Both standing on ground
  ### Model Output (50 samples)
  **Button Press Frequency:**
  #{Enum.map(button_counts, fn {btn, count} ->
    pct = Float.round(count / 50 * 100, 0)
    bar = String.duplicate("█", trunc(pct / 5))
    "| #{String.upcase(to_string(btn))} | #{bar} #{pct}% |"
  end) |> Enum.join("\n")}
  **Main Stick:**
  - Most common X: #{main_x_mode} #{if main_x_mode < 8, do: "(left)", else: if main_x_mode > 8, do: "(right)", else: "(neutral)"}
  - Most common Y: #{main_y_mode} #{if main_y_mode < 8, do: "(down)", else: if main_y_mode > 8, do: "(up)", else: "(neutral)"}
  *Note: Stochastic sampling means each run gives slightly different results*
  """)
else
  Kino.Markdown.new("**No policy loaded.**")
end
5. Deterministic vs Stochastic Mode
The model can run in two modes:
- Stochastic (default) - Samples actions from probability distributions. More varied, less predictable.
- Deterministic - Always picks the highest-probability action. Consistent but exploitable.
Let’s compare them on a more interesting scenario: Player at high percent vs attacking opponent
# Section 5 cell: query the same high-pressure state in deterministic and
# stochastic mode (5 runs each) and print the sampled actions side by side.
policy = Process.get(:loaded_policy)

if policy do
  # Dangerous scenario: we're at high percent, opponent is attacking
  game_state = %GameState{
    frame: 0,
    stage: 2,
    players: %{
      1 => %Player{
        x: -30.0, # Left side of stage
        y: 0.0,
        percent: 120.0, # HIGH damage - could die!
        stock: 2, # Only 2 stocks left
        facing: 1,
        character: 9, # Mewtwo
        action: 14, # Standing (vulnerable)
        action_frame: 0,
        invulnerable: false,
        jumps_left: 1, # Used one jump
        on_ground: true,
        shield_strength: 30.0 # Shield is low
      },
      2 => %Player{
        x: -10.0, # Close and approaching
        y: 0.0,
        percent: 40.0,
        stock: 3,
        facing: -1,
        character: 2, # Fox
        action: 20, # Dashing toward us
        action_frame: 5,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      }
    }
  }

  # Deterministic: always same output
  {:ok, det_agent} = Agents.Agent.start_link(policy: policy, deterministic: true)
  det_actions = for _ <- 1..5 do
    {:ok, action} = Agents.Agent.get_action(det_agent, game_state, player_port: 1)
    action
  end
  GenServer.stop(det_agent)

  # Stochastic: varies each time
  {:ok, stoch_agent} = Agents.Agent.start_link(policy: policy, deterministic: false)
  stoch_actions = for _ <- 1..5 do
    {:ok, action} = Agents.Agent.get_action(stoch_agent, game_state, player_port: 1)
    action
  end
  GenServer.stop(stoch_agent)

  # NOTE(review): assumed to match the policy's button-head output order
  # when buttons arrive as a tensor — confirm against the network definition.
  button_names = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  # Render one action as "buttons stick=(x,y)". Buttons may arrive as an
  # Nx 0/1 tensor (indexed by `button_names`) or as a button -> boolean map.
  format_action = fn action ->
    buttons = case action.buttons do
      %Nx.Tensor{} = t ->
        Nx.to_flat_list(t)
        |> Enum.with_index()
        |> Enum.filter(fn {v, _} -> v == 1 end)
        |> Enum.map(fn {_, i} -> Enum.at(button_names, i) end)
      map when is_map(map) ->
        Enum.filter(map, fn {_, v} -> v end) |> Enum.map(fn {k, _} -> k end)
      _ -> []
    end
    # Stick values may be scalar tensors; squeeze them down to plain numbers.
    get_val = fn v ->
      case v do
        %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(v))
        n -> n
      end
    end
    btn_str = if buttons == [], do: "none", else: Enum.join(buttons, ",")
    "#{btn_str} stick=(#{get_val.(action.main_x)},#{get_val.(action.main_y)})"
  end

  Kino.Markdown.new("""
  ### Scenario: Survival Situation
  **Game State:**
  - Player (Mewtwo) at **120% damage** - kill percent!
  - Low shield, only 1 jump left
  - Fox dashing toward us from 20 units away
  ### Deterministic Mode (5 runs - should be identical)
  #{Enum.with_index(det_actions, 1) |> Enum.map(fn {a, i} -> "#{i}. #{format_action.(a)}" end) |> Enum.join("\n")}
  ### Stochastic Mode (5 runs - should vary)
  #{Enum.with_index(stoch_actions, 1) |> Enum.map(fn {a, i} -> "#{i}. #{format_action.(a)}" end) |> Enum.join("\n")}
  **Interpretation:**
  - Deterministic always picks the "best" action, but opponents can predict it
  - Stochastic explores alternatives, making the AI less exploitable
  - Good models should show sensible defensive options (shield, retreat, jump)
  """)
else
  Kino.Markdown.new("**No policy loaded.**")
end
6. Model vs Human Comparison
The ultimate test: How well does the model match human play?
We take random frames from the replays and compare:
- What the human actually did
- What the model would do in the same situation
Metrics:
- Button Accuracy - % of buttons that match (all 7 buttons)
- Stick Match - Whether stick position is within 2 units (allowing for small variations)
Note: 100% accuracy isn’t the goal! Humans vary their play, and there are often multiple good options.
# Section 6 cell: replay random human frames through the policy
# (deterministic mode) and measure how often its outputs match the human's.
frames = Process.get(:training_frames, [])
policy = Process.get(:loaded_policy)

if length(frames) > 0 and policy do
  # Deterministic so each frame yields the model's single "best" action,
  # making the comparison repeatable.
  {:ok, agent} = Agents.Agent.start_link(policy: policy, deterministic: true)

  # Sample frames for comparison
  num_samples = min(50, length(frames))
  sample_frames = Enum.take_random(frames, num_samples)

  # The 7 buttons compared below (D-pad is excluded in this section).
  # NOTE(review): assumed to match the policy's button-head output order.
  button_names = [:a, :b, :x, :y, :z, :l, :r]

  comparisons = Enum.map(sample_frames, fn frame ->
    # Rebuild a GameState from the recorded frame. Frame number and stage
    # are placeholders — only the two player snapshots come from the replay.
    game_state = %GameState{
      frame: 0,
      stage: 2,
      players: %{
        1 => frame.player,
        2 => frame.opponent
      }
    }
    {:ok, model_action} = Agents.Agent.get_action(agent, game_state, player_port: 1)
    human_action = frame.action

    # Get model button values
    # (Nx tensor of 0/1 flags in `button_names` order, or a button -> bool map)
    model_buttons = case model_action.buttons do
      %Nx.Tensor{} = t ->
        vals = Nx.to_flat_list(t)
        Map.new(Enum.with_index(button_names), fn {btn, i} -> {btn, Enum.at(vals, i, 0) == 1} end)
      map when is_map(map) -> map
      _ -> %{}
    end

    # Compare buttons
    # (count how many of the 7 tracked buttons agree, pressed or not)
    button_matches = Enum.count(button_names, fn btn ->
      model_val = Map.get(model_buttons, btn, false)
      human_val = Map.get(human_action.buttons, btn, false)
      model_val == human_val
    end)

    # Get stick values
    # (scalar tensors are squeezed; unknown values fall back to neutral = 8)
    get_val = fn v ->
      case v do
        %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(v))
        n when is_number(n) -> n
        _ -> 8
      end
    end
    model_x = get_val.(model_action.main_x)
    model_y = get_val.(model_action.main_y)
    human_x = human_action.main_x
    human_y = human_action.main_y
    # "Close" = within 2 discretized units on both axes.
    stick_x_close = abs(model_x - human_x) <= 2
    stick_y_close = abs(model_y - human_y) <= 2
    %{
      button_accuracy: button_matches / 7 * 100,
      stick_match: stick_x_close and stick_y_close,
      exact_stick: model_x == human_x and model_y == human_y
    }
  end)
  GenServer.stop(agent)

  # Aggregate per-frame results into overall percentages.
  avg_button_acc = Enum.map(comparisons, & &1.button_accuracy) |> Enum.sum() |> Kernel./(length(comparisons))
  stick_match_rate = Enum.count(comparisons, & &1.stick_match) / length(comparisons) * 100
  exact_stick_rate = Enum.count(comparisons, & &1.exact_stick) / length(comparisons) * 100

  # Determine quality
  # (heuristic thresholds; first matching clause wins)
  quality = cond do
    avg_button_acc > 85 and stick_match_rate > 70 -> {"Excellent", "The model closely matches human play"}
    avg_button_acc > 75 and stick_match_rate > 50 -> {"Good", "The model has learned basic patterns"}
    avg_button_acc > 65 -> {"Fair", "More training data or epochs may help"}
    true -> {"Needs work", "Consider more training or architecture changes"}
  end

  Kino.Markdown.new("""
  ### Model vs Human Comparison
  **Sample size:** #{num_samples} random frames
  | Metric | Value | Explanation |
  |--------|-------|-------------|
  | **Button Accuracy** | #{Float.round(avg_button_acc, 1)}% | How often all 7 buttons match |
  | **Stick Match (±2)** | #{Float.round(stick_match_rate, 1)}% | Stick within 2 units of human |
  | **Exact Stick Match** | #{Float.round(exact_stick_rate, 1)}% | Stick exactly matches human |
  ### Assessment: #{elem(quality, 0)}
  #{elem(quality, 1)}
  **Notes:**
  - Perfect accuracy (100%) isn't expected - humans have personal style
  - Button accuracy >80% and stick >60% indicates good imitation
  - Low scores may indicate: not enough training data, wrong player port, or model architecture issues
  """)
else
  Kino.Markdown.new("""
  **Cannot run comparison.**
  Make sure you have:
  1. Loaded a policy (Section 1)
  2. Parsed replay files (Section 3)
  """)
end
7. Backbone Comparison (Mamba vs LSTM)
Different backbone architectures have different trade-offs:
| Backbone | Complexity | Speed | Use Case |
|---|---|---|---|
| LSTM | O(L) per step | ~220ms | Good accuracy, slower inference |
| Mamba (SSM) | O(L) total | ~9ms | Fast inference, newer architecture |
| Sliding Window | O(W²) | ~50ms | Attention over recent frames |
This section lets you compare two trained policies with different backbones.
# Load two policies for comparison
policy1_path = Kino.Input.text("Policy 1 (e.g., LSTM)",
default: "checkpoints/lstm_policy.bin")
|> Kino.render()
policy2_path = Kino.Input.text("Policy 2 (e.g., Mamba)",
default: "checkpoints/mamba_policy.bin")
|> Kino.render()
:ok
# Section 7 cell: load both policies and show a side-by-side summary table.
# The Policy alias is also used by the benchmark cell below.
alias ExPhil.Networks.Policy

p1_file = Kino.Input.read(policy1_path)
p2_file = Kino.Input.read(policy2_path)

# Load a policy and extract the metadata the comparison table needs.
# Returns {:ok, info_map} or passes Training.load_policy/1's {:error, _}
# tuple through unchanged.
load_policy_info = fn path ->
  case Training.load_policy(path) do
    {:ok, policy} ->
      # Total parameter count: sum of element counts of every tensor.
      param_count = policy.params.data
      |> Enum.flat_map(fn {_layer, params} ->
        Enum.map(params, fn {_name, tensor} -> Nx.size(tensor) end)
      end)
      |> Enum.sum()
      {:ok, %{
        path: path,
        backbone: policy.config[:backbone] || :mlp,
        temporal: policy.config[:temporal] || false,
        # Older checkpoints may only carry :hidden_sizes; use its head then.
        hidden_size: policy.config[:hidden_size] || hd(policy.config[:hidden_sizes] || [256]),
        param_count: param_count,
        # float32 = 4 bytes per parameter
        size_mb: Float.round(param_count * 4 / 1024 / 1024, 2),
        policy: policy
      }}
    error -> error
  end
end

results = {load_policy_info.(p1_file), load_policy_info.(p2_file)}

# The two all-:ok / all-:error clauses must precede the single-:error ones
# so the combined-failure message wins when both loads fail.
case results do
  {{:ok, p1}, {:ok, p2}} ->
    # Stash the pair for the benchmark and chart cells below.
    Process.put(:compare_policies, {p1, p2})
    Kino.Markdown.new("""
    ### Policies Loaded for Comparison
    | Metric | #{p1.backbone} | #{p2.backbone} |
    |--------|----------------|----------------|
    | **Parameters** | #{Number.delimit(p1.param_count)} | #{Number.delimit(p2.param_count)} |
    | **Size** | #{p1.size_mb} MB | #{p2.size_mb} MB |
    | **Temporal** | #{p1.temporal} | #{p2.temporal} |
    | **Hidden Size** | #{p1.hidden_size} | #{p2.hidden_size} |
    """)
  {{:error, r1}, {:error, r2}} ->
    Kino.Markdown.new("**Failed to load both policies:** #{inspect(r1)}, #{inspect(r2)}")
  {{:error, r}, _} ->
    Kino.Markdown.new("**Failed to load Policy 1:** #{inspect(r)}")
  {_, {:error, r}} ->
    Kino.Markdown.new("**Failed to load Policy 2:** #{inspect(r)}")
end
Inference Speed Benchmark
Compare how fast each backbone processes game states. Lower is better for real-time gameplay (need <16ms for 60 FPS).
# Section 7 cell: time a forward pass for each loaded policy.
case Process.get(:compare_policies) do
  {p1, p2} ->
    # Benchmark one policy: rebuild its Axon graph from the saved config,
    # run warmup passes (triggers backend JIT compilation), then time
    # `runs` forward passes and average.
    benchmark_policy = fn policy, name ->
      {model, _} =
        if policy.config[:temporal] do
          Policy.build_temporal(
            embed_size: policy.config[:embed_size],
            hidden_sizes: policy.config[:hidden_sizes] || [256, 256],
            backbone: policy.config[:backbone] || :lstm,
            window_size: policy.config[:window_size] || 60,
            hidden_size: policy.config[:hidden_size] || 256,
            state_size: policy.config[:state_size] || 16,
            expand_factor: policy.config[:expand_factor] || 2,
            conv_size: policy.config[:conv_size] || 4,
            num_layers: policy.config[:num_layers] || 2
          )
        else
          Policy.build(
            embed_size: policy.config[:embed_size],
            hidden_sizes: policy.config[:hidden_sizes] || [256, 256]
          )
        end

      # Temporal models take a window of embedded frames; others one frame.
      input_shape =
        if policy.config[:temporal] do
          {1, policy.config[:window_size] || 60, policy.config[:embed_size]}
        else
          {1, policy.config[:embed_size]}
        end

      # A constant tensor is fine for latency testing, and unlike the
      # deprecated/removed Nx.random_uniform/1 it works across Nx versions.
      input = Nx.broadcast(Nx.tensor(0.5), input_shape)

      # Warmup so JIT compilation is excluded from the timed runs.
      predict_fn = Axon.build(model, mode: :inference)
      for _ <- 1..3, do: predict_fn.(policy.params, input)

      runs = 10

      times =
        for _ <- 1..runs do
          start = System.monotonic_time(:microsecond)
          predict_fn.(policy.params, input)
          System.monotonic_time(:microsecond) - start
        end

      avg_ms = Float.round(Enum.sum(times) / runs / 1000, 2)
      # <16ms per inference is the budget for real-time 60 FPS play.
      %{name: name, backbone: policy.config[:backbone], avg_ms: avg_ms, fps_ready: avg_ms < 16}
    end

    b1 = benchmark_policy.(p1.policy, "Policy 1")
    b2 = benchmark_policy.(p2.policy, "Policy 2")

    # Ratio between the slower and faster policy, and which one won.
    speedup =
      if b1.avg_ms > b2.avg_ms do
        Float.round(b1.avg_ms / b2.avg_ms, 1)
      else
        Float.round(b2.avg_ms / b1.avg_ms, 1)
      end

    faster = if b1.avg_ms < b2.avg_ms, do: b1.backbone, else: b2.backbone

    Kino.Markdown.new("""
    ### Benchmark Results
    | Policy | Backbone | Avg Inference | 60 FPS Ready? |
    |--------|----------|---------------|---------------|
    | Policy 1 | #{b1.backbone} | **#{b1.avg_ms} ms** | #{if b1.fps_ready, do: "✅ Yes", else: "❌ No"} |
    | Policy 2 | #{b2.backbone} | **#{b2.avg_ms} ms** | #{if b2.fps_ready, do: "✅ Yes", else: "❌ No"} |
    **#{faster}** is **#{speedup}x faster**
    > For 60 FPS gameplay, inference must complete in <16ms per frame.
    """)

  nil ->
    Kino.Markdown.new("**Load two policies above first.**")
end
Visual Comparison
# Section 7 cell: grouped bar chart comparing parameter count and size.
case Process.get(:compare_policies) do
  {p1, p2} ->
    # Wide-format rows; the fold transform below pivots the policy1/policy2
    # columns into generic key/value pairs so the bars can be grouped.
    rows = [
      %{metric: "Parameters (K)", policy1: p1.param_count / 1000, policy2: p2.param_count / 1000},
      %{metric: "Size (MB)", policy1: p1.size_mb, policy2: p2.size_mb}
    ]

    Vl.new(width: 400, height: 200, title: "Model Size Comparison")
    |> Vl.data_from_values(rows)
    |> Vl.transform(fold: ["policy1", "policy2"])
    |> Vl.mark(:bar)
    |> Vl.encode_field(:x, "metric", type: :nominal, title: "Metric")
    |> Vl.encode_field(:y, "value", type: :quantitative, title: "Value")
    |> Vl.encode_field(:color, "key", type: :nominal, title: "Policy",
      scale: %{range: ["#4C78A8", "#F58518"]})
    |> Vl.encode_field(:x_offset, "key")

  nil ->
    Kino.Markdown.new("**Load two policies above first.**")
end
Next Steps
After analyzing your model, you might want to:
1. **Train longer** — more epochs or more replay data:

   `mix run scripts/train_from_replays.exs --epochs 10 --max-files 100`

2. **Try temporal training** — uses past frames for context:

   ```
   # LSTM backbone (slower but proven)
   mix run scripts/train_from_replays.exs --temporal --backbone lstm
   # Mamba backbone (24x faster inference, recommended for real-time play)
   mix run scripts/train_from_replays.exs --temporal --backbone mamba \
     --hidden 256 --state-size 16 --expand-factor 2 --conv-size 4
   ```

3. **Fine-tune with PPO** — improve beyond imitation:

   `mix run scripts/train_ppo.exs --pretrained checkpoints/imitation_latest_policy.bin --mock`

4. **Test in Dolphin** — play against the AI:

   `mix run scripts/play_dolphin.exs --policy checkpoints/imitation_latest_policy.bin`