Powered by AppSignal & Oban Pro

MuZero — Backtesting PETR4 (espaço latente)

muzero.livemd

MuZero — Backtesting PETR4 (espaço latente)

# Notebook dependencies installed at the top of the Livebook.
# NOTE(review): Polaris is called directly in MuZeroNet.start_link/2 but is not
# listed here; presumably it is pulled in transitively by Axon — confirm.
Mix.install([
  {:nx, "~> 0.10.0"},
  {:axon, "~> 0.7.0"},
  {:exla, "~> 0.10.0"},
  {:kino, "~> 0.19.0"}
], system_env: %{
  # Environment variables exported while the dependencies are fetched and built.
  # NOTE(review): presumably selects a CPU-only EXLA build and relaxes C/C++
  # compiler warnings for the native build — confirm against the EXLA docs.
  "EXLA_CPU_ONLY" => "true",
  "CFLAGS" => "-Wno-error -Wno-invalid-specialization",
  "CXXFLAGS" => "-Wno-error -Wno-invalid-specialization"
})

Section

> Extensão do livebook AlphaZero. Depende das células anteriores:
> Utils, Dirichlet, MarketData, TradingEnv.
>
> Diferença central vs. AlphaZero: o MCTS nunca chama `TradingEnv.step`.
> Ele expande chamando a dinâmica aprendida `g` em espaço latente. O ambiente
> real só é usado para (a) gerar trajetórias de treino e (b) o backtest final.
>
> Três funções aprendidas, treinadas em conjunto:
>
> * `h` (representação): observação → s⁰
> * `g` (dinâmica): (sᵏ, aᵏ) → (sᵏ⁺¹, rᵏ⁺¹)
> * `f` (predição): sᵏ → (política, valor)

Config

# Make EXLA the default backend for tensors created in this notebook.
Nx.global_default_backend(EXLA.Backend)
# NOTE(review): the global defn compiler option below is left disabled —
# confirm this is intentional.
# Nx.Defn.global_default_options(compiler: EXLA)

Dirichlet

defmodule Dirichlet do
  @moduledoc """
  Dirichlet sampling built from independent Gamma(alpha, 1) draws
  (Marsaglia–Tsang rejection sampler).
  """

  @doc """
  Draws one sample from a Dirichlet distribution.

  `alphas` is a list of positive concentration parameters. Returns a list of
  the same length whose entries are non-negative and sum to 1. An empty list
  yields an empty list.
  """
  def sample(alphas) when is_list(alphas) do
    gammas = Enum.map(alphas, &sample_gamma/1)
    total = Enum.sum(gammas)

    if total > 0 do
      Enum.map(gammas, fn g -> g / total end)
    else
      # Every draw underflowed to zero (or the list is empty): fall back to the
      # uniform point instead of dividing by zero.
      n = length(gammas)
      Enum.map(gammas, fn _ -> 1.0 / n end)
    end
  end

  # alpha < 1: boost trick, Gamma(alpha) = Gamma(alpha + 1) * U^(1 / alpha).
  defp sample_gamma(alpha) when alpha < 1 do
    # uniform_real/0 never returns 0.0, unlike uniform/0.
    u = :rand.uniform_real()
    sample_gamma(alpha + 1.0) * :math.pow(u, 1.0 / alpha)
  end

  # alpha >= 1: Marsaglia–Tsang with d = alpha - 1/3 and c = 1 / sqrt(9 d).
  defp sample_gamma(alpha) do
    d = alpha - 1.0 / 3.0
    c = 1.0 / :math.sqrt(9.0 * d)
    marsaglia_loop(d, c)
  end

  # Rejection loop: propose v = (1 + c x)^3 with x ~ N(0, 1) and accept d * v
  # through the cheap squeeze test or the exact log test.
  defp marsaglia_loop(d, c) do
    x = :rand.normal()
    t = 1.0 + c * x
    v = t * t * t
    # Strictly positive, so :math.log(u) below cannot raise.
    u = :rand.uniform_real()

    cond do
      v <= 0.0 -> marsaglia_loop(d, c)
      u < 1.0 - 0.0331 * (x * x) * (x * x) -> d * v
      :math.log(u) < 0.5 * x * x + d * (1.0 - v + :math.log(v)) -> d * v
      true -> marsaglia_loop(d, c)
    end
  end
end

Utils

