
Optimize Everything

Mix.install([
  {:nx, "~> 0.5"}
])

Descending Gradients

This section generates training data from a known linear function and then tries to recover its parameters with stochastic gradient descent: start from random parameters and repeatedly nudge them against the gradient of the loss.

key = Nx.Random.key(42)

{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

true_function = fn params, x ->
  Nx.dot(x, params)
end

{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
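
To sanity-check the batching, you can peek at the first input/label pair; this check is an illustrative addition, not part of the original notebook.

{x, y} = hd(train_data)
{Nx.shape(x), Nx.shape(y)}
#=> {{1, 32}, {1, 1}}
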
defmodule SGD do
  import Nx.Defn

  # Initial parameters are drawn uniformly at random, in the same {32, 1} shape as the true parameters.
  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  # The model is a single linear transformation of the inputs.
  defn model(params, inputs) do
    Nx.dot(inputs, params)
  end

  # Mean squared error: mean of (y_true - y_pred)^2 over the last axis.
  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end

  defn loss(actual_label, predicted_label) do
    loss_value = mean_squared_error(actual_label, predicted_label)
    loss_value
  end

  defn objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end

  defn step(params, actual_inputs, actual_labels) do
    # Evaluate the objective and its gradient with respect to params in one pass.
    {loss, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    # Gradient descent: step against the gradient with a fixed learning rate of 1.0e-4.
    new_params = params - 1.0e-4 * params_grad
    {loss, new_params}
  end

  def evaluate(trained_params, test_data) do
    # Accumulate the loss over every batch in the test set.
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(y, prediction)
    end)
    |> Enum.reduce(0, &Nx.add/2)
  end

  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {loss, params} do
        {loss, params} ->
          # One epoch: take a gradient step per batch and keep a running average of the loss.
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, new_params} = step(params, x, y)
              avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1)
              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end

    trained_params
  end
end
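
To see what a single update does before running the full loop, you can call SGD.step/3 by hand on the first batch. This is an illustrative check, not part of the original notebook; the key and variable names are only for this sketch.

{x, y} = hd(train_data)
{init_params, _} = SGD.init_random_params(Nx.Random.key(0))
{batch_loss, new_params} = SGD.step(init_params, x, y)
{Nx.to_number(Nx.mean(batch_loss)), Nx.shape(new_params)}
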
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 10, key)

SGD.evaluate(trained_params, test_data)

For comparison, evaluate a fresh set of random parameters on the same test data; the trained parameters should score a much lower total loss than this baseline.

key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_params(key)
SGD.evaluate(random_params, test_data)

You can also compare the learned parameters to the true ones directly:

Nx.subtract(trained_params, true_params)
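
One compact way to summarize the fit is the largest absolute difference between learned and true parameters; this summary is an illustrative addition, not part of the original notebook.

trained_params
|> Nx.subtract(true_params)
|> Nx.abs()
|> Nx.reduce_max()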

Making it Fail

Now generate data from a nonlinear function, cos(x · θ), and fit it with the same linear model. Gradient descent still runs, but a linear model can only approximate the cosine, so it cannot drive the loss toward zero.

key = Nx.Random.key(42)

{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

true_function = fn params, x ->
  Nx.dot(x, params) |> Nx.cos()
end

{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))

key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 10, key)

SGD.evaluate(trained_params, test_data)
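
For a baseline on the nonlinear task, repeat the random-parameter comparison from the linear case; this is an illustrative addition, not part of the original notebook.

key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_params(key)
SGD.evaluate(random_params, test_data)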