
micrograd

Mix.install([
  {:axon, "~> 0.6.0"},
  :dg,
  {:scholar, "~> 0.1"},
  {:exla, ">= 0.0.0"},
  {:kino_vega_lite, "~> 0.1.9"}
])

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Value Tree

defmodule Node1 do
  @moduledoc "stores a single scalar value and its gradient"
  defstruct data: nil,
            grad: 0,
            label: "",
            _backward: nil,
            _op: "",
            _children: []

  @type t :: %Node1{
          data: number() | nil,
          grad: number(),
          label: String.t(),
          _backward: (t() -> t()) | nil,
          _op: String.t(),
          _children: [t()]
        }

  def add(%Node1{} = left, %Node1{} = right, label \\ "add", grad \\ 0) do
    out = %Node1{
      data: left.data + right.data,
      grad: grad,
      _children: previous(left, right),
      _op: "+",
      label: label
    }

    backward = fn node ->
      # chain rule: addition passes the output gradient through to both children
      left = %Node1{left | grad: left.grad + node.grad}
      right = %Node1{right | grad: right.grad + node.grad}

      %Node1{node | _children: previous(left, right)}
    end

    %Node1{out | _backward: backward}
  end

  def mult(%Node1{} = left, %Node1{} = right, label \\ "mult", grad \\ 0) do
    out = %Node1{
      data: left.data * right.data,
      grad: grad,
      _children: previous(left, right),
      _op: "*",
      label: label
    }

    backward = fn node ->
      # chain rule: d(out)/d(left) = right.data and d(out)/d(right) = left.data
      left = %Node1{left | grad: left.grad + right.data * node.grad}
      right = %Node1{right | grad: right.grad + left.data * node.grad}

      %Node1{node | _children: previous(left, right)}
    end

    %Node1{out | _backward: backward}
  end

  def tanh(%Node1{data: data} = child, label \\ "tanh", grad \\ 1) do
    t = (:math.exp(2 * data) - 1) / (:math.exp(2 * data) + 1)

    out = %Node1{
      data: t,
      grad: grad,
      label: label,
      _children: [child],
      _op: "tanh"
    }

    backward = fn node ->
      # chain rule: d(tanh(x))/dx = 1 - tanh(x)^2
      child = %Node1{child | grad: child.grad + (1 - t ** 2) * node.grad}
      %Node1{node | _children: [child]}
    end

    %Node1{out | _backward: backward}
  end

  def previous(left, right) do
    [left, right]
  end

  def backward(%Node1{_children: []} = leaf) do
    leaf
  end

  def backward(
        %Node1{
          data: data,
          grad: grad,
          label: label,
          _op: op,
          _backward: backward
        } = node
      ) do
    node =
      if node._backward do
        node._backward.(node)
      else
        node
      end

    updated_children =
      Enum.map(node._children, fn child ->
        backward(child)
      end)

    %Node1{
      data: data,
      grad: grad,
      label: label,
      _backward: backward,
      _op: op,
      _children: updated_children
    }
  end
end
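
Each `_backward` closure applies the chain rule for its op. For addition, the local derivative with respect to either child is 1, so both children receive the output gradient unchanged. For multiplication, the derivative with respect to one child is the other child's data, so each child receives the other's data times the output gradient. For tanh, the derivative is `1 - tanh(x) ** 2`, which is why the closure reuses the already-computed `t`.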

Interacting with the Value API

a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
Node1.add(a, b)
a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
c = %Node1{data: 10, label: "c"}
d = Node1.mult(a, b, "e") |> Node1.add(c, "d", 1)
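
`Node1.tanh/3` can be exercised the same way. Here is a minimal sketch of a single tanh neuron built from the Value API; the input values mirror the Nx neuron later in the notebook, and the labels are illustrative:

x1 = %Node1{data: 2.0, label: "x1"}
w1 = %Node1{data: -3.0, label: "w1"}
x2 = %Node1{data: 0.0, label: "x2"}
w2 = %Node1{data: 1.0, label: "w2"}
b = %Node1{data: 6.8813735870195432, label: "b"}

n =
  Node1.mult(x1, w1, "x1w1")
  |> Node1.add(Node1.mult(x2, w2, "x2w2"), "x1w1 + x2w2")
  |> Node1.add(b, "n")

o = Node1.tanh(n, "o")
o.data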

Add Value Tree Visualizations

defmodule Graph do
  def draw_dot(root) do
    dot = DG.new()
    build_dot(root, dot, %{count: 0, ops: [], rids: []})

    dot
  end

  def build_dot(node, dot, visited) do
    label = "#{node.label} -> data #{node.data} -> grad #{node.grad}"
    rid = to_string(:rand.uniform(1000))

    DG.add_vertex(dot, node.label, label)

    count = Map.get(visited, :count)
    ops = Map.get(visited, :ops)
    rids = Map.get(visited, :rids)
    val = [node._op]
    visited = %{count: count + 1, ops: ops ++ val, rids: rids ++ [rid]}

    if node._op != "" do
      # create a vertex for this node's op and an edge from the op vertex to this node
      DG.add_vertex(dot, "#{node._op}" <> rid, node._op)
      DG.add_edge(dot, "#{node._op}" <> rid, node.label)
    end

    if count != 0 do
      # if this is not the root, connect this node to its parent's op vertex
      ops = Map.get(visited, :ops)
      op = Enum.at(ops, count - 1)
      rid = Enum.at(rids, count - 1)
      DG.add_edge(dot, String.trim(node.label), String.trim(op) <> rid)
    end

    Enum.map(node._children, fn child ->
      build_dot(child, dot, visited)
    end)
  end
end
d
|> Node1.backward()
|> Graph.draw_dot()
graph LR
    b[b -> data -3 -> grad 2]-->*201[*]
    *201[*]-->e[e -> data -6 -> grad 1]
    e[e -> data -6 -> grad 1]-->+336[+]
    c[c -> data 10 -> grad 1]-->+336[+]
    a[a -> data 2 -> grad -3]-->*201[*]
    +336[+]-->d[d -> data 4 -> grad 1]
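
Each box in the graph shows a node's label, its data, and its grad; the + and * vertices are the operations that produced the node they point to, with edges running from the inputs into the op and from the op to its result.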

What is h in the derivative?
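
`h` is the small nudge used in the numerical (finite-difference) approximation of the derivative: the slope of `d` with respect to an input is estimated as `(d2.data - d1.data) / h`, which approaches the true derivative as `h` shrinks toward 0. The cells below nudge one input at a time and measure how much `d` moves.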

h = 0.0001

# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}

# what does it mean to nudge a?
d1 = Node1.mult(a, b) |> Node1.add(c)
a = %Node1{data: 2.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)

IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
h = 0.0001

# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}

# what does it mean to nudge b?
d1 = Node1.mult(a, b) |> Node1.add(c)
b = %Node1{data: -3.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)

IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
h = 0.0001

# inputs
a = %Node1{data: 2.0, label: "a"}
b = %Node1{data: -3.0, label: "b"}
c = %Node1{data: 10.0, label: "c"}

# what does it mean to nudge c?
d1 = Node1.mult(a, b) |> Node1.add(c)
c = %Node1{data: 10.0 + h}
d2 = Node1.mult(a, b) |> Node1.add(c)

IO.inspect(d1.data)
IO.inspect(d2.data)
IO.inspect((d2.data - d1.data) / h)
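
The three measured slopes should agree with the gradients `Node1.backward/1` produced above: nudging `a` moves `d` at a rate of `b = -3`, nudging `b` at a rate of `a = 2`, and nudging `c` at a rate of `1`, the same values shown on the graph.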

A More Complicated Value Tree

a = %Node1{data: 2, label: "a"}
b = %Node1{data: -3, label: "b"}
c = %Node1{data: 10, label: "c"}
e = Node1.mult(a, b, "e")
d = Node1.add(e, c, "d")
f = %Node1{data: -2, label: "f"}
l = Node1.mult(d, f, "L", 1)

l
|> Node1.backward()
|> Graph.draw_dot()
graph LR
    f[f -> data -2 -> grad 4]-->*393[*]
    b[b -> data -3 -> grad -4]-->*699[*]
    *393[*]-->L[L -> data -8 -> grad 1]
    e[e -> data -6 -> grad -2]-->+920[+]
    *699[*]-->e[e -> data -6 -> grad -2]
    c[c -> data 10 -> grad -2]-->+920[+]
    a[a -> data 2 -> grad 6]-->*699[*]
    d[d -> data 4 -> grad -2]-->*393[*]
    +920[+]-->d[d -> data 4 -> grad -2]
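
Reading the gradients off with the chain rule by hand gives the same picture: `L = d * f`, so `dL/dd = f = -2` and `dL/df = d = 4`; `d = e + c` just passes its gradient through, so `e` and `c` both get `-2`; and `e = a * b`, so `a` gets `b * e.grad = -3 * -2 = 6` and `b` gets `a * e.grad = 2 * -2 = -4`.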

Intro to Nx

x1 = Nx.tensor([2.0])
x2 = Nx.tensor([0.0])

w1 = Nx.tensor([-3.0])
w2 = Nx.tensor([1.0])

b = Nx.tensor([6.8813735870195432])

x1w1 = Nx.multiply(x1, w1)
x2w2 = Nx.multiply(x2, w2)

output = x1w1 |> Nx.add(x2w2) |> Nx.add(b)
Nx.tanh(output)
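
The weighted sum is `-6.0 + 0.0 + 6.8813735870195432 = 0.8813735870195432`, so the final cell should return a tensor close to `0.7071`.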

Visualizing a Linear Function: y = 2x

defmodule Mathy do
  def actual(x) do
    2 * x
  end
end

chart =
  Vl.new(width: 400, height: 400)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

for i <- -5..5 do
  point = %{x: i, y: Mathy.actual(i)}
  Kino.VegaLite.push(chart, point)
  Process.sleep(25)
end

Solving for a linear function using Nx

defmodule Solver do
  import Nx.Defn

  @learning_rate 0.01
  @batch_size 64

  defn predict({w, b}, x) do
    w * x + b
  end

  defn mse(activations, x, y) do
    y_hat = predict(activations, x)

    (y - y_hat)
    |> Nx.pow(2)
    |> Nx.mean()
  end

  defn update({w, b} = activations, x, y) do
    {grad_w, grad_b} = grad(activations, &mse(&1, x, y))

    {
      w - grad_w * @learning_rate,
      b - grad_b * @learning_rate
    }
  end

  def train(data, epochs \\ 100) do
    Enum.reduce(1..epochs, random_parameters(), fn _i, acc ->
      data
      |> Enum.take(@batch_size)
      |> Enum.reduce(acc, fn batch, activations ->
        {x, y} = Enum.unzip(batch)
        x = Nx.tensor(x)
        y = Nx.tensor(y)
        update(activations, x, y)
      end)
    end)
  end

  def random_parameters do
    # split the key so w and b are drawn from different random values
    key = Nx.Random.key(12)
    {w, new_key} = Nx.Random.normal(key, 0.0, 1.0)
    {b, _new_key} = Nx.Random.normal(new_key, 0.0, 1.0)

    {w, b}
  end

  def generate_data do
    num = :rand.uniform()
    {x, y} = {num, Mathy.actual(num)}
    {Nx.tensor(x), Nx.tensor(y)}
  end
end
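
`update/3` is one step of gradient descent: each parameter is moved against its gradient, `w <- w - learning_rate * dL/dw` and `b <- b - learning_rate * dL/db`, where the loss `L` is the mean squared error computed by `mse/3`.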

Manual Updates

# initialize a random w and b
activations = Solver.random_parameters()
# generate a point we know lies on the line y = 2x
# (when x is `num`, y is `2 * num`)

{x, y} = Solver.generate_data()
Solver.update(activations, x, y)
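
A single update only nudges `w` and `b` a little. To watch them drift toward the true parameters without the full training loop, the update can be repeated by hand; a rough sketch using the same API (the iteration count is arbitrary):

Enum.reduce(1..500, Solver.random_parameters(), fn _i, acc ->
  {x, y} = Solver.generate_data()
  Solver.update(acc, x, y)
end)

Since the data comes from `y = 2x`, `w` should drift toward 2 and `b` toward 0.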

Train In a Loop

data =
  Stream.repeatedly(fn ->
    Stream.map(0..64, fn _i ->
      num = :rand.uniform() * 10
      {num, Mathy.actual(num)}
    end)
  end)

model = Solver.train(data)
Solver.predict(model, 2)
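
Because the training data comes from `Mathy.actual/1`, i.e. `y = 2x`, this prediction should land close to 4.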

Tanh Graphed

chart =
  Vl.new(width: 400, height: 400)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

for i <- -10..10 do
  point = %{x: i, y: :math.tanh(i)}
  Kino.VegaLite.push(chart, point)
  Process.sleep(25)
end

Solving for tanh using Axon

model =
  Axon.input("data")
  |> Axon.dense(4, activation: :relu)
  |> Axon.dense(1, activation: :tanh)

batch_size = 64

data =
  Stream.repeatedly(fn ->
    {x, _next_key} =
      12
      |> Nx.Random.key()
      |> Nx.Random.normal(-5.0, 5.0, shape: {batch_size, 1})

    y = Nx.tanh(x)
    {x, y}
  end)

params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.run(data, %{}, epochs: 100, iterations: 100)
Axon.predict(model, params, Nx.tensor([[1]]))
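
To sanity-check the fit, the network's output can be compared against `:math.tanh/1` on a handful of points; an illustrative check (the sample inputs are arbitrary, and how closely the net tracks tanh depends on how the training data was sampled):

for x <- [-2.0, -1.0, 0.0, 1.0, 2.0] do
  pred =
    model
    |> Axon.predict(params, Nx.tensor([[x]]))
    |> Nx.to_flat_list()
    |> hd()

  IO.puts("x: #{x}  predicted: #{Float.round(pred, 4)}  tanh: #{Float.round(:math.tanh(x), 4)}")
end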