Powered by AppSignal & Oban Pro

ExPhil Evaluation Dashboard

notebooks/evaluation_dashboard.livemd

ExPhil Evaluation Dashboard

What is ExPhil?

ExPhil is an Elixir-based AI that learns to play Super Smash Bros. Melee by watching human replay files. It uses imitation learning (also called behavioral cloning) to copy how humans play, then can be improved with reinforcement learning (PPO) to develop its own strategies.

This notebook lets you:

  • Inspect trained models - See the neural network architecture and parameters
  • Analyze training data - Visualize button press rates and stick positions from replays
  • Test inference - Run the model on game states and see what actions it chooses
  • Compare to humans - Measure how closely the model matches human play

Setup

Run this cell first to load all dependencies. This may take a minute on first run.

# Install notebook dependencies and set up aliases.
#
# FIX: later cells call Number.delimit/1 from the `:number` hex package,
# which was not listed here and is not guaranteed to come in transitively
# through :exphil — declare it explicitly so those cells cannot crash with
# UndefinedFunctionError.
Mix.install([
  {:exphil, path: Path.expand(".."), env: :dev},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"},
  {:number, "~> 1.0"}
])

# Shorthand aliases used throughout the notebook.
alias ExPhil.{Agents, Embeddings, Training}
alias ExPhil.Training.{Imitation, Data}
alias ExPhil.Bridge.{GameState, Player, ControllerState}
alias ExPhil.Data.Peppi
alias VegaLite, as: Vl

IO.puts("Setup complete!")

1. Load a Trained Policy

A policy is a neural network that takes a game state as input and outputs controller actions. The policy file (.bin) contains:

  • Model weights - The learned parameters of the neural network
  • Architecture config - How the network is structured (layer sizes, etc.)
  • Training config - Settings used during training (temporal mode, backbone type)
# Text input widget for the policy checkpoint path.
# Kino.render/1 returns its argument, so `policy_path` is the input itself
# and can be read with Kino.Input.read/1 in the next cell.
policy_path =
  "Policy Path"
  |> Kino.Input.text(default: "checkpoints/imitation_latest_policy.bin")
  |> Kino.render()

:ok
# Read the chosen path and attempt to load the policy checkpoint.
policy_file = Kino.Input.read(policy_path)

case Training.load_policy(policy_file) do
  {:ok, policy} ->
    # Stash the policy so later cells can fetch it without re-loading.
    Process.put(:loaded_policy, policy)

    # Pre-compute the explanatory notes so the markdown table stays readable.
    temporal_note =
      if policy.config.temporal,
        do: "Uses past frames for context (better for combos)",
        else: "Single-frame decisions only"

    backbone_note =
      case policy.config.backbone do
        :sliding_window -> "Transformer attention over recent frames"
        :lstm -> "LSTM recurrent memory"
        :hybrid -> "LSTM + attention combined"
        _ -> "Simple feedforward MLP"
      end

    Kino.Markdown.new("""
    ### Policy Loaded Successfully!

    | Setting | Value | What it means |
    |---------|-------|---------------|
    | **Temporal** | `#{policy.config.temporal}` | #{temporal_note} |
    | **Backbone** | `#{policy.config.backbone}` | #{backbone_note} |
    | **Embed Size** | `#{policy.config.embed_size}` | Dimensions of the game state representation |
    | **Hidden Sizes** | `#{inspect(policy.config.hidden_sizes)}` | Neurons in each hidden layer |
    """)

  {:error, reason} ->
    Kino.Markdown.new("""
    ### Failed to load policy

    **Error:** `#{inspect(reason)}`

    Make sure the path is correct and the file exists.
    """)
end

2. Model Architecture

Neural networks are made of layers. Each layer has weights (learned parameters) that transform the input. Here we show:

  • Total parameters - How many numbers the model learned (more = more capacity but slower)
  • Layer breakdown - Each layer’s shape and size
# Summarize the loaded policy's architecture: total parameter count plus a
# per-layer breakdown with a plain-English note on each layer's role.
policy = Process.get(:loaded_policy)

