MuZero — Backtesting PETR4 (espaço latente)
Mix.install([
{:nx, "~> 0.10.0"},
{:axon, "~> 0.7.0"},
{:exla, "~> 0.10.0"},
{:kino, "~> 0.19.0"}
], system_env: %{
"EXLA_CPU_ONLY" => "true",
"CFLAGS" => "-Wno-error -Wno-invalid-specialization",
"CXXFLAGS" => "-Wno-error -Wno-invalid-specialization"
})
Section
> Extensão do livebook AlphaZero. Depende das células anteriores:
> Utils, Dirichlet, MarketData, TradingEnv.
>
> Diferença central vs. AlphaZero: o MCTS nunca chama TradingEnv.step.
> Ele expande chamando a dinâmica aprendida g em espaço latente. O ambiente
> real só é usado para (a) gerar trajetórias de treino e (b) o backtest final.
>
> Três funções aprendidas, treinadas em conjunto:
>
> * h (representação): observação → s⁰
> * g (dinâmica): (sᵏ, aᵏ) → (sᵏ⁺¹, rᵏ⁺¹)
> * f (predição): sᵏ → (política, valor)
Config
# Use EXLA as the default Nx backend for the whole notebook.
Nx.global_default_backend(EXLA.Backend)
# Left disabled on purpose: this would also make EXLA the default `defn` compiler.
# Nx.Defn.global_default_options(compiler: EXLA)
Dirichlet
defmodule Dirichlet do
  @moduledoc """
  Dirichlet sampling built from independent Gamma draws
  (Marsaglia–Tsang squeeze method for the Gamma variates).
  """

  @doc """
  Draws one sample from a Dirichlet distribution.

  `alphas` is a list of positive concentration parameters. Returns a list of
  the same length whose entries are non-negative and sum to 1.
  """
  def sample(alphas) when is_list(alphas) do
    gammas = Enum.map(alphas, &sample_gamma/1)
    total = Enum.sum(gammas)

    if total > 0 do
      Enum.map(gammas, fn g -> g / total end)
    else
      # Every Gamma draw underflowed to 0.0 (possible with tiny alphas):
      # fall back to the uniform vector instead of dividing by zero.
      n = length(gammas)
      Enum.map(gammas, fn _ -> 1.0 / n end)
    end
  end

  # Gamma(alpha, 1) for alpha < 1 via the boost Gamma(alpha + 1) * U^(1/alpha).
  defp sample_gamma(alpha) when alpha < 1 do
    u = :rand.uniform_real()
    sample_gamma(alpha + 1.0) * :math.pow(u, 1.0 / alpha)
  end

  # Gamma(alpha, 1) for alpha >= 1 (Marsaglia–Tsang).
  defp sample_gamma(alpha) do
    d = alpha - 1.0 / 3.0
    c = 1.0 / :math.sqrt(9.0 * d)
    marsaglia_loop(d, c)
  end

  # Rejection loop: retries until a candidate passes the squeeze or the log test.
  defp marsaglia_loop(d, c) do
    x = :rand.normal()
    cube_root = 1.0 + c * x
    v = cube_root * cube_root * cube_root
    # uniform_real/0 never returns 0.0, so :math.log(u) below cannot raise
    u = :rand.uniform_real()

    cond do
      v <= 0.0 -> marsaglia_loop(d, c)
      u < 1.0 - 0.0331 * (x * x) * (x * x) -> d * v
      :math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)) -> d * v
      true -> marsaglia_loop(d, c)
    end
  end
