Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Forecasting the Future

ForecastingTheFuture.livemd

Forecasting the Future

Mix.install([
  {:explorer, "~> 0.5.0"},
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:vega_lite, "~> 0.1.6"},
  {:kino, "~> 0.8.0"},
  {:kino_vega_lite, "~> 0.1.7"}
])

alias VegaLite, as: Vl
Resolving Hex dependencies...
Resolution completed in 0.117s
New:
  axon 0.5.1
  castore 1.0.2
  complex 0.5.0
  elixir_make 0.7.6
  exla 0.5.3
  explorer 0.5.7
  kino 0.8.1
  kino_vega_lite 0.1.8
  nx 0.5.3
  rustler_precompiled 0.6.1
  table 0.1.2
  table_rex 3.1.1
  telemetry 1.2.1
  vega_lite 0.1.7
  xla 0.4.4
* Getting explorer (Hex package)
* Getting nx (Hex package)
* Getting exla (Hex package)
* Getting axon (Hex package)
* Getting vega_lite (Hex package)
* Getting kino (Hex package)
* Getting kino_vega_lite (Hex package)
* Getting table (Hex package)
* Getting elixir_make (Hex package)
* Getting telemetry (Hex package)
* Getting xla (Hex package)
* Getting complex (Hex package)
* Getting rustler_precompiled (Hex package)
* Getting table_rex (Hex package)
* Getting castore (Hex package)
==> table
Compiling 5 files (.ex)
Generated table app
==> vega_lite
Compiling 5 files (.ex)
Generated vega_lite app
===> Analyzing applications...
===> Compiling telemetry
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> kino
Compiling 37 files (.ex)
Generated kino app
==> kino_vega_lite
Compiling 4 files (.ex)
Generated kino_vega_lite app
==> table_rex
Compiling 7 files (.ex)
Generated table_rex app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/sean/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/sean/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.2/adb2e876bd0f4c056c66dfce006b140f/deps/exla/cache
Using libexla.so from /Users/sean/Library/Caches/xla/exla/elixir-1.14.2-erts-13.0.2-xla-0.4.4-exla-0.5.3-defofxrodwsk5sselkno4icm44/libexla.so
Compiling 21 files (.ex)
Generated exla app
==> rustler_precompiled
Compiling 4 files (.ex)
Generated rustler_precompiled app
==> explorer
Compiling 19 files (.ex)

06:46:43.386 [debug] Copying NIF from cache and extracting to /Users/sean/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.2/adb2e876bd0f4c056c66dfce006b140f/_build/dev/lib/explorer/priv/native/libexplorer-v0.5.7-nif-2.16-aarch64-apple-darwin.so
Generated explorer app
VegaLite

The Data

File.cwd!()
"/Users/sean/projects/scholar/notebooks"
df =
  Explorer.DataFrame.from_csv!(
    "all_stocks_2006-01-01_to_2018-01-01.csv",
    parse_dates: true
  )
#Explorer.DataFrame<
  Polars[93612 x 7]
  Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
  Open float [77.76, 79.49, 78.41, 78.64, 78.5, ...]
  High float [79.35, 79.49, 78.65, 78.9, 79.83, ...]
  Low float [77.24, 78.25, 77.56, 77.64, 78.46, ...]
  Close float [79.11, 78.71, 77.99, 78.63, 79.02, ...]
  Volume integer [3117200, 2558000, 2529500, 2479500, 1845600, ...]
  Name string ["MMM", "MMM", "MMM", "MMM", "MMM", ...]
>
df = Explorer.DataFrame.select(df, ["Date", "Close", "Name"])
#Explorer.DataFrame<
  Polars[93612 x 3]
  Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
  Close float [79.11, 78.71, 77.99, 78.63, 79.02, ...]
  Name string ["MMM", "MMM", "MMM", "MMM", "MMM", ...]
>
Vl.new(title: "DJIA Stock Prices", width: 1280, height: 720)
|> Vl.data_from_values(Explorer.DataFrame.to_columns(df))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "Date", type: :temporal)
|> Vl.encode_field(:y, "Close", type: :quantitative)
|> Vl.encode_field(:color, "Name", type: :nominal)
|> Kino.VegaLite.new()
aapl_df =
  Explorer.DataFrame.filter_with(df, fn df ->
    Explorer.Series.equal(df["Name"], "AAPL")
  end)

normalized_aapl_df =
  Explorer.DataFrame.mutate_with(aapl_df, fn df ->
    var = Explorer.Series.variance(df["Close"])
    mean = Explorer.Series.mean(df["Close"])
    centered = Explorer.Series.subtract(df["Close"], mean)
    norm = Explorer.Series.divide(centered, var)
    [Close: norm]
  end)
#Explorer.DataFrame<
  Polars[3019 x 3]
  Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
  Close float [-0.027216043254776588, -0.02720091843889246, -0.027241251281250132,
   -0.027105127938293, -0.027125294359471832, ...]
  Name string ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL", ...]
>
Vl.new(title: "AAPL Stock Price", width: 1280, height: 720)
|> Vl.data_from_values(Explorer.DataFrame.to_columns(normalized_aapl_df))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "Date", type: :temporal)
|> Vl.encode_field(:y, "Close", type: :quantitative)
|> Kino.VegaLite.new()
defmodule Data do
  def window(inputs, window_size, target_window_size) do
    inputs
    |> Stream.chunk_every(window_size + target_window_size, 1, :discard)
    |> Stream.map(fn window ->
      features = Enum.take(window, window_size)
      targets = Enum.drop(window, window_size)
      {Nx.tensor(features) |> Nx.new_axis(-1), Nx.tensor(targets) |> Nx.new_axis(-1)}
    end)
  end

  def batch(inputs, batch_size) do
    inputs
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(fn windows ->
      {features, targets} = Enum.unzip(windows)
      {Nx.stack(features), Nx.stack(targets)}
    end)
  end
end
{:module, Data, <<70, 79, 82, 49, 0, 0, 10, ...>>, {:batch, 2}}
train_df =
  Explorer.DataFrame.filter_with(normalized_aapl_df, fn df ->
    Explorer.Series.less(df["Date"], Date.new!(2016, 1, 1))
  end)