if policy do
  # Total learned parameters = sum of element counts over every tensor.
  param_count =
    Enum.sum(
      for {_layer, params} <- policy.params.data,
          {_name, tensor} <- params,
          do: Nx.size(tensor)
    )

  # [{layer_name, [{param_name, shape, size}, ...]}, ...] sorted by layer name.
  layer_info =
    Enum.sort(
      for {layer, params} <- policy.params.data do
        details = for {name, tensor} <- params, do: {name, Nx.shape(tensor), Nx.size(tensor)}
        {layer, details}
      end
    )

  # Map a layer name onto a one-line explanation of what it does.
  # Order matters: the first matching substring wins.
  describe_layer = fn layer ->
    cond do
      String.contains?(layer, "backbone") -> "Hidden layer - learns game patterns"
      String.contains?(layer, "buttons") -> "Button output - predicts A/B/X/Y/Z/L/R/D-pad"
      String.contains?(layer, "main_x") -> "Main stick X - predicts left/right movement"
      String.contains?(layer, "main_y") -> "Main stick Y - predicts up/down movement"
      String.contains?(layer, "c_x") or String.contains?(layer, "c_y") -> "C-stick - predicts aerial/smash direction"
      String.contains?(layer, "shoulder") -> "Shoulder - predicts L/R trigger pressure"
      String.contains?(layer, "embed") -> "Embedding - encodes game state"
      true -> "Neural network layer"
    end
  end

  layer_docs =
    Enum.map_join(layer_info, "\n", fn {layer, params} ->
      param_details =
        Enum.map_join(params, "\n", fn {name, shape, size} ->
          "  - `#{name}`: shape=`#{inspect(shape)}`, params=#{Number.delimit(size)}"
        end)

      """
      **#{layer}** - #{describe_layer.(layer)}
      #{param_details}
      """
    end)

  Kino.Markdown.new("""
  ### Model Summary

  | Metric | Value |
  |--------|-------|
  | **Total Parameters** | #{Number.delimit(param_count)} |
  | **Model Size** | ~#{Float.round(param_count * 4 / 1024 / 1024, 2)} MB (float32) |

  ### What do the parameters mean?

  - **kernel** = Weight matrix that transforms inputs
  - **bias** = Offset added after transformation
  - Shape `{input_dim, output_dim}` means: takes `input_dim` features, outputs `output_dim` features

  ### Layer Breakdown

  #{layer_docs}
  """)
else
  Kino.Markdown.new("**No policy loaded.** Run the cell above first.")
end

3. Training Data Analysis

Before we can understand the model, we need to understand the training data. Melee replays (.slp files) contain frame-by-frame recordings of:

  • Player positions, damage, stocks
  • Controller inputs (buttons, sticks)
  • Game state (stage, items, etc.)

What to look for:

  • Button rates show which buttons humans press most often
  • Stick distributions show common movement patterns
  • A well-trained model should roughly match these distributions
# Inputs selecting which replay directory to scan and how many files to parse.
replay_dir =
  "Replay Directory"
  |> Kino.Input.text(default: "./replays")
  |> Kino.render()

max_files =
  "Max Files to Parse"
  |> Kino.Input.number(default: 5)
  |> Kino.render()

:ok
# Parse the selected replays into flat training frames and stash them in the
# process dictionary for the analysis cells below.
replay_path = Kino.Input.read(replay_dir)
# Renamed from `max`, which shadowed Kernel.max/2.
file_limit = Kino.Input.read(max_files)

replay_files =
  replay_path
  |> Path.join("**/*.slp")
  |> Path.wildcard()
  |> Enum.take(file_limit)

# `== []` instead of `length(...) == 0`: length/1 is O(n) and checking
# emptiness with it is a flagged Elixir anti-pattern.
if replay_files == [] do
  Kino.Markdown.new("""
  **No replay files found** in `#{replay_path}`

  Make sure you have `.slp` files in that directory.
  """)
else
  IO.puts("Parsing #{length(replay_files)} replay files...")

  # Best-effort parsing: replays that fail to parse contribute no frames.
  all_frames =
    Enum.flat_map(replay_files, fn path ->
      case Peppi.parse(path) do
        {:ok, replay} -> Peppi.to_training_frames(replay)
        {:error, _} -> []
      end
    end)

  Process.put(:training_frames, all_frames)

  Kino.Markdown.new("""
  ### Replays Parsed

  | Stat | Value |
  |------|-------|
  | Files parsed | #{length(replay_files)} |
  | Total frames | #{Number.delimit(length(all_frames))} |
  | Duration | ~#{Float.round(length(all_frames) / 60 / 60, 1)} minutes of gameplay |

  *Melee runs at 60 FPS, so each frame is ~16.7ms*
  """)
end

Button Press Rates

Each bar shows what percentage of frames had that button pressed. In Melee:

  • A = Normal attacks (jab, tilts, aerials)
  • B = Special moves (character-specific)
  • X/Y = Jump
  • Z = Grab / aerial attack
  • L/R = Shield, wavedash, L-cancel
  • D-pad = Taunt (rarely used competitively)
# Tally how often each tracked button is held across all parsed frames and
# chart the press rates as percentages.
frames = Process.get(:training_frames, [])

