Forecasting the Future
Mix.install([
{:explorer, "~> 0.5.0"},
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:axon, "~> 0.5"},
{:vega_lite, "~> 0.1.6"},
{:kino, "~> 0.8.0"},
{:kino_vega_lite, "~> 0.1.7"}
])
alias VegaLite, as: Vl
Resolving Hex dependencies...
Resolution completed in 0.117s
New:
axon 0.5.1
castore 1.0.2
complex 0.5.0
elixir_make 0.7.6
exla 0.5.3
explorer 0.5.7
kino 0.8.1
kino_vega_lite 0.1.8
nx 0.5.3
rustler_precompiled 0.6.1
table 0.1.2
table_rex 3.1.1
telemetry 1.2.1
vega_lite 0.1.7
xla 0.4.4
* Getting explorer (Hex package)
* Getting nx (Hex package)
* Getting exla (Hex package)
* Getting axon (Hex package)
* Getting vega_lite (Hex package)
* Getting kino (Hex package)
* Getting kino_vega_lite (Hex package)
* Getting table (Hex package)
* Getting elixir_make (Hex package)
* Getting telemetry (Hex package)
* Getting xla (Hex package)
* Getting complex (Hex package)
* Getting rustler_precompiled (Hex package)
* Getting table_rex (Hex package)
* Getting castore (Hex package)
==> table
Compiling 5 files (.ex)
Generated table app
==> vega_lite
Compiling 5 files (.ex)
Generated vega_lite app
===> Analyzing applications...
===> Compiling telemetry
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> kino
Compiling 37 files (.ex)
Generated kino app
==> kino_vega_lite
Compiling 4 files (.ex)
Generated kino_vega_lite app
==> table_rex
Compiling 7 files (.ex)
Generated table_rex app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/sean/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/sean/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.2/adb2e876bd0f4c056c66dfce006b140f/deps/exla/cache
Using libexla.so from /Users/sean/Library/Caches/xla/exla/elixir-1.14.2-erts-13.0.2-xla-0.4.4-exla-0.5.3-defofxrodwsk5sselkno4icm44/libexla.so
Compiling 21 files (.ex)
Generated exla app
==> rustler_precompiled
Compiling 4 files (.ex)
Generated rustler_precompiled app
==> explorer
Compiling 19 files (.ex)
06:46:43.386 [debug] Copying NIF from cache and extracting to /Users/sean/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.2/adb2e876bd0f4c056c66dfce006b140f/_build/dev/lib/explorer/priv/native/libexplorer-v0.5.7-nif-2.16-aarch64-apple-darwin.so
Generated explorer app
VegaLite
The Data
File.cwd!()
"/Users/sean/projects/scholar/notebooks"
df =
Explorer.DataFrame.from_csv!(
"all_stocks_2006-01-01_to_2018-01-01.csv",
parse_dates: true
)
#Explorer.DataFrame<
Polars[93612 x 7]
Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
Open float [77.76, 79.49, 78.41, 78.64, 78.5, ...]
High float [79.35, 79.49, 78.65, 78.9, 79.83, ...]
Low float [77.24, 78.25, 77.56, 77.64, 78.46, ...]
Close float [79.11, 78.71, 77.99, 78.63, 79.02, ...]
Volume integer [3117200, 2558000, 2529500, 2479500, 1845600, ...]
Name string ["MMM", "MMM", "MMM", "MMM", "MMM", ...]
>
df = Explorer.DataFrame.select(df, ["Date", "Close", "Name"])
#Explorer.DataFrame<
Polars[93612 x 3]
Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
Close float [79.11, 78.71, 77.99, 78.63, 79.02, ...]
Name string ["MMM", "MMM", "MMM", "MMM", "MMM", ...]
>
Vl.new(title: "DJIA Stock Prices", width: 1280, height: 720)
|> Vl.data_from_values(Explorer.DataFrame.to_columns(df))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "Date", type: :temporal)
|> Vl.encode_field(:y, "Close", type: :quantitative)
|> Vl.encode_field(:color, "Name", type: :nominal)
|> Kino.VegaLite.new()
aapl_df =
Explorer.DataFrame.filter_with(df, fn df ->
Explorer.Series.equal(df["Name"], "AAPL")
end)
normalized_aapl_df =
Explorer.DataFrame.mutate_with(aapl_df, fn df ->
var = Explorer.Series.variance(df["Close"])
mean = Explorer.Series.mean(df["Close"])
centered = Explorer.Series.subtract(df["Close"], mean)
norm = Explorer.Series.divide(centered, var)
[Close: norm]
end)
#Explorer.DataFrame<
Polars[3019 x 3]
Date date [2006-01-03, 2006-01-04, 2006-01-05, 2006-01-06, 2006-01-09, ...]
Close float [-0.027216043254776588, -0.02720091843889246, -0.027241251281250132,
-0.027105127938293, -0.027125294359471832, ...]
Name string ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL", ...]
>
Vl.new(title: "AAPL Stock Price", width: 1280, height: 720)
|> Vl.data_from_values(Explorer.DataFrame.to_columns(normalized_aapl_df))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "Date", type: :temporal)
|> Vl.encode_field(:y, "Close", type: :quantitative)
|> Kino.VegaLite.new()
defmodule Data do
def window(inputs, window_size, target_window_size) do
inputs
|> Stream.chunk_every(window_size + target_window_size, 1, :discard)
|> Stream.map(fn window ->
features = Enum.take(window, window_size)
targets = Enum.drop(window, window_size)
{Nx.tensor(features) |> Nx.new_axis(-1), Nx.tensor(targets) |> Nx.new_axis(-1)}
end)
end
def batch(inputs, batch_size) do
inputs
|> Stream.chunk_every(batch_size, batch_size, :discard)
|> Stream.map(fn windows ->
{features, targets} = Enum.unzip(windows)
{Nx.stack(features), Nx.stack(targets)}
end)
end
end
{:module, Data, <<70, 79, 82, 49, 0, 0, 10, ...>>, {:batch, 2}}
train_df =
Explorer.DataFrame.filter_with(normalized_aapl_df, fn df ->
Explorer.Series.less(df["Date"], Date.new!(2016, 1, 1))
end)
test_df =
Explorer.DataFrame.filter_with(normalized_aapl_df, fn df ->
Explorer.Series.greater_equal(df["Date"], Date.new!(2016, 1, 1))
end)
#Explorer.DataFrame<
Polars[503 x 3]
Date date [2016-01-04, 2016-01-05, 2016-01-06, 2016-01-07, 2016-01-08, ...]
Close float [0.02051283407023046, 0.01918185027242737, 0.01816848760819093, 0.01602580535793974,
0.016282927227969878, ...]
Name string ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL", ...]
>
window_size = 30
batch_size = 32
train_prices = Explorer.Series.to_list(train_df["Close"])
test_prices = Explorer.Series.to_list(test_df["Close"])
single_step_train_data =
train_prices
|> Data.window(window_size, 1)
|> Data.batch(batch_size)
single_step_test_data =
test_prices
|> Data.window(window_size, 1)
|> Data.batch(batch_size)
multi_step_train_data =
train_prices
|> Data.window(window_size, 5)
|> Data.batch(batch_size)
multi_step_test_data =
test_prices
|> Data.window(window_size, 5)
|> Data.batch(batch_size)
#Stream<[
enum: #Stream<[
enum: #Stream<[
enum: [0.02051283407023046, 0.01918185027242737, 0.01816848760819093, 0.01602580535793974,
0.016282927227969878, 0.017074459259239144, 0.01779540881638248, 0.016499716255642356,
0.017573578183415303, 0.016368634517979926, 0.01613167906912862, 0.016197219937959837,
0.015950181278519108, 0.018531483189409954, 0.017533245341057633, 0.01781053363226661,
0.01449819895364301, 0.014835986508388493, 0.016474508229168815, 0.016015722147350326,
0.015032609114882132, 0.01597538930499265, 0.016101429437360367, 0.014800695271325529,
0.015299814195501693, 0.01528973098491227, 0.014926735403693246, 0.014639363901894855,
0.014785570455441403, 0.016121595858539204, 0.016867753442156092, 0.015930014857340278,
0.015819099540856685, 0.01624259438561221, 0.015138482826071012, 0.015849349172624934,
0.01618209512207571, 0.016257719201496337, 0.01614680388501275, 0.01808278031818088,
0.01819369563466447, 0.01857181603176762, 0.019333098431268635, 0.018758355427671847,
0.018334860582916314, 0.018380235030568695, 0.018405443057042236, ...],
funs: [#Function<3.6935098/1 in Stream.chunk_while/4>]
]>,
funs: [#Function<48.6935098/1 in Stream.map/2>, #Function<3.6935098/1 in Stream.chunk_while/4>]
]>,
funs: [#Function<48.6935098/1 in Stream.map/2>]
]>
Enum.take(single_step_train_data, 1)
[
{#Nx.Tensor<
f32[32][30][1]
[
[
[-0.027216043323278427],
[-0.027200918644666672],
[-0.02724125050008297],
[-0.02710512839257717],
[-0.027125295251607895],
[-0.026777423918247223],
[-0.026555592194199562],
[-0.02653038501739502],
[-0.02643459476530552],
[-0.02650013566017151],
[-0.026661466807127],
[-0.026908505707979202],
[-0.027120253071188927],
[-0.027004295960068703],
[-0.027125295251607895],
[-0.027256375178694725],
[-0.027392499148845673],
[-0.027412666007876396],
[-0.027200918644666672],
[-0.027160584926605225],
[-0.02717066928744316],
[-0.027407623827457428],
[-0.02742779068648815],
[-0.0277554951608181],
[-0.027730286121368408],
[-0.027644580230116844],
[-0.02792186848819256],
[-0.027750452980399132],
[-0.027942035347223282],
[-0.027730286121368408]
],
[
[-0.027200918644666672],
[-0.02724125050008297],
[-0.02710512839257717],
[-0.027125295251607895],
[-0.026777423918247223],
[-0.026555592194199562],
[-0.02653038501739502],
[-0.02643459476530552],
[-0.02650013566017151],
[-0.026661466807127],
[-0.026908505707979202],
[-0.027120253071188927],
[-0.027004295960068703],
[-0.027125295251607895],
[-0.027256375178694725],
[-0.027392499148845673],
[-0.027412666007876396],
[-0.027200918644666672],
...
],
...
]
>,
#Nx.Tensor<
f32[32][1][1]
[
[
[-0.027614330872893333]
],
[
[-0.027518538758158684]
],
[
[-0.027538705617189407]
],
[
[-0.02762441337108612]
],
[
[-0.02746308222413063]
],
[
[-0.02743283286690712]
],
[
[-0.027452997863292694]
],
[
[-0.027488289400935173]
],
[
[-0.027669787406921387]
],
[
[-0.02762441337108612]
],
[
[-0.02758912183344364]
],
[
[-0.02772524580359459]
],
[
[-0.02788657695055008]
],
[
[-0.027826078236103058]
],
[
[-0.027871452271938324]
],
[
[-0.027997491881251335]
],
[
[-0.02804790809750557]
],
[
[-0.027871452271938324]
],
[
[-0.027750452980399132]
],
[
[-0.027831118553876877]
],
[
[-0.027967242524027824]
],
[
[-0.027942035347223282]
],
[
[-0.027992449700832367]
],
[
[-0.028148740530014038]
],
[
[-0.028158823028206825]
],
[
[-0.02826973795890808]
],
[
[-0.028279822319746017]
],
[
[-0.028315113857388496]
],
[
[-0.02837057039141655]
],
[
[-0.02811344899237156]
],
[
[-0.02808319963514805]
],
[
[-0.02808319963514805]
]
]
>}
]
Enum.take(multi_step_train_data, 1)
[
{#Nx.Tensor<
f32[32][30][1]
[
[
[-0.027216043323278427],
[-0.027200918644666672],
[-0.02724125050008297],
[-0.02710512839257717],
[-0.027125295251607895],
[-0.026777423918247223],
[-0.026555592194199562],
[-0.02653038501739502],
[-0.02643459476530552],
[-0.02650013566017151],
[-0.026661466807127],
[-0.026908505707979202],
[-0.027120253071188927],
[-0.027004295960068703],
[-0.027125295251607895],
[-0.027256375178694725],
[-0.027392499148845673],
[-0.027412666007876396],
[-0.027200918644666672],
[-0.027160584926605225],
[-0.02717066928744316],
[-0.027407623827457428],
[-0.02742779068648815],
[-0.0277554951608181],
[-0.027730286121368408],
[-0.027644580230116844],
[-0.02792186848819256],
[-0.027750452980399132],
[-0.027942035347223282],
[-0.027730286121368408]
],
[
[-0.027200918644666672],
[-0.02724125050008297],
[-0.02710512839257717],
[-0.027125295251607895],
[-0.026777423918247223],
[-0.026555592194199562],
[-0.02653038501739502],
[-0.02643459476530552],
[-0.02650013566017151],
[-0.026661466807127],
[-0.026908505707979202],
[-0.027120253071188927],
[-0.027004295960068703],
[-0.027125295251607895],
[-0.027256375178694725],
[-0.027392499148845673],
[-0.027412666007876396],
[-0.027200918644666672],
...
],
...
]
>,
#Nx.Tensor<
f32[32][5][1]
[
[
[-0.027614330872893333],
[-0.027518538758158684],
[-0.027538705617189407],
[-0.02762441337108612],
[-0.02746308222413063]
],
[
[-0.027518538758158684],
[-0.027538705617189407],
[-0.02762441337108612],
[-0.02746308222413063],
[-0.02743283286690712]
],
[
[-0.027538705617189407],
[-0.02762441337108612],
[-0.02746308222413063],
[-0.02743283286690712],
[-0.027452997863292694]
],
[
[-0.02762441337108612],
[-0.02746308222413063],
[-0.02743283286690712],
[-0.027452997863292694],
[-0.027488289400935173]
],
[
[-0.02746308222413063],
[-0.02743283286690712],
[-0.027452997863292694],
[-0.027488289400935173],
[-0.027669787406921387]
],
[
[-0.02743283286690712],
[-0.027452997863292694],
[-0.027488289400935173],
[-0.027669787406921387],
[-0.02762441337108612]
],
[
[-0.027452997863292694],
[-0.027488289400935173],
[-0.027669787406921387],
[-0.02762441337108612],
[-0.02758912183344364]
],
[
[-0.027488289400935173],
[-0.027669787406921387],
[-0.02762441337108612],
[-0.02758912183344364],
[-0.02772524580359459]
],
[
[-0.027669787406921387],
[-0.02762441337108612],
[-0.02758912183344364],
[-0.02772524580359459],
[-0.02788657695055008]
],
[
[-0.02762441337108612],
[-0.02758912183344364],
...
],
...
]
>}
]
Using CNNs for Time Series
cnn_model =
Axon.input("stock_price")
|> Axon.conv(32, kernel_size: window_size, activation: :relu)
|> Axon.dense(32, activation: :relu)
|> Axon.dense(1)
#Axon<
inputs: %{"stock_price" => nil}
outputs: "dense_1"
nodes: 6
>
template = Nx.template({32, 30, 1}, :f32)
Axon.Display.as_graph(cnn_model, template)
graph TD;
22[/"stock_price (:input) {32, 30, 1}"/];
23["conv_0 (:conv) {32, 1, 32}"];
24["relu_0 (:relu) {32, 1, 32}"];
25["dense_0 (:dense) {32, 1, 32}"];
26["relu_1 (:relu) {32, 1, 32}"];
27["dense_1 (:dense) {32, 1, 1}"];
26 --> 27;
25 --> 26;
24 --> 25;
23 --> 24;
22 --> 23;
cnn_trained_model_state =
cnn_model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.run(single_step_train_data, %{}, epochs: 50, compiler: EXLA)
06:56:43.142 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 50, loss: 0.0000663 mean_absolute_error: 0.0058765
Epoch: 1, Batch: 73, loss: 0.0001107 mean_absolute_error: 0.0067741
Epoch: 2, Batch: 46, loss: 0.0000913 mean_absolute_error: 0.0042740
Epoch: 3, Batch: 69, loss: 0.0000695 mean_absolute_error: 0.0042149
Epoch: 4, Batch: 42, loss: 0.0000642 mean_absolute_error: 0.0042601
Epoch: 5, Batch: 65, loss: 0.0000530 mean_absolute_error: 0.0028081
Epoch: 6, Batch: 38, loss: 0.0000486 mean_absolute_error: 0.0019006
Epoch: 7, Batch: 61, loss: 0.0000415 mean_absolute_error: 0.0016131
Epoch: 8, Batch: 34, loss: 0.0000386 mean_absolute_error: 0.0010020
Epoch: 9, Batch: 57, loss: 0.0000342 mean_absolute_error: 0.0013065
Epoch: 10, Batch: 30, loss: 0.0000324 mean_absolute_error: 0.0013755
Epoch: 11, Batch: 53, loss: 0.0000293 mean_absolute_error: 0.0014773
Epoch: 12, Batch: 76, loss: 0.0000269 mean_absolute_error: 0.0017283
Epoch: 13, Batch: 49, loss: 0.0000258 mean_absolute_error: 0.0015440
Epoch: 14, Batch: 72, loss: 0.0000240 mean_absolute_error: 0.0016460
Epoch: 15, Batch: 45, loss: 0.0000232 mean_absolute_error: 0.0014592
Epoch: 16, Batch: 68, loss: 0.0000217 mean_absolute_error: 0.0015306
Epoch: 17, Batch: 41, loss: 0.0000211 mean_absolute_error: 0.0014993
Epoch: 18, Batch: 64, loss: 0.0000200 mean_absolute_error: 0.0015089
Epoch: 19, Batch: 37, loss: 0.0000195 mean_absolute_error: 0.0014502
Epoch: 20, Batch: 60, loss: 0.0000185 mean_absolute_error: 0.0014855
Epoch: 21, Batch: 33, loss: 0.0000181 mean_absolute_error: 0.0014808
Epoch: 22, Batch: 56, loss: 0.0000173 mean_absolute_error: 0.0013899
Epoch: 23, Batch: 29, loss: 0.0000170 mean_absolute_error: 0.0014624
Epoch: 24, Batch: 52, loss: 0.0000163 mean_absolute_error: 0.0013806
Epoch: 25, Batch: 75, loss: 0.0000158 mean_absolute_error: 0.0016300
Epoch: 26, Batch: 48, loss: 0.0000155 mean_absolute_error: 0.0013704
Epoch: 27, Batch: 71, loss: 0.0000150 mean_absolute_error: 0.0015049
Epoch: 28, Batch: 44, loss: 0.0000148 mean_absolute_error: 0.0012911
Epoch: 29, Batch: 67, loss: 0.0000144 mean_absolute_error: 0.0016583
Epoch: 30, Batch: 40, loss: 0.0000143 mean_absolute_error: 0.0011887
Epoch: 31, Batch: 63, loss: 0.0000140 mean_absolute_error: 0.0019614
Epoch: 32, Batch: 36, loss: 0.0000140 mean_absolute_error: 0.0030620
Epoch: 33, Batch: 59, loss: 0.0000137 mean_absolute_error: 0.0011744
Epoch: 34, Batch: 32, loss: 0.0000135 mean_absolute_error: 0.0012163
Epoch: 35, Batch: 55, loss: 0.0000132 mean_absolute_error: 0.0012842
Epoch: 36, Batch: 28, loss: 0.0000130 mean_absolute_error: 0.0010907
Epoch: 37, Batch: 51, loss: 0.0000127 mean_absolute_error: 0.0017228
Epoch: 38, Batch: 74, loss: 0.0000126 mean_absolute_error: 0.0021361
Epoch: 39, Batch: 47, loss: 0.0000125 mean_absolute_error: 0.0020527
Epoch: 40, Batch: 70, loss: 0.0000125 mean_absolute_error: 0.0028232
Epoch: 41, Batch: 43, loss: 0.0000129 mean_absolute_error: 0.0050399
Epoch: 42, Batch: 66, loss: 0.0000132 mean_absolute_error: 0.0044193
Epoch: 43, Batch: 39, loss: 0.0000138 mean_absolute_error: 0.0064203
Epoch: 44, Batch: 62, loss: 0.0000145 mean_absolute_error: 0.0049469
Epoch: 45, Batch: 35, loss: 0.0000149 mean_absolute_error: 0.0054889
Epoch: 46, Batch: 58, loss: 0.0000151 mean_absolute_error: 0.0037215
Epoch: 47, Batch: 31, loss: 0.0000150 mean_absolute_error: 0.0026493
Epoch: 48, Batch: 54, loss: 0.0000148 mean_absolute_error: 0.0013843
Epoch: 49, Batch: 27, loss: 0.0000147 mean_absolute_error: 0.0006616
%{
"conv_0" => %{
"bias" => #Nx.Tensor<
f32[32]
EXLA.Backend
[-0.008091067895293236, -0.016978450119495392, -0.004820083733648062, -0.00792402308434248, -0.017319275066256523, 0.03160012513399124, -0.006478423718363047, -0.011876466684043407, 0.040104787796735764, -0.010837228037416935, -0.012144810520112514, -0.005171791650354862, -0.008013750426471233, -0.00792587362229824, -0.016933709383010864, -0.01142617966979742, -0.012313197366893291, -0.02336074225604534, -0.013391278684139252, -0.005351855419576168, -0.004199669696390629, -0.01840357296168804, 0.017371967434883118, -0.006281338166445494, -0.01266816072165966, -0.010549230501055717, -0.0077864243648946285, -0.009549328126013279, -0.0049909790977835655, -0.005896904040127993, -0.007447496056556702, -0.011414170265197754]
>,
"kernel" => #Nx.Tensor<
f32[30][1][32]
EXLA.Backend
[
[
[-0.05040709674358368, -0.021235600113868713, -0.010630444623529911, -0.019848855212330818, -0.040217604488134384, -0.12504053115844727, -0.02118268609046936, 0.046331584453582764, 0.03421409800648689, 0.06607695668935776, 0.12348750978708267, -0.12631401419639587, 0.019740285351872444, -0.008459499105811119, 0.03768952935934067, 0.06856438517570496, -5.624796031042933e-4, -0.04322008043527603, 0.06655385345220566, -0.04173768684267998, -0.12437783926725388, -0.03484911471605301, -0.10082294046878815, -0.06766377389431, 0.034250643104314804, 0.05543597787618637, -0.01215132512152195, -0.06139170378446579, 0.004699850454926491, 0.06339087337255478, -0.05598803982138634, 0.05522901192307472]
],
[
[0.00991822686046362, -0.0031699652317911386, -0.0019580740481615067, 0.055364325642585754, -0.02978527545928955, -0.006119378376752138, 0.002985905623063445, -0.035792168229818344, -0.04653836041688919, 0.003442571498453617, 0.12638360261917114, -0.11723502725362778, -0.0502297542989254, -0.13008807599544525, 0.056758686900138855, ...]
],
...
]
>
},
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[32]
EXLA.Backend
[-0.03525220975279808, -0.016225244849920273, -0.028320161625742912, -0.02871190384030342, -0.013132854364812374, -0.012726775370538235, -0.007784419227391481, -0.022564707323908806, 0.0, -0.007780132349580526, -0.02287977933883667, 0.020672377198934555, -0.030719507485628128, -0.014918905682861805, 0.002109799301251769, -0.014595966786146164, -0.02758587710559368, -0.01325188297778368, -0.004173007793724537, -0.03749808669090271, -0.009482248686254025, -0.01655416563153267, -0.021616848185658455, -0.021459853276610374, -0.018326226621866226, -0.007884755730628967, -0.01816698908805847, -0.03544246032834053, -0.006969088688492775, -0.007480396889150143, 0.0022076379973441362, -0.032438069581985474]
>,
"kernel" => #Nx.Tensor<
f32[32][32]
EXLA.Backend
[
[-0.17533017694950104, 0.012447532266378403, 0.11476895958185196, -0.2081768661737442, -0.009203040972352028, 0.18192830681800842, -0.131138876080513, -0.2810400426387787, 0.19938461482524872, 0.2786736488342285, -0.27819138765335083, -0.2636408805847168, -0.1670035719871521, -0.09961113333702087, -0.146389439702034, 0.15396006405353546, -0.23006494343280792, 0.024088190868496895, -0.21437454223632812, 0.2188781350851059, -0.2704562842845917, -0.24468927085399628, 0.17817959189414978, -0.23606330156326294, -0.14990635216236115, -0.0874180942773819, 0.13848312199115753, -0.07637172192335129, -0.006744588725268841, -0.0221336018294096, 0.12596826255321503, -0.0856495127081871],
[-0.1696496605873108, -0.09901808202266693, -0.041791852563619614, -0.29172369837760925, 0.10078410804271698, 0.28704243898391724, 0.17552822828292847, -0.20169097185134888, 0.048185210675001144, 0.2521952986717224, 0.16852286458015442, -0.05647372454404831, -0.09328987449407578, -0.1908075213432312, ...],
...
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
EXLA.Backend
[0.003080096561461687]
>,
"kernel" => #Nx.Tensor<
f32[32][1]
EXLA.Backend
[
[-0.020385151728987694],
[0.1855364888906479],
[0.07487309724092484],
[-0.36159875988960266],
[0.2996695339679718],
[-0.14941802620887756],
[0.0248686745762825],
[-0.08696726709604263],
[0.025274399667978287],
[-0.06669101864099503],
[0.17602013051509857],
[0.36575624346733093],
[-0.13954167068004608],
[-0.08473435044288635],
[-0.29618003964424133],
[-0.1217213124036789],
[0.005105822812765837],
[0.00423429673537612],
[-0.05834190174937248],
[0.11214695125818253],
[0.07910924404859543],
[-0.3471531271934509],
[-0.003867669031023979],
[-3.470616138656624e-5],
[0.08255864679813385],
[0.38870716094970703],
[-0.11077144742012024],
[0.08462876826524734],
[0.018228236585855484],
[-0.2302139550447464],
[-0.19713029265403748],
[0.33921128511428833]
]
>
}
}
cnn_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.run(single_step_test_data, cnn_trained_model_state, compiler: EXLA)
06:56:44.326 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 13, mean_absolute_error: 0.0043546
%{
0 => %{
"mean_absolute_error" => #Nx.Tensor<
f32
EXLA.Backend
0.004354595206677914
>
}
}
0.0035810 * :math.sqrt(Explorer.Series.variance(aapl_df["Close"])) +
Explorer.Series.mean(aapl_df["Close"])
64.82237670703707
defmodule Analysis do
def visualize_predictions(
model,
model_state,
prices,
window_size,
target_window_size,
batch_size
) do
{_, predict_fn} = Axon.build(model, compiler: EXLA)
windows =
prices
|> Data.window(window_size, target_window_size)
|> Data.batch(batch_size)
|> Stream.map(&elem(&1, 0))
predicted =
Enum.flat_map(windows, fn window ->
predict_fn.(model_state, window) |> Nx.to_flat_list()
end)
predicted = [nil, nil, nil, nil, nil, nil, nil, nil, nil, nil | predicted]
types =
List.duplicate("AAPL", length(prices)) ++
List.duplicate("Predicted", length(prices))
days =
Enum.to_list(0..(length(prices) - 1)) ++
Enum.to_list(0..(length(prices) - 1))
prices = prices ++ predicted
plot(
%{
"day" => days,
"prices" => prices,
"types" => types
},
"AAPL Stock Price vs. Predicted, RNN Single-Shot"
)
end
defp plot(values, title) do
Vl.new(title: title, width: 1280, height: 720)
|> Vl.data_from_values(values)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "day", type: :temporal)
|> Vl.encode_field(:y, "prices", type: :quantitative)
|> Vl.encode_field(:color, "types", type: :nominal)
|> Kino.VegaLite.new()
end
end
{:module, Analysis, <<70, 79, 82, 49, 0, 0, 16, ...>>, {:plot, 2}}
Analysis.visualize_predictions(
cnn_model,
cnn_trained_model_state,
Explorer.Series.to_list(normalized_aapl_df["Close"]),
window_size,
1,
batch_size
)
Using RNNs for Time Series Prediction
rnn_model =
Axon.input("stock_prices")
|> Axon.lstm(32)
|> elem(0)
|> Axon.nx(& &1[[0..-1//1, -1, 0..-1//1]])
|> Axon.dense(1)
#Axon<
inputs: %{"stock_prices" => nil}
outputs: "dense_0"
nodes: 8
>
rnn_trained_model_state =
rnn_model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.run(single_step_train_data, %{}, epochs: 50, compiler: EXLA)
07:04:31.866 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 50, loss: 0.0000897 mean_absolute_error: 0.0073940
Epoch: 1, Batch: 73, loss: 0.0001568 mean_absolute_error: 0.0097099
Epoch: 2, Batch: 46, loss: 0.0001567 mean_absolute_error: 0.0086432
Epoch: 3, Batch: 69, loss: 0.0001245 mean_absolute_error: 0.0055883
Epoch: 4, Batch: 42, loss: 0.0001146 mean_absolute_error: 0.0052446
Epoch: 5, Batch: 65, loss: 0.0000927 mean_absolute_error: 0.0029908
Epoch: 6, Batch: 38, loss: 0.0000843 mean_absolute_error: 0.0022326
Epoch: 7, Batch: 61, loss: 0.0000710 mean_absolute_error: 0.0013697
Epoch: 8, Batch: 34, loss: 0.0000658 mean_absolute_error: 0.0011014
Epoch: 9, Batch: 57, loss: 0.0000575 mean_absolute_error: 0.0012567
Epoch: 10, Batch: 30, loss: 0.0000542 mean_absolute_error: 0.0013473
Epoch: 11, Batch: 53, loss: 0.0000486 mean_absolute_error: 0.0013981
Epoch: 12, Batch: 76, loss: 0.0000442 mean_absolute_error: 0.0015796
Epoch: 13, Batch: 49, loss: 0.0000423 mean_absolute_error: 0.0015473
Epoch: 14, Batch: 72, loss: 0.0000391 mean_absolute_error: 0.0016943
Epoch: 15, Batch: 45, loss: 0.0000377 mean_absolute_error: 0.0017310
Epoch: 16, Batch: 68, loss: 0.0000352 mean_absolute_error: 0.0018595
Epoch: 17, Batch: 41, loss: 0.0000342 mean_absolute_error: 0.0020594
Epoch: 18, Batch: 64, loss: 0.0000323 mean_absolute_error: 0.0021272
Epoch: 19, Batch: 37, loss: 0.0000316 mean_absolute_error: 0.0026239
Epoch: 20, Batch: 60, loss: 0.0000302 mean_absolute_error: 0.0025557
Epoch: 21, Batch: 33, loss: 0.0000297 mean_absolute_error: 0.0034725
Epoch: 22, Batch: 56, loss: 0.0000286 mean_absolute_error: 0.0029877
Epoch: 23, Batch: 29, loss: 0.0000283 mean_absolute_error: 0.0040037
Epoch: 24, Batch: 52, loss: 0.0000275 mean_absolute_error: 0.0031490
Epoch: 25, Batch: 75, loss: 0.0000267 mean_absolute_error: 0.0028734
Epoch: 26, Batch: 48, loss: 0.0000265 mean_absolute_error: 0.0030775
Epoch: 27, Batch: 71, loss: 0.0000258 mean_absolute_error: 0.0027187
Epoch: 28, Batch: 44, loss: 0.0000255 mean_absolute_error: 0.0029207
Epoch: 29, Batch: 67, loss: 0.0000248 mean_absolute_error: 0.0025301
Epoch: 30, Batch: 40, loss: 0.0000245 mean_absolute_error: 0.0028126
Epoch: 31, Batch: 63, loss: 0.0000239 mean_absolute_error: 0.0024180
Epoch: 32, Batch: 36, loss: 0.0000236 mean_absolute_error: 0.0028866
Epoch: 33, Batch: 59, loss: 0.0000230 mean_absolute_error: 0.0024546
Epoch: 34, Batch: 32, loss: 0.0000228 mean_absolute_error: 0.0030757
Epoch: 35, Batch: 55, loss: 0.0000223 mean_absolute_error: 0.0025202
Epoch: 36, Batch: 28, loss: 0.0000221 mean_absolute_error: 0.0032313
Epoch: 37, Batch: 51, loss: 0.0000217 mean_absolute_error: 0.0026088
Epoch: 38, Batch: 74, loss: 0.0000213 mean_absolute_error: 0.0025369
Epoch: 39, Batch: 47, loss: 0.0000211 mean_absolute_error: 0.0027412
Epoch: 40, Batch: 70, loss: 0.0000208 mean_absolute_error: 0.0026344
Epoch: 41, Batch: 43, loss: 0.0000207 mean_absolute_error: 0.0029186
Epoch: 42, Batch: 66, loss: 0.0000203 mean_absolute_error: 0.0026540
Epoch: 43, Batch: 39, loss: 0.0000203 mean_absolute_error: 0.0030953
Epoch: 44, Batch: 62, loss: 0.0000200 mean_absolute_error: 0.0026985
Epoch: 45, Batch: 35, loss: 0.0000199 mean_absolute_error: 0.0033785
Epoch: 46, Batch: 58, loss: 0.0000196 mean_absolute_error: 0.0028011
Epoch: 47, Batch: 31, loss: 0.0000196 mean_absolute_error: 0.0035820
Epoch: 48, Batch: 54, loss: 0.0000193 mean_absolute_error: 0.0028405
Epoch: 49, Batch: 27, loss: 0.0000193 mean_absolute_error: 0.0037105
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[1]
EXLA.Backend
[-0.0022059190087020397]
>,
"kernel" => #Nx.Tensor<
f32[32][1]
EXLA.Backend
[
[0.45946452021598816],
[-4.5703467912971973e-4],
[0.23180651664733887],
[-0.005186113528907299],
[0.2575457692146301],
[-0.04142376407980919],
[-0.016365909948945045],
[-0.059576015919446945],
[0.31844404339790344],
[-0.008791795000433922],
[0.39062708616256714],
[-0.06893440335988998],
[0.055386751890182495],
[0.002047080546617508],
[-0.07389257848262787],
[-0.08267084509134293],
[0.02030925452709198],
[-0.03783267363905907],
[0.29332268238067627],
[7.742071175016463e-4],
[0.37041202187538147],
[-0.004578940570354462],
[0.3596736788749695],
[0.33445674180984497],
[-0.28043803572654724],
[-0.3852922320365906],
[0.004091555718332529],
[-0.005642038304358721],
[0.014751603826880455],
[0.12446080893278122],
[-0.09465767443180084],
[-0.007954136468470097]
]
>
},
"lstm_0" => %{
"bias" => {#Nx.Tensor<
f32[32]
EXLA.Backend
[0.11335168778896332, -0.06401084363460541, 0.031347475945949554, -0.10822173953056335, 0.027776755392551422, -0.18160955607891083, -0.1230732873082161, -0.1875688135623932, -0.06549829989671707, -0.17022982239723206, -0.0032370283734053373, -0.22780485451221466, -0.1940574198961258, -0.10790818184614182, -0.08154703676700592, -0.2168857604265213, -0.16268092393875122, -0.1510622799396515, -0.11275982856750488, -0.053832173347473145, -0.0012655003229156137, -0.10015828162431717, 0.006671776529401541, -0.0446050763130188, -0.0055354926735162735, 0.03272601217031479, -0.07376735657453537, -0.11262574046850204, -0.054739225655794144, -0.23879562318325043, -0.11311056464910507, -0.17760637402534485]
>,
#Nx.Tensor<
f32[32]
EXLA.Backend
[0.10780646651983261, -0.06285148113965988, 0.023584432899951935, -0.10630014538764954, 0.022825146093964577, -0.17115405201911926, -0.11822430044412613, -0.17694641649723053, -0.0663384273648262, -0.16032005846500397, -0.010648874565958977, -0.21266703307628632, -0.18182429671287537, -0.10441649705171585, -0.08326958119869232, -0.209574893116951, -0.15406334400177002, -0.14615078270435333, -0.1299838423728943, -0.05346710607409477, -0.006063435692340136, -0.09791524708271027, 6.202268414199352e-4, -0.05215458944439888, -0.008472379297018051, 0.027613848447799683, -0.07298403233289719, -0.10905328392982483, -0.05714934691786766, -0.22263658046722412, -0.11449617892503738, -0.16627191007137299]
>,
#Nx.Tensor<
f32[32]
EXLA.Backend
[0.003731300588697195, 4.8132456140592694e-4, 0.005429568700492382, -0.003417501924559474, 0.004267899785190821, 0.0023204206954687834, -0.003725315211340785, 0.00241837534122169, 0.0011554744560271502, 0.00141581567004323, 0.0017936892108991742, 9.837766410782933e-4, -0.0013038406614214182, 0.0018337997607886791, -0.002206782577559352, -6.189559935592115e-4, -0.001890703453682363, -0.0014234009431675076, 0.003332881722599268, 6.696510245092213e-4, 0.0024242501240223646, -0.001337566296570003, 0.003793303854763508, 0.0028332320507615805, -0.001609275583177805, -0.0017298919847235084, 0.0016395390266552567, -3.6149637890048325e-4, 0.0016541111981496215, -3.5133626079186797e-4, -0.002388330176472664, -0.004853168968111277]
>,
#Nx.Tensor<
f32[32]
EXLA.Backend
[0.1105206236243248, -0.065388984978199, 0.027362428605556488, -0.11799526959657669, 0.025150544941425323, -0.17694349586963654, -0.12093527615070343, -0.18315498530864716, -0.06615359336137772, -0.16683995723724365, -0.007195259910076857, -0.22143132984638214, -0.18868662416934967, -0.1067797839641571, -0.083305224776268, -0.21811549365520477, -0.15893080830574036, -0.14961159229278564, -0.12222269177436829, -0.05434230715036392, -0.0038143477868288755, -0.09983281791210175, 0.0035336518194526434, -0.0486055389046669, -0.007181230932474136, 0.030025919899344444, -0.0738983154296875, -0.11130935698747635, -0.05862840265035629, -0.23184144496917725, -0.11468596011400223, -0.17238789796829224]
>},
"hidden_kernel" => {#Nx.Tensor<
f32[32][32]
EXLA.Backend
[
[-0.11421269923448563, 0.11127666383981705, -0.184355691075325, -0.27127766609191895, -0.06881655752658844, 0.10430511832237244, -0.2448892444372177, -0.2653137147426605, 0.06281697005033493, -0.013854310847818851, -0.006992769427597523, 0.13280320167541504, 0.14791104197502136, -0.1784118115901947, -0.06113160401582718, 0.1929483562707901, 0.04987379536032677, 0.2280454933643341, -0.19724753499031067, 0.15979085862636566, 0.27148231863975525, -0.2724880576133728, -0.1679612696170807, 0.21920432150363922, -0.23523777723312378, 0.31118467450141907, -0.2611495852470398, -0.065000019967556, 0.10206960141658783, 0.025555908679962158, 0.19122368097305298, 0.0023696101270616055],
[-0.027491245418787003, -0.2316892296075821, 0.015483269467949867, 0.02644273452460766, -0.23044149577617645, 0.122772216796875, -0.20443205535411835, -0.04793649911880493, -0.2743169069290161, -0.21698445081710815, 0.17927387356758118, 0.08292767405509949, 0.20277370512485504, ...],
...
]
>,
#Nx.Tensor<
f32[32][32]
EXLA.Backend
[
[-0.11538577824831009, 0.1112610325217247, -0.1845649778842926, -0.27134761214256287, -0.06905922293663025, 0.10186973214149475, -0.24537153542041779, -0.2668277323246002, 0.061841294169425964, -0.014799466356635094, -0.0065930732525885105, 0.12995310127735138, 0.14654433727264404, -0.17863759398460388, -0.06125305965542793, 0.1921101212501526, 0.049221768975257874, 0.2271725982427597, -0.197763130068779, 0.15969622135162354, 0.27121856808662415, -0.2727193534374237, -0.16848579049110413, 0.21903260052204132, -0.23458200693130493, 0.31177598237991333, -0.26127612590789795, -0.06564890593290329, 0.10203193873167038, 0.02351302094757557, 0.19066844880580902, 8.716363226994872e-4],
[-0.02810082957148552, -0.23168812692165375, 0.015438610687851906, 0.026450933888554573, -0.23074249923229218, 0.1225915178656578, -0.20462723076343536, -0.048075538128614426, -0.2743925154209137, -0.21695131063461304, 0.17911174893379211, 0.08302958309650421, ...],
...
]
>,
#Nx.Tensor<
f32[32][32]
EXLA.Backend
[
[0.008229146711528301, 0.1445600688457489, -0.10443291813135147, -0.23759327828884125, 0.014250179752707481, -0.02108207531273365, -0.19692735373973846, -0.379713773727417, 0.10434218496084213, -0.10935195535421371, 0.07177723199129105, -4.993749316781759e-4, 0.19762060046195984, -0.2601209580898285, -0.1646450161933899, 0.09917865693569183, 0.11334932595491409, 0.12844809889793396, -0.10628209263086319, 0.1521795243024826, 0.34273761510849, -0.22396697103977203, -0.08686701953411102, 0.28526702523231506, -0.358896940946579, 0.19280748069286346, -0.3215658664703369, -0.03377417474985123, 0.14930129051208496, 0.06669514626264572, 0.08144361525774002, 0.042655251920223236],
[-0.08700522035360336, -0.29304251074790955, -0.04271193593740463, -0.02882186509668827, -0.2872234284877777, 0.19526077806949615, -0.2664906084537506, 0.023175811395049095, -0.3267883062362671, -0.15051718056201935, 0.11813374608755112, ...],
...
]
>,
#Nx.Tensor<
f32[32][32]
EXLA.Backend
[
[-0.12982815504074097, 0.11378036439418793, -0.18149970471858978, -0.2677125632762909, -0.07103432714939117, 0.11825311928987503, -0.2336365133523941, -0.2543303370475769, 0.07081130146980286, 0.0022928505204617977, 0.0032019137870520353, 0.15594728291034698, 0.1622438281774521, -0.17180950939655304, -0.04690863564610481, 0.2031450867652893, 0.059193458408117294, 0.24310606718063354, -0.19648565351963043, 0.16401754319667816, 0.27107858657836914, -0.2651662528514862, -0.1745850145816803, 0.2315697968006134, -0.23139923810958862, 0.3055410087108612, -0.25599977374076843, -0.04990411549806595, 0.10681464523077011, 0.043445780873298645, 0.20912222564220428, 0.018217815086245537],
[-0.03424784541130066, -0.23208343982696533, 0.009599815122783184, 0.026051534339785576, -0.23864002525806427, 0.12222269177436829, -0.20658233761787415, -0.04860348626971245, -0.28558629751205444, -0.21860875189304352, ...],
...
]
>},
"input_kernel" => {#Nx.Tensor<
f32[1][32]
EXLA.Backend
[
[0.15832173824310303, -0.06945277750492096, 0.29155898094177246, -0.1992305964231491, 0.13684962689876556, 0.3534727692604065, -0.21600008010864258, 0.22524043917655945, 0.01816789247095585, 0.23636147379875183, 0.059991706162691116, 0.21631649136543274, -0.0855906531214714, 0.25307703018188477, 0.007452141959220171, 0.06080798804759979, -0.14563453197479248, 0.09734586626291275, 0.19660542905330658, 0.09763676673173904, 0.09367682784795761, -0.12293204665184021, 0.22191683948040009, 0.1307179480791092, -0.09499726444482803, -0.4519442617893219, 0.19236356019973755, -0.0888831838965416, 0.11068467795848846, -0.04182599112391472, -0.059468429535627365, -0.30685484409332275]
]
>,
#Nx.Tensor<
f32[1][32]
EXLA.Backend
[
[0.15510421991348267, -0.06956983357667923, 0.2913159430027008, -0.19954615831375122, 0.1365644782781601, 0.34723758697509766, -0.21715417504310608, 0.219896599650383, 0.018290145322680473, 0.23421815037727356, 0.06030489131808281, 0.21138715744018555, -0.08892093598842621, 0.2523980140686035, 0.007690330035984516, 0.05870341509580612, -0.14841823279857635, 0.09678869694471359, 0.1968754231929779, 0.09719568490982056, 0.09413376450538635, -0.12356305867433548, 0.22244440019130707, 0.1320788860321045, -0.09404245018959045, -0.4512008726596832, 0.19207976758480072, -0.08969637751579285, 0.11066418141126633, -0.04511291906237602, -0.05949528515338898, -0.3111755847930908]
]
>,
#Nx.Tensor<
f32[1][32]
EXLA.Backend
[
[0.5210095047950745, -0.04471205919981003, 0.5679916143417358, -0.24680404365062714, 0.42526307702064514, 0.10860466212034225, -0.1759696751832962, -0.01207458134740591, 0.21903882920742035, 0.051724907010793686, 0.30118006467819214, -0.03016113117337227, -0.004351356066763401, 0.12323488295078278, -0.2007794976234436, -0.13582377135753632, -0.060969483107328415, -0.09213265776634216, 0.4137095510959625, 0.08238348364830017, 0.3602970838546753, -0.0882539451122284, 0.5199012756347656, 0.3512928783893585, -0.25873085856437683, -0.5727313160896301, 0.12393758445978165, -0.06665874272584915, 0.23254042863845825, 0.04973064363002777, -0.27502191066741943, -0.26436713337898254]
]
>,
#Nx.Tensor<
f32[1][32]
EXLA.Backend
[
[0.14472824335098267, -0.0628005787730217, 0.27467724680900574, -0.1878383904695511, 0.1254066377878189, 0.36361345648765564, -0.211454838514328, 0.2393423318862915, 0.014579376205801964, 0.24944588541984558, 0.048402633517980576, 0.22778557240962982, -0.07714307308197021, 0.2585679888725281, 0.006932367570698261, 0.07686306536197662, -0.13720948994159698, 0.10166746377944946, 0.18309403955936432, 0.11173634976148605, 0.08637332171201706, -0.11274421215057373, 0.2134648561477661, 0.11872217804193497, -0.10573867708444595, -0.4594239592552185, 0.1963488757610321, -0.08307348936796188, 0.11799244582653046, -0.033315449953079224, -0.05906962230801582, -0.2986782193183899]
]
>}
},
"lstm__c_hidden_state" => %{
"key" => #Nx.Tensor<
u32[2]
EXLA.Backend
[392164166, 4227701866]
>
},
"lstm__h_hidden_state" => %{
"key" => #Nx.Tensor<
u32[2]
EXLA.Backend
[392164166, 4227701866]
>
}
}
0.0032470 * :math.sqrt(Explorer.Series.variance(aapl_df["Close"])) +
Explorer.Series.mean(aapl_df["Close"])
64.80750153333416
Analysis.visualize_predictions(
rnn_model,
rnn_trained_model_state,
Explorer.Series.to_list(normalized_aapl_df["Close"]),
30,
1,
32
)