Powered by AppSignal & Oban Pro

MCTS - ConnectFour

mcts_connect_four.livemd

MCTS - ConnectFour

# Notebook dependencies: Nx/Axon/Bumblebee for the neural network,
# Torchx as the LibTorch-backed Nx backend, Kino for the interactive UI.
Mix.install([
  {:bumblebee, "~> 0.6.3"},
  {:nx, "~> 0.10.0"},
  {:axon, "~> 0.7.0"},
  {:kino, "~> 0.19.0"},
  {:torchx, "~> 0.10.0"},
  {:uuid, "~> 1.1"}
])

Section

# Route all Nx tensor operations through the LibTorch (Torchx) backend.
Nx.default_backend(Torchx.Backend)
defmodule Utils do
  @moduledoc """
  List-based probability helpers shared by the MCTS / AlphaZero loop.
  Policies are plain lists of floats; valid-move masks are lists of 0/1.
  """

  @doc """
  Zeroes out probabilities for invalid moves and renormalizes to sum 1.

  When every valid move has probability 0, falls back to a uniform
  distribution over the valid moves.
  """
  def mask_and_normalize(policy, valid_moves) do
    masked = for {p, v} <- Enum.zip(policy, valid_moves), do: p * v

    case Enum.sum(masked) do
      total when total > 0 ->
        for p <- masked, do: p / total

      _zero ->
        n_valid = Enum.sum(valid_moves)
        for v <- valid_moves, do: v / n_valid
    end
  end

  @doc """
  Applies a temperature to a distribution (`p^(1/t)`), renormalizing so
  the result sums to 1. `t > 1` flattens, `t < 1` sharpens.
  """
  def apply_temperature(action_probs, temperature) do
    exponent = 1.0 / temperature
    scaled = for p <- action_probs, do: :math.pow(p, exponent)
    total = Enum.sum(scaled)
    for p <- scaled, do: p / total
  end

  @doc """
  Samples an index from the discrete distribution `probs` by inverse-CDF
  sampling. If floating-point error leaves the cumulative sum short of
  the drawn value, falls back to the last action index.
  """
  def sample_action(probs, action_size) do
    target = :rand.uniform()

    outcome =
      probs
      |> Enum.with_index()
      |> Enum.reduce_while(0.0, fn {p, idx}, cumulative ->
        cumulative = cumulative + p
        if target <= cumulative, do: {:halt, idx}, else: {:cont, cumulative}
      end)

    # reduce_while yields the float accumulator when no index was chosen.
    if is_integer(outcome), do: outcome, else: action_size - 1
  end

  @doc """
  Mixes Dirichlet noise into a root policy for exploration:
  `(1 - ε) * p + ε * noise`.
  """
  def add_dirichlet_noise(policy, args) do
    epsilon = args.dirichlet_epsilon
    alphas = List.duplicate(args.dirichlet_alpha, args.action_size)
    noise = Dirichlet.sample(alphas)

    Enum.zip_with(policy, noise, fn p, n -> (1 - epsilon) * p + epsilon * n end)
  end
end

Game