if frames != [] do
  # BUG FIX: the original merged the raw `frame.action.buttons` map into the
  # integer accumulator with Map.merge/3. Any key present in the buttons map
  # but not in the accumulator (e.g. :d_down) was copied in as a boolean, and
  # the next merge then attempted `boolean + integer` and crashed. Counting
  # only the tracked keys avoids that while producing identical counts when
  # the key sets match.
  tracked = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  button_counts =
    Enum.reduce(frames, Map.new(tracked, &{&1, 0}), fn frame, acc ->
      buttons = frame.action.buttons

      Map.new(acc, fn {btn, count} ->
        {btn, count + if(Map.get(buttons, btn, false), do: 1, else: 0)}
      end)
    end)

  total = length(frames)
  button_rates = Map.new(button_counts, fn {k, v} -> {k, v / total * 100} end)

  chart_data =
    Enum.map(button_rates, fn {button, rate} ->
      %{button: button |> to_string() |> String.upcase(), rate: Float.round(rate, 2)}
    end)

  Vl.new(width: 500, height: 250)
  |> Vl.data_from_values(chart_data)
  |> Vl.mark(:bar, color: "#4C78A8")
  |> Vl.encode_field(:x, "button",
    type: :nominal,
    sort: "-y",
    title: "Button",
    axis: [label_angle: 0])
  |> Vl.encode_field(:y, "rate",
    type: :quantitative,
    title: "Press Rate (%)")
  |> Vl.config(title: [text: "Human Button Press Rates", fontSize: 16])
else
  Kino.Markdown.new("**No frames loaded.** Run the replay parsing cell first.")
end

Stick Position Distributions

The GameCube controller has two analog sticks:

  • Main Stick (left) - Movement, attacks direction
  • C-Stick (right) - Smash attacks, aerial direction

Positions are discretized to 0-16:

  • 0 = Full left/down
  • 8 = Neutral (center)
  • 16 = Full right/up

Competitive players often use:

  • Cardinal directions (0, 8, 16) for precise inputs
  • Slight angles for DI (directional influence)
  • Neutral position when not moving
# Scatter-plot main-stick and C-stick positions from a random sample of frames.
frames = Process.get(:training_frames, [])

if frames != [] do
  # Sample for performance (scatter plots with too many points are slow)
  sample_size = min(3000, length(frames))
  sampled = Enum.take_random(frames, sample_size)

  stick_data =
    Enum.map(sampled, fn frame ->
      %{
        main_x: frame.action.main_x,
        main_y: frame.action.main_y,
        c_x: frame.action.c_x,
        c_y: frame.action.c_y
      }
    end)

  main_chart =
    Vl.new(width: 280, height: 280)
    |> Vl.data_from_values(stick_data)
    |> Vl.mark(:circle, opacity: 0.15, size: 8, color: "#E45756")
    |> Vl.encode_field(:x, "main_x",
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "X (0=left, 16=right)")
    |> Vl.encode_field(:y, "main_y",
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "Y (0=down, 16=up)")
    |> Vl.config(title: [text: "Main Stick", fontSize: 14])

  c_chart =
    Vl.new(width: 280, height: 280)
    |> Vl.data_from_values(stick_data)
    |> Vl.mark(:circle, opacity: 0.15, size: 8, color: "#72B7B2")
    |> Vl.encode_field(:x, "c_x",
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "X (0=left, 16=right)")
    |> Vl.encode_field(:y, "c_y",
      type: :quantitative,
      scale: [domain: [0, 16]],
      title: "Y (0=down, 16=up)")
    |> Vl.config(title: [text: "C-Stick", fontSize: 14])

  # BUG FIX: VegaLite.concat expects the root spec as its first argument
  # (`concat(vl, multi_view_specs, type)`), so the original
  # `Vl.concat([main_chart, c_chart], :horizontal)` passed the chart list
  # where the root spec belongs. Start from Vl.new/0 instead.
  Vl.new()
  |> Vl.concat([main_chart, c_chart], :horizontal)
else
  Kino.Markdown.new("**No frames loaded.**")
end

4. Live Inference Testing

Now let’s see what the model actually outputs! We create a synthetic game state and ask the model what action it would take.

Understanding the outputs:

  • Buttons - Which buttons the model wants to press (sampled from probabilities)
  • Main Stick - Movement direction (0-16 discretized)
  • C-Stick - Smash/aerial direction
  • Shoulder - L/R trigger pressure (for shields, wavedash)

Since the policy is stochastic (probabilistic), running it multiple times on the same state gives different actions. This is intentional - it prevents the AI from being predictable.