test_df =
  Explorer.DataFrame.filter_with(normalized_aapl_df, fn df ->
    Explorer.Series.greater_equal(df["Date"], Date.new!(2016, 1, 1))
  end)
#Explorer.DataFrame<
  Polars[503 x 3]
  Date date [2016-01-04, 2016-01-05, 2016-01-06, 2016-01-07, 2016-01-08, ...]
  Close float [0.02051283407023046, 0.01918185027242737, 0.01816848760819093, 0.01602580535793974,
   0.016282927227969878, ...]
  Name string ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL", ...]
>
window_size = 30
batch_size = 32

train_prices = Explorer.Series.to_list(train_df["Close"])
test_prices = Explorer.Series.to_list(test_df["Close"])

single_step_train_data =
  train_prices
  |> Data.window(window_size, 1)
  |> Data.batch(batch_size)

single_step_test_data =
  test_prices
  |> Data.window(window_size, 1)
  |> Data.batch(batch_size)

multi_step_train_data =
  train_prices
  |> Data.window(window_size, 5)
  |> Data.batch(batch_size)

multi_step_test_data =
  test_prices
  |> Data.window(window_size, 5)
  |> Data.batch(batch_size)
#Stream<[
  enum: #Stream<[
    enum: #Stream<[
      enum: [0.02051283407023046, 0.01918185027242737, 0.01816848760819093, 0.01602580535793974,
       0.016282927227969878, 0.017074459259239144, 0.01779540881638248, 0.016499716255642356,
       0.017573578183415303, 0.016368634517979926, 0.01613167906912862, 0.016197219937959837,
       0.015950181278519108, 0.018531483189409954, 0.017533245341057633, 0.01781053363226661,
       0.01449819895364301, 0.014835986508388493, 0.016474508229168815, 0.016015722147350326,
       0.015032609114882132, 0.01597538930499265, 0.016101429437360367, 0.014800695271325529,
       0.015299814195501693, 0.01528973098491227, 0.014926735403693246, 0.014639363901894855,
       0.014785570455441403, 0.016121595858539204, 0.016867753442156092, 0.015930014857340278,
       0.015819099540856685, 0.01624259438561221, 0.015138482826071012, 0.015849349172624934,
       0.01618209512207571, 0.016257719201496337, 0.01614680388501275, 0.01808278031818088,
       0.01819369563466447, 0.01857181603176762, 0.019333098431268635, 0.018758355427671847,
       0.018334860582916314, 0.018380235030568695, 0.018405443057042236, ...],
      funs: [#Function<3.6935098/1 in Stream.chunk_while/4>]
    ]>,
    funs: [#Function<48.6935098/1 in Stream.map/2>, #Function<3.6935098/1 in Stream.chunk_while/4>]
  ]>,
  funs: [#Function<48.6935098/1 in Stream.map/2>]
]>
Enum.take(single_step_train_data, 1)
[
  {#Nx.Tensor<
     f32[32][30][1]
     [
       [
         [-0.027216043323278427],
         [-0.027200918644666672],
         [-0.02724125050008297],
         [-0.02710512839257717],
         [-0.027125295251607895],
         [-0.026777423918247223],
         [-0.026555592194199562],
         [-0.02653038501739502],
         [-0.02643459476530552],
         [-0.02650013566017151],
         [-0.026661466807127],
         [-0.026908505707979202],
         [-0.027120253071188927],
         [-0.027004295960068703],
         [-0.027125295251607895],
         [-0.027256375178694725],
         [-0.027392499148845673],
         [-0.027412666007876396],
         [-0.027200918644666672],
         [-0.027160584926605225],
         [-0.02717066928744316],
         [-0.027407623827457428],
         [-0.02742779068648815],
         [-0.0277554951608181],
         [-0.027730286121368408],
         [-0.027644580230116844],
         [-0.02792186848819256],
         [-0.027750452980399132],
         [-0.027942035347223282],
         [-0.027730286121368408]
       ],
       [
         [-0.027200918644666672],
         [-0.02724125050008297],
         [-0.02710512839257717],
         [-0.027125295251607895],
         [-0.026777423918247223],
         [-0.026555592194199562],
         [-0.02653038501739502],
         [-0.02643459476530552],
         [-0.02650013566017151],
         [-0.026661466807127],
         [-0.026908505707979202],
         [-0.027120253071188927],
         [-0.027004295960068703],
         [-0.027125295251607895],
         [-0.027256375178694725],
         [-0.027392499148845673],
         [-0.027412666007876396],
         [-0.027200918644666672],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     f32[32][1][1]
     [
       [
         [-0.027614330872893333]
       ],
       [
         [-0.027518538758158684]
       ],
       [
         [-0.027538705617189407]
       ],
       [
         [-0.02762441337108612]
       ],
       [
         [-0.02746308222413063]
       ],
       [
         [-0.02743283286690712]
       ],
       [
         [-0.027452997863292694]
       ],
       [
         [-0.027488289400935173]
       ],
       [
         [-0.027669787406921387]
       ],
       [
         [-0.02762441337108612]
       ],
       [
         [-0.02758912183344364]
       ],
       [
         [-0.02772524580359459]
       ],
       [
         [-0.02788657695055008]
       ],
       [
         [-0.027826078236103058]
       ],
       [
         [-0.027871452271938324]
       ],
       [
         [-0.027997491881251335]
       ],
       [
         [-0.02804790809750557]
       ],
       [
         [-0.027871452271938324]
       ],
       [
         [-0.027750452980399132]
       ],
       [
         [-0.027831118553876877]
       ],
       [
         [-0.027967242524027824]
       ],
       [
         [-0.027942035347223282]
       ],
       [
         [-0.027992449700832367]
       ],
       [
         [-0.028148740530014038]
       ],
       [
         [-0.028158823028206825]
       ],
       [
         [-0.02826973795890808]
       ],
       [
         [-0.028279822319746017]
       ],
       [
         [-0.028315113857388496]
       ],
       [
         [-0.02837057039141655]
       ],
       [
         [-0.02811344899237156]
       ],
       [
         [-0.02808319963514805]
       ],
       [
         [-0.02808319963514805]
       ]
     ]
   >}
]
Enum.take(multi_step_train_data, 1)
[
  {#Nx.Tensor<
     f32[32][30][1]
     [
       [
         [-0.027216043323278427],
         [-0.027200918644666672],
         [-0.02724125050008297],
         [-0.02710512839257717],
         [-0.027125295251607895],
         [-0.026777423918247223],
         [-0.026555592194199562],
         [-0.02653038501739502],
         [-0.02643459476530552],
         [-0.02650013566017151],
         [-0.026661466807127],
         [-0.026908505707979202],
         [-0.027120253071188927],
         [-0.027004295960068703],
         [-0.027125295251607895],
         [-0.027256375178694725],
         [-0.027392499148845673],
         [-0.027412666007876396],
         [-0.027200918644666672],
         [-0.027160584926605225],
         [-0.02717066928744316],
         [-0.027407623827457428],
         [-0.02742779068648815],
         [-0.0277554951608181],
         [-0.027730286121368408],
         [-0.027644580230116844],
         [-0.02792186848819256],
         [-0.027750452980399132],
         [-0.027942035347223282],
         [-0.027730286121368408]
       ],
       [
         [-0.027200918644666672],
         [-0.02724125050008297],
         [-0.02710512839257717],
         [-0.027125295251607895],
         [-0.026777423918247223],
         [-0.026555592194199562],
         [-0.02653038501739502],
         [-0.02643459476530552],
         [-0.02650013566017151],
         [-0.026661466807127],
         [-0.026908505707979202],
         [-0.027120253071188927],
         [-0.027004295960068703],
         [-0.027125295251607895],
         [-0.027256375178694725],
         [-0.027392499148845673],
         [-0.027412666007876396],
         [-0.027200918644666672],
         ...
       ],
       ...
     ]
   >,
   #Nx.Tensor<
     f32[32][5][1]
     [
       [
         [-0.027614330872893333],
         [-0.027518538758158684],
         [-0.027538705617189407],
         [-0.02762441337108612],
         [-0.02746308222413063]
       ],
       [
         [-0.027518538758158684],
         [-0.027538705617189407],
         [-0.02762441337108612],
         [-0.02746308222413063],
         [-0.02743283286690712]
       ],
       [
         [-0.027538705617189407],
         [-0.02762441337108612],
         [-0.02746308222413063],
         [-0.02743283286690712],
         [-0.027452997863292694]
       ],
       [
         [-0.02762441337108612],
         [-0.02746308222413063],
         [-0.02743283286690712],
         [-0.027452997863292694],
         [-0.027488289400935173]
       ],
       [
         [-0.02746308222413063],
         [-0.02743283286690712],
         [-0.027452997863292694],
         [-0.027488289400935173],
         [-0.027669787406921387]
       ],
       [
         [-0.02743283286690712],
         [-0.027452997863292694],
         [-0.027488289400935173],
         [-0.027669787406921387],
         [-0.02762441337108612]
       ],
       [
         [-0.027452997863292694],
         [-0.027488289400935173],
         [-0.027669787406921387],
         [-0.02762441337108612],
         [-0.02758912183344364]
       ],
       [
         [-0.027488289400935173],
         [-0.027669787406921387],
         [-0.02762441337108612],
         [-0.02758912183344364],
         [-0.02772524580359459]
       ],
       [
         [-0.027669787406921387],
         [-0.02762441337108612],
         [-0.02758912183344364],
         [-0.02772524580359459],
         [-0.02788657695055008]
       ],
       [
         [-0.02762441337108612],
         [-0.02758912183344364],
         ...
       ],
       ...
     ]
   >}
]

Using CNNs for Time Series

cnn_model =
  Axon.input("stock_price")
  |> Axon.conv(32, kernel_size: window_size, activation: :relu)
  |> Axon.dense(32, activation: :relu)
  |> Axon.dense(1)
#Axon<
  inputs: %{"stock_price" => nil}
  outputs: "dense_1"
  nodes: 6
>
template = Nx.template({32, 30, 1}, :f32)
Axon.Display.as_graph(cnn_model, template)
graph TD;
22[/"stock_price (:input) {32, 30, 1}"/];
23["conv_0 (:conv) {32, 1, 32}"];
24["relu_0 (:relu) {32, 1, 32}"];
25["dense_0 (:dense) {32, 1, 32}"];
26["relu_1 (:relu) {32, 1, 32}"];
27["dense_1 (:dense) {32, 1, 1}"];
26 --> 27;
25 --> 26;
24 --> 25;
23 --> 24;
22 --> 23;
cnn_trained_model_state =
  cnn_model
  |> Axon.Loop.trainer(:mean_squared_error, :adam)
  |> Axon.Loop.metric(:mean_absolute_error)
  |> Axon.Loop.run(single_step_train_data, %{}, epochs: 50, compiler: EXLA)

06:56:43.142 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 50, loss: 0.0000663 mean_absolute_error: 0.0058765
Epoch: 1, Batch: 73, loss: 0.0001107 mean_absolute_error: 0.0067741
Epoch: 2, Batch: 46, loss: 0.0000913 mean_absolute_error: 0.0042740
Epoch: 3, Batch: 69, loss: 0.0000695 mean_absolute_error: 0.0042149
Epoch: 4, Batch: 42, loss: 0.0000642 mean_absolute_error: 0.0042601
Epoch: 5, Batch: 65, loss: 0.0000530 mean_absolute_error: 0.0028081
Epoch: 6, Batch: 38, loss: 0.0000486 mean_absolute_error: 0.0019006
Epoch: 7, Batch: 61, loss: 0.0000415 mean_absolute_error: 0.0016131
Epoch: 8, Batch: 34, loss: 0.0000386 mean_absolute_error: 0.0010020
Epoch: 9, Batch: 57, loss: 0.0000342 mean_absolute_error: 0.0013065
Epoch: 10, Batch: 30, loss: 0.0000324 mean_absolute_error: 0.0013755
Epoch: 11, Batch: 53, loss: 0.0000293 mean_absolute_error: 0.0014773
Epoch: 12, Batch: 76, loss: 0.0000269 mean_absolute_error: 0.0017283
Epoch: 13, Batch: 49, loss: 0.0000258 mean_absolute_error: 0.0015440
Epoch: 14, Batch: 72, loss: 0.0000240 mean_absolute_error: 0.0016460
Epoch: 15, Batch: 45, loss: 0.0000232 mean_absolute_error: 0.0014592
Epoch: 16, Batch: 68, loss: 0.0000217 mean_absolute_error: 0.0015306
Epoch: 17, Batch: 41, loss: 0.0000211 mean_absolute_error: 0.0014993
Epoch: 18, Batch: 64, loss: 0.0000200 mean_absolute_error: 0.0015089
Epoch: 19, Batch: 37, loss: 0.0000195 mean_absolute_error: 0.0014502
Epoch: 20, Batch: 60, loss: 0.0000185 mean_absolute_error: 0.0014855
Epoch: 21, Batch: 33, loss: 0.0000181 mean_absolute_error: 0.0014808
Epoch: 22, Batch: 56, loss: 0.0000173 mean_absolute_error: 0.0013899
Epoch: 23, Batch: 29, loss: 0.0000170 mean_absolute_error: 0.0014624
Epoch: 24, Batch: 52, loss: 0.0000163 mean_absolute_error: 0.0013806
Epoch: 25, Batch: 75, loss: 0.0000158 mean_absolute_error: 0.0016300
Epoch: 26, Batch: 48, loss: 0.0000155 mean_absolute_error: 0.0013704
Epoch: 27, Batch: 71, loss: 0.0000150 mean_absolute_error: 0.0015049
Epoch: 28, Batch: 44, loss: 0.0000148 mean_absolute_error: 0.0012911
Epoch: 29, Batch: 67, loss: 0.0000144 mean_absolute_error: 0.0016583
Epoch: 30, Batch: 40, loss: 0.0000143 mean_absolute_error: 0.0011887
Epoch: 31, Batch: 63, loss: 0.0000140 mean_absolute_error: 0.0019614
Epoch: 32, Batch: 36, loss: 0.0000140 mean_absolute_error: 0.0030620
Epoch: 33, Batch: 59, loss: 0.0000137 mean_absolute_error: 0.0011744
Epoch: 34, Batch: 32, loss: 0.0000135 mean_absolute_error: 0.0012163
Epoch: 35, Batch: 55, loss: 0.0000132 mean_absolute_error: 0.0012842
Epoch: 36, Batch: 28, loss: 0.0000130 mean_absolute_error: 0.0010907
Epoch: 37, Batch: 51, loss: 0.0000127 mean_absolute_error: 0.0017228
Epoch: 38, Batch: 74, loss: 0.0000126 mean_absolute_error: 0.0021361
Epoch: 39, Batch: 47, loss: 0.0000125 mean_absolute_error: 0.0020527
Epoch: 40, Batch: 70, loss: 0.0000125 mean_absolute_error: 0.0028232
Epoch: 41, Batch: 43, loss: 0.0000129 mean_absolute_error: 0.0050399
Epoch: 42, Batch: 66, loss: 0.0000132 mean_absolute_error: 0.0044193
Epoch: 43, Batch: 39, loss: 0.0000138 mean_absolute_error: 0.0064203
Epoch: 44, Batch: 62, loss: 0.0000145 mean_absolute_error: 0.0049469
Epoch: 45, Batch: 35, loss: 0.0000149 mean_absolute_error: 0.0054889
Epoch: 46, Batch: 58, loss: 0.0000151 mean_absolute_error: 0.0037215
Epoch: 47, Batch: 31, loss: 0.0000150 mean_absolute_error: 0.0026493
Epoch: 48, Batch: 54, loss: 0.0000148 mean_absolute_error: 0.0013843
Epoch: 49, Batch: 27, loss: 0.0000147 mean_absolute_error: 0.0006616
%{
  "conv_0" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.008091067895293236, -0.016978450119495392, -0.004820083733648062, -0.00792402308434248, -0.017319275066256523, 0.03160012513399124, -0.006478423718363047, -0.011876466684043407, 0.040104787796735764, -0.010837228037416935, -0.012144810520112514, -0.005171791650354862, -0.008013750426471233, -0.00792587362229824, -0.016933709383010864, -0.01142617966979742, -0.012313197366893291, -0.02336074225604534, -0.013391278684139252, -0.005351855419576168, -0.004199669696390629, -0.01840357296168804, 0.017371967434883118, -0.006281338166445494, -0.01266816072165966, -0.010549230501055717, -0.0077864243648946285, -0.009549328126013279, -0.0049909790977835655, -0.005896904040127993, -0.007447496056556702, -0.011414170265197754]
    >,
    "kernel" => #Nx.Tensor<
      f32[30][1][32]
      EXLA.Backend
      [
        [
          [-0.05040709674358368, -0.021235600113868713, -0.010630444623529911, -0.019848855212330818, -0.040217604488134384, -0.12504053115844727, -0.02118268609046936, 0.046331584453582764, 0.03421409800648689, 0.06607695668935776, 0.12348750978708267, -0.12631401419639587, 0.019740285351872444, -0.008459499105811119, 0.03768952935934067, 0.06856438517570496, -5.624796031042933e-4, -0.04322008043527603, 0.06655385345220566, -0.04173768684267998, -0.12437783926725388, -0.03484911471605301, -0.10082294046878815, -0.06766377389431, 0.034250643104314804, 0.05543597787618637, -0.01215132512152195, -0.06139170378446579, 0.004699850454926491, 0.06339087337255478, -0.05598803982138634, 0.05522901192307472]
        ],
        [
          [0.00991822686046362, -0.0031699652317911386, -0.0019580740481615067, 0.055364325642585754, -0.02978527545928955, -0.006119378376752138, 0.002985905623063445, -0.035792168229818344, -0.04653836041688919, 0.003442571498453617, 0.12638360261917114, -0.11723502725362778, -0.0502297542989254, -0.13008807599544525, 0.056758686900138855, ...]
        ],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.03525220975279808, -0.016225244849920273, -0.028320161625742912, -0.02871190384030342, -0.013132854364812374, -0.012726775370538235, -0.007784419227391481, -0.022564707323908806, 0.0, -0.007780132349580526, -0.02287977933883667, 0.020672377198934555, -0.030719507485628128, -0.014918905682861805, 0.002109799301251769, -0.014595966786146164, -0.02758587710559368, -0.01325188297778368, -0.004173007793724537, -0.03749808669090271, -0.009482248686254025, -0.01655416563153267, -0.021616848185658455, -0.021459853276610374, -0.018326226621866226, -0.007884755730628967, -0.01816698908805847, -0.03544246032834053, -0.006969088688492775, -0.007480396889150143, 0.0022076379973441362, -0.032438069581985474]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [-0.17533017694950104, 0.012447532266378403, 0.11476895958185196, -0.2081768661737442, -0.009203040972352028, 0.18192830681800842, -0.131138876080513, -0.2810400426387787, 0.19938461482524872, 0.2786736488342285, -0.27819138765335083, -0.2636408805847168, -0.1670035719871521, -0.09961113333702087, -0.146389439702034, 0.15396006405353546, -0.23006494343280792, 0.024088190868496895, -0.21437454223632812, 0.2188781350851059, -0.2704562842845917, -0.24468927085399628, 0.17817959189414978, -0.23606330156326294, -0.14990635216236115, -0.0874180942773819, 0.13848312199115753, -0.07637172192335129, -0.006744588725268841, -0.0221336018294096, 0.12596826255321503, -0.0856495127081871],
        [-0.1696496605873108, -0.09901808202266693, -0.041791852563619614, -0.29172369837760925, 0.10078410804271698, 0.28704243898391724, 0.17552822828292847, -0.20169097185134888, 0.048185210675001144, 0.2521952986717224, 0.16852286458015442, -0.05647372454404831, -0.09328987449407578, -0.1908075213432312, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend
      [0.003080096561461687]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][1]
      EXLA.Backend
      [
        [-0.020385151728987694],
        [0.1855364888906479],
        [0.07487309724092484],
        [-0.36159875988960266],
        [0.2996695339679718],
        [-0.14941802620887756],
        [0.0248686745762825],
        [-0.08696726709604263],
        [0.025274399667978287],
        [-0.06669101864099503],
        [0.17602013051509857],
        [0.36575624346733093],
        [-0.13954167068004608],
        [-0.08473435044288635],
        [-0.29618003964424133],
        [-0.1217213124036789],
        [0.005105822812765837],
        [0.00423429673537612],
        [-0.05834190174937248],
        [0.11214695125818253],
        [0.07910924404859543],
        [-0.3471531271934509],
        [-0.003867669031023979],
        [-3.470616138656624e-5],
        [0.08255864679813385],
        [0.38870716094970703],
        [-0.11077144742012024],
        [0.08462876826524734],
        [0.018228236585855484],
        [-0.2302139550447464],
        [-0.19713029265403748],
        [0.33921128511428833]
      ]
    >
  }
}
cnn_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.run(single_step_test_data, cnn_trained_model_state, compiler: EXLA)

