MCTS - ConnectFour
Mix.install([
{:bumblebee, "~> 0.6.3"},
{:nx, "~> 0.10.0"},
{:axon, "~> 0.7.0"},
{:kino, "~> 0.19.0"},
{:torchx, "~> 0.10.0"},
{:uuid, "~> 1.1"}
])
Setup
Nx.default_backend(Torchx.Backend)
Game Board
defmodule GameBoard do
use GenServer
@row_count 6
@col_count 7
@action_size 7
@in_a_row 4
@fields [:board]
@type t :: %__MODULE__{
board: Nx.Tensor.t()
}
defstruct @fields
def start_link() do
GenServer.start_link(__MODULE__, %{}, name: __MODULE__)
end
def get_state() do
GenServer.call(__MODULE__, :get_state)
end
def save_board(board) do
GenServer.cast(__MODULE__, {:save, board})
end
def create_new_board() do
Nx.broadcast(0, {@row_count, @col_count})
end
def change_perspective(board, player) do
board |> Nx.multiply(player)
end
def get_row_count, do: @row_count
def get_col_count, do: @col_count
def get_action_size, do: @action_size
def play_move(board, col, player_id) do
  # Gravity: the piece lands on the highest-index empty row of the chosen
  # column (empty cells keep their row index, filled cells become 0).
  row_indices = Nx.iota({@row_count})

  row =
    Nx.slice_along_axis(board, col, 1, axis: 1)
    |> Nx.flatten()
    |> Nx.equal(0)
    |> Nx.multiply(row_indices)
    |> Nx.reduce_max()
    |> Nx.to_number()
indices = Nx.tensor([[row, col]])
updates = Nx.tensor([player_id], type: Nx.type(board))
board = Nx.indexed_put(board, indices, updates)
save_board(board)
end
def get_valid_moves(board) do
  case Nx.shape(board) do
    {_, _row, _col} ->
      board[[.., 0]] # row 0 of every board in the batch
      |> Nx.equal(0) # boolean mask (true where empty)
      |> Nx.as_type(:u8) # cast to 0/1
    _ ->
      board[0] # top row (index 0)
      |> Nx.equal(0) # compare against zero
      |> Nx.as_type(:u8) # cast to 0/1
  end
end
def check_win(_board, nil), do: false

def check_win(board, action) do
  # 1. Vector of row indices: [0, 1, 2, 3, 4, 5]
  row_indices = Nx.iota({@row_count})
  # 2. Take the played column and mark where it is NOT empty
  mask = board[[.., action]] |> Nx.not_equal(0)
  # 3. Use a sentinel value (99) for empty cells so that
  #    Nx.reduce_min never selects them
  row =
    Nx.select(mask, row_indices, 99)
    |> Nx.reduce_min()
    |> Nx.to_number()

  col = action
  player = Nx.to_number(board[row][col])
  # Pieces needed beyond the one just played (e.g. 4 - 1 = 3)
  needed = @in_a_row - 1

  # Check each axis by summing the two opposite directions
  count(board, row, col, player, 1, 0) >= needed or # vertical
    count(board, row, col, player, 0, 1) + count(board, row, col, player, 0, -1) >= needed or # horizontal
    count(board, row, col, player, 1, 1) + count(board, row, col, player, -1, -1) >= needed or # main diagonal
    count(board, row, col, player, 1, -1) + count(board, row, col, player, -1, 1) >= needed # anti-diagonal
end
def count(board, row, action, player, offset_row, offset_column) do
  # 1..(@in_a_row - 1) mirrors Python's range(1, self.in_a_row)
  Enum.reduce_while(1..(@in_a_row - 1), 0, fn i, _acc ->
    r = row + offset_row * i
    c = action + offset_column * i
    # Stay inside the board and check the piece belongs to the player
    if r >= 0 and r < @row_count and
         c >= 0 and c < @col_count and
         Nx.to_number(board[r][c]) == player do
      {:cont, i} # keep counting; the accumulator becomes i
    else
      {:halt, i - 1} # stop and return the previous count
    end
  end)
end
def get_value_and_terminated(board, action) do
  cond do
    # the last move won the game
    check_win(board, action) ->
      {1, true}

    # no empty column left: draw
    board |> get_valid_moves() |> Nx.sum() |> Nx.to_number() == 0 ->
      {0, true}

    true ->
      {0, false}
  end
end
def get_encoded_board(state) do
  # three binary planes: state == -1, state == 0, state == 1
  planes =
    [-1, 0, 1]
    |> Enum.map(fn player ->
      Nx.equal(state, Nx.tensor(player, type: :s32))
    end)
    |> Nx.stack() # {3, 6, 7}
    |> Nx.as_type(:f32)

  # batched input {batch, 6, 7} → swap axes 0 and 1 → {batch, 3, 6, 7}
  case Nx.shape(state) do
    {_rows, _cols} -> planes # {3, 6, 7}
    {_batch, _rows, _cols} -> Nx.transpose(planes, axes: [1, 0, 2, 3])
  end
end
def get_next_board(board, action, player) do
  # target column as a boolean view: which rows are still empty
  col = board[[.., action]] # {6}
  # lowest empty row index, like np.max(np.where(col == 0))
  row =
    col
    |> Nx.equal(0) # e.g. [1, 1, 1, 1, 1, 0]
    |> Nx.to_flat_list()
    |> Enum.with_index()
    |> Enum.filter(fn {val, _i} -> val == 1 end)
    |> Enum.map(fn {_val, i} -> i end)
    |> Enum.max()

  # write the player's piece into cell [row, action]
  indices = Nx.tensor([[row, action]])
  updates = Nx.tensor([player], type: Nx.type(board))
  Nx.indexed_put(board, indices, updates)
end
def get_opponent(player), do: -player
def get_opponent_value(value), do: -value
@impl true
def init(state) do
{:ok, state}
end
@impl true
def handle_cast({:save, value}, state) do
new_state =
state
|> Map.put(:board, value)
{:noreply, new_state}
end
@impl true
def handle_call(:get_state, _from, state) do
{:reply, state, state}
end
end
new_board = GameBoard.create_new_board()
GameBoard.start_link()
GameBoard.save_board(new_board)
%{board: board} = GameBoard.get_state()
encoded_board = GameBoard.get_encoded_board(board)
input_tensor = Nx.new_axis(encoded_board, 0)
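A quick shape check: get_encoded_board stacks three binary planes (opponent pieces, empty cells, own pieces), so the unbatched encoding is {3, 6, 7} and input_tensor adds the batch axis in front.

{Nx.shape(encoded_board), Nx.shape(input_tensor)}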
Dirichlet
defmodule Dirichlet do
@moduledoc """
Pure-Elixir implementation of np.random.dirichlet.
Uses the Marsaglia-Tsang method to sample from the Gamma distribution.
"""
@doc """
Samples from a Dirichlet distribution.
## Parameters
- alphas: list of concentrations [α₁, α₂, …, αₖ], all > 0
## Returns
List of floats that sum to 1.0
## Example
    iex> Dirichlet.sample([1.0, 1.0, 1.0])
    [0.312, 0.485, 0.203] # random values
"""
def sample(alphas) when is_list(alphas) do
gammas = Enum.map(alphas, &sample_gamma/1)
total = Enum.sum(gammas)
Enum.map(gammas, fn g -> g / total end)
end
# ── Gamma(alpha, 1) via Marsaglia-Tsang ──────────────────────────────
# Special case: alpha < 1 → Gamma(alpha + 1) * U^(1/alpha)
defp sample_gamma(alpha) when alpha < 1 do
u = :rand.uniform()
sample_gamma(alpha + 1.0) * :math.pow(u, 1.0 / alpha)
end
# General case: alpha >= 1 → Marsaglia-Tsang d-c method
defp sample_gamma(alpha) do
d = alpha - 1.0 / 3.0
c = 1.0 / :math.sqrt(9.0 * d)
marsaglia_loop(d, c)
end
defp marsaglia_loop(d, c) do
{x, v} = sample_v(c)
u = :rand.uniform()
cond do
v <= 0.0 ->
marsaglia_loop(d, c)
# Fast acceptance test (squeeze)
u < 1.0 - 0.0331 * (x * x) * (x * x) ->
d * v
:math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)) ->
d * v
true ->
marsaglia_loop(d, c)
end
end
# Generate candidate v = (1 + c*x)^3 with x ~ N(0,1)
defp sample_v(c) do
x = sample_normal()
v = 1.0 + c * x
{x, v * v * v}
end
# Normal(0,1) via Box-Muller
defp sample_normal do
u1 = :rand.uniform()
u2 = :rand.uniform()
:math.sqrt(-2.0 * :math.log(u1)) * :math.cos(2.0 * :math.pi() * u2)
end
end
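A quick usage check: AlphaZero perturbs the root policy with Dirichlet noise, one concentration parameter per column, so a typical call here draws a random 7-element probability vector.

Dirichlet.sample(List.duplicate(0.3, GameBoard.get_action_size()))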
Model ResNet
defmodule Resnet do
defstruct [:policy_network, :value_network, :policy_params, :value_params, :config]
@type t :: %__MODULE__{
policy_network: Axon.t(),
value_network: Axon.t(),
policy_params: term() | nil,
value_params: term() | nil,
config: map()
}
def new(config) do
%__MODULE__{
policy_network: build_policy_network(config),
value_network: build_value_network(config),
config: config
}
end
def initial_fit(:policy, network) do
  {init_fn, _predict_fn} = Axon.build(network.policy_network)
  # the template must include the batch dimension: {1, 3, 6, 7}
  template = Nx.template({1, 3, 6, 7}, :f32)
  initial_params = init_fn.(template, Axon.ModelState.empty())
  %{network | policy_params: initial_params}
end

def initial_fit(:value, network) do
  {init_fn, _predict_fn} = Axon.build(network.value_network)
  template = Nx.template({1, 3, 6, 7}, :f32)
  initial_params = init_fn.(template, Axon.ModelState.empty())
  %{network | value_params: initial_params}
end
# defp trainer_fit(network, x, y, epochs, config) do
# {init_fn, _predict_fn} = Axon.build(network)
# initial_params = init_fn.(x, Axon.ModelState.empty())
# train_standard(network, x, y, epochs, initial_params, config)
# end
def predict(:policy, network, inputs) do
{_init_fn, predict_fn} = Axon.build(network.policy_network)
predict_fn.(network.policy_params, inputs)
end
def predict(:value, network, inputs) do
{_init_fn, predict_fn} = Axon.build(network.value_network)
predict_fn.(network.value_params, inputs)
end
# defp train_standard(network, x, y, epochs, initial_params, config) do
# network
# |> Axon.Loop.trainer(
# &Axon.Losses.huber(&1, &2.combined, reduction: :mean),
# Polaris.Optimizers.adam(learning_rate: config.learning_rate)
# )
# |> Axon.Loop.run(Stream.repeatedly(fn -> {x, y} end), initial_params,
# epochs: epochs,
# iterations: elem(Nx.shape(y), 0),
# compiler: Torchx
# )
# end
defp start_block(num_hidden) do
  # channels-first input: {batch, planes, rows, cols} = {nil, 3, 6, 7}
  Axon.input("input", shape: {nil, 3, 6, 7})
  |> Axon.conv(num_hidden, kernel_size: 3, padding: :same, channels: :first)
  |> Axon.batch_norm(channel_index: 1)
  |> Axon.relu()
end
defp policy_head(x) do
  x
  |> Axon.conv(32, kernel_size: 3, padding: :same, channels: :first)
  |> Axon.batch_norm(channel_index: 1)
  |> Axon.relu()
  |> Axon.flatten()
  |> Axon.dense(GameBoard.get_col_count(), kernel_initializer: :lecun_normal)
end
defp value_head(x) do
  x
  |> Axon.conv(3, kernel_size: 3, padding: :same, channels: :first)
  |> Axon.batch_norm(channel_index: 1)
  |> Axon.relu()
  |> Axon.flatten()
  |> Axon.dense(1, kernel_initializer: :lecun_normal)
  |> Axon.tanh()
end
# Residual block: two conv + batch-norm layers whose output is added back
# to the block input (skip connection) before the final ReLU.
defp res_block(x, num_hidden) do
  residual = x

  x =
    x
    |> Axon.conv(num_hidden, kernel_size: 3, padding: :same, channels: :first)
    |> Axon.batch_norm(channel_index: 1)
    |> Axon.relu()
    |> Axon.conv(num_hidden, kernel_size: 3, padding: :same, channels: :first)
    |> Axon.batch_norm(channel_index: 1)

  x
  |> Axon.add(residual)
  |> Axon.relu()
end
defp build_base_network(config) do
num_res_blocks = get_in(config, [:num_res_blocks])
num_hidden = get_in(config, [:num_hidden])
start_block(num_hidden)
|> then(fn input ->
Enum.reduce(1..num_res_blocks, input, fn _, acc ->
res_block(acc, num_hidden)
end)
end)
end
defp build_policy_network(config) do
  build_base_network(config)
  |> policy_head()
end
defp build_value_network(config) do
build_base_network(config)
|> value_head()
end
end
model =
Resnet.new(%{
num_hidden: 128,
num_res_blocks: 9,
learning_rate: 0.01
})
model = Resnet.initial_fit(:policy, model)
model = Resnet.initial_fit(:value, model)
policy = Resnet.predict(:policy, model, input_tensor)
Resnet.predict(:value, model, input_tensor)
MCTS
defmodule MCTS do
defmodule Node do
defstruct [
:id, :args, :board, :parent_id, :action_taken,
prior: 0,
visit_count: 0,
value_sum: 0,
children_ids: []
]
end
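# The search tree lives in an Agent holding {nodes_by_id, next_id}; nodes
# reference their parent and children by integer id instead of by pid.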
defmodule Tree do
use Agent
def start_link, do: Agent.start_link(fn -> {%{}, 0} end)
def stop(tree), do: Agent.stop(tree)
# insert a node, return its assigned id
def put_node(tree, node) do
Agent.get_and_update(tree, fn {nodes, next_id} ->
node = %{node | id: next_id}
{next_id, {Map.put(nodes, next_id, node), next_id + 1}}
end)
end
def get_node(tree, id) do
Agent.get(tree, fn {nodes, _} -> Map.fetch!(nodes, id) end)
end
def update_node(tree, id, fun) do
Agent.update(tree, fn {nodes, next_id} ->
{Map.update!(nodes, id, fun), next_id}
end)
end
def add_child(tree, parent_id, child_id) do
update_node(tree, parent_id, fn node ->
%{node | children_ids: node.children_ids ++ [child_id]}
end)
end
end
defstruct [:model, :args]
# ── public API ────────────────────────────────────────────────────────
def search(mcts, board) do
  {:ok, tree} = Tree.start_link()

  root_id =
    Tree.put_node(tree, %Node{
      args: mcts.args,
      board: board,
      parent_id: nil,
      action_taken: nil,
      visit_count: 1
    })

  # initial inference + Dirichlet noise at the root
  {policy, _} = predict(mcts.model, board)

  policy =
    policy
    |> add_dirichlet_noise(mcts.args, GameBoard.get_action_size())
    |> mask_and_normalize(GameBoard.get_valid_moves(board))

  expand(tree, root_id, policy)

  # simulations
  Enum.each(1..mcts.args.num_mcts_searches, fn _ ->
    # selection
    node_id = select_leaf(tree, root_id)
    node = Tree.get_node(tree, node_id)

    {value, is_terminal} =
      GameBoard.get_value_and_terminated(node.board, node.action_taken)

    value = GameBoard.get_opponent_value(value)

    # expansion + evaluation (only for non-terminal leaves)
    value =
      if not is_terminal do
        {policy, model_value} = predict(mcts.model, node.board)

        policy = mask_and_normalize(policy, GameBoard.get_valid_moves(node.board))
        expand(tree, node_id, policy)
        model_value
      else
        value
      end

    backpropagate(tree, node_id, value)
  end)

  # final distribution: normalized visit counts of the root's children
  root = Tree.get_node(tree, root_id)
  n = GameBoard.get_action_size()

  probs =
    Enum.reduce(root.children_ids, List.duplicate(0.0, n), fn child_id, acc ->
      child = Tree.get_node(tree, child_id)
      List.replace_at(acc, child.action_taken, child.visit_count * 1.0)
    end)

  total = Enum.sum(probs)
  result = Enum.map(probs, fn v -> v / total end)
  Tree.stop(tree)
  result
end
# ── tree operations ───────────────────────────────────────────────────
defp expand(tree, node_id, policy) do
node = Tree.get_node(tree, node_id)
Enum.each(Enum.with_index(policy), fn {prob, action} ->
if prob > 0 do
child_state =
node.board
|> GameBoard.get_next_board(action, 1)
|> GameBoard.change_perspective(-1)
child_id =
Tree.put_node(tree, %Node{
args: node.args,
board: child_state,
parent_id: node_id,
action_taken: action,
prior: prob
})
Tree.add_child(tree, node_id, child_id)
end
end)
end
defp select_leaf(tree, node_id) do
node = Tree.get_node(tree, node_id)
if node.children_ids == [] do
node_id
else
best_child_id =
Enum.max_by(node.children_ids, fn child_id ->
child = Tree.get_node(tree, child_id)
get_ucb(node, child)
end)
select_leaf(tree, best_child_id)
end
end
defp backpropagate(tree, node_id, value) do
node = Tree.get_node(tree, node_id)
Tree.update_node(tree, node_id, fn n ->
%{n | value_sum: n.value_sum + value, visit_count: n.visit_count + 1}
end)
if node.parent_id != nil do
backpropagate(tree, node.parent_id, GameBoard.get_opponent_value(value))
end
end
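# PUCT score used during selection, in the AlphaZero formulation:
#   UCB(child) = Q(child) + c * P(child) * sqrt(N(parent)) / (1 + N(child))
# Q rescales the mean value from [-1, 1] to [0, 1] and flips it, because a
# child's value is stored from the opponent's perspective.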
defp get_ucb(parent, child) do
q_value =
if child.visit_count == 0 do
0
else
1 - (child.value_sum / child.visit_count + 1) / 2
end
q_value +
parent.args.c *
(:math.sqrt(parent.visit_count) / (child.visit_count + 1)) *
child.prior
end
defp predict(model, board) do
  inputs =
    board
    |> GameBoard.get_encoded_board()
    |> Nx.new_axis(0) # add the batch dimension: {1, 3, 6, 7}

  policy = Resnet.predict(:policy, model, inputs)
  value = Resnet.predict(:value, model, inputs)

  policy =
    policy
    |> Axon.Activations.softmax(axis: 1)
    |> Nx.to_flat_list()

  value =
    value
    |> Nx.squeeze()
    |> Nx.to_number()

  {policy, value}
end
defp add_dirichlet_noise(policy, args, action_size) do
epsilon = args.dirichlet_epsilon
noise = Dirichlet.sample(List.duplicate(args.dirichlet_alpha, action_size))
Enum.zip_with(policy, noise, fn p, n ->
(1 - epsilon) * p + epsilon * n
end)
end
defp mask_and_normalize(policy, valid_moves) do
  valid_moves = Nx.to_flat_list(valid_moves)
  masked = Enum.zip_with(policy, valid_moves, fn p, v -> p * v end)
  total = Enum.sum(masked)

  if total > 0 do
    Enum.map(masked, fn p -> p / total end)
  else
    # fallback: uniform distribution over the valid moves
    n_valid = Enum.sum(valid_moves)
    Enum.map(valid_moves, fn v -> v / n_valid end)
  end
end
end
args = %{
c: 2.0,
num_mcts_searches: 600,
dirichlet_epsilon: 0.25,
dirichlet_alpha: 0.3
}
mcts = %MCTS{model: model, args: args}
probs = MCTS.search(mcts, board)
action = Enum.zip(probs, 0..6) |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
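With the visit-count distribution in hand, the greedy move is the argmax above; advancing the position is then a single call (a sketch — self-play would sample from the distribution instead of taking the argmax).

GameBoard.get_next_board(board, action, 1)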
defmodule GameNode do
use GenServer
@fields [:board, :parent, :action_taken, :prior, :children, :visit_count, :value_sum]
@type t :: %__MODULE__{
        board: Nx.Tensor.t(),
        parent: pid() | nil,
        action_taken: integer() | nil,
        prior: number(),
        children: list(),
        visit_count: integer(),
        value_sum: number()
      }
defstruct @fields
def start_link(args) do
uuid = UUID.uuid4()
GenServer.start_link(__MODULE__, args, name: {:global, uuid})
end
def get_state(server) do
  GenServer.call(server, :get_state)
end
def expand(node, policy) do
  policy
  |> Enum.with_index()
  |> Enum.filter(fn {prob, _action} -> prob > 0 end)
  |> Enum.reduce(node, fn {prob, action}, acc ->
    child_board =
      acc.board
      |> GameBoard.get_next_board(action, 1)
      |> GameBoard.change_perspective(-1)

    {:ok, child_pid} =
      GameNode.start_link(%{
        board: child_board,
        action_taken: action,
        prior: prob
      })

    # accumulate the child pids instead of overwriting on each pass
    %{acc | children: acc.children ++ [child_pid]}
  end)
end
@impl true
def init(args) do
  state = %{
    board: args.board,
    parent: Map.get(args, :parent),
    action_taken: Map.get(args, :action_taken),
    prior: Map.get(args, :prior, 0),
    visit_count: Map.get(args, :visit_count, 0),
    value_sum: 0,
    children: []
  }

  {:ok, state}
end
@impl true
def handle_cast({:save, value}, state) do
new_state =
state
|> Map.put(:board, value)
{:noreply, new_state}
end
@impl true
def handle_call(:get_state, _from, state) do
{:reply, state, state}
end
end
defmodule MCTS do
@dirichlet_epsilon 0.25
@dirichlet_alpha 0.3
@action_size 7
def search(network, board) do
  {:ok, _root} = GameNode.start_link(%{board: board, visit_count: 1})
  # 1. Encode the board
  encoded_board = GameBoard.get_encoded_board(board)
  # 2. Add the batch dimension (the .unsqueeze(0) in the Python original):
  #    turns {3, 6, 7} into {1, 3, 6, 7}
  input_tensor = Nx.new_axis(encoded_board, 0)
  # 3. Run inference. Axon.build returns {init_fn, predict_fn};
  #    the params hold the network's trained weights.
  policy = Resnet.predict(:policy, network, input_tensor)
# probabilities =
# policy
# |> Axon.Activations.softmax(axis: 1)
# |> Nx.to_flat_list()
# noise = Dirichlet.sample(List.duplicate(@dirichlet_alpha, @action_size))
# policy =
# probabilities
# |> Enum.zip(noise)
# |> Enum.map(fn {p, n} -> (1 - @dirichlet_epsilon) * p + @dirichlet_epsilon * n end)
# valid_moves = GameBoard.get_valid_moves(board)
# policy =
# policy
# |> Enum.zip(valid_moves)
# |> Enum.map(fn {p, v} -> p * v end)
# total = Enum.sum(policy)
# policy =
# if total > 0 do
# Enum.map(policy, fn p -> p / total end)
# else
# # fallback: uniform distribution over valid moves
# n_valid = Enum.sum(valid_moves)
# Enum.zip(valid_moves, valid_moves)
# |> Enum.map(fn {v, _} -> v / n_valid end)
# end
# GameNode.expand(root, policy)
end
end
policy
MCTS.search(model, board)
AlphaZero
defmodule AlphaZero do
@num_iterations 48

def learn() do
  Enum.each(1..@num_iterations, fn _ ->
    _memory = []
# self.model.eval()
# for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
# memory += self.selfPlay()
# self.model.train()
# for epoch in trange(self.args['num_epochs']):
# self.train(memory)
end)
end
def self_play() do
  _memory = []
  player = 1
  board = GameBoard.create_new_board()
  _neutral_board = GameBoard.change_perspective(board, player)
# action_probs = MCTS.search(neutral_state)
# memory.append((neutral_state, action_probs, player))
# temperature_action_probs = action_probs ** (1 / self.args['temperature'])
# temperature_action_probs /= np.sum(temperature_action_probs)
# action = np.random.choice(self.game.action_size, p=temperature_action_probs)
# state = self.game.get_next_state(state, action, player)
# value, is_terminal = self.game.get_value_and_terminated(state, action)
# if is_terminal:
# returnMemory = []
# for hist_neutral_state, hist_action_probs, hist_player in memory:
# hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
# returnMemory.append((
# self.game.get_encoded_state(hist_neutral_state),
# hist_action_probs,
# hist_outcome
# ))
# return returnMemory
# player = self.game.get_opponent(player)
end
end
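The temperature step in the pseudocode above flattens or sharpens the visit-count distribution before sampling a move. A minimal Elixir sketch of that step, using a uniform placeholder for probs (in real self-play it would be the list returned by MCTS.search):

temperature = 1.25
probs = List.duplicate(1.0 / 7.0, 7)
weighted = Enum.map(probs, fn p -> :math.pow(p, 1.0 / temperature) end)
total = Enum.sum(weighted)
temperature_probs = Enum.map(weighted, fn w -> w / total end)

# inverse-CDF sampling over the action indices 0..6
r = :rand.uniform()

{sampled_action, _} =
  temperature_probs
  |> Enum.with_index()
  |> Enum.reduce_while({6, 0.0}, fn {p, i}, {_, acc} ->
    if acc + p >= r, do: {:halt, {i, acc + p}}, else: {:cont, {6, acc + p}}
  end)

sampled_action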
AlphaZero.learn()
GameBoard.play_move(board, 4, -1)
GameBoard.change_perspective(board, 1)
{:ok, pid} = GameNode.start_link(%{board: board, visit_count: 1})
GameNode.get_state(pid)
AlphaZero.self_play()