defmodule Utils do
  @moduledoc """
  Small list-based helpers for action probabilities: masking, temperature,
  sampling, Dirichlet exploration noise and argmax.
  """

  @doc """
  Zeroes the entries of `policy` whose mask in `valid_moves` is 0 and
  renormalizes the rest to sum to 1.

  When the masked policy has no mass left, returns the uniform distribution
  over the valid moves. Raises `ArithmeticError` if `valid_moves` has no valid
  entry at all.
  """
  def mask_and_normalize(policy, valid_moves) do
    masked = Enum.zip_with(policy, valid_moves, fn p, v -> p * v end)
    total = Enum.sum(masked)

    if total > 0 do
      Enum.map(masked, fn p -> p / total end)
    else
      n_valid = Enum.sum(valid_moves)
      Enum.map(valid_moves, fn v -> v / n_valid end)
    end
  end

  @doc """
  Sharpens or flattens `action_probs` by raising each entry to
  `1 / temperature` and renormalizing.

  A temperature of 0 is treated as the greedy limit: a one-hot vector on the
  first maximal entry. An all-zero input is returned unchanged.
  """
  def apply_temperature(action_probs, temperature) when temperature == 0 do
    best = argmax_index(action_probs)

    action_probs
    |> Enum.with_index()
    |> Enum.map(fn {_p, i} -> if i == best, do: 1.0, else: 0.0 end)
  end

  def apply_temperature(action_probs, temperature) do
    probs = Enum.map(action_probs, fn p -> :math.pow(p, 1.0 / temperature) end)
    total = Enum.sum(probs)

    # Nothing to normalize when every entry is zero; avoid dividing by zero.
    if total > 0, do: Enum.map(probs, fn p -> p / total end), else: probs
  end

  @doc """
  Samples an index from the distribution `probs` (inverse-CDF walk).

  If rounding leaves the cumulative sum below the random draw, the last index
  with positive probability is returned; `action_size - 1` is used only when
  no entry is positive.
  """
  def sample_action(probs, action_size) do
    # r is in [0.0, 1.0); the strict `<` below keeps zero-probability entries
    # from being selected when r == 0.0.
    r = :rand.uniform()

    outcome =
      probs
      |> Enum.with_index()
      |> Enum.reduce_while({0.0, nil}, fn {p, i}, {acc, last_positive} ->
        acc = acc + p
        last_positive = if p > 0, do: i, else: last_positive
        if r < acc, do: {:halt, i}, else: {:cont, {acc, last_positive}}
      end)

    case outcome do
      i when is_integer(i) -> i
      {_acc, nil} -> action_size - 1
      {_acc, last_positive} -> last_positive
    end
  end

  @doc """
  Mixes `policy` with Dirichlet noise:
  `(1 - epsilon) * p + epsilon * noise`.

  Reads `dirichlet_epsilon`, `dirichlet_alpha` and `action_size` from `args`.
  """
  def add_dirichlet_noise(policy, args) do
    epsilon = args.dirichlet_epsilon
    noise = Dirichlet.sample(List.duplicate(args.dirichlet_alpha, args.action_size))

    Enum.zip_with(policy, noise, fn p, n ->
      (1 - epsilon) * p + epsilon * n
    end)
  end

  @doc "Returns the index of the largest element (first one on ties)."
  def argmax_index(list) do
    list |> Enum.with_index() |> Enum.max_by(fn {p, _} -> p end) |> elem(1)
  end
end
defmodule MuZeroCfg do
  @moduledoc """
  Hyper-parameters shared by the MuZero networks.
  """

  defstruct [
    # taken from env.n_channels
    :n_channels,
    # taken from env.window
    :window,
    latent_dim: 64,
    hidden: 128,
    action_size: 3,
    # number of unroll steps used during training
    k_unroll: 5,
    gamma: 0.997
  ]

  @doc """
  Builds a config whose `n_channels` and `window` come from `env`, then applies
  `opts` on top. Keys in `opts` that are not struct fields are ignored.
  """
  def from_env(env, opts \\ []) do
    from_environment = struct(__MODULE__, n_channels: env.n_channels, window: env.window)
    struct(from_environment, opts)
  end
end

MuZeroNet — h, g, f com parâmetros conjuntos

