MCTS - ConnectFour
Mix.install([
{:bumblebee, "~> 0.6.3"},
{:nx, "~> 0.10.0"},
{:axon, "~> 0.7.0"},
{:kino, "~> 0.19.0"},
{:torchx, "~> 0.10.0"},
{:uuid, "~> 1.1"}
])
Setup
# Route all Nx tensor computation through the Torchx (LibTorch) backend.
Nx.default_backend(Torchx.Backend)
defmodule Utils do
  @moduledoc """
  Probability-vector helpers shared by self-play and MCTS: masking,
  temperature scaling, categorical sampling and Dirichlet noise mixing.
  """

  @doc """
  Zeroes the entries of `policy` whose move is invalid and renormalizes the
  result to sum to 1. If all masked entries are zero, falls back to a
  uniform distribution over the valid moves.
  """
  def mask_and_normalize(policy, valid_moves) do
    masked = Enum.zip_with(policy, valid_moves, &(&1 * &2))

    case Enum.sum(masked) do
      total when total > 0 ->
        for p <- masked, do: p / total

      _ ->
        n_valid = Enum.sum(valid_moves)
        for v <- valid_moves, do: v / n_valid
    end
  end

  @doc """
  Sharpens (`temperature < 1`) or flattens (`temperature > 1`) a probability
  distribution by raising each entry to `1 / temperature`, then renormalizes.
  """
  def apply_temperature(action_probs, temperature) do
    exponent = 1.0 / temperature
    scaled = for p <- action_probs, do: :math.pow(p, exponent)
    total = Enum.sum(scaled)
    for p <- scaled, do: p / total
  end

  @doc """
  Draws one action index from the categorical distribution `probs` via
  inverse-CDF sampling. If floating-point rounding keeps the cumulative sum
  below the drawn value, falls back to the last action (`action_size - 1`).
  """
  def sample_action(probs, action_size) do
    draw = :rand.uniform()

    outcome =
      probs
      |> Enum.with_index()
      |> Enum.reduce_while(0.0, fn {p, idx}, cumulative ->
        cumulative = cumulative + p
        if draw <= cumulative, do: {:halt, idx}, else: {:cont, cumulative}
      end)

    # reduce_while yields the float accumulator when no index was selected.
    if is_integer(outcome), do: outcome, else: action_size - 1
  end

  @doc """
  Mixes the root policy with Dirichlet noise for exploration, AlphaZero-style:
  `(1 - ε) * policy + ε * noise`, with ε and α taken from `args`.
  """
  def add_dirichlet_noise(policy, args) do
    epsilon = args.dirichlet_epsilon

    noise =
      args.dirichlet_alpha
      |> List.duplicate(args.action_size)
      |> Dirichlet.sample()

    Enum.zip_with(policy, noise, fn p, n -> (1 - epsilon) * p + epsilon * n end)
  end
