# Chapter 6
Mix.install([{:nx, "~> 0.5"}])
## Section
defmodule NeuralNetwork do
  @moduledoc """
  A small two-layer network expressed as numerical definitions (`defn`).

  Each layer takes the dot product of its input with a weight, adds a bias,
  and applies a sigmoid. `predict/5` chains a hidden layer into an output layer.
  """

  import Nx.Defn

  @doc """
  Computes `Nx.dot(input, weight)` and adds `bias` to the result.
  """
  defn dense(input, weight, bias) do
    Nx.add(Nx.dot(input, weight), bias)
  end

  @doc """
  Applies the sigmoid function to `input`.
  """
  defn activation(input) do
    Nx.sigmoid(input)
  end

  @doc """
  Hidden layer: `dense/3` followed by `activation/1`.
  """
  defn hidden(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  @doc """
  Output layer: `dense/3` followed by `activation/1`.

  This is intentionally the same computation as `hidden/3`; it is kept as a
  separate function so the two layers can be told apart by name.
  """
  defn output(input, weight, bias) do
    pre_activation = dense(input, weight, bias)
    activation(pre_activation)
  end

  @doc """
  Runs the full forward pass.

  `input` goes through `hidden/3` with `w1`/`b1`, and that result goes through
  `output/3` with `w2`/`b2`.
  """
  defn predict(input, w1, b1, w2, b2) do
    hidden_result = hidden(input, w1, b1)
    output(hidden_result, w2, b2)
  end
end
# Seed the pseudo-random generator. Each `Nx.Random.uniform/1` call returns
# `{tensor, new_key}`; the new key is rebound and threaded into the next draw
# so every parameter comes from a different key.
key = Nx.Random.key(42)
{w1, key} = Nx.Random.uniform(key)
{b1, key} = Nx.Random.uniform(key)
{w2, key} = Nx.Random.uniform(key)
{b2, key} = Nx.Random.uniform(key)

# Draw a scalar input. `uniform_split` must be given its bounds explicitly
# (the options-only shortcut exists for `uniform`, not for `uniform_split`),
# so pass 0.0 and 1.0 here. Its result is used directly as the tensor, with no
# `{tensor, key}` tuple to unpack.
input = Nx.Random.uniform_split(key, 0.0, 1.0, shape: {})

# The first layer takes w1/b1 and the second takes w2/b2. The first-layer bias
# must be b1 (it was previously passed as b2, leaving b1 unused).
NeuralNetwork.predict(input, w1, b1, w2, b2)