defmodule Game do
  @moduledoc """
  Connect Four rules over Nx tensors.

  The board is a `{row_count, col_count}` s32 tensor: `0` is an empty
  cell and `1`/`-1` are the two players' pieces. Row 0 is the top row;
  a played piece occupies the highest-indexed empty row of its column.
  """

  defstruct [:row_count, :col_count, :action_size, :in_a_row]

  # Builds a game configuration; defaults to the classic 6x7 board with
  # four-in-a-row to win. `action_size` defaults to one action per column.
  def new(opts \\ []) do
    row_count = Keyword.get(opts, :row_count, 6)
    col_count = Keyword.get(opts, :col_count, 7)

    %Game{
      row_count:   row_count,
      col_count:   col_count,
      action_size: Keyword.get(opts, :action_size, col_count),
      in_a_row:    Keyword.get(opts, :in_a_row, 4)
    }
  end

  # ── initial state ──────────────────────────────────────────────────────

  # Returns an all-zero (empty) board tensor.
  def get_initial_board(%Game{row_count: rows, col_count: cols}) do
    Nx.broadcast(Nx.tensor(0, type: :s32), {rows, cols})
  end

  # ── next state ─────────────────────────────────────────────────────────

  # Drops `player`'s piece into column `action`: it lands in the
  # highest-indexed empty row of that column. A full column returns the
  # board unchanged — callers are expected to pre-check valid moves.
  def get_next_board(board, action, player) do
    col = board[[.., action]]

    empty_rows =
      col
      |> Nx.equal(0)
      |> Nx.to_flat_list()
      |> Enum.with_index()
      |> Enum.filter(fn {val, _i} -> val == 1 end)
      |> Enum.map(fn {_val, i} -> i end)

    case empty_rows do
      [] ->
        board

      rows ->
        row     = Enum.max(rows)
        indices = Nx.tensor([[row, action]])
        updates = Nx.tensor([player], type: :s32)
        Nx.indexed_put(board, indices, updates)
    end
  end

  # ── valid moves ────────────────────────────────────────────────────────

  # A column is playable while its top cell (row 0) is still empty.
  # Returns a u8 mask of length `col_count`; also accepts a batched
  # {batch, rows, cols} board.
  def get_valid_moves(board) do
    case Nx.shape(board) do
      {_rows, _cols} ->
        board[0]
        |> Nx.equal(0)
        |> Nx.as_type(:u8)

      {_batch, _rows, _cols} ->
        board[[.., 0, ..]]
        |> Nx.equal(0)
        |> Nx.as_type(:u8)
    end
  end

  # ── win check ──────────────────────────────────────────────────────────

  # No action has been played yet (e.g. the search root) → no winner.
  def check_win(_game, _board, nil), do: false

  # Checks whether the piece most recently dropped into column `action`
  # completed `in_a_row` pieces in any direction.
  def check_win(%Game{row_count: row_count, col_count: col_count, in_a_row: in_a_row}, board, action) do
    col = board[[.., action]] |> Nx.to_flat_list()

    occupied =
      col
      |> Enum.with_index()
      |> Enum.filter(fn {val, _i} -> val != 0 end)
      |> Enum.map(fn {_val, i} -> i end)

    case occupied do
      [] ->
        false

      rows ->
        # The last piece dropped into this column sits in the topmost
        # occupied row (smallest row index).
        row        = Enum.min(rows)
        player     = board[row][action] |> Nx.to_number()
        board_list = Nx.to_list(board)

        # Counts consecutive pieces of `player` along direction
        # (offset_row, offset_col), excluding the new piece itself;
        # stops at the board edge or the first non-matching cell.
        count = fn offset_row, offset_col ->
          Enum.reduce_while(1..(in_a_row - 1), 0, fn i, _acc ->
            r = row + offset_row * i
            c = action + offset_col * i

            if r < 0 or r >= row_count or c < 0 or c >= col_count do
              {:halt, i - 1}
            else
              val = board_list |> Enum.at(r) |> Enum.at(c)
              if val != player, do: {:halt, i - 1}, else: {:cont, i}
            end
          end)
        end

        # Vertical only needs the downward direction (the new piece is on
        # top); the other three lines sum the two opposite directions.
        count.(1,  0) >= in_a_row - 1 or
        count.(0,  1) + count.(0,  -1) >= in_a_row - 1 or
        count.(1,  1) + count.(-1, -1) >= in_a_row - 1 or
        count.(1, -1) + count.(-1,  1) >= in_a_row - 1
    end
  end

  # ── value and terminal ─────────────────────────────────────────────────

  # Returns {value, terminated?} from the viewpoint of the player who
  # just played `action`: {1, true} on a win, {0, true} on a draw (no
  # valid moves left), {0, false} otherwise.
  def get_value_and_terminated(%Game{} = game, board, action) do
    if check_win(game, board, action) do
      {1, true}
    else
      valid_sum =
        board
        |> get_valid_moves()
        |> Nx.sum()
        |> Nx.to_number()

      if valid_sum == 0, do: {0, true}, else: {0, false}
    end
  end

  # ── perspective ────────────────────────────────────────────────────────

  # Flips the board so `player`'s pieces always read as +1: multiplying
  # by 1 keeps it as-is, by -1 swaps the two players.
  def change_perspective(board, player) do
    Nx.multiply(board, Nx.tensor(player, type: :s32))
  end

  # ── encode ─────────────────────────────────────────────────────────────

  # One-hot encodes the board into three f32 planes (cells equal to -1,
  # 0, and 1) for the network: shape {3, rows, cols}, or
  # {batch, 3, rows, cols} when the input board is batched.
  def get_encoded_board(board) do
    planes =
      [-1, 0, 1]
      |> Enum.map(fn player ->
        Nx.equal(board, Nx.tensor(player, type: :s32))
      end)
      |> Nx.stack()
      |> Nx.as_type(:f32)

    case Nx.shape(board) do
      {_rows, _cols}         -> planes
      {_batch, _rows, _cols} -> Nx.transpose(planes, axes: [1, 0, 2, 3])
    end
  end

  # ── helpers ────────────────────────────────────────────────────────────

  # In a zero-sum two-player game both the opponent and the opponent's
  # value are simple negations.
  def get_opponent(player),      do: -player
  def get_opponent_value(value), do: -value
end

Distribuição de Dirichlet