end
Game
defmodule Game do
# Connect-Four rules implemented on Nx tensors. The board is an s32 matrix
# of shape {row_count, col_count}; cell values are 1 (one player), -1 (the
# other) and 0 (empty). Row 0 is the TOP of the board, so pieces "fall" to
# the highest-index empty row of a column.
defstruct [:row_count, :col_count, :action_size, :in_a_row]
# Builds a game configuration. Defaults to the classic 6x7 board with 4 in
# a row to win; `action_size` defaults to one action per column.
def new(opts \\ []) do
row_count = Keyword.get(opts, :row_count, 6)
col_count = Keyword.get(opts, :col_count, 7)
%Game{
row_count: row_count,
col_count: col_count,
action_size: Keyword.get(opts, :action_size, col_count),
in_a_row: Keyword.get(opts, :in_a_row, 4)
}
end
# ── initial state ──────────────────────────────────────────────────────
# Returns an all-zero s32 tensor of shape {rows, cols}.
def get_initial_board(%Game{row_count: rows, col_count: cols}) do
Nx.broadcast(Nx.tensor(0, type: :s32), {rows, cols})
end
# ── next state ─────────────────────────────────────────────────────────
# Drops `player`'s piece into column `action`: it lands on the empty cell
# with the LARGEST row index (gravity). A full column returns the board
# unchanged, so callers are expected to pre-check move validity.
def get_next_board(board, action, player) do
col = board[[.., action]]
empty_rows =
col
|> Nx.equal(0)
|> Nx.to_flat_list()
|> Enum.with_index()
|> Enum.filter(fn {val, _i} -> val == 1 end)
|> Enum.map(fn {_val, i} -> i end)
case empty_rows do
[] ->
board
rows ->
row = Enum.max(rows)
indices = Nx.tensor([[row, action]])
updates = Nx.tensor([player], type: :s32)
Nx.indexed_put(board, indices, updates)
end
end
# ── valid moves ────────────────────────────────────────────────────────
# A column is playable iff its top cell (row 0) is empty. Returns a u8
# 0/1 mask; also handles a batched {batch, rows, cols} board.
def get_valid_moves(board) do
case Nx.shape(board) do
{_rows, _cols} ->
board[0]
|> Nx.equal(0)
|> Nx.as_type(:u8)
{_batch, _rows, _cols} ->
board[[.., 0, ..]]
|> Nx.equal(0)
|> Nx.as_type(:u8)
end
end
# ── win check ──────────────────────────────────────────────────────────
# No move played yet: nothing can have won.
def check_win(_game, _board, nil), do: false
# Checks whether the LAST move in column `action` completed a line of
# `in_a_row` pieces. Only lines through the newest piece (the topmost
# occupied cell of that column) need to be examined.
def check_win(%Game{row_count: row_count, col_count: col_count, in_a_row: in_a_row}, board, action) do
col = board[[.., action]] |> Nx.to_flat_list()
occupied =
col
|> Enum.with_index()
|> Enum.filter(fn {val, _i} -> val != 0 end)
|> Enum.map(fn {_val, i} -> i end)
case occupied do
[] ->
false
rows ->
# Topmost piece of the column is the one just dropped.
row = Enum.min(rows)
player = board[row][action] |> Nx.to_number()
board_list = Nx.to_list(board)
# Counts consecutive same-player pieces starting NEXT to (row, action)
# along direction (offset_row, offset_col); the dropped piece itself
# is excluded, hence the `in_a_row - 1` thresholds below.
count = fn offset_row, offset_col ->
Enum.reduce_while(1..(in_a_row - 1), 0, fn i, _acc ->
r = row + offset_row * i
c = action + offset_col * i
if r < 0 or r >= row_count or c < 0 or c >= col_count do
{:halt, i - 1}
else
val = board_list |> Enum.at(r) |> Enum.at(c)
if val != player, do: {:halt, i - 1}, else: {:cont, i}
end
end)
end
# Vertical only needs the downward direction (the new piece is on top);
# the other three axes sum counts from both opposite directions.
count.(1, 0) >= in_a_row - 1 or
count.(0, 1) + count.(0, -1) >= in_a_row - 1 or
count.(1, 1) + count.(-1, -1) >= in_a_row - 1 or
count.(1, -1) + count.(-1, 1) >= in_a_row - 1
end
end
# ── value and terminal ─────────────────────────────────────────────────
# Returns {value, terminated?} relative to the move just played:
# {1, true} on a win, {0, true} on a draw (board full), {0, false} otherwise.
def get_value_and_terminated(%Game{} = game, board, action) do
if check_win(game, board, action) do
{1, true}
else
valid_sum =
board
|> get_valid_moves()
|> Nx.sum()
|> Nx.to_number()
if valid_sum == 0, do: {0, true}, else: {0, false}
end
end
# ── perspective ────────────────────────────────────────────────────────
# Flips the board so `player`'s own pieces read as +1 (multiply by ±1).
def change_perspective(board, player) do
Nx.multiply(board, Nx.tensor(player, type: :s32))
end
# ── encode ─────────────────────────────────────────────────────────────
# One-hot encodes the board into three f32 planes (cells == -1 / 0 / 1).
# For batched boards, moves the plane axis after batch: {batch, 3, r, c}.
def get_encoded_board(board) do
planes =
[-1, 0, 1]
|> Enum.map(fn player ->
Nx.equal(board, Nx.tensor(player, type: :s32))
end)
|> Nx.stack()
|> Nx.as_type(:f32)
case Nx.shape(board) do
{_rows, _cols} -> planes
{_batch, _rows, _cols} -> Nx.transpose(planes, axes: [1, 0, 2, 3])
end
end
# ── helpers ────────────────────────────────────────────────────────────
# Two-player zero-sum game: opponents and values are simple negations.
def get_opponent(player), do: -player
def get_opponent_value(value), do: -value
end
Dirichlet Distribution
defmodule Dirichlet do
  @moduledoc """
  Pure-Elixir port of `np.random.dirichlet`.

  Gamma variates are drawn with the Marsaglia–Tsang method and normalized
  to produce a Dirichlet sample.
  """

  @doc """
  Samples one point from the Dirichlet distribution.

  ## Parameters

    - alphas: concentration parameters [α₁, α₂, …, αₖ], all > 0

  ## Returns

  A list of floats summing to 1.0.

  ## Example

      iex> Dirichlet.sample([1.0, 1.0, 1.0])
      [0.312, 0.485, 0.203] # random values
  """
  def sample(alphas) when is_list(alphas) do
    draws = for alpha <- alphas, do: sample_gamma(alpha)
    total = Enum.sum(draws)
    for g <- draws, do: g / total
  end

  # ── Gamma(alpha, 1) via Marsaglia–Tsang ──────────────────────────────

  # Boosting trick for alpha < 1: Gamma(α) = Gamma(α + 1) * U^(1/α).
  defp sample_gamma(alpha) when alpha < 1 do
    boost = :math.pow(:rand.uniform(), 1.0 / alpha)
    boost * sample_gamma(alpha + 1.0)
  end

  # General case alpha >= 1: rejection sampling with d = α - 1/3, c = 1/√(9d).
  defp sample_gamma(alpha) do
    d = alpha - 1.0 / 3.0
    marsaglia_loop(d, 1.0 / :math.sqrt(9.0 * d))
  end

  # One rejection round: accept d*v via the fast squeeze or the exact
  # log criterion, otherwise retry. The `v > 0.0` guard also protects
  # :math.log(v) thanks to short-circuit evaluation.
  defp marsaglia_loop(d, c) do
    {x, v} = sample_v(c)
    u = :rand.uniform()

    accepted? =
      v > 0.0 and
        (u < 1.0 - 0.0331 * (x * x) * (x * x) or
           :math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)))

    if accepted?, do: d * v, else: marsaglia_loop(d, c)
  end

  # Candidate v = (1 + c*x)^3 with x ~ N(0,1); returns {x, v}.
  defp sample_v(c) do
    x = sample_normal()
    t = 1.0 + c * x
    {x, t * t * t}
  end

  # Standard normal via the Box–Muller transform.
  defp sample_normal do
    u1 = :rand.uniform()
    u2 = :rand.uniform()
    :math.sqrt(-2.0 * :math.log(u1)) * :math.cos(2.0 * :math.pi() * u2)
  end