06:56:44.326 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 13, mean_absolute_error: 0.0043546
%{
  0 => %{
    "mean_absolute_error" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.004354595206677914
    >
  }
}
0.0035810 * :math.sqrt(Explorer.Series.variance(aapl_df["Close"])) +
  Explorer.Series.mean(aapl_df["Close"])
64.82237670703707
defmodule Analysis do
  def visualize_predictions(
        model,
        model_state,
        prices,
        window_size,
        target_window_size,
        batch_size
      ) do
    {_, predict_fn} = Axon.build(model, compiler: EXLA)

    windows =
      prices
      |> Data.window(window_size, target_window_size)
      |> Data.batch(batch_size)
      |> Stream.map(&amp;elem(&amp;1, 0))

    predicted =
      Enum.flat_map(windows, fn window ->
        predict_fn.(model_state, window) |> Nx.to_flat_list()
      end)

    predicted = [nil, nil, nil, nil, nil, nil, nil, nil, nil, nil | predicted]

    types =
      List.duplicate("AAPL", length(prices)) ++
        List.duplicate("Predicted", length(prices))

    days =
      Enum.to_list(0..(length(prices) - 1)) ++
        Enum.to_list(0..(length(prices) - 1))

    prices = prices ++ predicted

    plot(
      %{
        "day" => days,
        "prices" => prices,
        "types" => types
      },
      "AAPL Stock Price vs. Predicted, RNN Single-Shot"
    )
  end

  defp plot(values, title) do
    Vl.new(title: title, width: 1280, height: 720)
    |> Vl.data_from_values(values)
    |> Vl.mark(:line)
    |> Vl.encode_field(:x, "day", type: :temporal)
    |> Vl.encode_field(:y, "prices", type: :quantitative)
    |> Vl.encode_field(:color, "types", type: :nominal)
    |> Kino.VegaLite.new()
  end