defmodule MuZeroNet do
  @moduledoc """
  The three MuZero functions — representation `h`, dynamics `g` and
  prediction `f` — kept together with their parameters and optimizer state in
  one Agent registered under this module's name.

  Inference entry points (`representation/1`, `dynamics/2`, `prediction/1`)
  are used by the latent MCTS; `train_step/5` runs one joint gradient step over
  a K-step unroll.
  """

  import Axon
  import Nx.Defn

  # ── models ────────────────────────────────────────────────────────────

  # h: observation {nc, 1, W} -> latent {latent_dim}
  defp repr_model(cfg) do
    input("obs", shape: {nil, cfg.n_channels, 1, cfg.window})
    |> conv(cfg.hidden, kernel_size: {1, 3}, padding: :same, channels: :first, use_bias: false)
    |> relu()
    |> conv(cfg.hidden, kernel_size: {1, 3}, padding: :same, channels: :first, use_bias: false)
    |> relu()
    |> flatten()
    |> dense(cfg.latent_dim)
  end

  # g: [latent ; one-hot action] -> {next_latent, reward}
  defp dyn_model(cfg) do
    sa = input("sa", shape: {nil, cfg.latent_dim + cfg.action_size})
    trunk = sa |> dense(cfg.hidden) |> relu() |> dense(cfg.hidden) |> relu()

    next   = trunk |> dense(cfg.latent_dim)
    # scalar regression (a categorical head is noted as the better option)
    reward = trunk |> dense(1)

    Axon.container({next, reward})
  end

  # f: latent -> {policy_logits, value}
  defp pred_model(cfg) do
    s = input("s", shape: {nil, cfg.latent_dim})
    trunk = s |> dense(cfg.hidden) |> relu()

    policy = trunk |> dense(cfg.action_size)
    # linear output
    value  = trunk |> dense(1)

    Axon.container({policy, value})
  end

  # ── lifecycle ─────────────────────────────────────────────────────────

  @doc """
  Builds the three models, initializes parameters and an Adam optimizer with
  `learning_rate`, and starts the Agent registered as `MuZeroNet`.
  """
  def start_link(cfg, learning_rate) do
    # IMPORTANT: no `compiler: EXLA` here. The predict functions must stay plain
    # Nx functions so they can be composed inside value_and_grad.
    {h_init, h_pred} = Axon.build(repr_model(cfg), mode: :inference)
    {g_init, g_pred} = Axon.build(dyn_model(cfg),  mode: :inference)
    {f_init, f_pred} = Axon.build(pred_model(cfg), mode: :inference)

    params = %{
      h: h_init.(Nx.template({1, cfg.n_channels, 1, cfg.window}, :f32), %{}),
      g: g_init.(Nx.template({1, cfg.latent_dim + cfg.action_size}, :f32), %{}),
      f: f_init.(Nx.template({1, cfg.latent_dim}, :f32), %{})
    }

    {opt_init, opt_update} = Polaris.Optimizers.adam(learning_rate: learning_rate)
    opt_state = opt_init.(params)

    # The batch tensors are captured by the closure rather than passed through
    # value_and_grad, so indexing them with tensor[i] works; only `p` is
    # differentiated.
    step_fn = fn params, obs, actions_oh, policies, values, rewards ->
      Nx.Defn.value_and_grad(params, fn p ->
        loss_unroll(p, h_pred, g_pred, f_pred, cfg.k_unroll,
                    obs, actions_oh, policies, values, rewards)
      end)
    end

    Agent.start_link(fn ->
      %{cfg: cfg, params: params,
        h_pred: h_pred, g_pred: g_pred, f_pred: f_pred,
        opt_update: opt_update, opt_state: opt_state, step_fn: step_fn}
    end, name: __MODULE__)
  end

  # Snapshot of the whole Agent state.
  defp state, do: Agent.get(__MODULE__, & &1)

  # ── inference (used by the MCTS) ──────────────────────────────────────

  @doc "Observation `{1, nc, 1, W}` -> scaled latent `{1, D}`."
  def representation(obs) do
    s = state()
    s.h_pred.(s.params.h, %{"obs" => obs}) |> scale_hidden()
  end

  @doc """
  Applies the learned dynamics to `latent` (`{1, D}`) and integer `action`.

  Returns `{next_latent, reward}` with `next_latent` scaled like the output of
  `representation/1` and `reward` converted to a number.
  """
  def dynamics(latent, action) do
    s  = state()
    oh = action_onehot(action, s.cfg.action_size)
    sa = Nx.concatenate([latent, oh], axis: 1)

    {next_raw, reward} = s.g_pred.(s.params.g, %{"sa" => sa})
    {scale_hidden(next_raw), Nx.to_number(Nx.reshape(reward, {}))}
  end

  @doc """
  Latent `{1, D}` -> `{probs, value}` where `probs` is the softmax of the
  policy logits as a flat list and `value` is a number.
  """
  def prediction(latent) do
    s = state()
    {logits, value} = s.f_pred.(s.params.f, %{"s" => latent})

    probs =
      logits
      |> Axon.Activations.softmax(axis: 1)
      |> Nx.squeeze(axes: [0])
      |> Nx.to_flat_list()

    {probs, Nx.to_number(Nx.reshape(value, {}))}
  end

  # ── joint training (K-step unroll) ────────────────────────────────────

  @doc """
  Runs one optimizer step on a batch and returns the loss as a number.

  The arguments are stacked tensors with the unroll step on the leading axis,
  as built by `MuZeroZero.train/2`:

    * `obs`        — `{B, nc, 1, W}`
    * `actions_oh` — `{K, B, action_size}`
    * `policies`   — `{K + 1, B, action_size}`
    * `values`     — `{K + 1, B, 1}`
    * `rewards`    — `{K, B, 1}`
  """
  def train_step(obs, actions_oh, policies, values, rewards) do
    %{params: params, opt_update: opt_update, opt_state: opt_state, step_fn: step_fn} = state()

    {loss, grads} = step_fn.(params, obs, actions_oh, policies, values, rewards)

    {new_params, new_opt} = opt_update.(grads, opt_state, params)
    Agent.update(__MODULE__, fn st -> %{st | params: new_params, opt_state: new_opt} end)
    Nx.to_number(loss)
  end

  # Unroll loss — a regular function that receives the predict functions and is
  # differentiated by value_and_grad. Step 0 scores f(h(obs)); each of the `k`
  # following steps applies g, then scores the reward and f's outputs.
  # `next_raw |> scale_hidden() |> scale_gradient(0.5)` halves the gradient
  # flowing back through the dynamics.
  defp loss_unroll(p, hp, gp, fp, k, obs, actions_oh, policies, values, rewards) do
    s0 = scale_hidden(hp.(p.h, %{"obs" => obs}))

    {pl0, vl0} = fp.(p.f, %{"s" => s0})
    loss0 = Nx.add(ce(pl0, policies[0]), mse(vl0, values[0]))

    # `//1` keeps the range empty for k = 0 instead of counting down from 0.
    {_s, total} =
      Enum.reduce(0..(k - 1)//1, {s0, loss0}, fn step, {s, acc} ->
        sa = Nx.concatenate([s, actions_oh[step]], axis: 1)
        {next_raw, r} = gp.(p.g, %{"sa" => sa})
        next = next_raw |> scale_hidden() |> scale_gradient(0.5)
        {pl, vl} = fp.(p.f, %{"s" => next})

        step_loss =
          mse(r, rewards[step])
          |> Nx.add(ce(pl, policies[step + 1]))
          |> Nx.add(mse(vl, values[step + 1]))

        {next, Nx.add(acc, step_loss)}
      end)

    total
  end

  @doc "Serializes the current parameters to `models/muzero_<it>.nx`."
  def save(it) do
    File.mkdir_p!("models")
    File.write!("models/muzero_#{it}.nx", Nx.serialize(state().params))
  end

  @doc "Replaces the Agent's parameters with those stored in `models/muzero_<it>.nx`."
  def load(it) do
    params = "models/muzero_#{it}.nx" |> File.read!() |> Nx.deserialize()
    Agent.update(__MODULE__, fn st -> %{st | params: params} end)
  end

  # ── helpers ───────────────────────────────────────────────────────────

  # Min-max normalizes the latent to [0, 1] per sample (axis 1), for unroll
  # stability in the style of MuZero. The epsilon avoids dividing by zero.
  defn scale_hidden(x) do
    lo = Nx.reduce_min(x, axes: [1], keep_axes: true)
    hi = Nx.reduce_max(x, axes: [1], keep_axes: true)
    (x - lo) / (hi - lo + 1.0e-5)
  end

  # Scales the gradient by `s` without changing the forward value.
  defn scale_gradient(x, s) do
    s * x + stop_grad((1.0 - s) * x)
  end

  # Cross-entropy between logits and a target distribution, averaged over the batch.
  defp ce(logits, target) do
    logits
    |> Axon.Activations.log_softmax(axis: 1)
    |> Nx.multiply(target)
    |> Nx.sum(axes: [1])
    |> Nx.mean()
    |> Nx.negate()
  end

  # Mean squared error over every element.
  defp mse(pred, target), do: pred |> Nx.subtract(target) |> Nx.pow(2) |> Nx.mean()

  # One-hot row {1, n} for integer action `a`.
  defp action_onehot(a, n) do
    0..(n - 1)
    |> Enum.map(fn i -> if i == a, do: 1.0, else: 0.0 end)
    |> Nx.tensor(type: :f32)
    |> Nx.reshape({1, n})
  end