end
Model ResNet
defmodule ResNet do
  import Axon

  # AlphaZero-style residual network held in a named Agent. The Agent state
  # keeps the Axon model, its parameters, a compiled inference function and
  # the learning rate used by train_step/3.

  @doc """
  Builds the network for `game` and starts the owning Agent (named
  `__MODULE__`). `num_res_blocks` and `num_hidden` size the residual tower.
  """
  def start_link(game, num_res_blocks, num_hidden, learning_rate) do
    model = build(game, num_res_blocks, num_hidden)
    {init_fn, predict_fn} = Axon.build(model, mode: :inference)
    # Input template: batch of 1, the 3 encoded planes, board height x width.
    params = init_fn.(Nx.template({1, 3, game.row_count, game.col_count}, :f32), %{})

    Agent.start_link(
      fn ->
        %{
          model: model,
          params: params,
          predict_fn: predict_fn,
          learning_rate: learning_rate
        }
      end,
      name: __MODULE__
    )
  end

  # ── inference ──────────────────────────────────────────────────────────

  @doc """
  Forward pass on `input` (shape {batch, 3, rows, cols}); returns the
  `{policy_logits, value}` container produced by the model.
  """
  def predict(input) do
    %{params: params, predict_fn: predict_fn} = Agent.get(__MODULE__, & &1)
    predict_fn.(params, %{"state" => input})
  end

  def get_params, do: Agent.get(__MODULE__, & &1.params)

  # FIX: the Agent state never stores an :opt_state entry, so the previous
  # `& &1.opt_state` access raised a KeyError on every call. Map.get/2
  # returns nil instead of crashing.
  def get_opt_state, do: Agent.get(__MODULE__, &Map.get(&1, :opt_state))

  # Kept for API compatibility; the compiled predict_fn is inference-only,
  # so there is no runtime mode to toggle.
  def set_mode(_mode), do: :ok

  # ── training via Axon.Loop ─────────────────────────────────────────────

  @doc """
  Performs one optimization pass (a single epoch over one batch) with the
  AlphaZero loss: policy cross-entropy + value MSE. Updates the stored
  parameters in place.
  """
  def train_step(state_tensor, policy_targets, value_targets) do
    %{model: model, params: params, learning_rate: learning_rate} =
      Agent.get(__MODULE__, & &1)

    # FIX: Axon.Loop.trainer invokes the loss as loss(y_true, y_pred). The
    # original bound the model output as the FIRST argument, silently
    # computing the loss with targets and predictions swapped.
    loss_fn = fn {p_targets, v_targets}, {out_policy, out_value} ->
      # Cross-entropy between the target distribution and log-softmax output.
      policy_loss =
        out_policy
        |> Axon.Activations.log_softmax(axis: 1)
        |> Nx.multiply(p_targets)
        |> Nx.sum(axes: [1])
        |> Nx.mean()
        |> Nx.negate()

      # Mean squared error on the scalar value head.
      value_loss =
        out_value
        |> Nx.subtract(v_targets)
        |> Nx.pow(2)
        |> Nx.mean()

      Nx.add(policy_loss, value_loss)
    end

    trained =
      model
      |> Axon.Loop.trainer(loss_fn, Polaris.Optimizers.adam(learning_rate: learning_rate))
      |> Axon.Loop.run(
        [{%{"state" => state_tensor}, {policy_targets, value_targets}}],
        params,
        epochs: 1
      )

    # NOTE(review): depending on the Axon version, Loop.run may return the
    # trainer step state (holding the params under :model_state) instead of
    # the bare parameters. Extract :model_state when present so the stored
    # params stay usable by predict_fn; otherwise keep the result as-is.
    new_params = Map.get(trained, :model_state, trained)

    Agent.update(__MODULE__, fn state ->
      %{state | params: new_params}
    end)
  end

  # ── save / load ────────────────────────────────────────────────────────

  @doc "Serializes the current parameters to models/params_<iteration>.nx."
  def save(iteration) do
    params = get_params()
    File.mkdir_p!("models")
    File.write!("models/params_#{iteration}.nx", Nx.serialize(params))
    IO.puts("modelo salvo: models/params_#{iteration}.nx")
  end

  @doc "Loads parameters saved by `save/1` into the running Agent."
  def load(iteration) do
    params = "models/params_#{iteration}.nx" |> File.read!() |> Nx.deserialize()
    Agent.update(__MODULE__, fn state -> %{state | params: params} end)
  end

  # ── build ──────────────────────────────────────────────────────────────

  @doc "Assembles input -> conv stem -> residual tower -> {policy, value} heads."
  def build(game, num_res_blocks, num_hidden) do
    input = input("state", shape: {nil, 3, game.row_count, game.col_count})
    x = start_block(input, num_hidden)
    x = Enum.reduce(1..num_res_blocks, x, fn _i, x -> res_block(x, num_hidden) end)
    policy = policy_head(x, game)
    value = value_head(x, game)
    Axon.container({policy, value})
  end

  # Initial conv stem: 3x3 conv -> batch norm -> ReLU.
  defp start_block(x, num_hidden) do
    x
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
  end

  # Residual block with identity skip connection.
  defp res_block(x, num_hidden) do
    residual = x

    x
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> conv(num_hidden, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> add(residual)
    |> relu()
  end

  # Policy head: logits over game.action_size moves (callers apply softmax).
  defp policy_head(x, game) do
    x
    |> conv(32, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> flatten()
    |> dense(game.action_size)
  end

  # Value head: a single scalar squashed into [-1, 1] by tanh.
  defp value_head(x, _game) do
    x
    |> conv(3, kernel_size: {3, 3}, padding: :same, use_bias: false)
    |> batch_norm()
    |> relu()
    |> flatten()
    |> dense(1)
    |> tanh()
  end
end
MCTS
defmodule MCTS do
# Monte-Carlo Tree Search guided by the ResNet policy/value network
# (AlphaZero-style PUCT). Each search builds a fresh tree inside an Agent
# and discards it when done.
defmodule Node do
# One search-tree node. `board` is stored from the perspective of the
# player to move at this node; `action_taken` is the move that led here
# from the parent (nil at the root).
defstruct [
:id, :game, :args, :board, :parent_id, :action_taken,
prior: 0,
visit_count: 0,
value_sum: 0,
children_ids: []
]
end
defmodule Tree do
# Id-indexed node store. Agent state is {nodes_map, next_id}.
use Agent
def start_link, do: Agent.start_link(fn -> {%{}, 0} end)
def stop(tree), do: Agent.stop(tree)
# Inserts `node`, assigning it the next sequential id; returns that id.
def put_node(tree, node) do
Agent.get_and_update(tree, fn {nodes, next_id} ->
node = %{node | id: next_id}
{next_id, {Map.put(nodes, next_id, node), next_id + 1}}
end)
end
def get_node(tree, id) do
Agent.get(tree, fn {nodes, _} -> Map.fetch!(nodes, id) end)
end
# Applies `fun` to the stored node `id` in place.
def update_node(tree, id, fun) do
Agent.update(tree, fn {nodes, next_id} ->
{Map.update!(nodes, id, fun), next_id}
end)
end
def add_child(tree, parent_id, child_id) do
update_node(tree, parent_id, fn node ->
%{node | children_ids: node.children_ids ++ [child_id]}
end)
end
end
# Runs args.num_mcts_searches simulations from `board` (already in the
# perspective of the player to move) and returns visit-count-proportional
# probabilities over the action space.
def search(%{args: args, game: game}, board) do
{:ok, tree} = Tree.start_link()
root_id =
Tree.put_node(tree, %Node{
game: game,
args: args,
board: board,
parent_id: nil,
action_taken: nil,
visit_count: 1
})
{policy, _} = predict(board)
valid_moves =
board
|> Game.get_valid_moves()
|> Nx.to_flat_list()
# Root-only Dirichlet exploration noise before masking (AlphaZero).
policy =
policy
|> Utils.add_dirichlet_noise(args)
|> Utils.mask_and_normalize(valid_moves)
expand(tree, root_id, policy)
Enum.each(1..args.num_mcts_searches, fn _ ->
# Selection: descend via UCB to a leaf.
node_id = select_leaf(tree, root_id)
node = Tree.get_node(tree, node_id)
{value, is_terminal} =
Game.get_value_and_terminated(game, node.board, node.action_taken)
# get_value_and_terminated scores the move that led here, i.e. the
# parent's mover; negate to express it from this node's player.
value = Game.get_opponent_value(value)
value =
if not is_terminal do
# Expansion + evaluation: the network's value estimate replaces a
# random rollout (value is used as returned, without negation).
{policy, value} = predict(node.board)
valid_moves =
node.board
|> Game.get_valid_moves()
|> Nx.to_flat_list()
policy = Utils.mask_and_normalize(policy, valid_moves)
expand(tree, node_id, policy)
value |> Nx.squeeze() |> Nx.to_number()
else
value
end
backpropagate(tree, node_id, value)
end)
# Final policy: children visit counts normalized to probabilities.
root = Tree.get_node(tree, root_id)
probs = List.duplicate(0.0, args.action_size)
probs =
Enum.reduce(root.children_ids, probs, fn child_id, acc ->
child = Tree.get_node(tree, child_id)
List.replace_at(acc, child.action_taken, child.visit_count * 1.0)
end)
total = Enum.sum(probs)
result = Enum.map(probs, fn v -> v / total end)
Tree.stop(tree)
result
end
# ── tree operations ───────────────────────────────────────────────────
# Creates a child node for every valid action with positive prior. Each
# child board is advanced with player 1's move, then flipped so the child
# is stored from the opponent's perspective.
defp expand(tree, node_id, policy) do
node = Tree.get_node(tree, node_id)
valid_moves = Game.get_valid_moves(node.board) |> Nx.to_flat_list()
policy
|> Enum.with_index()
|> Enum.zip(valid_moves)
|> Enum.each(fn {{prob, action}, valid} ->
if prob > 0 and valid == 1 do
child_board =
node.board
|> Game.get_next_board(action, 1)
|> Game.change_perspective(-1)
child_id =
Tree.put_node(tree, %Node{
game: node.game,
args: node.args,
board: child_board,
parent_id: node_id,
action_taken: action,
prior: prob
})
Tree.add_child(tree, node_id, child_id)
end
end)
end
# Walks down the tree, always following the child with maximal UCB, until
# reaching a node with no children.
defp select_leaf(tree, node_id) do
node = Tree.get_node(tree, node_id)
if node.children_ids == [] do
node_id
else
best_child_id =
Enum.max_by(node.children_ids, fn child_id ->
child = Tree.get_node(tree, child_id)
get_ucb(node, child)
end)
select_leaf(tree, best_child_id)
end
end
# Accumulates `value` along the path to the root, negating at each level
# because players alternate in this zero-sum game.
defp backpropagate(tree, node_id, value) do
node = Tree.get_node(tree, node_id)
Tree.update_node(tree, node_id, fn n ->
%{n | value_sum: n.value_sum + value, visit_count: n.visit_count + 1}
end)
if node.parent_id != nil do
backpropagate(tree, node.parent_id, Game.get_opponent_value(value))
end
end
# PUCT score: exploitation Q (the child's mean value, mapped from [-1, 1]
# to [0, 1] and inverted since child values are from the opponent's
# perspective) plus a prior-weighted exploration bonus scaled by args.c.
defp get_ucb(parent, child) do
q_value =
if child.visit_count == 0 do
0
else
1 - (child.value_sum / child.visit_count + 1) / 2
end
q_value +
parent.args.c *
(:math.sqrt(parent.visit_count) / (child.visit_count + 1)) *
child.prior
end
# Encodes the board, adds a batch axis, runs the network and returns
# {policy probabilities as a flat list, raw value tensor}.
defp predict(board) do
{policy, value} =
board
|> Game.get_encoded_board()
|> Nx.new_axis(0)
|> ResNet.predict()
policy =
policy
|> Axon.Activations.softmax(axis: 1)
|> Nx.squeeze(axes: [0])
|> Nx.to_flat_list()
{policy, value}
end
end
AlphaZero
defmodule AlphaZero do
# Orchestrates the AlphaZero training loop: repeated self-play to collect
# (encoded_board, policy, outcome) samples, followed by training epochs
# and a checkpoint per iteration.
use GenServer
# NOTE(review): :memory is initialized to [] but never written back into
# the server state; samples are threaded through learn/0 locally instead.
defstruct [:args, :game, :memory]
# ── public API ─────────────────────────────────────────────────────────
def start_link(%{args: args, game: game}) do
GenServer.start_link(__MODULE__, %__MODULE__{args: args, game: game, memory: []}, name: __MODULE__)
end
def get_args, do: GenServer.call(__MODULE__, :get_args)
def get_game, do: GenServer.call(__MODULE__, :get_game)
# Runs args.num_iterations cycles of (self-play -> train -> save).
def learn do
args = get_args()
game = get_game()
Enum.each(1..args.num_iterations, fn iteration ->
IO.puts("Iteração #{iteration}/#{args.num_iterations}")
ResNet.set_mode(:inference)
memory =
Enum.flat_map(1..args.num_selfplay_iterations, fn i ->
IO.puts(" self-play #{i}/#{args.num_selfplay_iterations}")
self_play(game, args)
end)
ResNet.set_mode(:train)
Enum.each(1..args.num_epochs, fn epoch ->
IO.puts(" epoch #{epoch}/#{args.num_epochs}")
train(memory, args)
end)
ResNet.save(iteration)
end)
end
# ── self-play ──────────────────────────────────────────────────────────
# Plays one complete game against itself, returning a list of training
# samples {encoded_board, mcts_policy, outcome} for every visited position.
def self_play(game, args) do
board = Game.get_initial_board(game)
player = 1
do_self_play(game, args, board, player, _memory = [])
end
defp do_self_play(game, args, board, player, memory) do
# MCTS always sees the board from the current player's perspective.
neutral_board = Game.change_perspective(board, player)
action_probs = MCTS.search(%{args: args, game: game}, neutral_board)
memory = [{neutral_board, action_probs, player} | memory]
valid_moves = Game.get_valid_moves(board) |> Nx.to_flat_list()
# Temperature controls exploration; re-mask before sampling the move.
action =
action_probs
|> Utils.apply_temperature(args.temperature)
|> Utils.mask_and_normalize(valid_moves)
|> Utils.sample_action(args.action_size)
board = Game.get_next_board(board, action, player)
{value, is_terminal} = Game.get_value_and_terminated(game, board, action)
if is_terminal do
build_return_memory(memory, value, player)
else
do_self_play(game, args, board, Game.get_opponent(player), memory)
end
end
# Labels every stored position with the final outcome as seen by the
# player who was to move at that position.
defp build_return_memory(memory, value, current_player) do
Enum.map(memory, fn {hist_board, hist_probs, hist_player} ->
outcome =
if hist_player == current_player,
do: value,
else: Game.get_opponent_value(value)
{Game.get_encoded_board(hist_board), hist_probs, outcome}
end)
end
# ── train ──────────────────────────────────────────────────────────────
# Shuffles the sample buffer, batches it, converts each batch to tensors
# and performs one optimizer step per batch.
def train(memory, args) do
memory
|> Enum.shuffle()
|> Enum.chunk_every(args.batch_size)
|> Enum.with_index(1)
|> Enum.each(fn {batch, _i} ->
# Unzip the batch into parallel lists, preserving sample order.
{states, policy_targets, value_targets} =
Enum.reduce(batch, {[], [], []}, fn {s, p, v}, {sa, pa, va} ->
{[s | sa], [p | pa], [v | va]}
end)
|> then(fn {sa, pa, va} ->
{Enum.reverse(sa), Enum.reverse(pa), Enum.reverse(va)}
end)
state_tensor =
states
|> Enum.map(&Nx.new_axis(&1, 0))
|> Nx.concatenate(axis: 0)
|> Nx.as_type(:f32)
policy_tensor =
policy_targets
|> Enum.map(&Nx.tensor(&1, type: :f32))
|> Nx.stack()
value_tensor =
value_targets
|> Enum.map(&Nx.tensor([&1], type: :f32))
|> Nx.stack()
ResNet.train_step(state_tensor, policy_tensor, value_tensor)
end)
end
# ── GenServer callbacks ────────────────────────────────────────────────
@impl true
def init(state), do: {:ok, state}
@impl true
def handle_call(:get_args, _from, state), do: {:reply, state.args, state}
@impl true
def handle_call(:get_game, _from, state), do: {:reply, state.game, state}
end
defmodule PlayLive do
# Kino-based interactive UI: train a new model or load a checkpoint, then
# play Connect-Four against the network (human plays -1 / 🟡, AI plays 1 / 🔴).
def start do
frame = Kino.Frame.new()
Kino.render(frame)
game = Game.new()
# choice buttons: train from scratch vs. load a saved checkpoint
btn_train = Kino.Control.button("Treinar novo modelo")
btn_load = Kino.Control.button("Carregar modelo salvo")
Kino.Frame.render(frame, Kino.Layout.grid([btn_train, btn_load], columns: 2))
Kino.listen(btn_train, fn _event ->
Kino.Frame.render(frame, Kino.Markdown.new("**Treinando...**"))
# Small training schedule, sized for an interactive session.
args = %{
num_iterations: 3,
num_selfplay_iterations: 5,
num_epochs: 4,
batch_size: 64,
temperature: 1.25,
learning_rate: 0.001,
num_mcts_searches: 60,
dirichlet_epsilon: 0.25,
dirichlet_alpha: 0.3,
action_size: game.action_size,
c: 2.0
}
ensure_resnet(game, 9, 128, args.learning_rate)
{:ok, _} = AlphaZero.start_link(%{args: args, game: game})
AlphaZero.learn()
Kino.Frame.render(frame, Kino.Markdown.new("**Treino concluído!**"))
start_game(frame, game, args)
end)
Kino.listen(btn_load, fn _event ->
# list saved checkpoints (models/params_<iteration>.nx)
case File.ls("models") do
{:ok, files} ->
iterations =
files
|> Enum.filter(&String.starts_with?(&1, "params_"))
|> Enum.map(fn f ->
f
|> String.replace("params_", "")
|> String.replace(".nx", "")
|> String.to_integer()
end)
|> Enum.sort()
if iterations == [] do
Kino.Frame.render(frame, Kino.Markdown.new("**Nenhum modelo salvo encontrado.**"))
else
show_model_picker(frame, game, iterations)
end
{:error, _} ->
Kino.Frame.render(frame, Kino.Markdown.new("**Pasta models/ não encontrada.**"))
end
end)
end
# Renders one button per saved checkpoint and wires each to load + play.
defp show_model_picker(frame, game, iterations) do
buttons =
Enum.map(iterations, fn i ->
{Kino.Control.button("Modelo #{i}"), i}
end)
button_row =
buttons
|> Enum.map(fn {btn, _} -> btn end)
|> Kino.Layout.grid(columns: length(buttons))
Kino.Frame.render(frame, Kino.Markdown.new("**Escolha o modelo:**"))
Kino.Frame.append(frame, button_row)
Enum.each(buttons, fn {btn, iteration} ->
Kino.listen(btn, fn _event ->
# Play-time settings; the training-only fields are unused here.
args = %{
num_iterations: 8,
num_selfplay_iterations: 100,
num_epochs: 4,
batch_size: 64,
temperature: 1.25,
learning_rate: 0.001,
num_mcts_searches: 60,
dirichlet_epsilon: 0.25,
dirichlet_alpha: 0.3,
action_size: game.action_size,
c: 2.0
}
ensure_resnet(game, 9, 128, args.learning_rate)
ResNet.load(iteration)
Kino.Frame.render(frame, Kino.Markdown.new("**Modelo #{iteration} carregado!**"))
start_game(frame, game, args)
end)
end)
end
# (Re)starts the ResNet Agent so repeated clicks get a fresh network.
defp ensure_resnet(game, num_res_blocks, num_hidden, learning_rate) do
case Process.whereis(ResNet) do
nil -> :ok
pid -> Agent.stop(pid)
end
{:ok, _} = ResNet.start_link(game, num_res_blocks, num_hidden, learning_rate)
end
# Builds the match UI (board, column buttons, AI-mode buttons) and keeps
# the mutable match state in the :play_state Agent.
defp start_game(frame, game, args) do
board = Game.get_initial_board(game)
# Stop any previous match's Agent; ignore if it is not running.
try do Agent.stop(:play_state) catch _, _ -> :ok end
Agent.start_link(fn ->
%{board: board, game: game, args: args, game_over: false, ai_mode: :mcts}
end, name: :play_state)
board_frame = Kino.Frame.new()
# column buttons (one per playable column, 0..6)
buttons =
Enum.map(0..6, fn col ->
{Kino.Control.button("#{col}"), col}
end)
button_row =
buttons
|> Enum.map(fn {btn, _} -> btn end)
|> Kino.Layout.grid(columns: 7)
# AI mode buttons
btn_mcts = Kino.Control.button("IA forte (MCTS)")
btn_parallel = Kino.Control.button("IA forte paralela")
btn_raw = Kino.Control.button("IA rápida (sem MCTS)")
btn_restart = Kino.Control.button("Nova partida")
mode_row = Kino.Layout.grid([btn_mcts, btn_parallel, btn_raw], columns: 3)
Kino.Frame.render(frame, Kino.Layout.grid([
board_frame,
button_row,
mode_row,
btn_restart
], columns: 1))
render_board(board_frame, board, "Sua vez (modo: MCTS):")
# mode listeners: switch the AI strategy mid-match
Kino.listen(btn_mcts, fn _event ->
Agent.update(:play_state, fn s -> %{s | ai_mode: :mcts} end)
state = Agent.get(:play_state, & &1)
render_board(board_frame, state.board, "Modo: MCTS")
end)
Kino.listen(btn_parallel, fn _event ->
Agent.update(:play_state, fn s -> %{s | ai_mode: :parallel} end)
state = Agent.get(:play_state, & &1)
render_board(board_frame, state.board, "Modo: MCTS paralelo")
end)
Kino.listen(btn_raw, fn _event ->
Agent.update(:play_state, fn s -> %{s | ai_mode: :raw} end)
state = Agent.get(:play_state, & &1)
render_board(board_frame, state.board, "Modo: rede neural pura")
end)
Kino.listen(btn_restart, fn _event ->
start_game(frame, game, args)
end)
Enum.each(buttons, fn {btn, col} ->
Kino.listen(btn, fn _event ->
state = Agent.get(:play_state, & &1)
if not state.game_over do
# Ignore clicks on full columns or after the match ended.
valid = Game.get_valid_moves(state.board) |> Nx.to_flat_list()
if Enum.at(valid, col) == 1 do
play_turn(board_frame, state.game, state.args, state.board, state.ai_mode, col)
end
end
end)
end)
end
# Applies the human move (player -1) in `col`; if the game continues,
# lets the AI (player 1) answer and re-renders the board after each move.
defp play_turn(board_frame, game, args, board, ai_mode, col) do
board = Game.get_next_board(board, col, -1)
{value, is_terminal} = Game.get_value_and_terminated(game, board, col)
if is_terminal do
msg = if value == 1, do: "Você venceu!", else: "Empate!"
render_board(board_frame, board, msg)
Agent.update(:play_state, fn s -> %{s | board: board, game_over: true} end)
else
render_board(board_frame, board, "IA pensando...")
ai_action = get_ai_action(game, args, board, ai_mode)
board = Game.get_next_board(board, ai_action, 1)
{value, is_terminal} = Game.get_value_and_terminated(game, board, ai_action)
if is_terminal do
msg = if value == 1, do: "IA venceu!", else: "Empate!"
render_board(board_frame, board, msg)
Agent.update(:play_state, fn s -> %{s | board: board, game_over: true} end)
else
mode_label =
case ai_mode do
:mcts -> "MCTS"
:parallel -> "MCTS paralelo"
:raw -> "rede neural"
end
render_board(board_frame, board, "Sua vez — IA jogou coluna #{ai_action} (#{mode_label}):")
Agent.update(:play_state, fn s -> %{s | board: board} end)
end
end
end
# AI move via a full MCTS search; picks the most-visited action.
defp get_ai_action(game, args, board, :mcts) do
neutral_board = Game.change_perspective(board, 1)
action_probs = MCTS.search(%{args: args, game: game}, neutral_board)
action_probs |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
end
# AI move via several independent MCTS searches run in parallel tasks.
defp get_ai_action(game, args, board, :parallel) do
neutral_board = Game.change_perspective(board, 1)
action_probs = search_parallel(%{args: args, game: game}, neutral_board)
action_probs |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
end
# AI move from the raw policy network (no search): softmax, mask invalid
# columns, take the argmax.
defp get_ai_action(_game, _args, board, :raw) do
neutral_board = Game.change_perspective(board, 1)
{policy, _value} =
neutral_board
|> Game.get_encoded_board()
|> Nx.new_axis(0)
|> ResNet.predict()
valid_moves = Game.get_valid_moves(board) |> Nx.to_flat_list()
policy
|> Axon.Activations.softmax(axis: 1)
|> Nx.squeeze(axes: [0])
|> Nx.to_flat_list()
|> Utils.mask_and_normalize(valid_moves)
|> Enum.with_index()
|> Enum.max_by(fn {p, _} -> p end)
|> elem(1)
end
# Root-parallel MCTS: runs `num_workers` independent searches, each with a
# share of the search budget, and averages their distributions element-wise.
defp search_parallel(%{args: args, game: game}, board, num_workers \\ 4) do
searches_per_worker = div(args.num_mcts_searches, num_workers)
results =
1..num_workers
|> Enum.map(fn _ ->
Task.async(fn ->
MCTS.search(
%{args: %{args | num_mcts_searches: searches_per_worker}, game: game},
board
)
end)
end)
|> Task.await_many(:infinity)
results
|> Enum.zip()
|> Enum.map(fn probs_tuple ->
probs_tuple
|> Tuple.to_list()
|> Enum.sum()
|> Kernel./(num_workers)
end)
end
# Renders the board as a Markdown table: 🔴 = 1 (AI), 🟡 = -1 (human),
# ⚪ = empty, with `message` shown above it.
defp render_board(frame, board, message) do
header = "| 0 | 1 | 2 | 3 | 4 | 5 | 6 |\n|---|---|---|---|---|---|---|"
rows =
board
|> Nx.to_list()
|> Enum.map(fn row ->
cells =
Enum.map(row, fn
1 -> "🔴"
-1 -> "🟡"
0 -> "⚪"
end)
|> Enum.join(" | ")
"| #{cells} |"
end)
|> Enum.join("\n")
Kino.Frame.render(frame, Kino.Markdown.new("**#{message}**\n\n#{header}\n#{rows}"))
end
end
# Entry point of this Livebook cell: launch the interactive UI.
PlayLive.start()