FizzBuzz Neural Network
Mix.install([
{:nx, "~> 0.7.2"},
{:axon, "~> 0.6.1"},
{:exla, "~> 0.7.2"},
{:kino, "~> 0.12.3"}
])
FizzBuzz
graph LR;
n--/3==0-->fizz;
n--/5==0-->buzz;
n--/3==0 and /5==0 -->fizz_buzz;
# Encodes the FizzBuzz class of `n` as a one-hot label
# [fizz, buzz, fizz_buzz, other].
#
# The fallback used to be [0, 0, 0, 0]. The network below ends in a 4-way
# softmax trained with categorical cross-entropy, and `guess` maps index 3 to
# "womp"; an all-zero target contributes nothing to that loss, so the "other"
# class was never learned (16 was reported as "buzz"). Every class now has its
# own slot.
#
# Variant 1: multi-clause anonymous function with guards. The 15 clause must
# come first because multiples of 15 also satisfy the 3 and 5 guards.
fizz_buzz = fn
  n when rem(n, 15) == 0 -> [0, 0, 1, 0]
  n when rem(n, 3) == 0 -> [1, 0, 0, 0]
  n when rem(n, 5) == 0 -> [0, 1, 0, 0]
  _ -> [0, 0, 0, 1]
end
{fizz_buzz.(1), fizz_buzz.(3), fizz_buzz.(5), fizz_buzz.(15)}
{[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]}
# Variant 2: the same mapping written with `cond`. This rebinding is the one
# the training data is built from.
fizz_buzz = fn n ->
  cond do
    rem(n, 15) == 0 -> [0, 0, 1, 0]
    rem(n, 3) == 0 -> [1, 0, 0, 0]
    rem(n, 5) == 0 -> [0, 1, 0, 0]
    true -> [0, 0, 0, 1]
  end
end
{fizz_buzz.(1), fizz_buzz.(3), fizz_buzz.(5), fizz_buzz.(15)}
{[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]}
Machine Learning
graph LR;
subgraph Input
A((3))
end
subgraph Prediction
B((fizz?))
end
A -."weight=0.2".-> B
# Simplest possible "network": one input scaled by one fixed weight.
weight = 0.2
predict = &(&1 * weight)
predict.(3)
0.6000000000000001
graph LR;
subgraph Input
A((3)) --/3--> A1((0))
A --/5--> A2((3))
end
subgraph Prediction
B((fizz?))
end
A1 -."weight=0.2".-> B
A2 -."weight=0.7".-> B
# Feature extractor: the remainders of `x` after dividing by 3 and by 5,
# returned as a two-element list in that order.
mods = fn x -> Enum.map([3, 5], &rem(x, &1)) end
vector = mods.(3)
# Weighted sum (dot product): multiplies each input by the weight at the same
# position and adds the products up, starting from 0.
#
#   input   - enumerable of numbers
#   weights - enumerable of numbers, at least as long as `input`; surplus
#             weights are ignored, as before
#
# Raises ArgumentError when there are fewer weights than inputs.
w_sum = fn input, weights ->
  input_count = Enum.count(input)
  weight_count = Enum.count(weights)

  # A missing weight used to surface as a nil-arithmetic error deep inside the
  # reduce; fail early with a message that names the actual problem.
  if weight_count < input_count do
    raise ArgumentError, "expected at least #{input_count} weights, got #{weight_count}"
  end

  # Walk both lists in lock-step instead of calling Enum.at/2 per element,
  # which re-traversed the weights list on every step (quadratic overall).
  Enum.zip_reduce(input, weights, 0, fn value, weight, acc ->
    acc + value * weight
  end)
end
# Several inputs, one output: the prediction is the weighted sum of the inputs.
predict = &w_sum.(&1, &2)
predict.([0.0, 3.0], [0.2, 0.7])
2.0999999999999996
graph LR;
A((3)) --w=0.7--> B[fizz?]
A --w=0.4--> C[buzz?]
A --w=0.5--> D[fizz_buzz?]
# Scales every element of `vector` by `number`, returning a list of the
# products in the same order.
ele_mul = fn number, vector ->
  Enum.map(vector, &(number * &1))
