FizzBuzz Neural Network

Mix.install([
  {:nx, "~> 0.7.2"},
  {:axon, "~> 0.6.1"},
  {:exla, "~> 0.7.2"},
  {:kino, "~> 0.12.3"}
])

FizzBuzz

graph LR;
  n--/3==0-->fizz;
  n--/5==0-->buzz;
  n--/3==0 and /5==0 -->fizz_buzz;

Encode each number as a four-element label: fizz, buzz, fizz_buzz, or no match.

fizz_buzz = fn
  n when rem(n, 15) == 0 -> [0, 0, 1, 0]
  n when rem(n, 3) == 0 -> [1, 0, 0, 0]
  n when rem(n, 5) == 0 -> [0, 1, 0, 0]
  _ -> [0, 0, 0, 0]
end

{fizz_buzz.(1), fizz_buzz.(3), fizz_buzz.(5), fizz_buzz.(15)}
{[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]}

The same encoding, written with cond instead of multiple function clauses:

fizz_buzz = fn n ->
  cond do
    rem(n, 15) == 0 -> [0, 0, 1, 0]
    rem(n, 3) == 0 -> [1, 0, 0, 0]
    rem(n, 5) == 0 -> [0, 1, 0, 0]
    true -> [0, 0, 0, 0]
  end
end
{fizz_buzz.(1), fizz_buzz.(3), fizz_buzz.(5), fizz_buzz.(15)}
{[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]}
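
The positions in the label vector matter later, when the network's prediction is decoded by index. As a rough sketch, with a hypothetical decode helper that is not part of the notebook ("womp" is the name the notebook later uses for "no match"):

labels = ["fizz", "buzz", "fizz_buzz", "womp"]

decode = fn label -> Enum.at(labels, Enum.find_index(label, &(&1 == 1)) || 3) end

{decode.(fizz_buzz.(3)), decode.(fizz_buzz.(5)), decode.(fizz_buzz.(15)), decode.(fizz_buzz.(1))}
# => {"fizz", "buzz", "fizz_buzz", "womp"}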

Machine Learning

graph LR;
  subgraph Input
    A((3))
  end
  subgraph Prediction
    B((fizz?))
  end

  A -."weight=0.2".-> B
  
weight = 0.2

predict = fn input -> input * weight end

predict.(3)
0.6000000000000001

Next, split the input into two features, rem(n, 3) and rem(n, 5), each with its own weight:

graph LR;
  subgraph Input
    A((3)) --/3--> A1((0))
    A --/5--> A2((3))
  end
  subgraph Prediction
    B((fizz?))
  end

  A1 -."weight=0.2".-> B
  A2 -."weight=0.7".-> B

mods = fn x -> [rem(x, 3), rem(x, 5)] end

vector = mods.(3)

w_sum = fn input, weights ->
  input
  |> Enum.with_index()
  |> Enum.reduce(0, fn {input, i}, acc ->
    weight = Enum.at(weights, i)
    acc + input * weight
  end)
end

predict = fn input, weights ->
  w_sum.(input, weights)
end

predict.([0.0, 3.0], [0.2, 0.7])
2.0999999999999996
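
Since Nx is already in the dependency list, the same weighted sum can be written as a dot product. A minimal sketch, not from the notebook (Nx defaults to f32, so the printed value differs slightly from the float arithmetic above):

Nx.dot(Nx.tensor([0.0, 3.0]), Nx.tensor([0.2, 0.7]))
# => an f32 scalar tensor of roughly 2.1
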
graph LR;
  A((3)) --w=0.7--> B[fizz?]
  A --w=0.4--> C[buzz?]
  A --w=0.5--> D[fizz_buzz?]

One input feeding several outputs: multiply the single input by one weight per output.

ele_mul = fn number, vector ->
  vector
  |> Enum.map(fn item -> 
    number * item
  end)
end

predict = fn input, weights ->
  ele_mul.(input, weights)
end

input = 3.0

weights = [0.7, 0.4, 0.5]

predict.(input, weights)
[2.0999999999999996, 1.2000000000000002, 1.5]
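
The same idea in Nx, broadcasting the scalar input across the weight vector; a sketch, not from the notebook:

Nx.multiply(3.0, Nx.tensor([0.7, 0.4, 0.5]))
# => an f32 tensor of roughly [2.1, 1.2, 1.5]
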
graph LR;
  I1((0))-->O1[fizz?]
  I1-->O2[buzz?]
  I1-->O3[fizz_buzz?]
  I2((3)) --> O1
  I2 --> O2
  I2 --> O3
  linkStyle 3 stroke:#0f0,stroke-width:1px;
  linkStyle 4 stroke:#0f0,stroke-width:1px;
  linkStyle 5 stroke:#0f0,stroke-width:1px;

Several inputs feeding several outputs: a weighted sum for each output, i.e. a vector-matrix multiplication.

vect_mat_mul = fn vector, matrix ->
  matrix
  |> Enum.map(fn item ->
    w_sum.(vector, item)
  end)
end

predict = fn input, weights ->
  vect_mat_mul.(input, weights)
end

weights = [[0.2, 0.3], [0.7, 0.1], [0.5, 0.1]]

input = [0.0, 3.0]

predict.(input, weights)
[0.8999999999999999, 0.30000000000000004, 0.30000000000000004]
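
This vector-matrix product is essentially what a dense layer computes, plus a bias term. A sketch of the same calculation in Nx, assuming the weights and input bound above:

Nx.dot(Nx.tensor(weights), Nx.tensor(input))
# => an f32 tensor of roughly [0.9, 0.3, 0.3]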

Normalize the three outputs so they sum to 1:

softmax = fn inputs ->
  sum = Enum.sum(inputs)
  Enum.map(inputs, fn n -> n / sum end)
end

softmax.([0.8999999999999999, 0.30000000000000004, 0.30000000000000004])
[0.6, 0.20000000000000004, 0.20000000000000004]
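
Strictly speaking, dividing by the sum is a proportional normalization rather than the exponential softmax that the model's :softmax activation uses below. A minimal sketch of the exponential version in Nx, not from the notebook:

logits = Nx.tensor([0.9, 0.3, 0.3])
exps = Nx.exp(logits)
Nx.divide(exps, Nx.sum(exps))
# => roughly [0.48, 0.26, 0.26]
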
graph LR;
  I1((0))-->L1((0.6))
  I1-->L2((0.2))
  I1-->L3((0.2))
  I2((3)) --> L1
  I2 --> L2
  I2 --> L3
  L1 --> O1((1.0))
  L2 --> O2((0.0))
  L3 --> O3((0.0))
  linkStyle 3 stroke:#0f0,stroke-width:1px;
  linkStyle 4 stroke:#0f0,stroke-width:1px;
  linkStyle 5 stroke:#0f0,stroke-width:1px;

Data

train_data =
  1..1000
  |> Stream.map(fn n ->
    tensor = Nx.tensor([mods.(n)])
    label = Nx.tensor([fizz_buzz.(n)])
    {tensor, label}
  end)
#Stream<[enum: 1..1000, funs: [#Function<50.38948127/1 in Stream.map/2>]]>
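
Each element of the stream is an {input, label} pair with a batch dimension of 1, so every epoch runs 1000 single-example steps. A quick way to peek at the first pair; the values in the comment are what mods and fizz_buzz produce for n = 1:

Enum.take(train_data, 1)
# => one {tensor, label} pair: [[1, 1]] and [[0, 0, 0, 0]],
#    i.e. [rem(1, 3), rem(1, 5)] and the "no match" label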

Model

model =
  Axon.input("input", shape: {nil, 2})
  |> Axon.dense(10, activation: :tanh)
  |> Axon.dense(4, activation: :softmax)
#Axon<
  inputs: %{"input" => {nil, 2}}
  outputs: "softmax_0"
  nodes: 5
>
template = Nx.template({1, 2}, :s32)
Axon.Display.as_graph(model, template)
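
Nx.template/2 builds a placeholder that carries only shape and type, which is all Axon needs to trace the model for display. As a sketch, the same structure can also be rendered as a layer-by-layer table:

Axon.Display.as_table(model, template)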

Training

trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)

19:35:43.337 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 950, loss: 0.3323180
Epoch: 1, Batch: 950, loss: 0.2097961
Epoch: 2, Batch: 950, loss: 0.1507273
Epoch: 3, Batch: 950, loss: 0.1159618
Epoch: 4, Batch: 950, loss: 0.0937235
Epoch: 5, Batch: 950, loss: 0.0784553
Epoch: 6, Batch: 950, loss: 0.0673833
Epoch: 7, Batch: 950, loss: 0.0590107
Epoch: 8, Batch: 950, loss: 0.0524685
Epoch: 9, Batch: 950, loss: 0.0472216
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[10]
      EXLA.Backend
      [0.007045082747936249, 0.6407318115234375, 0.3835016191005707, -0.7641347050666809, 0.588534951210022, -0.49212053418159485, 0.6598906517028809, 0.4348321259021759, -0.6103333830833435, 0.7605070471763611]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][10]
      EXLA.Backend
      [
        [-0.4310312271118164, 1.1210333108901978, 0.8010596036911011, 1.7776334285736084, -1.3366260528564453, 1.6495720148086548, -1.4637233018875122, -0.579444944858551, 0.7187064290046692, 0.9409835338592529],
        [-0.15303368866443634, -1.3053309917449951, -1.0788722038269043, -0.537500262260437, -1.1483451128005981, 0.7073906660079956, -0.028834063559770584, 1.2421365976333618, 1.7351068258285522, -1.694563627243042]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      EXLA.Backend
      [0.07947669178247452, -0.17550566792488098, 0.2959376275539398, -1.2535045146942139]
    >,
    "kernel" => #Nx.Tensor<
      f32[10][4]
      EXLA.Backend
      [
        [0.10124798119068146, -0.6136413812637329, 0.43526801466941833, 0.7222546935081482],
        [-1.0806410312652588, 0.7023622393608093, 0.8863538503646851, -0.6920934319496155],
        [-1.474867820739746, 1.2845275402069092, 0.9321955442428589, -0.33254721760749817],
        [-0.9582555890083313, 1.3919050693511963, -0.7911892533302307, 0.8373909592628479],
        [-0.4029693603515625, -0.9608103632926941, 1.9728877544403076, 0.24115674197673798],
        [0.8901402950286865, 1.4944347143173218, -1.8724830150604248, -0.4015873968601227],
        [0.2996208071708679, -1.1419256925582886, 0.9844873547554016, -0.2348821610212326],
        [0.755868673324585, -0.8567834496498108, 0.28964880108833313, -1.5673366785049438],
        [1.1101288795471191, -0.17633414268493652, -1.2638652324676514, -0.7720271348953247],
        [-0.8016149401664734, 0.9841588735580444, 1.231026291847229, -0.46669310331344604]
      ]
    >
  }
}
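
There is no held-out test set here, but the trained parameters can at least be scored against the training stream with Axon's evaluation loop. A sketch, not from the notebook; note that the all-zero labels for non-fizzbuzz numbers make the accuracy figure hard to interpret:

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_data, trained_model_state, compiler: EXLA)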

Prediction

Build the model into an init function and a predict function:

{_init_fn, predict_fn} = Axon.build(model)
{#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>,
 #Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>}
guess = fn x ->
  mod = Nx.tensor([mods.(x)])

  case predict_fn.(trained_model_state, mod) |> Nx.argmax() |> Nx.to_flat_list() do
    [0] -> "fizz"
    [1] -> "buzz"
    [2] -> "fizz_buzz"
    [3] -> "womp"
  end
end
#Function<42.39164016/1 in :erl_eval.expr/6>
guess.(3) |> IO.inspect(label: "3")
guess.(5) |> IO.inspect(label: "5")
guess.(15) |> IO.inspect(label: "15")
guess.(16) |> IO.inspect(label: "16")
3: "fizz"
5: "buzz"
15: "fizz_buzz"
16: "buzz"
"buzz"