end
Utils
defmodule Utils do
  @moduledoc """
  Small list-based helpers for action probabilities: masking, temperature,
  sampling, Dirichlet exploration noise and argmax.
  """

  @doc """
  Zeroes the entries of `policy` whose `valid_moves` flag is 0 and renormalises.

  If no probability mass survives the mask, returns the uniform distribution
  over the valid moves.
  """
  def mask_and_normalize(policy, valid_moves) do
    masked = Enum.zip_with(policy, valid_moves, fn p, v -> p * v end)
    total = Enum.sum(masked)

    if total > 0 do
      Enum.map(masked, fn p -> p / total end)
    else
      n_valid = Enum.sum(valid_moves)
      Enum.map(valid_moves, fn v -> v / n_valid end)
    end
  end

  @doc """
  Sharpens (`temperature < 1`) or flattens (`temperature > 1`) a distribution
  by raising each entry to `1 / temperature` and renormalising.

  A temperature `<= 0` is treated as the greedy limit: a one-hot vector on the
  argmax (it used to raise a division-by-zero error).
  """
  def apply_temperature(action_probs, temperature) when temperature > 0 do
    probs = Enum.map(action_probs, fn p -> :math.pow(p, 1.0 / temperature) end)
    total = Enum.sum(probs)
    Enum.map(probs, fn p -> p / total end)
  end

  def apply_temperature(action_probs, _temperature) do
    best = argmax_index(action_probs)

    action_probs
    |> Enum.with_index()
    |> Enum.map(fn {_p, i} -> if i == best, do: 1.0, else: 0.0 end)
  end

  @doc """
  Samples an index from `probs` by inverting the cumulative distribution.

  When rounding (or an unnormalised input) leaves the cumulative sum below the
  random draw, the last index with positive probability is returned, so an
  action with zero probability is never picked. `action_size - 1` is only used
  when every probability is zero.
  """
  def sample_action(probs, action_size) do
    r = :rand.uniform()

    outcome =
      probs
      |> Enum.with_index()
      |> Enum.reduce_while({0.0, nil}, fn {p, i}, {acc, last_positive} ->
        acc = acc + p
        last_positive = if p > 0, do: i, else: last_positive
        # strict `<`: with r in [0, 1) a zero-probability entry can never win
        if r < acc, do: {:halt, i}, else: {:cont, {acc, last_positive}}
      end)

    case outcome do
      i when is_integer(i) -> i
      {_acc, nil} -> action_size - 1
      {_acc, last_positive} -> last_positive
    end
  end

  @doc """
  Mixes Dirichlet noise into `policy`: `(1 - eps) * p + eps * noise`.

  Reads `:dirichlet_epsilon`, `:dirichlet_alpha` and `:action_size` from `args`.
  The result is not re-masked; noise lands on every action.
  """
  def add_dirichlet_noise(policy, args) do
    epsilon = args.dirichlet_epsilon
    noise = Dirichlet.sample(List.duplicate(args.dirichlet_alpha, args.action_size))

    Enum.zip_with(policy, noise, fn p, n ->
      (1 - epsilon) * p + epsilon * n
    end)
  end

  @doc "Returns the index of the largest element of `list` (first one on ties)."
  def argmax_index(list) do
    list |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
  end
end
defmodule MuZeroCfg do
  @moduledoc """
  Hyper-parameters shared by the MuZero networks and the training loop.
  """

  defstruct n_channels: nil,
            window: nil,
            latent_dim: 64,
            hidden: 128,
            action_size: 3,
            # number of unroll steps used during training
            k_unroll: 5,
            gamma: 0.997

  @doc """
  Builds a config whose `n_channels` and `window` are copied from `env`,
  then overrides any struct field present in `opts` (unknown keys are ignored).
  """
  def from_env(env, opts \\ []) do
    base = %__MODULE__{n_channels: env.n_channels, window: env.window}
    struct(base, opts)
  end