end
# One input, several outputs: each weight produces its own prediction.
predict = fn value, per_output_weights ->
  ele_mul.(value, per_output_weights)
end
weights = [0.7, 0.4, 0.5]
input = 3.0
predict.(input, weights)
[2.0999999999999996, 1.2000000000000002, 1.5]
graph LR;
I1((0))-->O1[fizz?]
I1-->O2[buzz?]
I1-->O3[fizz_buzz?]
I2((3)) --> O1
I2 --> O2
I2 --> O3
linkStyle 3 stroke:#0f0,stroke-width:1px;
linkStyle 4 stroke:#0f0,stroke-width:1px;
linkStyle 5 stroke:#0f0,stroke-width:1px;
# Vector-matrix product: one weighted sum of `vector` per row of `matrix`,
# returned as a list with one entry per row.
vect_mat_mul = fn vector, matrix ->
  for row <- matrix, do: w_sum.(vector, row)
end
# Several inputs, several outputs: each row of weights yields one prediction.
predict = fn features, weight_rows ->
  vect_mat_mul.(features, weight_rows)
end
input = [0.0, 3.0]
weights = [[0.2, 0.3], [0.7, 0.1], [0.5, 0.1]]
predict.(input, weights)
[0.8999999999999999, 0.30000000000000004, 0.30000000000000004]
# Rescales `inputs` so they add up to 1 by dividing each value by the total.
# NOTE(review): despite the name this is plain sum-normalisation; it does not
# exponentiate the values the way the textbook softmax does, and a total of
# zero raises ArithmeticError.
softmax = fn inputs ->
  total = Enum.sum(inputs)
  for value <- inputs, do: value / total
end
softmax.([0.8999999999999999, 0.30000000000000004, 0.30000000000000004])
[0.6, 0.20000000000000004, 0.20000000000000004]
graph LR;
I1((0))-->L1((0.6))
I1-->L2((0.2))
I1-->L3((0.2))
I2((3)) --> L1
I2 --> L2
I2 --> L3
L1 --> O1((1.0))
L2 --> O2((0.0))
L3 --> O3((0.0))
linkStyle 3 stroke:#0f0,stroke-width:1px;
linkStyle 4 stroke:#0f0,stroke-width:1px;
linkStyle 5 stroke:#0f0,stroke-width:1px;
Data
# Lazy stream of {features, label} tensor pairs for n = 1..1000. Features are
# the remainders from `mods`, the label is the class list from `fizz_buzz`;
# each is wrapped in an outer list so every element carries a single row.
train_data =
  Stream.map(1..1000, fn n ->
    {Nx.tensor([mods.(n)]), Nx.tensor([fizz_buzz.(n)])}
  end)
#Stream<[enum: 1..1000, funs: [#Function<50.38948127/1 in Stream.map/2>]]>
Model
# Network layout: 2 input features -> dense layer of 10 units with tanh ->
# dense layer of 4 units with softmax (one unit per label slot).
input_layer = Axon.input("input", shape: {nil, 2})
hidden_layer = Axon.dense(input_layer, 10, activation: :tanh)
model = Axon.dense(hidden_layer, 4, activation: :softmax)
#Axon<
inputs: %{"input" => {nil, 2}}
outputs: "softmax_0"
nodes: 5
>
# Build an s32 template of shape {1, 2} (one row, two features) and use it as
# the example input when rendering the model as a graph.
template = Nx.template({1, 2}, :s32)
Axon.Display.as_graph(model, template)
Training
# Train the model on `train_data` for 10 epochs using categorical
# cross-entropy loss and the Adam optimizer, compiling with EXLA.
# `%{}` is passed as the third argument to run (presumably the initial state;
# confirm against the Axon.Loop docs).
training_loop = Axon.Loop.trainer(model, :categorical_cross_entropy, :adam)

trained_model_state =
  Axon.Loop.run(training_loop, train_data, %{}, epochs: 10, compiler: EXLA)