end

MuZeroMCTS — busca em espaço latente

defmodule MuZeroMCTS do
  @moduledoc """
  Monte Carlo tree search in latent space. The tree is expanded with the
  learned dynamics (`MuZeroNet.dynamics/2`); the real environment is never
  stepped here.
  """

  defmodule Node do
    @moduledoc false
    defstruct [
      :id, :latent, :parent_id, :action_taken,
      prior: 0.0, reward: 0.0,
      visit_count: 0, value_sum: 0.0,
      children_ids: [], expanded: false
    ]
  end

  defmodule Tree do
    @moduledoc false
    # Agent state is {nodes_by_id, next_id}.
    use Agent
    def start_link, do: Agent.start_link(fn -> {%{}, 0} end)
    def stop(t), do: Agent.stop(t)

    # Stores `node` under a fresh id and returns that id.
    def put(t, node) do
      Agent.get_and_update(t, fn {nodes, id} ->
        {id, {Map.put(nodes, id, %{node | id: id}), id + 1}}
      end)
    end

    def get(t, id), do: Agent.get(t, fn {n, _} -> Map.fetch!(n, id) end)
    def upd(t, id, f), do: Agent.update(t, fn {n, i} -> {Map.update!(n, id, f), i} end)
  end

  @doc """
  Runs `args.num_mcts_searches` simulations from `obs` and returns the root
  visit distribution as a list of `args.action_size` probabilities.

  `obs` is `{1, nc, 1, W}`. `valid` is the action mask of the REAL state and is
  applied at the root only: masked actions get no child and therefore end up
  with probability 0.
  """
  def search(obs, valid, args) do
    {:ok, tree} = Tree.start_link()

    s0 = MuZeroNet.representation(obs)
    root_id = Tree.put(tree, %Node{latent: s0, parent_id: nil, action_taken: nil, visit_count: 1})

    expand(tree, root_id, valid, _noise = true, args)

    Enum.each(1..args.num_mcts_searches, fn _ ->
      leaf_id = select(tree, root_id)
      leaf    = Tree.get(tree, leaf_id)
      parent  = Tree.get(tree, leaf.parent_id)

      # LEARNED dynamics — no TradingEnv call here
      {latent, reward} = MuZeroNet.dynamics(parent.latent, leaf.action_taken)

      Tree.upd(tree, leaf_id, fn n -> %{n | latent: latent, reward: reward} end)

      value = expand(tree, leaf_id, _valid = nil, _noise = false, args)

      backpropagate(tree, leaf_id, value, args.gamma)
    end)

    root  = Tree.get(tree, root_id)
    probs = List.duplicate(0.0, args.action_size)

    probs =
      Enum.reduce(root.children_ids, probs, fn cid, acc ->
        c = Tree.get(tree, cid)
        List.replace_at(acc, c.action_taken, c.visit_count * 1.0)
      end)

    total  = Enum.sum(probs)
    result = if total > 0, do: Enum.map(probs, &(&1 / total)), else: probs

    Tree.stop(tree)
    result
  end

  # Creates placeholder children for the node and returns the value f(latent).
  # With a mask, the policy is masked before AND after the Dirichlet noise (the
  # noise would otherwise put mass back on masked actions), and children are
  # created for valid actions only: a masked child with prior 0 would still
  # score UCB = 0 and win against valid children whose Q is negative.
  defp expand(tree, node_id, valid, noise, args) do
    node = Tree.get(tree, node_id)
    {policy, value} = MuZeroNet.prediction(node.latent)

    policy =
      policy
      |> maybe_mask(valid)
      |> maybe_noise(noise, args)
      |> maybe_mask(valid)

    child_ids =
      for {prob, action} <- Enum.with_index(policy), allowed?(valid, action) do
        Tree.put(tree, %Node{
          parent_id: node_id, action_taken: action,
          prior: prob, expanded: false
        })
      end

    Tree.upd(tree, node_id, fn n -> %{n | children_ids: child_ids, expanded: true} end)
    value
  end

  defp maybe_mask(policy, nil), do: policy
  defp maybe_mask(policy, valid), do: Utils.mask_and_normalize(policy, valid)

  defp maybe_noise(policy, true, args), do: policy |> Utils.add_dirichlet_noise(args) |> renorm()
  defp maybe_noise(policy, _noise, _args), do: policy

  defp allowed?(nil, _action), do: true
  defp allowed?(valid, action), do: Enum.at(valid, action) > 0

  # Descends by the UCB score until it reaches a node that is not expanded yet (a leaf).
  defp select(tree, node_id) do
    node = Tree.get(tree, node_id)

    if not node.expanded do
      node_id
    else
      best =
        Enum.max_by(node.children_ids, fn cid ->
          get_ucb(node, Tree.get(tree, cid))
        end)

      select(tree, best)
    end
  end

  # Single-agent backup: G = reward + gamma * G, with no sign flip between levels.
  defp backpropagate(tree, node_id, g, gamma) do
    node = Tree.get(tree, node_id)
    g2 = node.reward + gamma * g

    Tree.upd(tree, node_id, fn n ->
      %{n | value_sum: n.value_sum + g2, visit_count: n.visit_count + 1}
    end)

    if node.parent_id != nil, do: backpropagate(tree, node.parent_id, g2, gamma)
  end

  # Mean value of the child (0.0 when unvisited) plus a prior-weighted
  # exploration bonus with constant 1.5.
  defp get_ucb(parent, child) do
    q = if child.visit_count == 0, do: 0.0, else: child.value_sum / child.visit_count
    q + 1.5 * (:math.sqrt(parent.visit_count) / (child.visit_count + 1)) * child.prior
  end

  defp renorm(list) do
    total = Enum.sum(list)
    if total > 0, do: Enum.map(list, &(&1 / total)), else: list
  end