end
MuZeroNet — h, g, f com parâmetros conjuntos
defmodule MuZeroNet do
  @moduledoc """
  The three MuZero networks — representation `h`, dynamics `g`, prediction `f` —
  with jointly trained parameters.

  All state (config, parameters, predict functions, optimizer) lives in an
  `Agent` registered under this module's name; call `start_link/2` first.
  """

  import Axon
  import Nx.Defn

  # ── models ────────────────────────────────────────────────────────────────
  # h: observation {nc, 1, W} → latent {latent_dim}
  defp repr_model(cfg) do
    input("obs", shape: {nil, cfg.n_channels, 1, cfg.window})
    |> conv(cfg.hidden, kernel_size: {1, 3}, padding: :same, channels: :first, use_bias: false)
    |> relu()
    |> conv(cfg.hidden, kernel_size: {1, 3}, padding: :same, channels: :first, use_bias: false)
    |> relu()
    |> flatten()
    |> dense(cfg.latent_dim)
  end

  # g: [latent ; one-hot action] → {next_latent, reward}
  defp dyn_model(cfg) do
    sa = input("sa", shape: {nil, cfg.latent_dim + cfg.action_size})
    trunk = sa |> dense(cfg.hidden) |> relu() |> dense(cfg.hidden) |> relu()
    next = trunk |> dense(cfg.latent_dim)
    # scalar regression head (a categorical head would be an improvement)
    reward = trunk |> dense(1)
    Axon.container({next, reward})
  end

  # f: latent → {policy_logits, value}
  defp pred_model(cfg) do
    s = input("s", shape: {nil, cfg.latent_dim})
    trunk = s |> dense(cfg.hidden) |> relu()
    policy = trunk |> dense(cfg.action_size)
    # linear value head
    value = trunk |> dense(1)
    Axon.container({policy, value})
  end

  # ── lifecycle ─────────────────────────────────────────────────────────────
  @doc """
  Builds the three models, initialises their parameters and an Adam optimizer
  with `learning_rate`, and starts the named `Agent` holding everything.

  Returns whatever `Agent.start_link/2` returns.
  """
  def start_link(cfg, learning_rate) do
    # No `compiler: EXLA` here on purpose: the predict functions must stay plain
    # Nx functions so they can be composed inside value_and_grad.
    {h_init, h_pred} = Axon.build(repr_model(cfg), mode: :inference)
    {g_init, g_pred} = Axon.build(dyn_model(cfg), mode: :inference)
    {f_init, f_pred} = Axon.build(pred_model(cfg), mode: :inference)

    params = %{
      h: h_init.(Nx.template({1, cfg.n_channels, 1, cfg.window}, :f32), %{}),
      g: g_init.(Nx.template({1, cfg.latent_dim + cfg.action_size}, :f32), %{}),
      f: f_init.(Nx.template({1, cfg.latent_dim}, :f32), %{})
    }

    {opt_init, opt_update} = Polaris.Optimizers.adam(learning_rate: learning_rate)
    opt_state = opt_init.(params)

    # No outer jit: obs/actions_oh/policies/values/rewards are closed over as
    # concrete tensors, so `tensor[i]` works. Only `p` is symbolic inside
    # value_and_grad.
    step_fn = fn params, obs, actions_oh, policies, values, rewards ->
      Nx.Defn.value_and_grad(params, fn p ->
        loss_unroll(p, h_pred, g_pred, f_pred, cfg.k_unroll,
          obs, actions_oh, policies, values, rewards)
      end)
    end

    Agent.start_link(fn ->
      %{cfg: cfg, params: params,
        h_pred: h_pred, g_pred: g_pred, f_pred: f_pred,
        opt_update: opt_update, opt_state: opt_state, step_fn: step_fn}
    end, name: __MODULE__)
  end

  defp state, do: Agent.get(__MODULE__, & &1)

  # ── inference (used by the MCTS) ──────────────────────────────────────────
  @doc "Representation `h`: observation `{1, nc, 1, W}` → scaled latent `{1, D}`."
  def representation(obs) do
    s = state()
    s.h_pred.(s.params.h, %{"obs" => obs}) |> scale_hidden()
  end

  @doc """
  Dynamics `g`: `(latent {1, D}, integer action)` →
  `{scaled next latent {1, D}, reward as a number}`.
  """
  def dynamics(latent, action) do
    s = state()
    oh = action_onehot(action, s.cfg.action_size)
    sa = Nx.concatenate([latent, oh], axis: 1)
    {next_raw, reward} = s.g_pred.(s.params.g, %{"sa" => sa})
    {scale_hidden(next_raw), Nx.to_number(Nx.reshape(reward, {}))}
  end

  @doc """
  Prediction `f`: latent `{1, D}` → `{probs, value}` where `probs` is the
  softmax of the policy logits as a flat list and `value` is a number.
  """
  def prediction(latent) do
    s = state()
    {logits, value} = s.f_pred.(s.params.f, %{"s" => latent})

    probs =
      logits
      |> Axon.Activations.softmax(axis: 1)
      |> Nx.squeeze(axes: [0])
      |> Nx.to_flat_list()

    {probs, Nx.to_number(Nx.reshape(value, {}))}
  end

  # ── joint training (K-step unroll) ────────────────────────────────────────
  @doc """
  Runs one optimisation step on a batch and returns the loss as a number.

  Inputs are stacked tensors with the unroll step as leading axis
  (as built by `MuZeroZero.train/2`):

    * `obs`        — `{B, nc, 1, W}`
    * `actions_oh` — `{K, B, action_size}`
    * `policies`   — `{K + 1, B, action_size}`
    * `values`     — `{K + 1, B, 1}`
    * `rewards`    — `{K, B, 1}`
  """
  def train_step(obs, actions_oh, policies, values, rewards) do
    %{params: params, opt_update: opt_update, opt_state: opt_state, step_fn: step_fn} = state()
    {loss, grads} = step_fn.(params, obs, actions_oh, policies, values, rewards)

    # The optimizer returns parameter *updates* (deltas), not new parameters;
    # they must be added onto the current parameters.
    {updates, new_opt} = opt_update.(grads, opt_state, params)
    new_params = Polaris.Updates.apply_updates(params, updates)

    Agent.update(__MODULE__, fn st -> %{st | params: new_params, opt_state: new_opt} end)
    Nx.to_number(loss)
  end

  # Unroll loss — a regular function that captures the predict functions and is
  # traced inside value_and_grad. Indexing `policies[i]` etc. selects the slice
  # for unroll step `i` along the leading axis.
  # `next_raw |> scale_hidden() |> scale_gradient(0.5)` halves the gradient that
  # flows back through the dynamics at every step.
  defp loss_unroll(p, hp, gp, fp, k, obs, actions_oh, policies, values, rewards) do
    s0 = scale_hidden(hp.(p.h, %{"obs" => obs}))
    {pl0, vl0} = fp.(p.f, %{"s" => s0})
    loss0 = Nx.add(ce(pl0, policies[0]), mse(vl0, values[0]))

    {_s, total} =
      Enum.reduce(0..(k - 1), {s0, loss0}, fn step, {s, acc} ->
        sa = Nx.concatenate([s, actions_oh[step]], axis: 1)
        {next_raw, r} = gp.(p.g, %{"sa" => sa})
        next = next_raw |> scale_hidden() |> scale_gradient(0.5)
        {pl, vl} = fp.(p.f, %{"s" => next})

        step_loss =
          mse(r, rewards[step])
          |> Nx.add(ce(pl, policies[step + 1]))
          |> Nx.add(mse(vl, values[step + 1]))

        {next, Nx.add(acc, step_loss)}
      end)

    total
  end

  @doc "Serialises the current parameters to `models/muzero_<it>.nx`."
  def save(it) do
    File.mkdir_p!("models")
    File.write!("models/muzero_#{it}.nx", Nx.serialize(state().params))
  end

  @doc """
  Replaces the current parameters with those stored in `models/muzero_<it>.nx`.
  The optimizer state is left as it is.
  """
  def load(it) do
    params = "models/muzero_#{it}.nx" |> File.read!() |> Nx.deserialize()
    Agent.update(__MODULE__, fn st -> %{st | params: params} end)
  end

  # ── helpers ───────────────────────────────────────────────────────────────
  # Min-max scales the latent to [0, 1] per sample (axis 1); the epsilon guards
  # against a constant latent. Keeps the unroll numerically stable, MuZero style.
  defn scale_hidden(x) do
    min = Nx.reduce_min(x, axes: [1], keep_axes: true)
    max = Nx.reduce_max(x, axes: [1], keep_axes: true)
    (x - min) / (max - min + 1.0e-5)
  end

  # Scales the gradient by `s` without changing the forward value.
  defn scale_gradient(x, s) do
    s * x + stop_grad((1.0 - s) * x)
  end

  # Cross-entropy between `logits` and a target distribution, averaged over the batch.
  defp ce(logits, target) do
    logits
    |> Axon.Activations.log_softmax(axis: 1)
    |> Nx.multiply(target)
    |> Nx.sum(axes: [1])
    |> Nx.mean()
    |> Nx.negate()
  end

  # Mean squared error over every element.
  defp mse(pred, target), do: pred |> Nx.subtract(target) |> Nx.pow(2) |> Nx.mean()

  # One-hot row vector {1, n} for integer action `a`.
  defp action_onehot(a, n) do
    0..(n - 1)
    |> Enum.map(fn i -> if i == a, do: 1.0, else: 0.0 end)
    |> Nx.tensor(type: :f32)
    |> Nx.reshape({1, n})
  end