# Run the loaded policy on a synthetic "neutral game" state and summarize the
# distribution of 50 sampled actions.
#
# FIX: the `&` capture operators in this cell had been HTML-escaped to
# `&amp;` (button lookup and stick sampling), which does not compile.
policy = Process.get(:loaded_policy)

if policy do
  # Create a neutral game state: both players standing on stage
  game_state = %GameState{
    frame: 0,
    stage: 2,  # Final Destination (flat stage, no platforms)
    players: %{
      1 => %Player{
        x: 0.0,           # Center stage
        y: 0.0,           # On ground
        percent: 0.0,     # No damage
        stock: 4,         # Full stocks
        facing: 1,        # Facing right
        character: 9,     # Mewtwo
        action: 14,       # Standing/idle action
        action_frame: 0,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      },
      2 => %Player{
        x: 20.0,          # Opponent to the right
        y: 0.0,
        percent: 50.0,    # Opponent has some damage
        stock: 4,
        facing: -1,       # Facing left (toward us)
        character: 2,     # Fox (common opponent)
        action: 14,
        action_frame: 0,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      }
    }
  }

  {:ok, agent} = Agents.Agent.start_link(policy: policy)

  # Sample 50 actions to see the distribution
  samples = for _ <- 1..50 do
    {:ok, action} = Agents.Agent.get_action(agent, game_state, player_port: 1)
    action
  end

  GenServer.stop(agent)

  # Assumed index order of the policy's button output tensor — TODO confirm
  # against the ControllerState encoding.
  button_names = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  # Decode an action's buttons into %{button => pressed?} once per action,
  # instead of re-scanning the tensor with find_index/at for every button
  # (the original was O(buttons²) per sample).
  decode_buttons = fn action ->
    case action.buttons do
      %Nx.Tensor{} = t ->
        flags = Nx.to_flat_list(t)

        button_names
        |> Enum.with_index()
        |> Map.new(fn {btn, i} -> {btn, Enum.at(flags, i) == 1} end)

      map when is_map(map) ->
        Map.new(button_names, fn btn -> {btn, Map.get(map, btn, false)} end)

      _ ->
        Map.new(button_names, fn btn -> {btn, false} end)
    end
  end

  button_counts =
    Enum.reduce(samples, Map.new(button_names, fn b -> {b, 0} end), fn action, acc ->
      pressed = decode_buttons.(action)

      Enum.reduce(button_names, acc, fn btn, inner_acc ->
        Map.update!(inner_acc, btn, &(&1 + if(pressed[btn], do: 1, else: 0)))
      end)
    end)

  # Stick outputs may arrive as scalar tensors or plain numbers; 8 = neutral.
  get_stick_value = fn val ->
    case val do
      %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(val))
      n when is_number(n) -> n
      _ -> 8
    end
  end

  main_x_samples = Enum.map(samples, &get_stick_value.(&1.main_x))
  main_y_samples = Enum.map(samples, &get_stick_value.(&1.main_y))

  main_x_mode = main_x_samples |> Enum.frequencies() |> Enum.max_by(fn {_, v} -> v end) |> elem(0)
  main_y_mode = main_y_samples |> Enum.frequencies() |> Enum.max_by(fn {_, v} -> v end) |> elem(0)

  # Human-readable direction labels (replaces the hard-to-parse nested
  # keyword-`if` the original interpolated into the markdown).
  horizontal_label = fn v ->
    cond do
      v < 8 -> "(left)"
      v > 8 -> "(right)"
      true -> "(neutral)"
    end
  end

  vertical_label = fn v ->
    cond do
      v < 8 -> "(down)"
      v > 8 -> "(up)"
      true -> "(neutral)"
    end
  end

  Kino.Markdown.new("""
  ### Scenario: Neutral Game

  **Game State:**
  - Player (Mewtwo) at center stage, 0% damage
  - Opponent (Fox) 20 units to the right, 50% damage
  - Both standing on ground

  ### Model Output (50 samples)

  **Button Press Frequency:**
  #{Enum.map_join(button_counts, "\n", fn {btn, count} ->
    pct = Float.round(count / 50 * 100, 0)
    bar = String.duplicate("█", trunc(pct / 5))
    "| #{String.upcase(to_string(btn))} | #{bar} #{pct}% |"
  end)}

  **Main Stick:**
  - Most common X: #{main_x_mode} #{horizontal_label.(main_x_mode)}
  - Most common Y: #{main_y_mode} #{vertical_label.(main_y_mode)}

  *Note: Stochastic sampling means each run gives slightly different results*
  """)