end

MarketData — carga de ticks + features causais

Carregue ticks reais de PETR4 (export do MetaTrader5 ou série da B3) via from_csv/1, ou use synthetic/1 para rodar a pipeline end-to-end sem dados.

Colunas esperadas no CSV: price (obrigatória), volume (opcional). Uma linha por tick/barra.

defmodule MarketData do
  @moduledoc """
  Loads a price series (from CSV or synthetic) and derives causal features:
  the feature row at tick `t` only uses data at or before `t`.
  """

  @doc """
  Reads a CSV whose header contains at least a `price` column (`volume` is
  optional). Header names are matched case-insensitively.

  Fields may be separated by `,`, `;` or a tab; both LF and CRLF line endings
  are accepted and blank lines are skipped. Empty cells are kept in place so
  that columns stay aligned. Raises `ArgumentError` when a price or volume cell
  is not a number.
  """
  def from_csv(path) do
    [header | rows] =
      path
      |> File.read!()
      |> String.split(["\r\n", "\n"], trim: true)
      |> Enum.reject(&(String.trim(&1) == ""))
      |> Enum.map(&split_fields/1)

    cols = header |> Enum.map(&String.downcase/1) |> Enum.with_index() |> Map.new()
    pcol = Map.fetch!(cols, "price")
    vcol = Map.get(cols, "volume")

    prices = Enum.map(rows, fn r -> r |> Enum.at(pcol) |> parse_float() end)

    volumes =
      if vcol, do: Enum.map(rows, fn r -> r |> Enum.at(vcol) |> parse_float() end), else: nil

    build(prices, volumes)
  end

  @doc """
  Synthetic mean-reverting series with a small drift, meant only for running
  the pipeline end to end.

  Re-seeds the calling process's `:rand` state with a fixed seed, so the output
  is the same on every call.
  """
  def synthetic(n \\ 20_000) do
    :rand.seed(:exsss, {1, 2, 3})

    {prices, _} =
      Enum.map_reduce(1..n, 100.0, fn _i, p ->
        drift = 0.00002
        mr    = (100.0 - p) * 0.0008
        shock = :rand.normal() * 0.05
        np    = max(p + drift * p + mr + shock, 1.0)
        {np, np}
      end)

    build(prices, nil)
  end

  # ── causal features (each feature at tick t uses only data <= t) ────────
  # Columns per tick: log return, 10- and 60-tick mean return, 30-tick return
  # std, 60-tick price z-score, 60-tick volume z-score (zeros without volume).
  defp build(prices, volumes) do
    n = length(prices)

    log_ret = causal_log_returns(prices)
    mom_s   = rolling_mean(log_ret, 10)
    mom_l   = rolling_mean(log_ret, 60)
    vol_s   = rolling_std(log_ret, 30)
    zprice  = zscore_rolling(prices, 60)

    vol_feat =
      case volumes do
        nil -> List.duplicate(0.0, n)
        v   -> zscore_rolling(v, 60)
      end

    # transpose the per-feature columns into one row per tick
    feats = Enum.zip_with([log_ret, mom_s, mom_l, vol_s, zprice, vol_feat], & &1)

    %{
      prices: prices,
      # list of lists, shape {n, n_features}
      features: feats,
      n: n,
      n_features: 6
    }
  end

  # Splits one CSV line without dropping empty cells, trimming each field.
  defp split_fields(line) do
    line |> String.split([",", ";", "\t"]) |> Enum.map(&String.trim/1)
  end

  defp parse_float(field) do
    case Float.parse(String.trim(field)) do
      {value, _rest} -> value
      :error -> raise ArgumentError, "cannot parse #{inspect(field)} as a number"
    end
  end

  # Log return of each tick against the previous one; the first tick gets 0.0.
  defp causal_log_returns([_ | rest] = p) do
    [0.0 | Enum.zip_with(rest, p, fn a, b -> :math.log(a / b) end)]
  end

  # Trailing windows in one pass: element i is xs[max(0, i - w + 1)..i], in the
  # original order. Avoids re-walking the list from the head for every tick.
  defp trailing_windows(xs, w) do
    {windows, _} =
      Enum.map_reduce(xs, [], fn x, newest_first ->
        newest_first = Enum.take([x | newest_first], w)
        {Enum.reverse(newest_first), newest_first}
      end)

    windows
  end

  # Mean and population standard deviation of a non-empty window.
  defp mean_std(win) do
    n = length(win)
    m = Enum.sum(win) / n
    var = Enum.reduce(win, 0.0, fn v, a -> a + (v - m) * (v - m) end) / n
    {m, :math.sqrt(var)}
  end

  defp rolling_mean(xs, w) do
    xs
    |> trailing_windows(w)
    |> Enum.map(fn win -> Enum.sum(win) / length(win) end)
  end

  defp rolling_std(xs, w) do
    xs
    |> trailing_windows(w)
    |> Enum.map(fn win -> win |> mean_std() |> elem(1) end)
  end

  # Z-score of each value against its own trailing window; 0.0 when the window
  # is (numerically) constant.
  defp zscore_rolling(xs, w) do
    Enum.zip_with(xs, trailing_windows(xs, w), fn x, win ->
      {m, sd} = mean_std(win)
      if sd > 1.0e-9, do: (x - m) / sd, else: 0.0
    end)
  end