end
MuZeroMCTS — busca em espaço latente
defmodule MuZeroMCTS do
  @moduledoc """
  Monte Carlo tree search in latent space.

  The search never steps the real environment: the root latent comes from
  `MuZeroNet.representation/1` and every expansion uses
  `MuZeroNet.dynamics/2` and `MuZeroNet.prediction/1`.
  """

  defmodule Node do
    @moduledoc """
    One search-tree node. `latent` and `reward` are filled in when the node is
    first reached by a simulation; until then it is a placeholder carrying
    only `prior` and `action_taken`.
    """
    defstruct [
      :id, :latent, :parent_id, :action_taken,
      prior: 0.0, reward: 0.0,
      visit_count: 0, value_sum: 0.0,
      children_ids: [], expanded: false
    ]
  end

  defmodule Tree do
    @moduledoc """
    Agent-backed node store. State is `{nodes_by_id, next_id}`.
    """
    use Agent

    def start_link, do: Agent.start_link(fn -> {%{}, 0} end)
    def stop(t), do: Agent.stop(t)

    # Stores `node` under a fresh id (written into the node) and returns that id.
    def put(t, node) do
      Agent.get_and_update(t, fn {nodes, id} ->
        {id, {Map.put(nodes, id, %{node | id: id}), id + 1}}
      end)
    end

    def get(t, id), do: Agent.get(t, fn {n, _} -> Map.fetch!(n, id) end)
    def upd(t, id, f), do: Agent.update(t, fn {n, i} -> {Map.update!(n, id, f), i} end)
  end

  @doc """
  Runs `args.num_mcts_searches` simulations from `obs` and returns the
  visit-count distribution over actions (a list of `args.action_size` floats).

    * `obs`   — observation tensor `{1, nc, 1, W}`
    * `valid` — validity mask of the REAL state, e.g. `[hold, buy, sell]`;
      it is applied at the root only, and invalid root actions get no child,
      so their returned probability is always `0.0`
    * `args`  — map with `:num_mcts_searches`, `:action_size`, `:gamma`,
      `:dirichlet_epsilon`, `:dirichlet_alpha`
  """
  def search(obs, valid, args) do
    {:ok, tree} = Tree.start_link()

    try do
      run_search(tree, obs, valid, args)
    after
      # always release the tree process, even when a simulation raises
      Tree.stop(tree)
    end
  end

  defp run_search(tree, obs, valid, args) do
    s0 = MuZeroNet.representation(obs)
    root_id = Tree.put(tree, %Node{latent: s0, parent_id: nil, action_taken: nil, visit_count: 1})
    expand(tree, root_id, valid, _noise = true, args)

    Enum.each(1..args.num_mcts_searches, fn _ ->
      leaf_id = select(tree, root_id)
      leaf = Tree.get(tree, leaf_id)
      parent = Tree.get(tree, leaf.parent_id)
      # LEARNED dynamics — the trading environment is never touched here
      {latent, reward} = MuZeroNet.dynamics(parent.latent, leaf.action_taken)
      Tree.upd(tree, leaf_id, fn n -> %{n | latent: latent, reward: reward} end)
      value = expand(tree, leaf_id, _valid = nil, _noise = false, args)
      backpropagate(tree, leaf_id, value, args.gamma)
    end)

    root = Tree.get(tree, root_id)

    counts =
      Enum.reduce(root.children_ids, List.duplicate(0.0, args.action_size), fn cid, acc ->
        c = Tree.get(tree, cid)
        List.replace_at(acc, c.action_taken, c.visit_count * 1.0)
      end)

    renorm(counts)
  end

  # Creates placeholder children for the node and returns the value f(latent).
  # With a `valid` mask, the mask is re-applied AFTER the Dirichlet noise (the
  # noise would otherwise put prior mass back on invalid actions) and no child
  # is created for a masked-out action.
  defp expand(tree, node_id, valid, noise, args) do
    node = Tree.get(tree, node_id)
    {policy, value} = MuZeroNet.prediction(node.latent)
    policy = if valid, do: Utils.mask_and_normalize(policy, valid), else: policy

    policy =
      cond do
        not noise -> policy
        valid -> policy |> Utils.add_dirichlet_noise(args) |> Utils.mask_and_normalize(valid)
        true -> policy |> Utils.add_dirichlet_noise(args) |> renorm()
      end

    mask = valid || List.duplicate(1, length(policy))

    child_ids =
      [policy, mask]
      |> Enum.zip()
      |> Enum.with_index()
      |> Enum.filter(fn {{_prob, allowed}, _action} -> allowed > 0 end)
      |> Enum.map(fn {{prob, _allowed}, action} ->
        Tree.put(tree, %Node{
          parent_id: node_id, action_taken: action,
          prior: prob, expanded: false
        })
      end)

    Tree.upd(tree, node_id, fn n -> %{n | children_ids: child_ids, expanded: true} end)
    value
  end

  # Descends by PUCT until it reaches a node that has not been expanded yet (a leaf).
  defp select(tree, node_id) do
    node = Tree.get(tree, node_id)

    if not node.expanded do
      node_id
    else
      best =
        Enum.max_by(node.children_ids, fn cid ->
          get_ucb(node, Tree.get(tree, cid))
        end)

      select(tree, best)
    end
  end

  # Single-agent backup: G = reward + gamma * G, with no sign flip between levels.
  defp backpropagate(tree, node_id, g, gamma) do
    node = Tree.get(tree, node_id)
    g2 = node.reward + gamma * g

    Tree.upd(tree, node_id, fn n ->
      %{n | value_sum: n.value_sum + g2, visit_count: n.visit_count + 1}
    end)

    if node.parent_id != nil, do: backpropagate(tree, node.parent_id, g2, gamma)
  end

  # PUCT score: mean value (0.0 for an unvisited child) plus the exploration bonus.
  defp get_ucb(parent, child) do
    q = if child.visit_count == 0, do: 0.0, else: child.value_sum / child.visit_count
    q + 1.5 * (:math.sqrt(parent.visit_count) / (child.visit_count + 1)) * child.prior
  end

  # Normalises a list to sum 1; an all-zero list is returned unchanged.
  defp renorm(list) do
    total = Enum.sum(list)
    if total > 0, do: Enum.map(list, &(&1 / total)), else: list
  end
