MCTS - ConnectFour

Mix.install([
  {:bumblebee, "~> 0.6.3"},
  {:nx, "~> 0.10.0"},
  {:axon, "~> 0.7.0"},
  {:kino, "~> 0.19.0"},
  {:torchx, "~> 0.10.0"},
  {:uuid, "~> 1.1"}
])

Section

Nx.default_backend(Torchx.Backend)

Game Board

defmodule GameBoard do
  use GenServer
  
  @row_count 6
  @col_count 7
  @action_size 7
  @in_a_row 4

  @fields [:board]
  @type t :: %__MODULE__{
    board: Nx.Tensor.t()  
  }
  
  defstruct @fields

  def start_link() do
    GenServer.start_link(__MODULE__, %{}, name: __MODULE__)
  end

  def get_state() do
    GenServer.call(__MODULE__, :get_state)
  end
  
  def save_board(board) do
    GenServer.cast(__MODULE__, {:save, board})
  end
  
  def create_new_board() do
    Nx.broadcast(0, {@row_count, @col_count})
  end

  def change_perspective(board, player) do
    board |> Nx.multiply(player)
  end
  
  def get_row_count, do: @row_count
  def get_col_count, do: @col_count
  def get_action_size, do: @action_size
  
  def play_move(board, col, player_id) do
    row_indices = Nx.iota({@row_count})
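    # Empty cells in the target column keep their own row index after the
    # multiply (occupied cells become 0), so reduce_max selects the lowest
    # empty row — row indices grow downward.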
    
    row =
      Nx.slice_along_axis(board, col, 1, axis: 1)
      |> Nx.flatten()
      |> Nx.equal(0)
      |> Nx.multiply(row_indices)
      |> Nx.reduce_max()
      |> Nx.to_number()
    
    indices = Nx.tensor([[row, col]])
    updates = Nx.tensor([player_id], type: Nx.type(board))
    
    board = Nx.indexed_put(board, indices, updates)
    
    save_board(board)
  end  
 
  def get_valid_moves(board) do
    case Nx.shape(board) do
      {_batch, _row, _col} ->
        board[[.., 0]]          # top row of every board in the batch
        |> Nx.equal(0)          # boolean mask (true where the cell is empty)
        |> Nx.as_type(:u8)      # cast to unsigned 8-bit integers (0 or 1)

      _ ->
        board[0]                # the top row (index 0)
        |> Nx.equal(0)          # compare with zero (boolean mask)
        |> Nx.as_type(:u8)      # cast to 1 (true) or 0 (false)
    end
  end

  def check_win(_board, nil), do: false

  def check_win(board, action) do
    # 1. Build a vector of row indices: [0, 1, 2, 3, 4, 5]
    row_indices = Nx.iota({@row_count})
    # 2. Take the played column and mark where it is NOT empty
    mask = board[[.., action]] |> Nx.not_equal(0)
    # 3. Use a sentinel value (99) for empty cells so that
    #    Nx.reduce_min never selects them
    row =
      Nx.select(mask, row_indices, 99)
      |> Nx.reduce_min()
      |> Nx.to_number()

    col = action
    player = Nx.to_number(board[row][col])

    # Pieces needed besides the one just played (4 - 1 = 3)
    needed = @in_a_row - 1

    # Check each axis by summing the counts in the two opposite directions
    count(board, row, col, player, 1, 0) >= needed or                                            # vertical
      count(board, row, col, player, 0, 1) + count(board, row, col, player, 0, -1) >= needed or  # horizontal
      count(board, row, col, player, 1, 1) + count(board, row, col, player, -1, -1) >= needed or # main diagonal
      count(board, row, col, player, 1, -1) + count(board, row, col, player, -1, 1) >= needed    # anti-diagonal
  end

  def count(board, row, col, player, offset_row, offset_column) do
    # 1..(@in_a_row - 1) mirrors Python's range(1, self.in_a_row)
    Enum.reduce_while(1..(@in_a_row - 1), 0, fn i, _acc ->
      r = row + offset_row * i
      c = col + offset_column * i

      # Stay within the board and require the piece to belong to the player
      if r >= 0 and r < @row_count and
         c >= 0 and c < @col_count and
         Nx.to_number(board[r][c]) == player do
        {:cont, i}       # keep counting; the accumulator becomes i
      else
        {:halt, i - 1}   # stop and return the previous count
      end
    end)
  end

  def get_value_and_terminated(board, action) do
    no_moves_left =
      board
      |> get_valid_moves()      # tensor of 0s and 1s
      |> Nx.sum()               # sum all elements
      |> Nx.to_number()         # unwrap the scalar tensor into an Elixir number
      |> Kernel.==(0)

    cond do
      check_win(board, action) -> {1, true}
      no_moves_left            -> {0, true}   # draw: the board is full
      true                     -> {0, false}
    end
  end

  def get_encoded_board(state) do
    # three binary planes: state == -1, state == 0, state == 1
    planes =
      [-1, 0, 1]
      |> Enum.map(fn player ->
        Nx.equal(state, Nx.tensor(player, type: :s32))
      end)
      |> Nx.stack()                          # {3, 6, 7}
      |> Nx.as_type(:f32)

    # batched input: {batch, 6, 7} → swap axes 0 and 1 → {batch, 3, 6, 7}
    case Nx.shape(state) do
      {_rows, _cols}         -> planes                        # {3, 6, 7}
      {_batch, _rows, _cols} -> Nx.transpose(planes, axes: [1, 0, 2, 3])
    end
  end
  
  def get_next_board(board, action, player) do
    # target column as a tensor: which rows hold 0 in that column
    col = board[[.., action]]                          # {6}

    # index of the lowest row containing 0 (np.max(np.where(...)))
    row =
      col
      |> Nx.equal(0)                                   # e.g. [1, 1, 1, 1, 1, 0]
      |> Nx.to_flat_list()
      |> Enum.with_index()
      |> Enum.filter(fn {val, _i} -> val == 1 end)
      |> Enum.map(fn {_val, i} -> i end)
      |> Enum.max()

    # write the player's value into cell [row, action]
    indices = Nx.tensor([[row, action]])
    updates = Nx.tensor([player], type: :s32)
    Nx.indexed_put(board, indices, updates)
  end
    
  def get_opponent(player), do: -player
  def get_opponent_value(value), do: -value
      
  @impl true
  def init(state) do
    {:ok, state}
  end

  @impl true
  def handle_cast({:save, value}, state) do
    new_state =
      state
      |> Map.put(:board, value)
    
    {:noreply, new_state}
  end

  @impl true
  def handle_call(:get_state, _from, state) do
    {:reply, state, state}
  end
end
new_board = GameBoard.create_new_board()
GameBoard.start_link()
GameBoard.save_board(new_board)
%{board: board} = GameBoard.get_state()
encoded_board = GameBoard.get_encoded_board(board)
input_tensor = Nx.new_axis(encoded_board, 0)
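
A quick smoke test of the board API (a minimal sketch; the column choice is arbitrary, and the final save restores the empty board for the cells below):

demo = GameBoard.create_new_board()
GameBoard.play_move(demo, 3, 1)           # player 1 drops a piece into column 3
%{board: demo} = GameBoard.get_state()    # read the updated board back
GameBoard.get_valid_moves(demo)           # u8 tensor: 1 for every non-full column
GameBoard.check_win(demo, 3)              # => false after a single move
GameBoard.save_board(board)               # restore the empty board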

Dirichlet

defmodule Dirichlet do
  @moduledoc """
  Pure-Elixir implementation of np.random.dirichlet.
  Uses the Marsaglia-Tsang method to sample the Gamma distribution.
  """

  @doc """
  Samples from a Dirichlet distribution.

  ## Parameters
    - alphas: list of concentrations [α₁, α₂, …, αₖ], all > 0

  ## Returns
    A list of floats that sum to 1.0

  ## Example
      iex> Dirichlet.sample([1.0, 1.0, 1.0])
      [0.312, 0.485, 0.203]  # random values
  """
  def sample(alphas) when is_list(alphas) do
    gammas = Enum.map(alphas, &sample_gamma/1)
    total  = Enum.sum(gammas)
    Enum.map(gammas, fn g -> g / total end)
  end

  # ── Gamma(alpha, 1) via Marsaglia-Tsang ──────────────────────────────

  # Special case: alpha < 1  →  Gamma(alpha + 1) * U^(1/alpha)
  defp sample_gamma(alpha) when alpha < 1 do
    u = :rand.uniform()
    sample_gamma(alpha + 1.0) * :math.pow(u, 1.0 / alpha)
  end

  # General case: alpha >= 1  →  Marsaglia-Tsang d-c method
  defp sample_gamma(alpha) do
    d = alpha - 1.0 / 3.0
    c = 1.0 / :math.sqrt(9.0 * d)
    marsaglia_loop(d, c)
  end

  defp marsaglia_loop(d, c) do
    {x, v} = sample_v(c)
    u = :rand.uniform()

    cond do
      v <= 0.0 ->
        marsaglia_loop(d, c)

      # Fast acceptance criterion (squeeze test)
      u < 1.0 - 0.0331 * (x * x) * (x * x) ->
        d * v

      :math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)) ->
        d * v

      true ->
        marsaglia_loop(d, c)
    end
  end

  # Generate the candidate v = (1 + c*x)^3 with x ~ N(0,1)
  defp sample_v(c) do
    x = sample_normal()
    v = 1.0 + c * x
    {x, v * v * v}
  end

  # Normal(0,1) via Box-Muller
  defp sample_normal do
    u1 = :rand.uniform()
    u2 = :rand.uniform()
    :math.sqrt(-2.0 * :math.log(u1)) * :math.cos(2.0 * :math.pi() * u2)
  end
end
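
As a usage sketch (the sampled values are random; alpha 0.3 and epsilon 0.25 match the dirichlet_alpha and dirichlet_epsilon used later), mixing Dirichlet noise into a uniform prior the way AlphaZero perturbs the root policy:

uniform = List.duplicate(1.0 / 7, 7)
noise = Dirichlet.sample(List.duplicate(0.3, 7))
epsilon = 0.25
noisy = Enum.zip_with(uniform, noise, fn p, n -> (1 - epsilon) * p + epsilon * n end)
Enum.sum(noisy)   # ≈ 1.0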

Model ResNet

defmodule Resnet do
  defstruct [:policy_network, :value_network, :policy_params, :value_params, :config]

  @type t :: %__MODULE__{
          policy_network: Axon.t(),
          value_network: Axon.t(),
          policy_params: term() | nil,
          value_params: term() | nil,
          config: map()
        }

  def new(config) do
    %__MODULE__{
      policy_network: build_policy_network(config),
      value_network: build_value_network(config),
      config: config
    }
  end

  def initial_fit(:policy, network) do
    {init_fn, _predict_fn} = Axon.build(network.policy_network)
    template = Nx.template({1, 3, 6, 7}, :f32)   # batched: {batch, planes, rows, cols}
    initial_params = init_fn.(template, Axon.ModelState.empty())

    %{network | policy_params: initial_params}
  end

  def initial_fit(:value, network) do
    {init_fn, _predict_fn} = Axon.build(network.value_network)
    template = Nx.template({1, 3, 6, 7}, :f32)
    initial_params = init_fn.(template, Axon.ModelState.empty())

    %{network | value_params: initial_params}
  end

  # defp trainer_fit(network, x, y, epochs, config) do
  #   {init_fn, _predict_fn} = Axon.build(network)
  #   initial_params = init_fn.(x, Axon.ModelState.empty())

  #   train_standard(network, x, y, epochs, initial_params, config)
  # end

  def predict(:policy, network, inputs) do
    {_init_fn, predict_fn} = Axon.build(network.policy_network)
    predict_fn.(network.policy_params, inputs)
  end
  
  def predict(:value, network, inputs) do
    {_init_fn, predict_fn} = Axon.build(network.value_network)
    predict_fn.(network.value_params, inputs)
  end
  
  # defp train_standard(network, x, y, epochs, initial_params, config) do
  #   network
  #   |> Axon.Loop.trainer(
  #     &Axon.Losses.huber(&1, &2.combined, reduction: :mean),
  #     Polaris.Optimizers.adam(learning_rate: config.learning_rate)
  #   )
  #   |> Axon.Loop.run(Stream.repeatedly(fn -> {x, y} end), initial_params,
  #     epochs: epochs,
  #     iterations: elem(Nx.shape(y), 0),
  #     compiler: Torchx
  #   )
  # end
  
  defp start_block(num_hidden) do
    # encoded board with a batch axis: {batch, 3, 6, 7}
    Axon.input("input", shape: {nil, 3, 6, 7})
    |> Axon.conv(num_hidden, kernel_size: 3, padding: :same)
    |> Axon.batch_norm()
    |> Axon.relu()
  end

  defp policy_head(x) do
    x
    |> Axon.conv(32, kernel_size: 3, padding: :same)
    |> Axon.batch_norm()
    |> Axon.relu()
    |> Axon.flatten()
    |> Axon.dense(GameBoard.get_col_count(), kernel_initializer: :lecun_normal)
  end

  defp value_head(x) do
    x
    |> Axon.conv(3, kernel_size: 3, padding: :same)
    |> Axon.batch_norm()
    |> Axon.relu()
    |> Axon.flatten()
    |> Axon.dense(1, kernel_initializer: :lecun_normal)
    |> Axon.tanh()
  end

  defp res_block(x, num_hidden) do
    residual = x
  
    x =
      x
      |> Axon.conv(num_hidden, kernel_size: 3, padding: :same)
      |> Axon.batch_norm()
      |> Axon.relu()
      |> Axon.conv(num_hidden, kernel_size: 3, padding: :same)
      |> Axon.batch_norm()
  
    x
    |> Axon.add(residual)
    |> Axon.relu()
  end

  defp build_base_network(config) do
    num_res_blocks = get_in(config, [:num_res_blocks])
    num_hidden = get_in(config, [:num_hidden])

    start_block(num_hidden)
    |> then(fn input -> 
       Enum.reduce(1..num_res_blocks, input, fn _, acc -> 
         res_block(acc, num_hidden) 
       end)
     end)
  end
  
  defp build_policy_network(config) do
    build_base_network(config)
    |> policy_head()
  end

  defp build_value_network(config) do
    build_base_network(config) 
    |> value_head()
  end
end
model =
  Resnet.new(%{
    num_hidden: 128,
    num_res_blocks: 9,
    learning_rate: 0.01
  })
model = Resnet.initial_fit(:policy, model)
model = Resnet.initial_fit(:value, model)
policy = Resnet.predict(:policy, model, input_tensor)
Resnet.predict(:value, model, input_tensor)
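
The policy head returns raw logits over the 7 columns. A short sketch (mirroring what MCTS.predict/2 and mask_and_normalize/2 do below) of turning them into move probabilities and masking invalid columns:

probs =
  policy
  |> Axon.Activations.softmax(axis: 1)   # logits -> probabilities, shape {1, 7}
  |> Nx.to_flat_list()

valid = GameBoard.get_valid_moves(board) |> Nx.to_flat_list()
masked = Enum.zip_with(probs, valid, fn p, v -> p * v end)
total = Enum.sum(masked)
Enum.map(masked, &(&1 / total))          # renormalized distribution over valid moves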

MCTS

defmodule MCTS do
  defmodule Node do
    defstruct [
      :id, :args, :board, :parent_id, :action_taken,
      prior: 0,
      visit_count: 0,
      value_sum: 0,
      children_ids: []
    ]
  end

  defmodule Tree do
    use Agent
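
    # Agent state: {nodes :: %{id => %Node{}}, next_id :: non_neg_integer()}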

    def start_link, do: Agent.start_link(fn -> {%{}, 0} end)

    def stop(tree), do: Agent.stop(tree)

    # insert a node and return the assigned id
    def put_node(tree, node) do
      Agent.get_and_update(tree, fn {nodes, next_id} ->
        node = %{node | id: next_id}
        {next_id, {Map.put(nodes, next_id, node), next_id + 1}}
      end)
    end

    def get_node(tree, id) do
      Agent.get(tree, fn {nodes, _} -> Map.fetch!(nodes, id) end)
    end

    def update_node(tree, id, fun) do
      Agent.update(tree, fn {nodes, next_id} ->
        {Map.update!(nodes, id, fun), next_id}
      end)
    end

    def add_child(tree, parent_id, child_id) do
      update_node(tree, parent_id, fn node ->
        %{node | children_ids: node.children_ids ++ [child_id]}
      end)
    end
  end

  defstruct [:model, :args]

  # ── Public API ────────────────────────────────────────────────────────

  def search(mcts, board) do
    {:ok, tree} = Tree.start_link()

    root_id =
      Tree.put_node(tree, %Node{
        args:         mcts.args,
        board:        board,
        parent_id:    nil,
        action_taken: nil,
        visit_count:  1
      })

    # initial inference + Dirichlet noise
    {policy, _} = predict(mcts.model, board)

    # policy =
    #   policy
    #   |> add_dirichlet_noise(mcts.args, GameBoard.get_action_size())
    #   |> mask_and_normalize(GameBoard.get_valid_moves(board))

    # expand(tree, root_id, policy)

    # # simulations
    # Enum.each(1..mcts.args.num_mcts_searches, fn _ ->
    #   # selection
    #   node_id = select_leaf(tree, root_id)
    #   node    = Tree.get_node(tree, node_id)

    #   {value, is_terminal} =
    #     GameBoard.get_value_and_terminated(node.board, node.action_taken)

    #   value = GameBoard.get_opponent_value(value)

    #   value =
    #     if not is_terminal do
    #       {policy, model_value} = predict(mcts.model, node.board)

    #       policy =
    #         policy
    #         |> mask_and_normalize(GameBoard.get_valid_moves(node.board))

    #       expand(tree, node_id, policy)
    #       model_value
    #     else
    #       value
    #     end

    #   backpropagate(tree, node_id, value)
    # end)

    # final distribution of visit counts
    # root     = Tree.get_node(tree, root_id)
    # n        = GameBoard.get_action_size()
    # probs    = List.duplicate(0.0, n)

    # probs =
    #   Enum.reduce(root.children_ids, probs, fn child_id, acc ->
    #     child = Tree.get_node(tree, child_id)
    #     List.replace_at(acc, child.action_taken, child.visit_count * 1.0)
    #   end)

    # total = Enum.sum(probs)
    # result = Enum.map(probs, fn v -> v / total end)

    # Tree.stop(tree)
    # result
  end

  # ── Tree operations ───────────────────────────────────────────────────

  defp expand(tree, node_id, policy) do
    node = Tree.get_node(tree, node_id)

    Enum.each(Enum.with_index(policy), fn {prob, action} ->
      if prob > 0 do
        child_state =
          node.board
          |> GameBoard.get_next_board(action, 1)
          |> GameBoard.change_perspective(-1)

        child_id =
          Tree.put_node(tree, %Node{
            args:         node.args,
            board:        child_state,
            parent_id:    node_id,
            action_taken: action,
            prior:        prob
          })

        Tree.add_child(tree, node_id, child_id)
      end
    end)
  end

  defp select_leaf(tree, node_id) do
    node = Tree.get_node(tree, node_id)

    if node.children_ids == [] do
      node_id
    else
      best_child_id =
        Enum.max_by(node.children_ids, fn child_id ->
          child = Tree.get_node(tree, child_id)
          get_ucb(node, child)
        end)

      select_leaf(tree, best_child_id)
    end
  end

  defp backpropagate(tree, node_id, value) do
    node = Tree.get_node(tree, node_id)

    Tree.update_node(tree, node_id, fn n ->
      %{n | value_sum: n.value_sum + value, visit_count: n.visit_count + 1}
    end)

    if node.parent_id != nil do
      backpropagate(tree, node.parent_id, GameBoard.get_opponent_value(value))
    end
  end
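
  # AlphaZero-style PUCT score, as implemented below:
  #
  #   UCB(child) = Q(child) + c * sqrt(N(parent)) / (1 + N(child)) * P(child)
  #
  # The child's mean value is rescaled from [-1, 1] into [0, 1] and inverted,
  # because value_sum is stored from the opponent's point of view.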

  defp get_ucb(parent, child) do
    q_value =
      if child.visit_count == 0 do
        0
      else
        1 - (child.value_sum / child.visit_count + 1) / 2
      end

    q_value +
      parent.args.c *
      (:math.sqrt(parent.visit_count) / (child.visit_count + 1)) *
      child.prior
  end

  defp predict(model, board) do
    inputs =
      board
      |> GameBoard.get_encoded_board()
      |> Nx.new_axis(0)                  # add the batch axis: {1, 3, 6, 7}

    policy = Resnet.predict(:policy, model, inputs)
    value = Resnet.predict(:value, model, inputs)
    
    policy =
      policy
      |> Axon.Activations.softmax(axis: 1)
      |> Nx.to_flat_list()

    # value =
    #   value
    #   |> Nx.squeeze()
    #   |> Nx.to_number()

    {policy, value}
  end

  defp add_dirichlet_noise(policy, args, action_size) do
    epsilon = args.dirichlet_epsilon
    noise   = Dirichlet.sample(List.duplicate(args.dirichlet_alpha, action_size))

    Enum.zip_with(policy, noise, fn p, n ->
      (1 - epsilon) * p + epsilon * n
    end)
  end

  defp mask_and_normalize(policy, valid_moves) do
    masked = Enum.zip_with(policy, valid_moves, fn p, v -> p * v end)
    total  = Enum.sum(masked)

    if total > 0 do
      Enum.map(masked, fn p -> p / total end)
    else
      n_valid = Enum.sum(valid_moves)
      Enum.map(valid_moves, fn v -> v / n_valid end)
    end
  end
end
args = %{
  c: 2.0,
  num_mcts_searches: 600,
  dirichlet_epsilon: 0.25,
  dirichlet_alpha: 0.3
}

mcts   = %MCTS{model: model, args: args}
{x, y} = MCTS.search(mcts, board)
# action = Enum.zip(probs, 0..6) |> Enum.max_by(fn {p, _} -> p end) |> elem(1)

Nx.squeeze(y) 
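
Once the commented-out simulation loop in search/2 is restored and the function returns the visit-count distribution, the greedy action can be extracted as the commented line above suggests. A minimal sketch, with placeholder probabilities standing in for the real search output:

probs = [0.05, 0.1, 0.2, 0.4, 0.15, 0.05, 0.05]   # hypothetical visit-count distribution
action = probs |> Enum.with_index() |> Enum.max_by(fn {p, _i} -> p end) |> elem(1)   # => 3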
defmodule GameNode do
  use GenServer

  @fields [:board, :parent, :action_taken, :prior, :children, :visit_count, :value_sum]
  @type t :: %__MODULE__{
    board: Nx.Tensor.t(),
    parent: integer(),
    action_taken: integer(),
    prior: integer(),
    children: list(),
    visit_count: integer(),
    value_sum: integer()
  }
  
  defstruct @fields

  def start_link(args) do
    uuid = UUID.uuid4()
    GenServer.start_link(__MODULE__, args, name: {:global, uuid})
  end

  def get_state(pid) do
    GenServer.call(pid, :get_state)
  end

  def expand(node, policy) do
    children =
      policy
      |> Enum.with_index()
      |> Enum.filter(fn {prob, _action} -> prob > 0 end)
      |> Enum.map(fn {prob, action} ->
        %{board: board} = GameBoard.get_state()

        child_board =
          board
          |> GameBoard.get_next_board(action, 1)
          |> GameBoard.change_perspective(-1)

        {:ok, child_pid} =
          GameNode.start_link(%{
            board: child_board,
            action_taken: action,
            prior: prob,
            visit_count: 0
          })

        child_pid
      end)

    %{node | children: children}
  end
  
  @impl true
  def init(args) do
    state =
      %{
        board: args.board,
        parent: Map.get(args, :parent),
        action_taken: Map.get(args, :action_taken),
        prior: Map.get(args, :prior, 0),
        visit_count: Map.get(args, :visit_count, 0),
        children: []
      }

    {:ok, state}
  end

  @impl true
  def handle_cast({:save, value}, state) do
    new_state =
      state
      |> Map.put(:board, value)
    
    {:noreply, new_state}
  end

  @impl true
  def handle_call(:get_state, _from, state) do
    {:reply, state, state}
  end
end
defmodule MCTS do
  @dirichlet_epsilon 0.25
  @dirichlet_alpha 0.3
  @action_size 7
  
  def search(network, board) do
    {:ok, _root} = GameNode.start_link(%{board: board, visit_count: 1})

    # 1. Encode the board state
    encoded_board = GameBoard.get_encoded_board(board)

    # 2. Add the batch dimension (the .unsqueeze(0))
    # Turns {3, 6, 7} into {1, 3, 6, 7}
    input_tensor = Nx.new_axis(encoded_board, 0)

    # 3. Run inference
    # Axon.build returns {init_fn, predict_fn};
    # 'params' holds the network's trained weights
    policy = Resnet.predict(:policy, network, input_tensor)

    # probabilities = 
    #   policy
    #   |> Axon.Activations.softmax(axis: 1)
    #   |> Nx.to_flat_list()

    # noise = Dirichlet.sample(List.duplicate(@dirichlet_alpha, @action_size))

    # policy =
    #   probabilities
    #   |> Enum.zip(noise)
    #   |> Enum.map(fn {p, n} -> (1 - @dirichlet_epsilon) * p + @dirichlet_epsilon * n end)

    # valid_moves = GameBoard.get_valid_moves(board)

    # policy =
    #   policy
    #   |> Enum.zip(valid_moves)
    #   |> Enum.map(fn {p, v} -> p * v end)

    # total = Enum.sum(policy)

    # policy =
    #   if total > 0 do
    #     Enum.map(policy, fn p -> p / total end)
    #   else
    #     # fallback: uniform distribution over valid moves
    #     n_valid = Enum.sum(valid_moves)
    #     Enum.zip(valid_moves, valid_moves)
    #     |> Enum.map(fn {v, _} -> v / n_valid end)
    #   end  

    # GameNode.expand(root, policy)
  end
end
MCTS.search(model, board)

AlphaZero

defmodule AlphaZero do
  @num_iterations 48

  def learn() do
    Enum.each(1..@num_iterations, fn _iteration ->
      _memory = []

      # self.model.eval()
      # for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
      #     memory += self.selfPlay()

      # self.model.train()
      # for epoch in trange(self.args['num_epochs']):
      #     self.train(memory)
    end)
  end

  def self_play(board \\ GameBoard.create_new_board()) do
    _memory = []
    player = 1

    neutral_board = GameBoard.change_perspective(board, player)

    # action_probs = MCTS.search(neutral_state)
            
    # memory.append((neutral_state, action_probs, player))
    
    # temperature_action_probs = action_probs ** (1 / self.args['temperature'])
    # temperature_action_probs /= np.sum(temperature_action_probs)
    # action = np.random.choice(self.game.action_size, p=temperature_action_probs)
    
    # state = self.game.get_next_state(state, action, player)
    
    # value, is_terminal = self.game.get_value_and_terminated(state, action)
    
    # if is_terminal:
    #     returnMemory = []
    #     for hist_neutral_state, hist_action_probs, hist_player in memory:
    #         hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
    #         returnMemory.append((
    #             self.game.get_encoded_state(hist_neutral_state),
    #             hist_action_probs,
    #             hist_outcome
    #         ))
    #     return returnMemory
    
    # player = self.game.get_opponent(player)
    
  end
end
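
The temperature step in the Python comments above flattens (temperature > 1) or sharpens (temperature < 1) the search distribution before sampling a move. A hedged Elixir sketch of that transform; action_probs and temperature are placeholder values, not part of the notebook yet:

action_probs = [0.1, 0.1, 0.3, 0.3, 0.1, 0.05, 0.05]   # hypothetical MCTS output
temperature = 1.25

weighted = Enum.map(action_probs, &:math.pow(&1, 1 / temperature))
total = Enum.sum(weighted)
temperature_probs = Enum.map(weighted, &(&1 / total))

# sample an action index in proportion to the adjusted probabilities
r = :rand.uniform()

action =
  temperature_probs
  |> Enum.with_index()
  |> Enum.reduce_while(r, fn {p, i}, remaining ->
    if remaining <= p, do: {:halt, i}, else: {:cont, remaining - p}
  end)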
AlphaZero.learn()
GameBoard.play_move(board, 4, -1)
GameBoard.change_perspective(board, 1)
{:ok, pid} = GameNode.start_link(%{board: board, visit_count: 1})
GameNode.get_state(pid)
AlphaZero.self_play(board)