end

TradingEnv

defmodule TradingEnv do
  @moduledoc """
  Single-agent environment (MDP). The "state" passed around is a map
  `%{idx, position, t}` — the counterpart of the `board` in the game version.

  Actions: 0 = Hold, 1 = Buy (go/stay long), 2 = Sell (go/stay flat).
  Long-only by default (position ∈ {0, 1}). To enable shorting, see `apply_action/2`.
  """

  defstruct [
    # Nx tensor {n, n_features} f32 — read by encode/2 (network input)
    :feat_t,
    # Elixir list {n} — read by step/3 to compute the reward
    :prices_list,
    :n_features,
    # n_features + 1 (position plane)
    :n_channels,
    :window,
    :action_size,
    # cost per unit of position change (e.g. 0.0005 = 5 bps)
    :cost,
    :gamma,
    :max_steps,
    # first valid idx (window - 1 ticks of history are required)
    :start_min,
    # exclusive final idx of the usable data
    :end_idx
  ]

  @doc """
  Builds an environment from a market map (`features`, `prices`, `n`,
  `n_features`). Options: `:window` (64), `:cost` (0.0005), `:gamma` (0.997),
  `:max_steps` (256).
  """
  def new(market, opts \\ []) do
    window = Keyword.get(opts, :window, 64)
    features = Nx.tensor(market.features, type: :f32)
    n_features = market.n_features

    %TradingEnv{
      feat_t: features,
      prices_list: market.prices,
      n_features: n_features,
      n_channels: n_features + 1,
      window: window,
      action_size: 3,
      cost: Keyword.get(opts, :cost, 0.0005),
      gamma: Keyword.get(opts, :gamma, 0.997),
      max_steps: Keyword.get(opts, :max_steps, 256),
      start_min: window - 1,
      # idx + 1 must exist to compute the return
      end_idx: market.n - 1
    }
  end

  # ── initial state (counterpart of get_initial_board) ────────────────────
  @doc "Flat state at `start_idx`, clamped so it is never before `start_min`."
  def initial_state(%TradingEnv{start_min: earliest}, start_idx) do
    %{idx: max(start_idx, earliest), position: 0, t: 0}
  end

  # ── valid actions (long-only) as a mask [hold, buy, sell] ───────────────
  @doc "Action mask `[hold, buy, sell]` for a flat (0) or long (1) position."
  def valid_actions(_env, %{position: 0}), do: [1, 1, 0]
  def valid_actions(_env, %{position: 1}), do: [1, 0, 1]

  # ── transition (counterpart of get_next_board, plus reward and done) ────
  @doc """
  Applies `action` at tick `idx` and returns `{next_state, reward, done?}`.

  The reward is the new position times the simple return from `idx` to
  `idx + 1`, minus `cost` times the absolute change in position.
  """
  def step(%TradingEnv{} = env, %{idx: i} = state, action) do
    held = state.position
    target = apply_action(held, action)

    price_now = Enum.at(env.prices_list, i)
    price_next = Enum.at(env.prices_list, i + 1)
    simple_return = (price_next - price_now) / price_now

    reward = target * simple_return - env.cost * abs(target - held)

    advanced = %{state | idx: i + 1, position: target, t: state.t + 1}
    finished? = advanced.idx >= env.end_idx or advanced.t >= env.max_steps

    {advanced, reward, finished?}
  end

  # long-only: 0 = Hold keeps the position, 1 = Buy -> 1, 2 = Sell -> 0
  # (for long/short: Buy -> +1, Sell -> -1, Hold keeps)
  defp apply_action(current, 0), do: current
  defp apply_action(_current, 1), do: 1
  defp apply_action(_current, 2), do: 0

  # ── pure terminal check (for testing a node inside an MCTS) ─────────────
  # NOTE(review): this tests `idx + 1` against end_idx while step/3 tests the
  # already advanced idx, so the two can disagree by one tick — confirm intent.
  @doc "True when `idx + 1` has reached `end_idx` or `t` has reached `max_steps`."
  def terminal?(%TradingEnv{end_idx: last, max_steps: limit}, %{idx: i, t: t}) do
    i + 1 >= last or t >= limit
  end

  # ── CAUSAL encode: only the past window [i-w+1 .. i] plus the position plane ──
  @doc "Returns a `{n_channels, window}` tensor for the state."
  def encode(%TradingEnv{} = env, %{idx: i, position: pos}) do
    width = env.window
    first = i - width + 1

    # {w, n_features} -> {n_features, w}
    history = Nx.transpose(env.feat_t[first..i])

    position_plane =
      (pos * 1.0)
      |> Nx.tensor(type: :f32)
      |> Nx.broadcast({1, width})

    # {n_channels, w}
    Nx.concatenate([history, position_plane], axis: 0)
  end