defmodule Dirichlet do
  @moduledoc """
  Pure-Elixir implementation of `np.random.dirichlet`.

  Gamma variates are drawn with the Marsaglia-Tsang method and then
  normalized to produce a Dirichlet sample.
  """

  @doc """
  Samples a Dirichlet distribution.

  ## Parameters
    - alphas: list of concentrations [α₁, α₂, …, αₖ], all > 0

  ## Returns
    A list of non-negative floats summing to 1.0.

  ## Example
      iex> Dirichlet.sample([1.0, 1.0, 1.0])
      [0.312, 0.485, 0.203]  # random values
  """
  def sample(alphas) when is_list(alphas) do
    gammas = Enum.map(alphas, &sample_gamma/1)
    total  = Enum.sum(gammas)
    Enum.map(gammas, fn g -> g / total end)
  end

  # ── Gamma(alpha, 1) via Marsaglia-Tsang ──────────────────────────────

  # Boost case: alpha < 1  →  Gamma(alpha+1) * U^(1/alpha)
  defp sample_gamma(alpha) when alpha < 1 do
    # uniform_pos/0 keeps U strictly positive so the variate never
    # collapses to exactly 0 (which could make the normalizing sum 0).
    u = uniform_pos()
    sample_gamma(alpha + 1.0) * :math.pow(u, 1.0 / alpha)
  end

  # General case: alpha >= 1  →  Marsaglia-Tsang d-c method
  defp sample_gamma(alpha) do
    d = alpha - 1.0 / 3.0
    c = 1.0 / :math.sqrt(9.0 * d)
    marsaglia_loop(d, c)
  end

  defp marsaglia_loop(d, c) do
    {x, v} = sample_v(c)
    # u must be > 0: the slow acceptance branch evaluates :math.log(u),
    # which is undefined at 0. :rand.uniform/0 can return exactly 0.0.
    u = uniform_pos()

    cond do
      v <= 0.0 ->
        marsaglia_loop(d, c)

      # Fast acceptance criterion (squeeze test)
      u < 1.0 - 0.0331 * (x * x) * (x * x) ->
        d * v

      :math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)) ->
        d * v

      true ->
        marsaglia_loop(d, c)
    end
  end

  # Generates candidate v = (1 + c*x)^3 with x ~ N(0,1)
  defp sample_v(c) do
    x = sample_normal()
    v = 1.0 + c * x
    {x, v * v * v}
  end

  # Normal(0,1) via Box-Muller
  defp sample_normal do
    # Box-Muller takes log(u1), so u1 must lie in (0, 1] — a raw
    # :rand.uniform() of 0.0 would raise on :math.log/1.
    u1 = uniform_pos()
    u2 = :rand.uniform()
    :math.sqrt(-2.0 * :math.log(u1)) * :math.cos(2.0 * :math.pi() * u2)
  end

  # Uniform float in (0.0, 1.0]: :rand.uniform/0 yields [0.0, 1.0),
  # and 1.0 - that value shifts the range off zero.
  defp uniform_pos do
    1.0 - :rand.uniform()
  end
end

ResNet Model