else
  Kino.Markdown.new("**No policy loaded.**")
end

5. Deterministic vs Stochastic Mode

The model can run in two modes:

  • Stochastic (default) - Samples actions from probability distributions. More varied, less predictable.
  • Deterministic - Always picks the highest-probability action. Consistent but exploitable.

Let’s compare them on a more interesting scenario: Player at high percent vs attacking opponent

# Contrast deterministic (argmax) actions with stochastic (sampled) actions
# on a high-pressure scenario: player at kill percent, opponent approaching.
# Starts and stops a short-lived agent process for each mode.
policy = Process.get(:loaded_policy)

if policy do
  # Dangerous scenario: we're at high percent, opponent is attacking
  game_state = %GameState{
    frame: 0,
    stage: 2,
    players: %{
      1 => %Player{
        x: -30.0,         # Left side of stage
        y: 0.0,
        percent: 120.0,   # HIGH damage - could die!
        stock: 2,         # Only 2 stocks left
        facing: 1,
        character: 9,     # Mewtwo
        action: 14,       # Standing (vulnerable)
        action_frame: 0,
        invulnerable: false,
        jumps_left: 1,    # Used one jump
        on_ground: true,
        shield_strength: 30.0  # Shield is low
      },
      2 => %Player{
        x: -10.0,         # Close and approaching
        y: 0.0,
        percent: 40.0,
        stock: 3,
        facing: -1,
        character: 2,     # Fox
        action: 20,       # Dashing toward us
        action_frame: 5,
        invulnerable: false,
        jumps_left: 2,
        on_ground: true,
        shield_strength: 60.0
      }
    }
  }

  # Deterministic: always same output
  {:ok, det_agent} = Agents.Agent.start_link(policy: policy, deterministic: true)
  det_actions = for _ <- 1..5 do
    {:ok, action} = Agents.Agent.get_action(det_agent, game_state, player_port: 1)
    action
  end
  GenServer.stop(det_agent)

  # Stochastic: varies each time
  {:ok, stoch_agent} = Agents.Agent.start_link(policy: policy, deterministic: false)
  stoch_actions = for _ <- 1..5 do
    {:ok, action} = Agents.Agent.get_action(stoch_agent, game_state, player_port: 1)
    action
  end
  GenServer.stop(stoch_agent)

  # Assumed index order of the policy's button output tensor — TODO confirm
  # against the ControllerState encoding used in training.
  button_names = [:a, :b, :x, :y, :z, :l, :r, :d_up]

  # Render one action as "buttons stick=(x,y)". Buttons may arrive either as
  # a 0/1 tensor (indexed by button_names) or as a %{button => bool} map.
  format_action = fn action ->
    buttons = case action.buttons do
      %Nx.Tensor{} = t ->
        Nx.to_flat_list(t)
        |> Enum.with_index()
        |> Enum.filter(fn {v, _} -> v == 1 end)
        |> Enum.map(fn {_, i} -> Enum.at(button_names, i) end)
      map when is_map(map) ->
        Enum.filter(map, fn {_, v} -> v end) |> Enum.map(fn {k, _} -> k end)
      _ -> []
    end

    # Stick values may be scalar tensors or plain numbers.
    get_val = fn v ->
      case v do
        %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(v))
        n -> n
      end
    end

    btn_str = if buttons == [], do: "none", else: Enum.join(buttons, ",")
    "#{btn_str} stick=(#{get_val.(action.main_x)},#{get_val.(action.main_y)})"
  end

  Kino.Markdown.new("""
  ### Scenario: Survival Situation

  **Game State:**
  - Player (Mewtwo) at **120% damage** - kill percent!
  - Low shield, only 1 jump left
  - Fox dashing toward us from 20 units away

  ### Deterministic Mode (5 runs - should be identical)
  #{Enum.with_index(det_actions, 1) |> Enum.map(fn {a, i} -> "#{i}. #{format_action.(a)}" end) |> Enum.join("\n")}

  ### Stochastic Mode (5 runs - should vary)
  #{Enum.with_index(stoch_actions, 1) |> Enum.map(fn {a, i} -> "#{i}. #{format_action.(a)}" end) |> Enum.join("\n")}

  **Interpretation:**
  - Deterministic always picks the "best" action, but opponents can predict it
  - Stochastic explores alternatives, making the AI less exploitable
  - Good models should show sensible defensive options (shield, retreat, jump)
  """)
else
  Kino.Markdown.new("**No policy loaded.**")
end

6. Model vs Human Comparison

The ultimate test: How well does the model match human play?

We take random frames from the replays and compare:

  1. What the human actually did
  2. What the model would do in the same situation