end

MuZeroZero — self-play (env real) + treino (unroll)

defmodule MuZeroZero do
  @moduledoc """
  Training loop: self-play on the REAL environment with actions chosen by the
  latent MCTS, followed by joint training on K-step unroll samples.
  """

  @doc """
  Runs `args.num_iterations` rounds of self-play, training and checkpointing.

  Each round plays `args.num_selfplay_iterations` episodes, turns them into
  unroll samples, trains for `args.num_epochs` epochs and saves the model.
  """
  def learn(env, args) do
    Enum.each(1..args.num_iterations, fn it ->
      IO.puts("iteração #{it}/#{args.num_iterations}")

      trajectories =
        Enum.map(1..args.num_selfplay_iterations, fn i ->
          IO.write("  self-play #{i}/#{args.num_selfplay_iterations}\r")
          self_play(env, args)
        end)

      samples = Enum.flat_map(trajectories, &to_unroll_samples(&1, args))
      IO.puts("\n  amostras de unroll: #{length(samples)}")

      Enum.each(1..args.num_epochs, fn ep ->
        loss = train(samples, args)
        IO.puts("  epoch #{ep}/#{args.num_epochs} — loss: #{Float.round(loss, 5)}")
      end)

      MuZeroNet.save(it)
    end)
  end

  # ── self-play: the latent MCTS picks the actions, the REAL env advances ──

  @doc """
  Plays one episode from a random start and returns the trajectory as a list
  of `%{obs, action, policy, reward}` maps in time order. `reward` is the
  reward obtained by `action` in the state encoded by `obs`.
  """
  def self_play(env, args) do
    # Latest start that still leaves room for max_steps; never before start_min,
    # so short data sets yield a short episode instead of a reversed range.
    latest = max(env.start_min, env.end_idx - env.max_steps - 1)
    start = Enum.random(env.start_min..latest)
    state = TradingEnv.initial_state(env, start)
    roll(env, args, state, [])
  end

  defp roll(env, args, state, hist) do
    obs   = TradingEnv.encode(env, state) |> Nx.reshape({1, env.n_channels, 1, env.window})
    valid = TradingEnv.valid_actions(env, state)

    probs = MuZeroMCTS.search(obs, valid, args)

    action =
      probs
      |> Utils.apply_temperature(args.temperature)
      |> Utils.mask_and_normalize(valid)
      |> Utils.sample_action(args.action_size)

    {next, reward, done?} = TradingEnv.step(env, state, action)

    # store obs WITHOUT the batch axis (it is re-stacked for training)
    obs2  = Nx.squeeze(obs, axes: [0])
    hist  = [%{obs: obs2, action: action, policy: probs, reward: reward} | hist]

    if done?, do: Enum.reverse(hist), else: roll(env, args, next, hist)
  end

  # ── value targets: discounted Monte Carlo return (n-step would be an improvement) ──
  # value_t = reward_t + gamma * value_{t+1}, with 0.0 after the last step.
  defp with_value_targets(traj, gamma) do
    {out, _g} =
      traj
      |> Enum.reverse()
      |> Enum.map_reduce(0.0, fn step, g_next ->
        g = step.reward + gamma * g_next
        {Map.put(step, :value, g), g}
      end)

    Enum.reverse(out)
  end

  # ── turns a trajectory into unroll samples of length K ──────────────────
  defp to_unroll_samples(traj, args) do
    traj = with_value_targets(traj, args.gamma)
    k    = args.k_unroll
    n    = length(traj)
    arr  = List.to_tuple(traj)

    # Only starts with K full steps ahead (the tail is dropped; padding would be
    # an improvement). `//1` yields no samples for trajectories shorter than
    # K + 1 instead of counting down into invalid indices.
    0..(n - k - 1)//1
    |> Enum.map(fn j ->
      window = for i <- j..(j + k), do: elem(arr, i)

      %{
        # {nc, 1, W}
        obs:      hd(window).obs,
        actions:  Enum.slice(window, 0, k) |> Enum.map(& &1.action),
        # K + 1 entries
        policies: Enum.map(window, & &1.policy),
        # K + 1 entries
        values:   Enum.map(window, & &1.value),
        # K entries: each entry stores the reward of ITS OWN action, so the
        # reward target of unroll step i is window[i].reward, aligned with
        # actions[i].
        rewards:  Enum.slice(window, 0, k) |> Enum.map(& &1.reward)
      }
    end)
  end

  # ── training: stack into batches and call the joint unroll ──────────────

  @doc """
  Trains one epoch over `samples` in shuffled batches of `args.batch_size`
  (an incomplete last batch is discarded) and returns the mean batch loss, or
  `0.0` when there is no full batch.
  """
  def train(samples, args) do
    k = args.k_unroll
    a = args.action_size

    losses =
      samples
      |> Enum.shuffle()
      |> Enum.chunk_every(args.batch_size, args.batch_size, :discard)
      |> Enum.map(fn batch ->
        obs =
          batch
          |> Enum.map(&Nx.new_axis(&1.obs, 0))
          |> Nx.concatenate(axis: 0)

        # each per-step list becomes ONE tensor with the step axis in front,
        # which MuZeroNet's loss indexes with tensor[step]
        # {K, B, A}
        actions_oh =
          (for step <- 0..(k - 1) do
             batch |> Enum.map(fn s -> onehot(Enum.at(s.actions, step), a) end) |> Nx.stack()
           end)
          |> Nx.stack()

        # {K+1, B, A}
        policies =
          (for step <- 0..k do
             batch |> Enum.map(fn s -> Nx.tensor(Enum.at(s.policies, step), type: :f32) end) |> Nx.stack()
           end)
          |> Nx.stack()

        # {K+1, B, 1}
        values =
          (for step <- 0..k do
             batch |> Enum.map(fn s -> Nx.tensor([Enum.at(s.values, step)], type: :f32) end) |> Nx.stack()
           end)
          |> Nx.stack()

        # {K, B, 1}
        rewards =
          (for step <- 0..(k - 1) do
             batch |> Enum.map(fn s -> Nx.tensor([Enum.at(s.rewards, step)], type: :f32) end) |> Nx.stack()
           end)
          |> Nx.stack()

        MuZeroNet.train_step(obs, actions_oh, policies, values, rewards)
      end)

    if losses == [], do: 0.0, else: Enum.sum(losses) / length(losses)
  end

  # One-hot vector {n} for integer action `a`.
  defp onehot(a, n) do
    0..(n - 1) |> Enum.map(fn i -> if i == a, do: 1.0, else: 0.0 end) |> Nx.tensor(type: :f32)
  end