end
MarketData — carga de ticks + features causais
Carregue ticks reais de PETR4 (export do MetaTrader5 ou série da B3) via from_csv/1, ou use synthetic/1 para rodar a pipeline end-to-end sem dados.
Colunas esperadas no CSV: price (obrigatória), volume (opcional). Uma linha por tick/barra.
defmodule MarketData do
  @moduledoc """
  Loads a price series and derives causal per-tick features from it.

  Both loaders return a map with `:prices` (list of floats), `:features`
  (list of 6-element rows, one per tick), `:n` and `:n_features`.
  """

  @doc """
  Reads a CSV with a header row containing at least a `price` column
  (`volume` is optional). Columns may be separated by `,`, `;` or tab, and
  lines may end in `\\n` or `\\r\\n`. Header names are matched
  case-insensitively after trimming whitespace.
  """
  def from_csv(path) do
    [header | rows] =
      path
      |> File.read!()
      |> String.split(["\r\n", "\n"], trim: true)
      |> Enum.map(&String.split(&1, [",", ";", "\t"], trim: true))

    cols =
      header
      |> Enum.map(fn name -> name |> String.trim() |> String.downcase() end)
      |> Enum.with_index()
      |> Map.new()

    pcol = Map.fetch!(cols, "price")
    vcol = Map.get(cols, "volume")

    prices = Enum.map(rows, fn r -> r |> Enum.at(pcol) |> parse_float() end)

    volumes =
      if vcol, do: Enum.map(rows, fn r -> r |> Enum.at(vcol) |> parse_float() end), else: nil

    build(prices, volumes)
  end

  @doc """
  Synthetic mean-reverting series with a small drift, only meant for running
  the pipeline end to end. Re-seeds the calling process's `:rand` generator
  with a fixed seed, so the output is reproducible.
  """
  def synthetic(n \\ 20_000) do
    :rand.seed(:exsss, {1, 2, 3})

    {prices, _} =
      Enum.map_reduce(1..n, 100.0, fn _i, p ->
        drift = 0.00002
        mr = (100.0 - p) * 0.0008
        shock = :rand.normal() * 0.05
        np = max(p + drift * p + mr + shock, 1.0)
        {np, np}
      end)

    build(prices, nil)
  end

  # ── causal features (the feature at tick t only uses data <= t) ───────────
  defp build(prices, volumes) do
    n = length(prices)
    log_ret = causal_log_returns(prices)
    mom_s = rolling_mean(log_ret, 10)
    mom_l = rolling_mean(log_ret, 60)
    vol_s = rolling_std(log_ret, 30)
    zprice = zscore_rolling(prices, 60)

    vol_feat =
      case volumes do
        nil -> List.duplicate(0.0, n)
        v -> zscore_rolling(v, 60)
      end

    # transpose the per-feature columns into one row per tick
    feats = Enum.zip_with([log_ret, mom_s, mom_l, vol_s, zprice, vol_feat], & &1)

    %{
      prices: prices,
      # list of lists, shape {n, n_features}
      features: feats,
      n: n,
      n_features: 6
    }
  end

  # Parses the leading float of a cell; raises with the offending text otherwise.
  defp parse_float(s) do
    case s |> String.trim() |> Float.parse() do
      {value, _rest} -> value
      :error -> raise ArgumentError, "cannot parse #{inspect(s)} as a number"
    end
  end

  # log(p_t / p_{t-1}) per tick, with 0.0 for the first tick; same length as the input.
  defp causal_log_returns([_ | rest] = prices) do
    [0.0 | Enum.zip_with(rest, prices, fn cur, prev -> :math.log(cur / prev) end)]
  end

  # Maps `fun.(x, window)` over `xs`, where `window` holds the (at most) `w`
  # most recent values up to and including `x`, oldest first. The trailing
  # window is carried along instead of re-slicing the list from its head.
  defp rolling_map(xs, w, fun) do
    {out, _tail} =
      Enum.map_reduce(xs, [], fn x, newest_first ->
        kept = Enum.take([x | newest_first], w)
        {fun.(x, Enum.reverse(kept)), kept}
      end)

    out
  end

  # Mean and population standard deviation of a non-empty window.
  defp mean_and_std(win) do
    len = length(win)
    m = Enum.sum(win) / len
    var = Enum.reduce(win, 0.0, fn v, a -> a + (v - m) * (v - m) end) / len
    {m, :math.sqrt(var)}
  end

  defp rolling_mean(xs, w) do
    rolling_map(xs, w, fn _x, win -> Enum.sum(win) / length(win) end)
  end

  defp rolling_std(xs, w) do
    rolling_map(xs, w, fn _x, win -> win |> mean_and_std() |> elem(1) end)
  end

  # Z-score of each value against its own trailing window; 0.0 when the window is flat.
  defp zscore_rolling(xs, w) do
    rolling_map(xs, w, fn x, win ->
      {m, sd} = mean_and_std(win)
      if sd > 1.0e-9, do: (x - m) / sd, else: 0.0
    end)
  end
