
# Chapter 4


```elixir
Mix.install([{:nx, "~> 0.5"}])
```
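By default this runs on Nx's pure-Elixir `BinaryBackend`, which is slow for the 10,000-example training loop below. If the EXLA compiler is available in your environment, you could opt into it instead (an optional tweak, not part of the original setup):

```elixir
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"}
])

# Route all tensor operations through EXLA's compiled backend
Nx.global_default_backend(EXLA.Backend)
```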

## Test Data

```elixir
key = Nx.Random.key(42)

# Fixed "true" parameters the model should recover
{true_params, key} = Nx.Random.uniform(key, shape: {32, 1})
true_fn = &Nx.dot(&2, &1)

# Training set: 10,000 examples with 32 features each
{train_x, key} = Nx.Random.uniform(key, shape: {10_000, 32})
train_y = true_fn.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

# Held-out test set drawn from the same distribution
{test_x, _} = Nx.Random.uniform(key, shape: {10_000, 32})
test_y = true_fn.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
```
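Before training, a quick shape check can confirm the batches line up; this is a minimal sanity check, not part of the original notebook:

```elixir
# Each element of train_data is an {input, label} pair with batch size 1
{x, y} = hd(train_data)
{Nx.shape(x), Nx.shape(y), length(train_data)}
#=> {{1, 32}, {1, 1}, 10000}
```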
```elixir
defmodule SGD do
  @moduledoc """
  Stochastic Gradient Descent
  """
  import Nx.Defn

  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  # Linear model: a single matrix multiplication.
  defn model(params, inputs) do
    Nx.dot(inputs, params)
  end

  # Mean squared error, reduced to a scalar so it can be differentiated.
  defn loss(actual, predicted) do
    actual
    |> Nx.subtract(predicted)
    |> Nx.pow(2)
    |> Nx.mean()
  end

  defn objective(params, actual_inputs, actual_labels) do
    loss(actual_labels, model(params, actual_inputs))
  end

  # One gradient descent update: params <- params - learning_rate * gradient.
  defn step(params, actual_inputs, actual_labels) do
    learning_rate = 1.0e-2
    {loss, params_grad} = value_and_grad(params, &objective(&1, actual_inputs, actual_labels))
    {loss, params - learning_rate * params_grad}
  end

  @doc """
  Sums the loss over every batch in `test_data`.
  """
  def evaluate(trained_params, test_data) do
    Enum.reduce(test_data, 0, fn {x, y}, acc ->
      prediction = model(trained_params, x)
      Nx.add(loss(y, prediction), acc)
    end)
  end

  @doc """
  Runs `iterations` epochs of SGD over `data`, printing a running
  average of the loss, and returns the trained parameters.
  """
  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {loss, params} do
        {loss, params} ->
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, params} = step(params, x, y)
              # Running mean of the loss over the batches seen so far.
              avg_loss = Nx.divide(Nx.add(batch_loss, Nx.multiply(loss, j)), j + 1)
              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, params}
          end
      end

    trained_params
  end
end
```
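Written out, `step/3` is the textbook gradient descent update on the mean squared error, with learning rate $\alpha = 10^{-2}$:

$$
\theta_{t+1} = \theta_t - \alpha \, \nabla_\theta \frac{1}{n} \sum_{i=1}^{n} \left(y_i - x_i^\top \theta_t\right)^2
$$

With a batch size of 1, each update uses a single $(x_i, y_i)$ pair, which is what makes the procedure *stochastic* gradient descent.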
```elixir
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)
SGD.evaluate(trained_params, test_data)
```
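`SGD.evaluate/2` returns the loss summed over all 10,000 test batches; dividing by the batch count gives a per-example figure that is easier to interpret (a small convenience on top of the code above, not part of the module):

```elixir
trained_params
|> SGD.evaluate(test_data)
|> Nx.divide(Enum.count(test_data))
|> Nx.to_number()
```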