end

Runner

# Requires env_train from the previous livebook; otherwise rebuild it here:
# synthetic series, first 70% of the ticks kept as the training market.
market = MarketData.synthetic(20_000)
split  = round(market.n * 0.7)
train_market = %{market | prices: Enum.slice(market.prices, 0, split),
                          features: Enum.slice(market.features, 0, split), n: split}
env_train = TradingEnv.new(train_market, window: 64)

# Network config derived from the environment; gamma is taken from the env.
cfg = MuZeroCfg.from_env(env_train, latent_dim: 64, hidden: 128, k_unroll: 5, gamma: env_train.gamma)

# Options read by MuZeroZero, MuZeroMCTS and Utils; action_size, k_unroll and
# gamma are copied from cfg so the search and the network agree.
args = %{
  num_iterations:          4,
  num_selfplay_iterations: 8,
  num_epochs:              4,
  batch_size:              64,
  temperature:             1.1,
  num_mcts_searches:       40,
  dirichlet_epsilon:       0.25,
  dirichlet_alpha:         0.3,
  action_size:             cfg.action_size,
  k_unroll:                cfg.k_unroll,
  gamma:                   cfg.gamma
}

# Stop a MuZeroNet Agent left over from a previous evaluation of this cell,
# then start a fresh one (the name is registered, so two cannot coexist).
case Process.whereis(MuZeroNet) do
  nil -> :ok
  pid -> Agent.stop(pid)
end
{:ok, _} = MuZeroNet.start_link(cfg, 0.001)

# Self-play + training; a checkpoint is saved after every iteration.
MuZeroZero.learn(env_train, args)