Metrics:

  • Button Accuracy - % of buttons that match (all 7 buttons)
  • Stick Match - Whether stick position is within 2 units (allowing for small variations)

Note: 100% accuracy isn’t the goal! Humans vary their play, and there are often multiple good options.

# Compare the model's deterministic action against what the human actually did
# on a random sample of replay frames, then grade the match.
#
# FIX: the `&` capture operators in the aggregate section had been
# HTML-escaped to `&amp;`, which does not compile.
frames = Process.get(:training_frames, [])
policy = Process.get(:loaded_policy)

if frames != [] and policy do
  # Deterministic so the comparison is reproducible run-to-run.
  {:ok, agent} = Agents.Agent.start_link(policy: policy, deterministic: true)

  # Sample frames for comparison
  num_samples = min(50, length(frames))
  sample_frames = Enum.take_random(frames, num_samples)

  button_names = [:a, :b, :x, :y, :z, :l, :r]

  comparisons = Enum.map(sample_frames, fn frame ->
    # Rebuild the state the human saw; the training frame carries only the
    # two player structs, so the stage is fixed to FD here.
    game_state = %GameState{
      frame: 0,
      stage: 2,
      players: %{
        1 => frame.player,
        2 => frame.opponent
      }
    }

    {:ok, model_action} = Agents.Agent.get_action(agent, game_state, player_port: 1)
    human_action = frame.action

    # Get model button values (tensor or map, as elsewhere in this notebook).
    model_buttons = case model_action.buttons do
      %Nx.Tensor{} = t ->
        vals = Nx.to_flat_list(t)
        Map.new(Enum.with_index(button_names), fn {btn, i} -> {btn, Enum.at(vals, i, 0) == 1} end)
      map when is_map(map) -> map
      _ -> %{}
    end

    # Compare buttons: count how many of the 7 agree with the human.
    button_matches = Enum.count(button_names, fn btn ->
      model_val = Map.get(model_buttons, btn, false)
      human_val = Map.get(human_action.buttons, btn, false)
      model_val == human_val
    end)

    # Sticks may be scalar tensors or plain numbers; 8 = neutral fallback.
    get_val = fn v ->
      case v do
        %Nx.Tensor{} -> Nx.to_number(Nx.squeeze(v))
        n when is_number(n) -> n
        _ -> 8
      end
    end

    model_x = get_val.(model_action.main_x)
    model_y = get_val.(model_action.main_y)
    human_x = human_action.main_x
    human_y = human_action.main_y

    # "Close" = within 2 discretized units on both axes.
    stick_x_close = abs(model_x - human_x) <= 2
    stick_y_close = abs(model_y - human_y) <= 2

    %{
      button_accuracy: button_matches / 7 * 100,
      stick_match: stick_x_close and stick_y_close,
      exact_stick: model_x == human_x and model_y == human_y
    }
  end)

  GenServer.stop(agent)

  avg_button_acc =
    comparisons
    |> Enum.map(& &1.button_accuracy)
    |> Enum.sum()
    |> Kernel./(length(comparisons))

  stick_match_rate = Enum.count(comparisons, & &1.stick_match) / length(comparisons) * 100
  exact_stick_rate = Enum.count(comparisons, & &1.exact_stick) / length(comparisons) * 100

  # Determine quality (thresholds are heuristic)
  quality = cond do
    avg_button_acc > 85 and stick_match_rate > 70 -> {"Excellent", "The model closely matches human play"}
    avg_button_acc > 75 and stick_match_rate > 50 -> {"Good", "The model has learned basic patterns"}
    avg_button_acc > 65 -> {"Fair", "More training data or epochs may help"}
    true -> {"Needs work", "Consider more training or architecture changes"}
  end

  Kino.Markdown.new("""
  ### Model vs Human Comparison

  **Sample size:** #{num_samples} random frames

  | Metric | Value | Explanation |
  |--------|-------|-------------|
  | **Button Accuracy** | #{Float.round(avg_button_acc, 1)}% | How often all 7 buttons match |
  | **Stick Match (±2)** | #{Float.round(stick_match_rate, 1)}% | Stick within 2 units of human |
  | **Exact Stick Match** | #{Float.round(exact_stick_rate, 1)}% | Stick exactly matches human |

  ### Assessment: #{elem(quality, 0)}

  #{elem(quality, 1)}

  **Notes:**
  - Perfect accuracy (100%) isn't expected - humans have personal style
  - Button accuracy >80% and stick >60% indicates good imitation
  - Low scores may indicate: not enough training data, wrong player port, or model architecture issues
  """)
