micrograd
Mix.install([
{:axon, "~> 0.6.0"},
:dg,
{:scholar, "~> 0.1"},
{:exla, ">= 0.0.0"},
{:kino_vega_lite, "~> 0.1.9"}
])
Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
Value Tree
defmodule Node1 do
@moduledoc "stores a single scalar value and its gradient"
defstruct data: nil,
grad: 0,
label: "",
_backward: nil,
_op: "",
_children: []
  @type t :: %Node1{
          data: number() | nil,
          grad: number(),
          label: String.t(),
          _backward: (t() -> t()) | nil,
          _op: String.t(),
          _children: [t()]
        }
def add(%Node1{} = left, %Node1{} = right, label \\ "add", grad \\ 0) do
out = %Node1{
data: left.data + right.data,
grad: grad,
_children: previous(left, right),
_op: "+",
label: label
}
backward = fn node ->
left = %Node1{left | grad: left.grad + node.grad}
right = %Node1{right | grad: right.grad + node.grad}
%Node1{node | _children: previous(left, right)}
end
%Node1{out | _backward: backward}
end
def mult(%Node1{} = left, %Node1{} = right, label \\ "mult", grad \\ 0) do
out = %Node1{
data: left.data * right.data,
grad: grad,
_children: previous(left, right),
_op: "*",
label: label
}
backward = fn node ->
  # chain rule: d(out)/d(left) = right.data and d(out)/d(right) = left.data,
  # accumulated into any gradient already on the child
  left = %Node1{left | grad: left.grad + right.data * node.grad}
  right = %Node1{right | grad: right.grad + left.data * node.grad}
  %Node1{node | _children: previous(left, right)}
end
%Node1{out | _backward: backward}
end
def tanh(%Node1{data: data} = children, label \\ "tanh", grad \\ 1) do
t = (:math.exp(2 * data) - 1) / (:math.exp(2 * data) + 1)
out = %Node1{
data: t,
grad: grad,
label: label,
_children: [children],
_op: "tanh"
}
backward = fn node ->
  # chain rule: d(tanh(x))/dx = 1 - tanh(x)^2, scaled by the gradient flowing in
  children = %Node1{children | grad: children.grad + (1 - t ** 2) * node.grad}
  %Node1{node | _children: [children]}
end
%Node1{out | _backward: backward}
end
def previous(left, right) do
[left, right]
end
def backward(%Node1{_children: []} = root) do
root
end
def backward(
%Node1{
data: data,
grad: grad,
label: label,
_op: op,
_backward: backward
} = node
) do
node =
if node._backward do
node._backward.(node)
else
node
end
updated_children =
Enum.map(node._children, fn child ->
backward(child)
end)
%Node1{
data: data,
grad: grad,
label: label,
_backward: backward,
_op: op,
_children: updated_children
}
end
end
Interacting with the Value API
a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
Node1.add(a, b)
a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
c = %Node1{data: 10, label: "c"}
d = Node1.mult(a, b, "e") |> Node1.add(c, "d", 1)
Add Value Tree Visualizations
defmodule Graph do
def draw_dot(root) do
dot = DG.new()
build_dot(root, dot, %{count: 0, ops: [], rids: []})
dot
end
def build_dot(node, dot, visited) do
label = "#{node.label} -> data #{node.data} -> grad #{node.grad}"
rid = to_string(:rand.uniform(1000))
DG.add_vertex(dot, node.label, label)
count = Map.get(visited, :count)
ops = Map.get(visited, :ops)
rids = Map.get(visited, :rids)
val = [node._op]
visited = %{count: count + 1, ops: ops ++ val, rids: rids ++ [rid]}
if node._op != "" do
# create left _op vertex and connect to right edge
DG.add_vertex(dot, "#{node._op}" <> rid, node._op)
DG.add_edge(dot, "#{node._op}" <> rid, node.label)
end
if count != 0 do
# if not root node, create edge between vertex and _op
ops = Map.get(visited, :ops)
op = Enum.at(ops, count - 1)
rid = Enum.at(rids, count - 1)
DG.add_edge(dot, String.trim(node.label), String.trim(op) <> rid)
end
Enum.map(node._children, fn child ->
build_dot(child, dot, visited)
end)
end
end
d
|> Node1.backward()
|> Graph.draw_dot()
graph LR
b[b -> data -3 -> grad 2]-->*201[*]
*201[*]-->e[e -> data -6 -> grad 1]
e[e -> data -6 -> grad 1]-->+336[+]
c[c -> data 10 -> grad 1]-->+336[+]
a[a -> data 2 -> grad -3]-->*201[*]
+336[+]-->d[d -> data 4 -> grad 1]
What is h in the derivative?
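Nudging one input by a tiny amount h and re-running the forward pass gives a numerical estimate of the partial derivative: (d2 - d1) / h, where d2 is the output recomputed with the nudged input. The cells below nudge a, b, and c in turn; since d = a*b + c, the slopes should come out near b = -3, a = 2, and 1, matching the gradients in the graph above.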
h = 0.0001
# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}
# what does it mean to nudge a?
d1 = Node1.mult(a, b) |> Node1.add(c)
a = %Node1{data: 2.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)
IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
h = 0.0001
# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}
# what does it mean to nudge b?
d1 = Node1.mult(a, b) |> Node1.add(c)
b = %Node1{data: -3.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)
IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
h = 0.0001
# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}
# what does it mean to nudge c?
d1 = Node1.mult(a, b) |> Node1.add(c)
c = %Node1{data: 10.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)
IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
A More Complicated Value Tree
a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
c = %Node1{data: 10, label: "c"}
e = Node1.mult(a, b, "e")
d = Node1.add(e, c, "d")
f = %Node1{data: -2, label: "f"}
l = Node1.mult(d, f, "L", 1)
l
|> Node1.backward()
|> Graph.draw_dot()
graph LR
f[f -> data -2 -> grad 4]-->*393[*]
b[b -> data -3 -> grad -4]-->*699[*]
*393[*]-->L[L -> data -8 -> grad 1]
e[e -> data -6 -> grad -2]-->+920[+]
*699[*]-->e[e -> data -6 -> grad -2]
c[c -> data 10 -> grad -2]-->+920[+]
a[a -> data 2 -> grad 6]-->*699[*]
d[d -> data 4 -> grad -2]-->*393[*]
+920[+]-->d[d -> data 4 -> grad -2]
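Reading the gradients off the graph is just the chain rule: L = d*f, so dL/dd = f = -2 and dL/df = d = 4; d = e + c passes that -2 straight through to e and c; and e = a*b turns it into dL/da = b * -2 = 6 and dL/db = a * -2 = -4, exactly the values shown above.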
Intro to Nx
x1 = Nx.tensor([2.0])
x2 = Nx.tensor([0.0])
w1 = Nx.tensor([-3.0])
w2 = Nx.tensor([1.0])
b = Nx.tensor([6.8813735870195432])
x1w1 = Nx.multiply(x1, w1)
x2w2 = Nx.multiply(x2, w2)
output = x1w1 |> Nx.add(x2w2) |> Nx.add(b)
Nx.tanh(output)
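Nx can also produce this gradient automatically, which is what makes the hand-written _backward closures above unnecessary at scale. A minimal sketch, not part of the original notebook (the NeuronGrad module name is made up for illustration): grad/2 inside defn differentiates the tanh neuron with respect to w1, and for the inputs above the result should come out close to 1.0.
defmodule NeuronGrad do
  import Nx.Defn

  # d/dw1 of tanh(x1*w1 + x2*w2 + b), computed by Nx's automatic differentiation
  defn grad_w1(w1, x1, x2, w2, b) do
    grad(w1, fn w1 -> Nx.tanh(x1 * w1 + x2 * w2 + b) end)
  end
end

NeuronGrad.grad_w1(w1, x1, x2, w2, b)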
Visualizing a Linear Function: y = 2x
defmodule Mathy do
def actual(x) do
2 * x
end
end
chart =
Vl.new(width: 400, height: 400)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Kino.VegaLite.new()
|> Kino.render()
for i <- -5..5 do
point = %{x: i, y: Mathy.actual(i)}
Kino.VegaLite.push(chart, point)
Process.sleep(25)
end
Solving for a linear function using Nx
defmodule Solver do
import Nx.Defn
@learning_rate 0.01
@batch_size 64
defn predict({w, b}, x) do
w * x + b
end
defn mse(activations, x, y) do
y_hat = predict(activations, x)
(y - y_hat)
|> Nx.pow(2)
|> Nx.mean()
end
defn update({w, b} = activations, x, y) do
{grad_w, grad_b} = grad(activations, &mse(&1, x, y))
{
w - grad_w * @learning_rate,
b - grad_b * @learning_rate
}
end
def train(data, epochs \\ 100) do
Enum.reduce(1..epochs, random_parameters(), fn _i, acc ->
data
|> Enum.take(@batch_size)
|> Enum.reduce(acc, fn batch, activations ->
{x, y} = Enum.unzip(batch)
x = Nx.tensor(x)
y = Nx.tensor(y)
update(activations, x, y)
end)
end)
end
def random_parameters do
    key = Nx.Random.key(12)
    # thread the key so w and b draw different random values
    {w, new_key} = Nx.Random.normal(key, 0.0, 1.0)
    {b, _new_key} = Nx.Random.normal(new_key, 0.0, 1.0)
{w, b}
end
def generate_data do
num = :rand.uniform()
{x, y} = {num, Mathy.actual(num)}
{Nx.tensor(x), Nx.tensor(y)}
end
end
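Each call to update/3 is one step of gradient descent: grad/2 differentiates the mean squared error with respect to w and b, and both parameters move a small step, scaled by the 0.01 learning rate, against their gradients.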
Manual Updates
# initialize a random w and b
activations = Solver.random_parameters()
# generate some data we know is true {1, 2}
# when x is 1, y is 2
{x, y} = Solver.generate_data()
Solver.update(activations, x, y)
Train In a Loop
data =
Stream.repeatedly(fn ->
Stream.map(0..64, fn _i ->
num = :rand.uniform() * 10
{num, Mathy.actual(num)}
end)
end)
model = Solver.train(data)
Solver.predict(model, 2)
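Since the target function is y = 2x, the trained parameters should land near w ≈ 2 and b ≈ 0. A quick check (a sketch, not part of the original notebook):
{w, b} = model
{Nx.to_number(w), Nx.to_number(b)}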
Tanh Graphed
chart =
Vl.new(width: 400, height: 400)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Kino.VegaLite.new()
|> Kino.render()
for i <- -10..10 do
point = %{x: i, y: :math.tanh(i)}
Kino.VegaLite.push(chart, point)
Process.sleep(25)
end
Solving for tanh using Axon
model =
Axon.input("data")
|> Axon.dense(4, activation: :relu)
|> Axon.dense(1, activation: :tanh)
batch_size = 64
data =
Stream.repeatedly(fn ->
    # note: building from the same key every time yields an identical batch each iteration
    {x, _next_key} =
      12
      |> Nx.Random.key()
      |> Nx.Random.normal(-5.0, 5.0, shape: {batch_size, 1})
y = Nx.tanh(x)
{x, y}
end)
params =
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.run(data, %{}, epochs: 100, iterations: 100)
Axon.predict(model, params, Nx.tensor([[1]]))
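As a rough sanity check, the prediction for 1 can be compared with the true function; assuming the loss has converged, it should sit near :math.tanh(1.0) ≈ 0.7616.
:math.tanh(1.0)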