end
{:module, Analysis, <<70, 79, 82, 49, 0, 0, 16, ...>>, {:plot, 2}}
Analysis.visualize_predictions(
  cnn_model,
  cnn_trained_model_state,
  Explorer.Series.to_list(normalized_aapl_df["Close"]),
  window_size,
  1,
  batch_size
)

Using RNNs for Time Series Prediction

rnn_model =
  Axon.input("stock_prices")
  |> Axon.lstm(32)
  |> elem(0)
  |> Axon.nx(&amp; &amp;1[[0..-1//1, -1, 0..-1//1]])
  |> Axon.dense(1)
#Axon<
  inputs: %{"stock_prices" => nil}
  outputs: "dense_0"
  nodes: 8
>
rnn_trained_model_state =
  rnn_model
  |> Axon.Loop.trainer(:mean_squared_error, :adam)
  |> Axon.Loop.metric(:mean_absolute_error)
  |> Axon.Loop.run(single_step_train_data, %{}, epochs: 50, compiler: EXLA)

07:04:31.866 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 50, loss: 0.0000897 mean_absolute_error: 0.0073940
Epoch: 1, Batch: 73, loss: 0.0001568 mean_absolute_error: 0.0097099
Epoch: 2, Batch: 46, loss: 0.0001567 mean_absolute_error: 0.0086432
Epoch: 3, Batch: 69, loss: 0.0001245 mean_absolute_error: 0.0055883
Epoch: 4, Batch: 42, loss: 0.0001146 mean_absolute_error: 0.0052446
Epoch: 5, Batch: 65, loss: 0.0000927 mean_absolute_error: 0.0029908
Epoch: 6, Batch: 38, loss: 0.0000843 mean_absolute_error: 0.0022326
Epoch: 7, Batch: 61, loss: 0.0000710 mean_absolute_error: 0.0013697
Epoch: 8, Batch: 34, loss: 0.0000658 mean_absolute_error: 0.0011014
Epoch: 9, Batch: 57, loss: 0.0000575 mean_absolute_error: 0.0012567
Epoch: 10, Batch: 30, loss: 0.0000542 mean_absolute_error: 0.0013473
Epoch: 11, Batch: 53, loss: 0.0000486 mean_absolute_error: 0.0013981
Epoch: 12, Batch: 76, loss: 0.0000442 mean_absolute_error: 0.0015796
Epoch: 13, Batch: 49, loss: 0.0000423 mean_absolute_error: 0.0015473
Epoch: 14, Batch: 72, loss: 0.0000391 mean_absolute_error: 0.0016943
Epoch: 15, Batch: 45, loss: 0.0000377 mean_absolute_error: 0.0017310
Epoch: 16, Batch: 68, loss: 0.0000352 mean_absolute_error: 0.0018595
Epoch: 17, Batch: 41, loss: 0.0000342 mean_absolute_error: 0.0020594
Epoch: 18, Batch: 64, loss: 0.0000323 mean_absolute_error: 0.0021272
Epoch: 19, Batch: 37, loss: 0.0000316 mean_absolute_error: 0.0026239
Epoch: 20, Batch: 60, loss: 0.0000302 mean_absolute_error: 0.0025557
Epoch: 21, Batch: 33, loss: 0.0000297 mean_absolute_error: 0.0034725
Epoch: 22, Batch: 56, loss: 0.0000286 mean_absolute_error: 0.0029877
Epoch: 23, Batch: 29, loss: 0.0000283 mean_absolute_error: 0.0040037
Epoch: 24, Batch: 52, loss: 0.0000275 mean_absolute_error: 0.0031490
Epoch: 25, Batch: 75, loss: 0.0000267 mean_absolute_error: 0.0028734
Epoch: 26, Batch: 48, loss: 0.0000265 mean_absolute_error: 0.0030775
Epoch: 27, Batch: 71, loss: 0.0000258 mean_absolute_error: 0.0027187
Epoch: 28, Batch: 44, loss: 0.0000255 mean_absolute_error: 0.0029207
Epoch: 29, Batch: 67, loss: 0.0000248 mean_absolute_error: 0.0025301
Epoch: 30, Batch: 40, loss: 0.0000245 mean_absolute_error: 0.0028126
Epoch: 31, Batch: 63, loss: 0.0000239 mean_absolute_error: 0.0024180
Epoch: 32, Batch: 36, loss: 0.0000236 mean_absolute_error: 0.0028866
Epoch: 33, Batch: 59, loss: 0.0000230 mean_absolute_error: 0.0024546
Epoch: 34, Batch: 32, loss: 0.0000228 mean_absolute_error: 0.0030757
Epoch: 35, Batch: 55, loss: 0.0000223 mean_absolute_error: 0.0025202
Epoch: 36, Batch: 28, loss: 0.0000221 mean_absolute_error: 0.0032313
Epoch: 37, Batch: 51, loss: 0.0000217 mean_absolute_error: 0.0026088
Epoch: 38, Batch: 74, loss: 0.0000213 mean_absolute_error: 0.0025369
Epoch: 39, Batch: 47, loss: 0.0000211 mean_absolute_error: 0.0027412
Epoch: 40, Batch: 70, loss: 0.0000208 mean_absolute_error: 0.0026344
Epoch: 41, Batch: 43, loss: 0.0000207 mean_absolute_error: 0.0029186
Epoch: 42, Batch: 66, loss: 0.0000203 mean_absolute_error: 0.0026540
Epoch: 43, Batch: 39, loss: 0.0000203 mean_absolute_error: 0.0030953
Epoch: 44, Batch: 62, loss: 0.0000200 mean_absolute_error: 0.0026985
Epoch: 45, Batch: 35, loss: 0.0000199 mean_absolute_error: 0.0033785
Epoch: 46, Batch: 58, loss: 0.0000196 mean_absolute_error: 0.0028011
Epoch: 47, Batch: 31, loss: 0.0000196 mean_absolute_error: 0.0035820
Epoch: 48, Batch: 54, loss: 0.0000193 mean_absolute_error: 0.0028405
Epoch: 49, Batch: 27, loss: 0.0000193 mean_absolute_error: 0.0037105
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend
      [-0.0022059190087020397]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][1]
      EXLA.Backend
      [
        [0.45946452021598816],
        [-4.5703467912971973e-4],
        [0.23180651664733887],
        [-0.005186113528907299],
        [0.2575457692146301],
        [-0.04142376407980919],
        [-0.016365909948945045],
        [-0.059576015919446945],
        [0.31844404339790344],
        [-0.008791795000433922],
        [0.39062708616256714],
        [-0.06893440335988998],
        [0.055386751890182495],
        [0.002047080546617508],
        [-0.07389257848262787],
        [-0.08267084509134293],
        [0.02030925452709198],
        [-0.03783267363905907],
        [0.29332268238067627],
        [7.742071175016463e-4],
        [0.37041202187538147],
        [-0.004578940570354462],
        [0.3596736788749695],
        [0.33445674180984497],
        [-0.28043803572654724],
        [-0.3852922320365906],
        [0.004091555718332529],
        [-0.005642038304358721],
        [0.014751603826880455],
        [0.12446080893278122],
        [-0.09465767443180084],
        [-0.007954136468470097]
      ]
    >
  },
  "lstm_0" => %{
    "bias" => {#Nx.Tensor<
       f32[32]
       EXLA.Backend
       [0.11335168778896332, -0.06401084363460541, 0.031347475945949554, -0.10822173953056335, 0.027776755392551422, -0.18160955607891083, -0.1230732873082161, -0.1875688135623932, -0.06549829989671707, -0.17022982239723206, -0.0032370283734053373, -0.22780485451221466, -0.1940574198961258, -0.10790818184614182, -0.08154703676700592, -0.2168857604265213, -0.16268092393875122, -0.1510622799396515, -0.11275982856750488, -0.053832173347473145, -0.0012655003229156137, -0.10015828162431717, 0.006671776529401541, -0.0446050763130188, -0.0055354926735162735, 0.03272601217031479, -0.07376735657453537, -0.11262574046850204, -0.054739225655794144, -0.23879562318325043, -0.11311056464910507, -0.17760637402534485]
     >,
     #Nx.Tensor<
       f32[32]
       EXLA.Backend
       [0.10780646651983261, -0.06285148113965988, 0.023584432899951935, -0.10630014538764954, 0.022825146093964577, -0.17115405201911926, -0.11822430044412613, -0.17694641649723053, -0.0663384273648262, -0.16032005846500397, -0.010648874565958977, -0.21266703307628632, -0.18182429671287537, -0.10441649705171585, -0.08326958119869232, -0.209574893116951, -0.15406334400177002, -0.14615078270435333, -0.1299838423728943, -0.05346710607409477, -0.006063435692340136, -0.09791524708271027, 6.202268414199352e-4, -0.05215458944439888, -0.008472379297018051, 0.027613848447799683, -0.07298403233289719, -0.10905328392982483, -0.05714934691786766, -0.22263658046722412, -0.11449617892503738, -0.16627191007137299]
     >,
     #Nx.Tensor<
       f32[32]
       EXLA.Backend
       [0.003731300588697195, 4.8132456140592694e-4, 0.005429568700492382, -0.003417501924559474, 0.004267899785190821, 0.0023204206954687834, -0.003725315211340785, 0.00241837534122169, 0.0011554744560271502, 0.00141581567004323, 0.0017936892108991742, 9.837766410782933e-4, -0.0013038406614214182, 0.0018337997607886791, -0.002206782577559352, -6.189559935592115e-4, -0.001890703453682363, -0.0014234009431675076, 0.003332881722599268, 6.696510245092213e-4, 0.0024242501240223646, -0.001337566296570003, 0.003793303854763508, 0.0028332320507615805, -0.001609275583177805, -0.0017298919847235084, 0.0016395390266552567, -3.6149637890048325e-4, 0.0016541111981496215, -3.5133626079186797e-4, -0.002388330176472664, -0.004853168968111277]
     >,
     #Nx.Tensor<
       f32[32]
       EXLA.Backend
       [0.1105206236243248, -0.065388984978199, 0.027362428605556488, -0.11799526959657669, 0.025150544941425323, -0.17694349586963654, -0.12093527615070343, -0.18315498530864716, -0.06615359336137772, -0.16683995723724365, -0.007195259910076857, -0.22143132984638214, -0.18868662416934967, -0.1067797839641571, -0.083305224776268, -0.21811549365520477, -0.15893080830574036, -0.14961159229278564, -0.12222269177436829, -0.05434230715036392, -0.0038143477868288755, -0.09983281791210175, 0.0035336518194526434, -0.0486055389046669, -0.007181230932474136, 0.030025919899344444, -0.0738983154296875, -0.11130935698747635, -0.05862840265035629, -0.23184144496917725, -0.11468596011400223, -0.17238789796829224]
     >},
    "hidden_kernel" => {#Nx.Tensor<
       f32[32][32]
       EXLA.Backend
       [
         [-0.11421269923448563, 0.11127666383981705, -0.184355691075325, -0.27127766609191895, -0.06881655752658844, 0.10430511832237244, -0.2448892444372177, -0.2653137147426605, 0.06281697005033493, -0.013854310847818851, -0.006992769427597523, 0.13280320167541504, 0.14791104197502136, -0.1784118115901947, -0.06113160401582718, 0.1929483562707901, 0.04987379536032677, 0.2280454933643341, -0.19724753499031067, 0.15979085862636566, 0.27148231863975525, -0.2724880576133728, -0.1679612696170807, 0.21920432150363922, -0.23523777723312378, 0.31118467450141907, -0.2611495852470398, -0.065000019967556, 0.10206960141658783, 0.025555908679962158, 0.19122368097305298, 0.0023696101270616055],
         [-0.027491245418787003, -0.2316892296075821, 0.015483269467949867, 0.02644273452460766, -0.23044149577617645, 0.122772216796875, -0.20443205535411835, -0.04793649911880493, -0.2743169069290161, -0.21698445081710815, 0.17927387356758118, 0.08292767405509949, 0.20277370512485504, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[32][32]
       EXLA.Backend
       [
         [-0.11538577824831009, 0.1112610325217247, -0.1845649778842926, -0.27134761214256287, -0.06905922293663025, 0.10186973214149475, -0.24537153542041779, -0.2668277323246002, 0.061841294169425964, -0.014799466356635094, -0.0065930732525885105, 0.12995310127735138, 0.14654433727264404, -0.17863759398460388, -0.06125305965542793, 0.1921101212501526, 0.049221768975257874, 0.2271725982427597, -0.197763130068779, 0.15969622135162354, 0.27121856808662415, -0.2727193534374237, -0.16848579049110413, 0.21903260052204132, -0.23458200693130493, 0.31177598237991333, -0.26127612590789795, -0.06564890593290329, 0.10203193873167038, 0.02351302094757557, 0.19066844880580902, 8.716363226994872e-4],
         [-0.02810082957148552, -0.23168812692165375, 0.015438610687851906, 0.026450933888554573, -0.23074249923229218, 0.1225915178656578, -0.20462723076343536, -0.048075538128614426, -0.2743925154209137, -0.21695131063461304, 0.17911174893379211, 0.08302958309650421, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[32][32]
       EXLA.Backend
       [
         [0.008229146711528301, 0.1445600688457489, -0.10443291813135147, -0.23759327828884125, 0.014250179752707481, -0.02108207531273365, -0.19692735373973846, -0.379713773727417, 0.10434218496084213, -0.10935195535421371, 0.07177723199129105, -4.993749316781759e-4, 0.19762060046195984, -0.2601209580898285, -0.1646450161933899, 0.09917865693569183, 0.11334932595491409, 0.12844809889793396, -0.10628209263086319, 0.1521795243024826, 0.34273761510849, -0.22396697103977203, -0.08686701953411102, 0.28526702523231506, -0.358896940946579, 0.19280748069286346, -0.3215658664703369, -0.03377417474985123, 0.14930129051208496, 0.06669514626264572, 0.08144361525774002, 0.042655251920223236],
         [-0.08700522035360336, -0.29304251074790955, -0.04271193593740463, -0.02882186509668827, -0.2872234284877777, 0.19526077806949615, -0.2664906084537506, 0.023175811395049095, -0.3267883062362671, -0.15051718056201935, 0.11813374608755112, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[32][32]
       EXLA.Backend
       [
         [-0.12982815504074097, 0.11378036439418793, -0.18149970471858978, -0.2677125632762909, -0.07103432714939117, 0.11825311928987503, -0.2336365133523941, -0.2543303370475769, 0.07081130146980286, 0.0022928505204617977, 0.0032019137870520353, 0.15594728291034698, 0.1622438281774521, -0.17180950939655304, -0.04690863564610481, 0.2031450867652893, 0.059193458408117294, 0.24310606718063354, -0.19648565351963043, 0.16401754319667816, 0.27107858657836914, -0.2651662528514862, -0.1745850145816803, 0.2315697968006134, -0.23139923810958862, 0.3055410087108612, -0.25599977374076843, -0.04990411549806595, 0.10681464523077011, 0.043445780873298645, 0.20912222564220428, 0.018217815086245537],
         [-0.03424784541130066, -0.23208343982696533, 0.009599815122783184, 0.026051534339785576, -0.23864002525806427, 0.12222269177436829, -0.20658233761787415, -0.04860348626971245, -0.28558629751205444, -0.21860875189304352, ...],
         ...
       ]
     >},
    "input_kernel" => {#Nx.Tensor<
       f32[1][32]
       EXLA.Backend
       [
         [0.15832173824310303, -0.06945277750492096, 0.29155898094177246, -0.1992305964231491, 0.13684962689876556, 0.3534727692604065, -0.21600008010864258, 0.22524043917655945, 0.01816789247095585, 0.23636147379875183, 0.059991706162691116, 0.21631649136543274, -0.0855906531214714, 0.25307703018188477, 0.007452141959220171, 0.06080798804759979, -0.14563453197479248, 0.09734586626291275, 0.19660542905330658, 0.09763676673173904, 0.09367682784795761, -0.12293204665184021, 0.22191683948040009, 0.1307179480791092, -0.09499726444482803, -0.4519442617893219, 0.19236356019973755, -0.0888831838965416, 0.11068467795848846, -0.04182599112391472, -0.059468429535627365, -0.30685484409332275]
       ]
     >,
     #Nx.Tensor<
       f32[1][32]
       EXLA.Backend
       [
         [0.15510421991348267, -0.06956983357667923, 0.2913159430027008, -0.19954615831375122, 0.1365644782781601, 0.34723758697509766, -0.21715417504310608, 0.219896599650383, 0.018290145322680473, 0.23421815037727356, 0.06030489131808281, 0.21138715744018555, -0.08892093598842621, 0.2523980140686035, 0.007690330035984516, 0.05870341509580612, -0.14841823279857635, 0.09678869694471359, 0.1968754231929779, 0.09719568490982056, 0.09413376450538635, -0.12356305867433548, 0.22244440019130707, 0.1320788860321045, -0.09404245018959045, -0.4512008726596832, 0.19207976758480072, -0.08969637751579285, 0.11066418141126633, -0.04511291906237602, -0.05949528515338898, -0.3111755847930908]
       ]
     >,
     #Nx.Tensor<
       f32[1][32]
       EXLA.Backend
       [
         [0.5210095047950745, -0.04471205919981003, 0.5679916143417358, -0.24680404365062714, 0.42526307702064514, 0.10860466212034225, -0.1759696751832962, -0.01207458134740591, 0.21903882920742035, 0.051724907010793686, 0.30118006467819214, -0.03016113117337227, -0.004351356066763401, 0.12323488295078278, -0.2007794976234436, -0.13582377135753632, -0.060969483107328415, -0.09213265776634216, 0.4137095510959625, 0.08238348364830017, 0.3602970838546753, -0.0882539451122284, 0.5199012756347656, 0.3512928783893585, -0.25873085856437683, -0.5727313160896301, 0.12393758445978165, -0.06665874272584915, 0.23254042863845825, 0.04973064363002777, -0.27502191066741943, -0.26436713337898254]
       ]
     >,
     #Nx.Tensor<
       f32[1][32]
       EXLA.Backend
       [
         [0.14472824335098267, -0.0628005787730217, 0.27467724680900574, -0.1878383904695511, 0.1254066377878189, 0.36361345648765564, -0.211454838514328, 0.2393423318862915, 0.014579376205801964, 0.24944588541984558, 0.048402633517980576, 0.22778557240962982, -0.07714307308197021, 0.2585679888725281, 0.006932367570698261, 0.07686306536197662, -0.13720948994159698, 0.10166746377944946, 0.18309403955936432, 0.11173634976148605, 0.08637332171201706, -0.11274421215057373, 0.2134648561477661, 0.11872217804193497, -0.10573867708444595, -0.4594239592552185, 0.1963488757610321, -0.08307348936796188, 0.11799244582653046, -0.033315449953079224, -0.05906962230801582, -0.2986782193183899]
       ]
     >}
  },
  "lstm__c_hidden_state" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [392164166, 4227701866]
    >
  },
  "lstm__h_hidden_state" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [392164166, 4227701866]
    >
  }
}
0.0032470 * :math.sqrt(Explorer.Series.variance(aapl_df["Close"])) +
  Explorer.Series.mean(aapl_df["Close"])
64.80750153333416
Analysis.visualize_predictions(
  rnn_model,
  rnn_trained_model_state,
  Explorer.Series.to_list(normalized_aapl_df["Close"]),
  30,
  1,
  32
)