end
TradingEnv
defmodule TradingEnv do
  @moduledoc """
  Single-agent trading environment (MDP). The state passed around is a map
  `%{idx, position, t}`.

  Actions: 0 = Hold, 1 = Buy (go/stay long), 2 = Sell (go/stay flat).
  Long-only (position is 0 or 1). To enable shorting, see `apply_action/2`.
  """

  defstruct [
    # Nx tensor {n, n_features} f32 — used by encode/2 (network input)
    :feat_t,
    # plain Elixir list of n prices — used to compute rewards
    :prices_list,
    :n_features,
    # n_features + 1 (the extra plane carries the position)
    :n_channels,
    :window,
    :action_size,
    # cost per unit of position change (e.g. 0.0005 = 5 bps)
    :cost,
    :gamma,
    :max_steps,
    # first valid idx (window - 1 ticks of history are needed)
    :start_min,
    # exclusive upper bound of the idx values that can still be stepped from
    :end_idx
  ]

  @doc """
  Builds an environment from a market map (`:features`, `:prices`,
  `:n_features`, `:n`).

  Options: `:window` (default 64), `:cost` (0.0005), `:gamma` (0.997),
  `:max_steps` (256).
  """
  def new(market, opts \\ []) do
    window = Keyword.get(opts, :window, 64)
    feat_t = market.features |> Nx.tensor(type: :f32)

    %TradingEnv{
      feat_t: feat_t,
      prices_list: market.prices,
      n_features: market.n_features,
      n_channels: market.n_features + 1,
      window: window,
      action_size: 3,
      cost: Keyword.get(opts, :cost, 0.0005),
      gamma: Keyword.get(opts, :gamma, 0.997),
      max_steps: Keyword.get(opts, :max_steps, 256),
      start_min: window - 1,
      # stepping from idx needs the price at idx + 1
      end_idx: market.n - 1
    }
  end

  @doc "Initial flat state at `start_idx`, clamped up to the first idx with a full window."
  def initial_state(%TradingEnv{} = env, start_idx) do
    %{idx: max(start_idx, env.start_min), position: 0, t: 0}
  end

  @doc "Validity mask `[hold, buy, sell]` for the long-only position (0 or 1)."
  def valid_actions(_env, %{position: 0}), do: [1, 1, 0]
  def valid_actions(_env, %{position: 1}), do: [1, 0, 1]

  @doc """
  Applies `action` and advances one tick. Returns `{next_state, reward, done?}`.

  The reward is the new position times the simple return from `idx` to
  `idx + 1`, minus `cost` times the absolute position change. `done?` is true
  when the next state can no longer be stepped (`idx + 1 >= end_idx`) or the
  step budget `max_steps` is used up.
  """
  def step(%TradingEnv{} = env, %{idx: i} = state, action) do
    old_pos = state.position
    new_pos = apply_action(old_pos, action)
    p_now = Enum.at(env.prices_list, i)
    p_next = Enum.at(env.prices_list, i + 1)
    ret = (p_next - p_now) / p_now
    pnl = new_pos * ret
    cost = env.cost * abs(new_pos - old_pos)
    reward = pnl - cost
    next = %{state | idx: i + 1, position: new_pos, t: state.t + 1}
    done? = i + 1 >= env.end_idx or state.t + 1 >= env.max_steps
    {next, reward, done?}
  end

  # long-only: 0 = Hold keeps the position, 1 = Buy → 1, 2 = Sell → 0
  # (for long/short: Buy → +1, Sell → -1, Hold keeps it)
  defp apply_action(pos, 0), do: pos
  defp apply_action(_pos, 1), do: 1
  defp apply_action(_pos, 2), do: 0

  @doc """
  Pure terminal check for a state, using the same condition as the `done?`
  flag of `step/3`: for the state returned by `step/3`, `terminal?/2` equals
  that flag.
  """
  def terminal?(%TradingEnv{} = env, %{idx: i, t: t}) do
    i >= env.end_idx or t >= env.max_steps
  end

  @doc """
  CAUSAL encoding: only the past window `[idx - window + 1 .. idx]` of the
  features plus one constant plane holding the position.

  Returns a tensor `{n_channels, window}`. Expects `idx >= window - 1`
  (which `initial_state/2` guarantees).
  """
  def encode(%TradingEnv{} = env, %{idx: i, position: pos}) do
    w = env.window
    lo = i - w + 1

    feat =
      # {w, n_features}
      env.feat_t[lo..i]
      # {n_features, w}
      |> Nx.transpose()

    pos_plane = Nx.broadcast(Nx.tensor(pos * 1.0, type: :f32), {1, w})
    # {n_channels, w}
    Nx.concatenate([feat, pos_plane], axis: 0)
  end
