
Micrograd

Mix.install([
  {:axon, "~> 0.6.0"},
  {:dg, "~> 0.4.0"},
  {:scholar, "~> 0.2.1"},
  {:exla, "~> 0.6.1"},
  {:kino_vega_lite, "~> 0.1.10"},
  {:kino, "~> 0.11.0"}
])

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

A node of the gradient-calculation graph (aka Value in micrograd)

defmodule MicrogradNode do
  @moduledoc "stores a single scalar value and its gradient"
  defstruct data: nil,
            grad: 0,
            label: "",
            _backward: nil,
            _op: "",
            _children: []

  @type t :: %MicrogradNode{
          data: number() | nil,
          grad: number(),
          label: String.t(),
          _backward: (t() -> t()) | nil,
          _op: String.t(),
          _children: list(t())
        }

  def add(%MicrogradNode{} = left, %MicrogradNode{} = right, label \\ "add", grad \\ 0) do
    out = %MicrogradNode{
      data: left.data + right.data,
      grad: grad,
      _children: previous(left, right),
      _op: "+",
      label: label
    }

    # d(out)/d(left) = d(out)/d(right) = 1, so the upstream gradient flows
    # through unchanged, accumulating into each child's existing gradient
    backward = fn node ->
      left = %MicrogradNode{left | grad: left.grad + node.grad}
      right = %MicrogradNode{right | grad: right.grad + node.grad}

      %MicrogradNode{node | _children: previous(left, right)}
    end

    %MicrogradNode{out | _backward: backward}
  end

  def mult(%MicrogradNode{} = left, %MicrogradNode{} = right, label \\ "mult", grad \\ 0) do
    out = %MicrogradNode{
      data: left.data * right.data,
      grad: grad,
      _children: previous(left, right),
      _op: "*",
      label: label
    }

    backward = fn node ->
      # chain rule: d(out)/d(left) = right.data and d(out)/d(right) = left.data,
      # scaled by the upstream gradient and accumulated into the children
      new_left = %MicrogradNode{left | grad: left.grad + right.data * node.grad}
      new_right = %MicrogradNode{right | grad: right.grad + left.data * node.grad}

      %MicrogradNode{node | _children: previous(new_left, new_right)}
    end

    %MicrogradNode{out | _backward: backward}
  end

  def tanh(%MicrogradNode{data: data} = child, label \\ "tanh", grad \\ 1) do
    t = (:math.exp(2 * data) - 1) / (:math.exp(2 * data) + 1)

    out = %MicrogradNode{
      data: t,
      grad: grad,
      label: label,
      _children: [child],
      _op: "tanh"
    }

    backward = fn node ->
      # d(tanh(x))/dx = 1 - tanh(x)^2, scaled by the upstream gradient
      child = %MicrogradNode{child | grad: child.grad + (1 - t ** 2) * node.grad}
      %MicrogradNode{node | _children: [child]}
    end

    %MicrogradNode{out | _backward: backward}
  end

  def previous(left, right) do
    [left, right]
  end

  # a node with no children is a leaf, so there is nothing left to propagate
  def backward(%MicrogradNode{_children: []} = leaf) do
    leaf
  end

  def backward(
        %MicrogradNode{
          data: data,
          grad: grad,
          label: label,
          _op: op,
          _backward: backward
        } = node
      ) do
    # run this node's backward fn to push its gradient into its children,
    # then recurse into each (now updated) child
    node =
      if node._backward do
        node._backward.(node)
      else
        node
      end

    updated_children = Enum.map(node._children, &backward/1)

    %MicrogradNode{
      data: data,
      grad: grad,
      label: label,
      _backward: backward,
      _op: op,
      _children: updated_children
    }
  end
end

Test the MicrogradNode API

a = %MicrogradNode{data: 2, label: "a"}
b = %MicrogradNode{data: -3, label: "b"}
MicrogradNode.add(a, b)

a = %MicrogradNode{data: 2, label: "a"}
b = %MicrogradNode{data: -3, label: "b"}
c = %MicrogradNode{data: 10, label: "c"}
d = MicrogradNode.mult(a, b, "e") |> MicrogradNode.add(c, "d", 1)
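
As a larger sanity check, here is a sketch of the single-neuron example from Karpathy's micrograd walkthrough, built with the same API (the inputs, weights, and bias mirror the Nx section below):

x1 = %MicrogradNode{data: 2.0, label: "x1"}
x2 = %MicrogradNode{data: 0.0, label: "x2"}
w1 = %MicrogradNode{data: -3.0, label: "w1"}
w2 = %MicrogradNode{data: 1.0, label: "w2"}
bias = %MicrogradNode{data: 6.8813735870195432, label: "bias"}

x1w1 = MicrogradNode.mult(x1, w1, "x1*w1")
x2w2 = MicrogradNode.mult(x2, w2, "x2*w2")

# o = tanh(x1*w1 + x2*w2 + bias)
x1w1
|> MicrogradNode.add(x2w2, "x1*w1 + x2*w2")
|> MicrogradNode.add(bias, "n")
|> MicrogradNode.tanh("o")
|> MicrogradNode.backward()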

Visualization

defmodule Graph do
  def draw_dot(root) do
    dot = DG.new()
    build_dot(root, dot, %{count: 0, ops: [], rids: []})

    dot
  end

  def build_dot(node, dot, visited) do
    label = "#{node.label} -> data #{node.data} -> grad #{node.grad}"
    rid = to_string(:rand.uniform(1000))

    DG.add_vertex(dot, node.label, label)

    count = Map.get(visited, :count)
    ops = Map.get(visited, :ops)
    rids = Map.get(visited, :rids)
    val = [node._op]
    visited = %{count: count + 1, ops: ops ++ val, rids: rids ++ [rid]}

    if node._op != "" do
      # create a vertex for this node's op and draw an edge from it to the node
      DG.add_vertex(dot, "#{node._op}" <> rid, node._op)
      DG.add_edge(dot, "#{node._op}" <> rid, node.label)
    end

    if count != 0 do
      # for non-root nodes, connect this vertex to its parent's op vertex
      ops = Map.get(visited, :ops)
      op = Enum.at(ops, count - 1)
      rid = Enum.at(rids, count - 1)
      DG.add_edge(dot, String.trim(node.label), String.trim(op) <> rid)
    end

    Enum.each(node._children, fn child ->
      build_dot(child, dot, visited)
    end)
  end
end

d
|> MicrogradNode.backward()
|> Graph.draw_dot()
|> inspect()
|> Kino.Mermaid.new()

Nx

x1 = Nx.tensor([2.0])
x2 = Nx.tensor([0.0])

w1 = Nx.tensor([-3.0])
w2 = Nx.tensor([1.0])

b = Nx.tensor([6.8813735870195432])

x1w1 = Nx.multiply(x1, w1)
x2w2 = Nx.multiply(x2, w2)

output = x1w1 |> Nx.add(x2w2) |> Nx.add(b)
Nx.tanh(output)
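
Nx can compute the same gradients automatically. A minimal sketch, assuming Nx.Defn.grad/1 (which returns a function that differentiates its argument) may be invoked directly with scalar tensors:

x1 = Nx.tensor(2.0)
x2 = Nx.tensor(0.0)

grad_fn =
  Nx.Defn.grad(fn {w1, w2, b} ->
    # the same neuron as above: tanh(x1*w1 + x2*w2 + b)
    x1
    |> Nx.multiply(w1)
    |> Nx.add(Nx.multiply(x2, w2))
    |> Nx.add(b)
    |> Nx.tanh()
  end)

# gradients of the output with respect to w1, w2, and b
grad_fn.({Nx.tensor(-3.0), Nx.tensor(1.0), Nx.tensor(6.8813735870195432)})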

Visualizing a Linear Function: y = 2x

defmodule Mathy do
  def actual(x) do
    2 * x
  end
end

chart =
  Vl.new(width: 400, height: 400)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

for i <- -5..5 do
  point = %{x: i, y: Mathy.actual(i)}
  Kino.VegaLite.push(chart, point)
  Process.sleep(25)
end

Solving a linear function with Nx

defmodule Solver do
  import Nx.Defn

  @learning_rate 0.01
  @batch_size 64

  defn predict({w, b}, x) do
    w * x + b
  end

  defn mse(activations, x, y) do
    y_hat = predict(activations, x)

    (y - y_hat)
    |> Nx.pow(2)
    |> Nx.mean()
  end

  defn update({w, b} = activations, x, y) do
    {grad_w, grad_b} = grad(activations, &mse(&1, x, y))

    {
      w - grad_w * @learning_rate,
      b - grad_b * @learning_rate
    }
  end

  def train(data, epochs \\ 100) do
    Enum.reduce(1..epochs, random_parameters(), fn _i, acc ->
      data
      |> Enum.take(@batch_size)
      |> Enum.reduce(acc, fn batch, activations ->
        {x, y} = Enum.unzip(batch)
        x = Nx.tensor(x)
        y = Nx.tensor(y)
        update(activations, x, y)
      end)
    end)
  end

  def random_parameters do
    # thread the PRNG key so w and b are independent samples
    key = Nx.Random.key(12)
    {w, new_key} = Nx.Random.normal(key, 0.0, 1.0)
    {b, _new_key} = Nx.Random.normal(new_key, 0.0, 1.0)

    {w, b}
  end

  def generate_data do
    num = :rand.uniform()
    {x, y} = {num, Mathy.actual(num)}
    {Nx.tensor(x), Nx.tensor(y)}
  end
end
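
Before training on a stream, a single hand-rolled step is worth sketching to see the update in isolation (the batch values here are illustrative):

params = Solver.random_parameters()

# a tiny hand-made batch of the target function y = 2x
x = Nx.tensor([1.0, 2.0, 3.0])
y = Nx.tensor([2.0, 4.0, 6.0])

# one gradient-descent step should nudge w toward 2 and b toward 0
Solver.update(params, x, y)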

Train in a loop

data =
  Stream.repeatedly(fn ->
    Stream.map(1..64, fn _i ->
      num = :rand.uniform() * 10
      {num, Mathy.actual(num)}
    end)
  end)

model = Solver.train(data)
Solver.predict(model, 2)
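
A quick (illustrative) check of the fit, comparing predictions against Mathy.actual/1 at a few points:

for x <- [1.0, 2.0, 5.0] do
  predicted = model |> Solver.predict(x) |> Nx.to_number()
  %{x: x, predicted: predicted, actual: Mathy.actual(x)}
end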

Tanh graph

chart =
  Vl.new(width: 400, height: 400)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

for i <- -10..10 do
  point = %{x: i, y: :math.tanh(i)}
  Kino.VegaLite.push(chart, point)
  Process.sleep(25)
end

Train a neural network to approximate tanh with Axon

model =
  Axon.input("data")
  |> Axon.dense(4, activation: :relu)
  |> Axon.dense(1, activation: :tanh)

batch_size = 64

data =
  Stream.unfold(Nx.Random.key(12), fn key ->
    # thread the PRNG key between batches so every batch is different,
    # sampling inputs uniformly from [-5, 5]
    {x, next_key} = Nx.Random.uniform(key, -5.0, 5.0, shape: {batch_size, 1})
    y = Nx.tanh(x)
    {{x, y}, next_key}
  end)

params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.run(data, %{}, epochs: 100, iterations: 100)

Axon.predict(model, params, Nx.tensor([[1.0]]))
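
As a sketch of evaluating the fit, the trained network's outputs can be compared against :math.tanh/1 at a few sample inputs:

inputs = [-2.0, 0.0, 1.0, 2.0]

predictions =
  model
  |> Axon.predict(params, Nx.tensor(Enum.map(inputs, &[&1])))
  |> Nx.to_flat_list()

inputs
|> Enum.zip(predictions)
|> Enum.map(fn {x, y_hat} -> %{x: x, predicted: y_hat, actual: :math.tanh(x)} end)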