else
  Kino.Markdown.new("""
  **Cannot run comparison.**

  Make sure you have:
  1. Loaded a policy (Section 1)
  2. Parsed replay files (Section 3)
  """)
end

7. Backbone Comparison (Mamba vs LSTM)

Different backbone architectures have different trade-offs:

| Backbone | Complexity | Speed | Use Case |
|----------|------------|-------|----------|
| LSTM | O(L) per step | ~220ms | Good accuracy, slower inference |
| Mamba (SSM) | O(L) total | ~9ms | Fast inference, newer architecture |
| Sliding Window | O(W²) | ~50ms | Attention over recent frames |

This section lets you compare two trained policies with different backbones.

# Load two policies for comparison
# Path inputs for the two checkpoints being compared.
# Kino.render/1 returns its argument, so each binding is the input widget.
policy1_path =
  "Policy 1 (e.g., LSTM)"
  |> Kino.Input.text(default: "checkpoints/lstm_policy.bin")
  |> Kino.render()

policy2_path =
  "Policy 2 (e.g., Mamba)"
  |> Kino.Input.text(default: "checkpoints/mamba_policy.bin")
  |> Kino.render()

:ok
# Load both candidate policies and show a side-by-side summary table.
alias ExPhil.Networks.Policy

p1_file = Kino.Input.read(policy1_path)
p2_file = Kino.Input.read(policy2_path)

# Load one checkpoint and derive display metadata; passes load errors through.
describe_policy = fn path ->
  case Training.load_policy(path) do
    {:ok, policy} ->
      # Total parameters = sum of element counts over every tensor.
      total_params =
        Enum.sum(
          for {_layer, params} <- policy.params.data,
              {_name, tensor} <- params,
              do: Nx.size(tensor)
        )

      {:ok,
       %{
         path: path,
         backbone: policy.config[:backbone] || :mlp,
         temporal: policy.config[:temporal] || false,
         hidden_size: policy.config[:hidden_size] || hd(policy.config[:hidden_sizes] || [256]),
         param_count: total_params,
         size_mb: Float.round(total_params * 4 / 1024 / 1024, 2),
         policy: policy
       }}

    error ->
      error
  end
end

case {describe_policy.(p1_file), describe_policy.(p2_file)} do
  {{:ok, p1}, {:ok, p2}} ->
    # Stash both for the benchmark and chart cells below.
    Process.put(:compare_policies, {p1, p2})

    Kino.Markdown.new("""
    ### Policies Loaded for Comparison

    | Metric | #{p1.backbone} | #{p2.backbone} |
    |--------|----------------|----------------|
    | **Parameters** | #{Number.delimit(p1.param_count)} | #{Number.delimit(p2.param_count)} |
    | **Size** | #{p1.size_mb} MB | #{p2.size_mb} MB |
    | **Temporal** | #{p1.temporal} | #{p2.temporal} |
    | **Hidden Size** | #{p1.hidden_size} | #{p2.hidden_size} |
    """)

  {{:error, r1}, {:error, r2}} ->
    Kino.Markdown.new("**Failed to load both policies:** #{inspect(r1)}, #{inspect(r2)}")

  {{:error, r}, _} ->
    Kino.Markdown.new("**Failed to load Policy 1:** #{inspect(r)}")

  {_, {:error, r}} ->
    Kino.Markdown.new("**Failed to load Policy 2:** #{inspect(r)}")
end

Inference Speed Benchmark

Compare how fast each backbone processes game states. Lower is better for real-time gameplay (need <16ms for 60 FPS).