19:35:43.337 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 950, loss: 0.3323180
Epoch: 1, Batch: 950, loss: 0.2097961
Epoch: 2, Batch: 950, loss: 0.1507273
Epoch: 3, Batch: 950, loss: 0.1159618
Epoch: 4, Batch: 950, loss: 0.0937235
Epoch: 5, Batch: 950, loss: 0.0784553
Epoch: 6, Batch: 950, loss: 0.0673833
Epoch: 7, Batch: 950, loss: 0.0590107
Epoch: 8, Batch: 950, loss: 0.0524685
Epoch: 9, Batch: 950, loss: 0.0472216
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[10]
EXLA.Backend
[0.007045082747936249, 0.6407318115234375, 0.3835016191005707, -0.7641347050666809, 0.588534951210022, -0.49212053418159485, 0.6598906517028809, 0.4348321259021759, -0.6103333830833435, 0.7605070471763611]
>,
"kernel" => #Nx.Tensor<
f32[2][10]
EXLA.Backend
[
[-0.4310312271118164, 1.1210333108901978, 0.8010596036911011, 1.7776334285736084, -1.3366260528564453, 1.6495720148086548, -1.4637233018875122, -0.579444944858551, 0.7187064290046692, 0.9409835338592529],
[-0.15303368866443634, -1.3053309917449951, -1.0788722038269043, -0.537500262260437, -1.1483451128005981, 0.7073906660079956, -0.028834063559770584, 1.2421365976333618, 1.7351068258285522, -1.694563627243042]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
EXLA.Backend
[0.07947669178247452, -0.17550566792488098, 0.2959376275539398, -1.2535045146942139]
>,
"kernel" => #Nx.Tensor<
f32[10][4]
EXLA.Backend
[
[0.10124798119068146, -0.6136413812637329, 0.43526801466941833, 0.7222546935081482],
[-1.0806410312652588, 0.7023622393608093, 0.8863538503646851, -0.6920934319496155],
[-1.474867820739746, 1.2845275402069092, 0.9321955442428589, -0.33254721760749817],
[-0.9582555890083313, 1.3919050693511963, -0.7911892533302307, 0.8373909592628479],
[-0.4029693603515625, -0.9608103632926941, 1.9728877544403076, 0.24115674197673798],
[0.8901402950286865, 1.4944347143173218, -1.8724830150604248, -0.4015873968601227],
[0.2996208071708679, -1.1419256925582886, 0.9844873547554016, -0.2348821610212326],
[0.755868673324585, -0.8567834496498108, 0.28964880108833313, -1.5673366785049438],
[1.1101288795471191, -0.17633414268493652, -1.2638652324676514, -0.7720271348953247],
[-0.8016149401664734, 0.9841588735580444, 1.231026291847229, -0.46669310331344604]
]
>
}
}
Prediction
# Axon.build yields a pair of functions; only the second is kept (as
# predict_fn), the first is deliberately ignored.
{_init_fn, predict_fn} = Axon.build(model)
{#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>,
#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>}
# Runs the trained model on the remainders of `x` and turns the position of
# the largest output into a human-readable label.
guess = fn x ->
  features = Nx.tensor([mods.(x)])
  outputs = predict_fn.(trained_model_state, features)
  winner = outputs |> Nx.argmax() |> Nx.to_flat_list()

  # `winner` is a one-element list holding the winning output index.
  case winner do
    [0] -> "fizz"
    [1] -> "buzz"
    [2] -> "fizz_buzz"
    [3] -> "womp"
  end
end
#Function<42.39164016/1 in :erl_eval.expr/6>
# Spot-check one number from each class; each result is printed under its
# input as the label, and the last call's value is the cell result.
IO.inspect(guess.(3), label: "3")
IO.inspect(guess.(5), label: "5")
IO.inspect(guess.(15), label: "15")
IO.inspect(guess.(16), label: "16")
3: "fizz"
5: "buzz"
15: "fizz_buzz"
16: "buzz"
"buzz"