end
MuZeroZero — self-play (env real) + treino (unroll)
defmodule MuZeroZero do
  @moduledoc """
  Training driver: self-play on the REAL environment (actions chosen by the
  latent MCTS) followed by joint K-step unroll training of `MuZeroNet`.
  """

  @doc """
  Runs `args.num_iterations` rounds of self-play, unroll-sample generation and
  `args.num_epochs` training epochs, saving the network after every round.
  """
  def learn(env, args) do
    Enum.each(1..args.num_iterations, fn it ->
      IO.puts("iteração #{it}/#{args.num_iterations}")

      trajectories =
        Enum.map(1..args.num_selfplay_iterations, fn i ->
          IO.write(" self-play #{i}/#{args.num_selfplay_iterations}\r")
          self_play(env, args)
        end)

      samples = Enum.flat_map(trajectories, &to_unroll_samples(&1, args))
      IO.puts("\n amostras de unroll: #{length(samples)}")

      Enum.each(1..args.num_epochs, fn ep ->
        loss = train(samples, args)
        IO.puts(" epoch #{ep}/#{args.num_epochs} — loss: #{Float.round(loss, 5)}")
      end)

      MuZeroNet.save(it)
    end)
  end

  @doc """
  Plays one episode from a random start index and returns the trajectory:
  a list of `%{obs, action, policy, reward}` maps in time order, where
  `reward` is the reward received for taking `action` from `obs`.
  """
  def self_play(env, args) do
    start = Enum.random(env.start_min..(env.end_idx - env.max_steps - 1))
    state = TradingEnv.initial_state(env, start)
    roll(env, args, state, [])
  end

  # One environment step per call; `hist` is accumulated newest-first.
  defp roll(env, args, state, hist) do
    obs = TradingEnv.encode(env, state) |> Nx.reshape({1, env.n_channels, 1, env.window})
    valid = TradingEnv.valid_actions(env, state)
    probs = MuZeroMCTS.search(obs, valid, args)

    action =
      probs
      |> Utils.apply_temperature(args.temperature)
      |> Utils.mask_and_normalize(valid)
      |> Utils.sample_action(args.action_size)

    {next, reward, done?} = TradingEnv.step(env, state, action)
    # store the observation WITHOUT the batch axis (re-stacked at training time)
    obs2 = Nx.squeeze(obs, axes: [0])
    hist = [%{obs: obs2, action: action, policy: probs, reward: reward} | hist]
    if done?, do: Enum.reverse(hist), else: roll(env, args, next, hist)
  end

  # Value targets: discounted Monte Carlo return from each step onwards
  # (n-step bootstrapping would be an improvement).
  defp with_value_targets(traj, gamma) do
    {out, _g} =
      traj
      |> Enum.reverse()
      |> Enum.map_reduce(0.0, fn step, g_next ->
        g = step.reward + gamma * g_next
        {Map.put(step, :value, g), g}
      end)

    Enum.reverse(out)
  end

  # Turns a trajectory into unroll samples of length K. Only start positions
  # with K full steps ahead are used (the tail is dropped; padding would be an
  # improvement). A trajectory with K steps or fewer yields no samples.
  defp to_unroll_samples(traj, args) do
    traj = with_value_targets(traj, args.gamma)
    k = args.k_unroll
    n = length(traj)
    arr = List.to_tuple(traj)

    # `//1` keeps the range empty (instead of descending) when n <= k
    for j <- 0..(n - k - 1)//1 do
      window = for i <- j..(j + k), do: elem(arr, i)
      # the first K entries are the steps whose action is unrolled
      {acted, _last} = Enum.split(window, k)

      %{
        # {nc, 1, W}
        obs: hd(window).obs,
        # K actions
        actions: Enum.map(acted, & &1.action),
        # K + 1
        policies: Enum.map(window, & &1.policy),
        # K + 1
        values: Enum.map(window, & &1.value),
        # K rewards, each one the reward of the action at the same position
        rewards: Enum.map(acted, & &1.reward)
      }
    end
  end

  @doc """
  One epoch over `samples`: shuffles, splits into full batches of
  `args.batch_size` (a trailing partial batch is discarded), stacks each batch
  into tensors and calls `MuZeroNet.train_step/5`.

  Returns the mean batch loss, or `0.0` when there is no full batch.
  """
  def train(samples, args) do
    k = args.k_unroll
    a = args.action_size

    losses =
      samples
      |> Enum.shuffle()
      |> Enum.chunk_every(args.batch_size, args.batch_size, :discard)
      |> Enum.map(fn batch ->
        obs =
          batch
          |> Enum.map(&Nx.new_axis(&1.obs, 0))
          |> Nx.concatenate(axis: 0)

        # every per-step list becomes ONE tensor with the step axis in front
        # {K, B, A}
        actions_oh =
          stack_steps(batch, 0..(k - 1), fn s, step -> onehot(Enum.at(s.actions, step), a) end)

        # {K+1, B, A}
        policies =
          stack_steps(batch, 0..k, fn s, step ->
            Nx.tensor(Enum.at(s.policies, step), type: :f32)
          end)

        # {K+1, B, 1}
        values =
          stack_steps(batch, 0..k, fn s, step ->
            Nx.tensor([Enum.at(s.values, step)], type: :f32)
          end)

        # {K, B, 1}
        rewards =
          stack_steps(batch, 0..(k - 1), fn s, step ->
            Nx.tensor([Enum.at(s.rewards, step)], type: :f32)
          end)

        MuZeroNet.train_step(obs, actions_oh, policies, values, rewards)
      end)

    if losses == [], do: 0.0, else: Enum.sum(losses) / length(losses)
  end

  # Stacks `fun.(sample, step)` over the batch, then over `steps`: {steps, B, ...}.
  defp stack_steps(batch, steps, fun) do
    steps
    |> Enum.map(fn step ->
      batch |> Enum.map(fn s -> fun.(s, step) end) |> Nx.stack()
    end)
    |> Nx.stack()
  end

  # One-hot vector {n} for integer action `a`.
  defp onehot(a, n) do
    0..(n - 1) |> Enum.map(fn i -> if i == a, do: 1.0, else: 0.0 end) |> Nx.tensor(type: :f32)
  end
end
Runner
# Needs `env_train` from the previous livebook; if it is not bound, rebuild it here.
market = MarketData.synthetic(20_000)
split = round(market.n * 0.7)

# keep the first 70% of the ticks for training
train_market =
  Map.merge(market, %{
    prices: Enum.take(market.prices, split),
    features: Enum.take(market.features, split),
    n: split
  })

env_train = TradingEnv.new(train_market, window: 64)

cfg =
  MuZeroCfg.from_env(env_train,
    latent_dim: 64,
    hidden: 128,
    k_unroll: 5,
    gamma: env_train.gamma
  )

args = %{
  num_iterations: 4,
  num_selfplay_iterations: 8,
  num_epochs: 4,
  batch_size: 64,
  temperature: 1.1,
  num_mcts_searches: 40,
  dirichlet_epsilon: 0.25,
  dirichlet_alpha: 0.3,
  action_size: cfg.action_size,
  k_unroll: cfg.k_unroll,
  gamma: cfg.gamma
}

# stop a network left over from an earlier evaluation of this cell
if pid = Process.whereis(MuZeroNet), do: Agent.stop(pid)

{:ok, _} = MuZeroNet.start_link(cfg, 0.001)
MuZeroZero.learn(env_train, args)