defmodule ResNet do
  @moduledoc """
  AlphaZero-style residual network with a policy head and a value head,
  held in a named Agent so the rest of the notebook can reach it
  globally (predict, train, save/load).
  """

  import Axon

  # Builds the network, initializes parameters against a
  # {1, 3, rows, cols} input template, and registers the Agent under
  # this module's name.
  def start_link(game, num_res_blocks, num_hidden, learning_rate) do
    model  = build(game, num_res_blocks, num_hidden)

    {init_fn, predict_fn} = Axon.build(model, mode: :inference)
    params = init_fn.(Nx.template({1, 3, game.row_count, game.col_count}, :f32), %{})

    Agent.start_link(fn ->
      %{
        model:         model,
        params:        params,
        predict_fn:    predict_fn,
        learning_rate: learning_rate
      }
    end, name: __MODULE__)
  end

  # ── inference ──────────────────────────────────────────────────────────

  # Runs the network on `input` (a {batch, 3, rows, cols} f32 tensor);
  # returns the {policy_logits, value} container built in build/3.
  def predict(input) do
    %{params: params, predict_fn: predict_fn} =
      Agent.get(__MODULE__, & &1)

    predict_fn.(params, %{"state" => input})
  end

  def get_params,    do: Agent.get(__MODULE__, & &1.params)
  # NOTE(review): the Agent state never stores an :opt_state key, so this
  # call raises KeyError if invoked — verify it is dead code or store the
  # optimizer state when training.
  def get_opt_state, do: Agent.get(__MODULE__, & &1.opt_state)

  # Train/inference mode switch is a deliberate no-op placeholder.
  def set_mode(_mode), do: :ok

  # ── training via Axon.Loop ─────────────────────────────────────────────

  # Runs one epoch over a single batch with the AlphaZero loss
  # (policy cross-entropy + value MSE) and stores the result back into
  # the Agent as the new params.
  def train_step(state_tensor, policy_targets, value_targets) do
    %{model: model, params: params, learning_rate: learning_rate} =
      Agent.get(__MODULE__, & &1)

    loss_fn = fn {out_policy, out_value}, {p_targets, v_targets} ->
      # Cross-entropy between target distribution and log-softmax logits.
      policy_loss =
        out_policy
        |> Axon.Activations.log_softmax(axis: 1)
        |> Nx.multiply(p_targets)
        |> Nx.sum(axes: [1])
        |> Nx.mean()
        |> Nx.negate()

      # Mean squared error on the scalar value head.
      value_loss =
        out_value
        |> Nx.subtract(v_targets)
        |> Nx.pow(2)
        |> Nx.mean()

      Nx.add(policy_loss, value_loss)
    end

    # NOTE(review): Axon.Loop.run returns the loop's final state, not a
    # bare params map — confirm for this Axon version whether the model
    # state must be extracted (e.g. :model_state) before storing it and
    # passing it back to predict_fn.
    new_params =
      model
      |> Axon.Loop.trainer(loss_fn, Polaris.Optimizers.adam(learning_rate: learning_rate))
      |> Axon.Loop.run(
        [{%{"state" => state_tensor}, {policy_targets, value_targets}}],
        params,
        epochs: 1
      )

    Agent.update(__MODULE__, fn state ->
      %{state | params: new_params}
    end)
  end

  # ── save / load ────────────────────────────────────────────────────────

  # Serializes the current params to models/params_<iteration>.nx.
  def save(iteration) do
    params = get_params()
    File.mkdir_p!("models")
    File.write!("models/params_#{iteration}.nx", Nx.serialize(params))
    IO.puts("modelo salvo: models/params_#{iteration}.nx")
  end

  # Replaces the Agent's params with a previously saved snapshot.
  def load(iteration) do
    params = "models/params_#{iteration}.nx" |> File.read!() |> Nx.deserialize()
    Agent.update(__MODULE__, fn state -> %{state | params: params} end)
  end

  # ── build ──────────────────────────────────────────────────────────────

  # Assembles the full graph: conv stem → N residual blocks → separate
  # policy (logits per action) and value (tanh scalar) heads.
  def build(game, num_res_blocks, num_hidden) do
    input = input("state", shape: {nil, 3, game.row_count, game.col_count})

    x = start_block(input, num_hidden)
    x = Enum.reduce(1..num_res_blocks, x, fn _i, x -> res_block(x, num_hidden) end)

    policy = policy_head(x, game)
    value  = value_head(x, game)

    Axon.container({policy, value})
  end

  # Conv stem: 3x3 conv → batch norm → ReLU.
  defp start_block(x, num_hidden) do
    x
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
  end

  # Standard pre-activation-free residual block with identity skip.
  defp res_block(x, num_hidden) do
    residual = x

    x
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> add(residual)
    |> relu()
  end

  # Policy head: raw logits, one per action (softmax applied by callers).
  defp policy_head(x, game) do
    x
    |> conv(32, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> flatten()
    |> dense(game.action_size)
  end

  # Value head: single tanh-squashed scalar in [-1, 1].
  defp value_head(x, game) do
    x
    |> conv(3, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> flatten()
    |> dense(1)
    |> tanh()
  end
end

MCTS

defmodule MCTS do
  @moduledoc """
  Monte Carlo Tree Search guided by the ResNet policy/value network
  (AlphaZero-style — network evaluations instead of rollouts).

  Boards passed to `search/2` must already be in the perspective of the
  player to move (their pieces read as +1); child boards are flipped so
  this invariant holds at every node.
  """

  defmodule Node do
    # A single search node; `board` is from the viewpoint of the player
    # to move at that node, `action_taken` is the move that produced it.
    defstruct [
      :id, :game, :args, :board, :parent_id, :action_taken,
      prior: 0,
      visit_count: 0,
      value_sum: 0,
      children_ids: []
    ]
  end

  defmodule Tree do
    use Agent

    # Agent state is {nodes_by_id, next_id}.
    def start_link, do: Agent.start_link(fn -> {%{}, 0} end)
    def stop(tree),  do: Agent.stop(tree)

    # Inserts `node` under a freshly assigned id and returns that id.
    def put_node(tree, node) do
      Agent.get_and_update(tree, fn {nodes, next_id} ->
        node = %{node | id: next_id}
        {next_id, {Map.put(nodes, next_id, node), next_id + 1}}
      end)
    end

    def get_node(tree, id) do
      Agent.get(tree, fn {nodes, _} -> Map.fetch!(nodes, id) end)
    end

    def update_node(tree, id, fun) do
      Agent.update(tree, fn {nodes, next_id} ->
        {Map.update!(nodes, id, fun), next_id}
      end)
    end

    def add_child(tree, parent_id, child_id) do
      update_node(tree, parent_id, fn node ->
        %{node | children_ids: node.children_ids ++ [child_id]}
      end)
    end
  end

  # Runs `args.num_mcts_searches` simulations from `board` and returns
  # the root's action probabilities (child visit counts normalized to
  # sum 1). The tree Agent is created per call and torn down at the end.
  def search(%{args: args, game: game}, board) do
    {:ok, tree} = Tree.start_link()

    root_id =
      Tree.put_node(tree, %Node{
        game:         game,
        args:         args,
        board:        board,
        parent_id:    nil,
        action_taken: nil,
        visit_count:  1
      })

    {policy, _} = predict(board)

    valid_moves = 
      board
      |> Game.get_valid_moves()
      |> Nx.to_flat_list()

    # Root policy gets Dirichlet noise for exploration, then is masked
    # to valid moves and renormalized.
    policy =
      policy
      |> Utils.add_dirichlet_noise(args)
      |> Utils.mask_and_normalize(valid_moves)

    expand(tree, root_id, policy)

    Enum.each(1..args.num_mcts_searches, fn _ ->
      node_id = select_leaf(tree, root_id)
      node    = Tree.get_node(tree, node_id)

      {value, is_terminal} =
        Game.get_value_and_terminated(game, node.board, node.action_taken)

      # The terminal value is from the viewpoint of the player who made
      # action_taken; negate it into this node's (opponent's) viewpoint.
      value = Game.get_opponent_value(value)

      value =
        if not is_terminal do
          # Non-terminal leaf: expand with the network policy and use
          # the network value estimate instead of the terminal value.
          {policy, value} = predict(node.board)

          valid_moves =
            node.board
            |> Game.get_valid_moves()
            |> Nx.to_flat_list()

          policy = Utils.mask_and_normalize(policy, valid_moves)

          expand(tree, node_id, policy)

          value |> Nx.squeeze() |> Nx.to_number()
        else
          value
        end

      backpropagate(tree, node_id, value)
    end)

    root  = Tree.get_node(tree, root_id)
    probs = List.duplicate(0.0, args.action_size)

    # Visit counts → action probabilities at the root.
    probs =
      Enum.reduce(root.children_ids, probs, fn child_id, acc ->
        child = Tree.get_node(tree, child_id)
        List.replace_at(acc, child.action_taken, child.visit_count * 1.0)
      end)

    total  = Enum.sum(probs)
    result = Enum.map(probs, fn v -> v / total end)

    Tree.stop(tree)
    result
  end

  # ── tree operations ───────────────────────────────────────────────────

  # Creates one child per valid move with non-zero prior. The child
  # board is played as +1 (the mover) then flipped so the child is in
  # its own player-to-move perspective.
  defp expand(tree, node_id, policy) do
    node        = Tree.get_node(tree, node_id)
    valid_moves = Game.get_valid_moves(node.board) |> Nx.to_flat_list()

    policy
    |> Enum.with_index()
    |> Enum.zip(valid_moves)
    |> Enum.each(fn {{prob, action}, valid} ->
      if prob > 0 and valid == 1 do
        child_board =
          node.board
          |> Game.get_next_board(action, 1)
          |> Game.change_perspective(-1)

        child_id =
          Tree.put_node(tree, %Node{
            game:         node.game,
            args:         node.args,
            board:        child_board,
            parent_id:    node_id,
            action_taken: action,
            prior:        prob
          })

        Tree.add_child(tree, node_id, child_id)
      end
    end)
  end

  # Descends from `node_id`, always following the child with the highest
  # UCB score, until reaching a node with no children.
  defp select_leaf(tree, node_id) do
    node = Tree.get_node(tree, node_id)

    if node.children_ids == [] do
      node_id
    else
      best_child_id =
        Enum.max_by(node.children_ids, fn child_id ->
          child = Tree.get_node(tree, child_id)
          get_ucb(node, child)
        end)

      select_leaf(tree, best_child_id)
    end
  end

  # Adds `value` and one visit at each node up to the root, negating the
  # value at every step because players alternate along the path.
  defp backpropagate(tree, node_id, value) do
    node = Tree.get_node(tree, node_id)

    Tree.update_node(tree, node_id, fn n ->
      %{n | value_sum: n.value_sum + value, visit_count: n.visit_count + 1}
    end)

    if node.parent_id != nil do
      backpropagate(tree, node.parent_id, Game.get_opponent_value(value))
    end
  end

  # PUCT score: exploitation term (child's mean value mapped from
  # [-1, 1] to [0, 1] and inverted into the parent's perspective) plus
  # the prior-weighted exploration term scaled by args.c.
  defp get_ucb(parent, child) do
    q_value =
      if child.visit_count == 0 do
        0
      else
        1 - (child.value_sum / child.visit_count + 1) / 2
      end

    q_value +
      parent.args.c *
      (:math.sqrt(parent.visit_count) / (child.visit_count + 1)) *
      child.prior
  end

  # Encodes a board, adds a batch axis, runs the network, and returns
  # {policy :: [float], value :: Nx.Tensor} with softmax applied to the
  # policy logits.
  defp predict(board) do
    {policy, value} = 
      board
      |> Game.get_encoded_board()
      |> Nx.new_axis(0)
      |> ResNet.predict()

    policy =
      policy
      |> Axon.Activations.softmax(axis: 1)
      |> Nx.squeeze(axes: [0])
      |> Nx.to_flat_list()

    {policy, value}
  end
end

AlphaZero

defmodule AlphaZero do
  @moduledoc """
  AlphaZero training loop: alternating self-play (data generation via
  MCTS) and supervised training of the ResNet on the generated samples.
  Held in a named GenServer so `learn/0` can fetch its config globally.
  """

  use GenServer

  defstruct [:args, :game, :memory]

  # ── public API ─────────────────────────────────────────────────────────

  def start_link(%{args: args, game: game}) do
    GenServer.start_link(__MODULE__, %__MODULE__{args: args, game: game, memory: []}, name: __MODULE__)
  end

  def get_args, do: GenServer.call(__MODULE__, :get_args)
  def get_game, do: GenServer.call(__MODULE__, :get_game)

  # Runs the full loop: per iteration, generate self-play games, train
  # for several epochs on that data, and checkpoint the model.
  def learn do
    args = get_args()
    game = get_game()

    Enum.each(1..args.num_iterations, fn iteration ->
      IO.puts("Iteração #{iteration}/#{args.num_iterations}")

      ResNet.set_mode(:inference)
      memory =
        Enum.flat_map(1..args.num_selfplay_iterations, fn i ->
          IO.puts("  self-play #{i}/#{args.num_selfplay_iterations}")
          self_play(game, args)
        end)

      ResNet.set_mode(:train)
      Enum.each(1..args.num_epochs, fn epoch ->
        IO.puts("  epoch #{epoch}/#{args.num_epochs}")
        train(memory, args)
      end)

      ResNet.save(iteration)
    end)
  end

  # ── self-play ──────────────────────────────────────────────────────────

  # Plays one full game from the empty board; returns a list of
  # {encoded_board, action_probs, outcome} training samples.
  def self_play(game, args) do
    board  = Game.get_initial_board(game)
    player = 1
    do_self_play(game, args, board, player, _memory = [])
  end

  # One move per call: run MCTS from the current player's perspective,
  # record the position, sample an action (with temperature), and recurse
  # until the game ends.
  defp do_self_play(game, args, board, player, memory) do
    neutral_board = Game.change_perspective(board, player)
    action_probs  = MCTS.search(%{args: args, game: game}, neutral_board)

    memory = [{neutral_board, action_probs, player} | memory]

    valid_moves = Game.get_valid_moves(board) |> Nx.to_flat_list()

    action =
      action_probs
      |> Utils.apply_temperature(args.temperature)
      |> Utils.mask_and_normalize(valid_moves)
      |> Utils.sample_action(args.action_size)

    board = Game.get_next_board(board, action, player)

    {value, is_terminal} = Game.get_value_and_terminated(game, board, action)

    if is_terminal do
      build_return_memory(memory, value, player)
    else
      do_self_play(game, args, board, Game.get_opponent(player), memory)
    end
  end

  # Labels every recorded position with the final outcome from that
  # position's player's viewpoint: the winner's positions get `value`,
  # the loser's get its negation.
  defp build_return_memory(memory, value, current_player) do
    Enum.map(memory, fn {hist_board, hist_probs, hist_player} ->
      outcome =
        if hist_player == current_player,
          do:   value,
          else: Game.get_opponent_value(value)

      {Game.get_encoded_board(hist_board), hist_probs, outcome}
    end)
  end

  # ── train ──────────────────────────────────────────────────────────────

  # Shuffles the replay memory, batches it, converts each batch to
  # tensors, and hands it to ResNet.train_step/3.
  def train(memory, args) do
    memory
    |> Enum.shuffle()
    |> Enum.chunk_every(args.batch_size)
    |> Enum.with_index(1)
    |> Enum.each(fn {batch, _i} ->
      # Unzip the 3-tuples into parallel lists, preserving batch order.
      {states, policy_targets, value_targets} =
        Enum.reduce(batch, {[], [], []}, fn {s, p, v}, {sa, pa, va} ->
          {[s | sa], [p | pa], [v | va]}
        end)
        |> then(fn {sa, pa, va} ->
          {Enum.reverse(sa), Enum.reverse(pa), Enum.reverse(va)}
        end)

      state_tensor =
        states
        |> Enum.map(&Nx.new_axis(&1, 0))
        |> Nx.concatenate(axis: 0)
        |> Nx.as_type(:f32)

      policy_tensor =
        policy_targets
        |> Enum.map(&Nx.tensor(&1, type: :f32))
        |> Nx.stack()

      value_tensor =
        value_targets
        |> Enum.map(&Nx.tensor([&1], type: :f32))
        |> Nx.stack()

      ResNet.train_step(state_tensor, policy_tensor, value_tensor)
    end)
  end

  # ── GenServer callbacks ────────────────────────────────────────────────

  @impl true
  def init(state), do: {:ok, state}

  @impl true
  def handle_call(:get_args, _from, state), do: {:reply, state.args, state}

  @impl true
  def handle_call(:get_game, _from, state), do: {:reply, state.game, state}
end
defmodule PlayLive do
  @moduledoc """
  Kino-based interactive UI for playing Connect Four against the
  trained network: train or load a model, then play via column buttons.
  Mutable game state lives in an Agent named `:play_state`.
  """

  # Entry point: renders the train/load choice and wires up its buttons.
  def start do
    frame = Kino.Frame.new()
    Kino.render(frame)

    game = Game.new()

    # choice buttons
    btn_train = Kino.Control.button("Treinar novo modelo")
    btn_load  = Kino.Control.button("Carregar modelo salvo")

    Kino.Frame.render(frame, Kino.Layout.grid([btn_train, btn_load], columns: 2))

    # Train a fresh model from scratch, then start a game with it.
    Kino.listen(btn_train, fn _event ->      
      Kino.Frame.render(frame, Kino.Markdown.new("**Treinando...**"))

      args = %{
        num_iterations:          3,
        num_selfplay_iterations: 5,
        num_epochs:              4,
        batch_size:              64,
        temperature:             1.25,
        learning_rate:           0.001,
        num_mcts_searches:       60,
        dirichlet_epsilon:       0.25,
        dirichlet_alpha:         0.3,
        action_size:             game.action_size,
        c:                       2.0
      }

      ensure_resnet(game, 9, 128, args.learning_rate)
      {:ok, _} = AlphaZero.start_link(%{args: args, game: game})

      AlphaZero.learn()

      Kino.Frame.render(frame, Kino.Markdown.new("**Treino concluído!**"))
      start_game(frame, game, args)
    end)

    Kino.listen(btn_load, fn _event ->
      # list saved models
      case File.ls("models") do
        {:ok, files} ->
          # Extract iteration numbers from "params_<n>.nx" filenames.
          iterations =
            files
            |> Enum.filter(&String.starts_with?(&1, "params_"))
            |> Enum.map(fn f ->
              f
              |> String.replace("params_", "")
              |> String.replace(".nx", "")
              |> String.to_integer()
            end)
            |> Enum.sort()

          if iterations == [] do
            Kino.Frame.render(frame, Kino.Markdown.new("**Nenhum modelo salvo encontrado.**"))
          else
            show_model_picker(frame, game, iterations)
          end

        {:error, _} ->
          Kino.Frame.render(frame, Kino.Markdown.new("**Pasta models/ não encontrada.**"))
      end
    end)
  end

  # Renders one button per saved checkpoint; picking one loads it and
  # starts a game.
  defp show_model_picker(frame, game, iterations) do
    buttons =
      Enum.map(iterations, fn i ->
        {Kino.Control.button("Modelo #{i}"), i}
      end)

    button_row =
      buttons
      |> Enum.map(fn {btn, _} -> btn end)
      |> Kino.Layout.grid(columns: length(buttons))

    Kino.Frame.render(frame, Kino.Markdown.new("**Escolha o modelo:**"))
    Kino.Frame.append(frame, button_row)

    Enum.each(buttons, fn {btn, iteration} ->
      Kino.listen(btn, fn _event ->
        args = %{
          num_iterations:          8,
          num_selfplay_iterations: 100,
          num_epochs:              4,
          batch_size:              64,
          temperature:             1.25,
          learning_rate:           0.001,
          num_mcts_searches:       60,
          dirichlet_epsilon:       0.25,
          dirichlet_alpha:         0.3,
          action_size:             game.action_size,
          c:                       2.0
        }

        ensure_resnet(game, 9, 128, args.learning_rate)
        ResNet.load(iteration)

        Kino.Frame.render(frame, Kino.Markdown.new("**Modelo #{iteration} carregado!**"))
        start_game(frame, game, args)
      end)
    end)
  end

  # Stops any running ResNet Agent and starts a fresh one with the
  # given architecture, so repeated clicks don't hit a name clash.
  defp ensure_resnet(game, num_res_blocks, num_hidden, learning_rate) do
    case Process.whereis(ResNet) do
      nil -> :ok
      pid -> Agent.stop(pid)
    end
  
    {:ok, _} = ResNet.start_link(game, num_res_blocks, num_hidden, learning_rate)
  end
    
# Sets up a new match: fresh :play_state Agent, the board frame, column
# buttons, AI-mode buttons, and a restart button, with all listeners.
defp start_game(frame, game, args) do
  board = Game.get_initial_board(game)

  # Best-effort stop of a previous match's state Agent.
  try do Agent.stop(:play_state) catch _, _ -> :ok end

  Agent.start_link(fn ->
    %{board: board, game: game, args: args, game_over: false, ai_mode: :mcts}
  end, name: :play_state)

  board_frame = Kino.Frame.new()

  # column buttons
  buttons =
    Enum.map(0..6, fn col ->
      {Kino.Control.button("#{col}"), col}
    end)

  button_row =
    buttons
    |> Enum.map(fn {btn, _} -> btn end)
    |> Kino.Layout.grid(columns: 7)

  # AI mode buttons
  btn_mcts     = Kino.Control.button("IA forte (MCTS)")
  btn_parallel = Kino.Control.button("IA forte paralela")
  btn_raw      = Kino.Control.button("IA rápida (sem MCTS)")
  btn_restart  = Kino.Control.button("Nova partida")

  mode_row = Kino.Layout.grid([btn_mcts, btn_parallel, btn_raw], columns: 3)

  Kino.Frame.render(frame, Kino.Layout.grid([
    board_frame,
    button_row,
    mode_row,
    btn_restart
  ], columns: 1))

  render_board(board_frame, board, "Sua vez (modo: MCTS):")

  # mode listeners
  Kino.listen(btn_mcts, fn _event ->
    Agent.update(:play_state, fn s -> %{s | ai_mode: :mcts} end)
    state = Agent.get(:play_state, & &1)
    render_board(board_frame, state.board, "Modo: MCTS")
  end)

  Kino.listen(btn_parallel, fn _event ->
    Agent.update(:play_state, fn s -> %{s | ai_mode: :parallel} end)
    state = Agent.get(:play_state, & &1)
    render_board(board_frame, state.board, "Modo: MCTS paralelo")
  end)

  Kino.listen(btn_raw, fn _event ->
    Agent.update(:play_state, fn s -> %{s | ai_mode: :raw} end)
    state = Agent.get(:play_state, & &1)
    render_board(board_frame, state.board, "Modo: rede neural pura")
  end)

  Kino.listen(btn_restart, fn _event ->
    start_game(frame, game, args)
  end)

  # Column clicks: only act while the game is live and the column is
  # not full.
  Enum.each(buttons, fn {btn, col} ->
    Kino.listen(btn, fn _event ->
      state = Agent.get(:play_state, & &1)

      if not state.game_over do
        valid = Game.get_valid_moves(state.board) |> Nx.to_flat_list()

        if Enum.at(valid, col) == 1 do
          play_turn(board_frame, state.game, state.args, state.board, state.ai_mode, col)
        end
      end
    end)
  end)
end

# Applies the human's move (player -1) in column `col`, then — if the
# game continues — the AI's reply (player 1), updating :play_state and
# the rendered board after each half-move.
defp play_turn(board_frame, game, args, board, ai_mode, col) do
  board = Game.get_next_board(board, col, -1)

  {value, is_terminal} = Game.get_value_and_terminated(game, board, col)

  if is_terminal do
    # value is from the mover's (human's) viewpoint: 1 = win, 0 = draw.
    msg = if value == 1, do: "Você venceu!", else: "Empate!"
    render_board(board_frame, board, msg)
    Agent.update(:play_state, fn s -> %{s | board: board, game_over: true} end)
  else
    render_board(board_frame, board, "IA pensando...")

    ai_action = get_ai_action(game, args, board, ai_mode)

    board = Game.get_next_board(board, ai_action, 1)

    {value, is_terminal} = Game.get_value_and_terminated(game, board, ai_action)

    if is_terminal do
      msg = if value == 1, do: "IA venceu!", else: "Empate!"
      render_board(board_frame, board, msg)
      Agent.update(:play_state, fn s -> %{s | board: board, game_over: true} end)
    else
      mode_label =
        case ai_mode do
          :mcts     -> "MCTS"
          :parallel -> "MCTS paralelo"
          :raw      -> "rede neural"
        end

      render_board(board_frame, board, "Sua vez — IA jogou coluna #{ai_action} (#{mode_label}):")
      Agent.update(:play_state, fn s -> %{s | board: board} end)
    end
  end
end

# AI move via a full MCTS search; picks the most-visited action.
defp get_ai_action(game, args, board, :mcts) do
  neutral_board = Game.change_perspective(board, 1)
  action_probs  = MCTS.search(%{args: args, game: game}, neutral_board)
  action_probs |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
end

# AI move via several independent MCTS searches run in parallel, with
# their action distributions averaged.
defp get_ai_action(game, args, board, :parallel) do
  neutral_board = Game.change_perspective(board, 1)
  action_probs  = search_parallel(%{args: args, game: game}, neutral_board)
  action_probs |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
end

# AI move from the raw network policy (no search): softmax the logits,
# mask to valid moves, take the argmax.
defp get_ai_action(_game, _args, board, :raw) do
  neutral_board = Game.change_perspective(board, 1)

  {policy, _value} =
    neutral_board
    |> Game.get_encoded_board()
    |> Nx.new_axis(0)
    |> ResNet.predict()

  valid_moves = Game.get_valid_moves(board) |> Nx.to_flat_list()

  policy
  |> Axon.Activations.softmax(axis: 1)
  |> Nx.squeeze(axes: [0])
  |> Nx.to_flat_list()
  |> Utils.mask_and_normalize(valid_moves)
  |> Enum.with_index()
  |> Enum.max_by(fn {p, _} -> p end)
  |> elem(1)
end

# Splits the MCTS search budget across `num_workers` independent tasks
# and averages the resulting per-action probability vectors.
defp search_parallel(%{args: args, game: game}, board, num_workers \\ 4) do
  searches_per_worker = div(args.num_mcts_searches, num_workers)

  results =
    1..num_workers
    |> Enum.map(fn _ ->
      Task.async(fn ->
        MCTS.search(
          %{args: %{args | num_mcts_searches: searches_per_worker}, game: game},
          board
        )
      end)
    end)
    |> Task.await_many(:infinity)

  # Element-wise mean across the workers' probability lists.
  results
  |> Enum.zip()
  |> Enum.map(fn probs_tuple ->
    probs_tuple
    |> Tuple.to_list()
    |> Enum.sum()
    |> Kernel./(num_workers)
  end)
end
  
  # Renders the board as a Markdown table: 🔴 = AI (1), 🟡 = human (-1),
  # ⚪ = empty, with `message` as the heading.
  defp render_board(frame, board, message) do
    header = "| 0 | 1 | 2 | 3 | 4 | 5 | 6 |\n|---|---|---|---|---|---|---|"

    rows =
      board
      |> Nx.to_list()
      |> Enum.map(fn row ->
        cells =
          Enum.map(row, fn
            1  -> "🔴"
            -1 -> "🟡"
            0  -> "⚪"
          end)
          |> Enum.join(" | ")

        "| #{cells} |"
      end)
      |> Enum.join("\n")

    Kino.Frame.render(frame, Kino.Markdown.new("**#{message}**\n\n#{header}\n#{rows}"))
  end
end
# Launch the interactive Connect Four UI.
PlayLive.start()