# Benchmark raw inference latency of both loaded policies on synthetic input.
# (The unused `embed_size`/`window_size` bindings the original computed up
# front have been removed; each benchmark reads its own policy's config.)
case Process.get(:compare_policies) do
  {p1, p2} ->
    # Rebuild the Axon model for a policy and time `runs` forward passes.
    benchmark_policy = fn policy, name ->
      # Build model (temporal backbones need the full backbone config).
      {model, _} = if policy.config[:temporal] do
        Policy.build_temporal(
          embed_size: policy.config[:embed_size],
          hidden_sizes: policy.config[:hidden_sizes] || [256, 256],
          backbone: policy.config[:backbone] || :lstm,
          window_size: policy.config[:window_size] || 60,
          hidden_size: policy.config[:hidden_size] || 256,
          state_size: policy.config[:state_size] || 16,
          expand_factor: policy.config[:expand_factor] || 2,
          conv_size: policy.config[:conv_size] || 4,
          num_layers: policy.config[:num_layers] || 2
        )
      else
        Policy.build(
          embed_size: policy.config[:embed_size],
          hidden_sizes: policy.config[:hidden_sizes] || [256, 256]
        )
      end

      # Create an input matching the model's expected shape.
      input_shape = if policy.config[:temporal] do
        {1, policy.config[:window_size] || 60, policy.config[:embed_size]}
      else
        {1, policy.config[:embed_size]}
      end

      # FIX: Nx.random_uniform/1 was removed from Nx (replaced by Nx.Random,
      # which needs an explicit PRNG key). The benchmark only needs *a* tensor
      # of the right shape, so a constant tensor works on every Nx version.
      input = Nx.broadcast(0.5, input_shape)

      # Warmup so one-time compilation cost doesn't pollute the timings.
      predict_fn = Axon.build(model, mode: :inference)
      for _ <- 1..3, do: predict_fn.(policy.params, input)

      # Benchmark: average wall-clock time over `runs` forward passes.
      runs = 10
      times = for _ <- 1..runs do
        start = System.monotonic_time(:microsecond)
        predict_fn.(policy.params, input)
        System.monotonic_time(:microsecond) - start
      end

      avg_us = Enum.sum(times) / runs
      avg_ms = Float.round(avg_us / 1000, 2)

      # 16 ms budget = one frame at 60 FPS.
      %{name: name, backbone: policy.config[:backbone], avg_ms: avg_ms, fps_ready: avg_ms < 16}
    end

    b1 = benchmark_policy.(p1.policy, "Policy 1")
    b2 = benchmark_policy.(p2.policy, "Policy 2")

    speedup = if b1.avg_ms > b2.avg_ms do
      Float.round(b1.avg_ms / b2.avg_ms, 1)
    else
      Float.round(b2.avg_ms / b1.avg_ms, 1)
    end

    faster = if b1.avg_ms < b2.avg_ms, do: b1.backbone, else: b2.backbone

    Kino.Markdown.new("""
    ### Benchmark Results

    | Policy | Backbone | Avg Inference | 60 FPS Ready? |
    |--------|----------|---------------|---------------|
    | Policy 1 | #{b1.backbone} | **#{b1.avg_ms} ms** | #{if b1.fps_ready, do: "✅ Yes", else: "❌ No"} |
    | Policy 2 | #{b2.backbone} | **#{b2.avg_ms} ms** | #{if b2.fps_ready, do: "✅ Yes", else: "❌ No"} |

    **#{faster}** is **#{speedup}x faster**

    > For 60 FPS gameplay, inference must complete in <16ms per frame.
    """)

  nil ->
    Kino.Markdown.new("**Load two policies above first.**")
end

Visual Comparison

# Grouped bar chart comparing parameter count and on-disk size of the two
# policies. The `fold` transform pivots the per-policy columns into
# key/value pairs so one bar per policy can be drawn per metric.
case Process.get(:compare_policies) do
  {p1, p2} ->
    size_rows = [
      %{metric: "Parameters (K)", policy1: p1.param_count / 1000, policy2: p2.param_count / 1000},
      %{metric: "Size (MB)", policy1: p1.size_mb, policy2: p2.size_mb}
    ]

    Vl.new(width: 400, height: 200, title: "Model Size Comparison")
    |> Vl.data_from_values(size_rows)
    |> Vl.transform(fold: ["policy1", "policy2"])
    |> Vl.mark(:bar)
    |> Vl.encode_field(:x, "metric", type: :nominal, title: "Metric")
    |> Vl.encode_field(:x_offset, "key")
    |> Vl.encode_field(:y, "value", type: :quantitative, title: "Value")
    |> Vl.encode_field(:color, "key", type: :nominal, title: "Policy",
      scale: %{range: ["#4C78A8", "#F58518"]})

  nil ->
    Kino.Markdown.new("**Load two policies above first.**")
end

Next Steps

After analyzing your model, you might want to:

  1. Train longer - More epochs or more replay data

    mix run scripts/train_from_replays.exs --epochs 10 --max-files 100
  2. Try temporal training - Uses past frames for context

    # LSTM backbone (slower but proven)
    mix run scripts/train_from_replays.exs --temporal --backbone lstm
    
    # Mamba backbone (24x faster inference, recommended for real-time play)
    mix run scripts/train_from_replays.exs --temporal --backbone mamba \
      --hidden 256 --state-size 16 --expand-factor 2 --conv-size 4
  3. Fine-tune with PPO - Improve beyond imitation

    mix run scripts/train_ppo.exs --pretrained checkpoints/imitation_latest_policy.bin --mock
  4. Test in Dolphin - Play against the AI

    mix run scripts/play_dolphin.exs --policy checkpoints/imitation_latest_policy.bin