Powered by AppSignal & Oban Pro

Ch 4: Optimisation

Ch4 - Optimize.livemd

Ch 4: Optimisation

# Notebook dependencies: Nx for tensors / numerical definitions and EXLA as
# the compiler/backend that Nx delegates to.
Mix.install(
  nx: "~> 0.5",
  exla: "~> 0.5"
)

# Make EXLA the default backend for tensors created in this notebook.
Nx.default_backend(EXLA.Backend)

Loss Functions

defmodule Loss do
  @moduledoc """
  Example loss functions written as numerical definitions (`defn`).

  These are illustrative implementations only and are not intended for
  real-world use (for instance, no clipping is applied before taking logs).
  """
  import Nx.Defn

  @doc """
  Element-wise binary cross-entropy.

  `y_pred` is expected to be a probability in the range 0-1; `y_true` is the
  corresponding target. The result has the broadcast shape of the inputs and
  is computed as `-(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))`.

  No clipping is performed, so a `y_pred` of exactly 0 or 1 feeds `0` into
  `Nx.log/1`.
  """
  defn binary_cross_entropy(y_true, y_pred) do
    # Both log terms are non-positive for probabilities in (0, 1); negating
    # their sum yields a non-negative loss that shrinks as y_pred -> y_true.
    -(y_true * Nx.log(y_pred) + (1 - y_true) * Nx.log(1 - y_pred))
  end

  @doc """
  Mean squared error between `y_true` and `y_pred`.

  The squared differences are averaged over the last axis only, so the
  result keeps every leading axis of the inputs.
  """
  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end
end

Gradient Descent

defmodule SGD do
  @moduledoc """
  A minimal stochastic gradient descent loop for a linear model, written
  with `Nx.Defn`.

  The model is a single dot product between the inputs and a `{32, 1}`
  parameter tensor, and the objective is `Loss.mean_squared_error/2`.
  """
  import Nx.Defn

  @doc """
  Draws an initial `{32, 1}` parameter tensor with `Nx.Random.uniform/2`.

  Returns the result of `Nx.Random.uniform/2` unchanged; callers in this
  module destructure it as `{params, new_key}`.
  """
  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  @doc """
  Linear model: the dot product of `inputs` with `params`.

  This is a contrived example: the model has the same form as the linear
  function used to generate data in this notebook. In practice that form is
  unknown and the model is only a guess at it.
  """
  defn model(params, inputs) do
    Nx.dot(inputs, params)
  end

  @doc """
  Mean squared error between `predicted_label` and `actual_label`.
  """
  defn loss(predicted_label, actual_label) do
    Loss.mean_squared_error(actual_label, predicted_label)
  end

  @doc """
  Loss of the model's predictions for `actual_inputs` against
  `actual_labels`, as a function of `params`.
  """
  defn objective(params, actual_inputs, actual_labels) do
    params
    |> model(actual_inputs)
    |> loss(actual_labels)
  end

  @doc """
  Performs one gradient-descent update.

  Computes the objective and its gradient with respect to `params`, then
  moves `params` against the gradient scaled by `learning_rate`
  (default `1.0e-2`).

  Returns `{loss, new_params}`.
  """
  defn step(params, actual_inputs, actual_labels, learning_rate \\ 1.0e-2) do
    {loss, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    new_params = params - learning_rate * params_grad
    {loss, new_params}
  end

  @doc """
  Sums the model's loss over every `{x, y}` pair in `test_data`.

  Returns the integer `0` when `test_data` is empty, otherwise a tensor
  holding the accumulated loss.
  """
  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      # Argument order follows loss(predicted_label, actual_label).
      trained_params
      |> model(x)
      |> loss(y)
    end)
    |> Enum.reduce(0, &Nx.add/2)
  end

  @doc """
  Trains the model on `data` for `iterations` epochs and returns the
  trained parameters.

  `data` is an enumerable of `{x, y}` batches. Parameters are initialised
  from `key` via `init_random_params/1`, and each batch triggers one
  `step/4` with the given `learning_rate` (default `1.0e-2`).

  Progress is written to standard output as the running mean of the batch
  losses seen so far in the current epoch. With `iterations <= 0` no
  training happens and the initial parameters are returned.
  """
  def train(data, iterations, key, learning_rate \\ 1.0e-2) do
    {initial_params, _key} = init_random_params(key)

    # `//1` forces an ascending range so that `iterations = 0` runs no epochs
    # (a bare `1..0` is a descending range that would run two).
    for epoch <- 1..iterations//1, reduce: initial_params do
      params ->
        # The running loss restarts at zero for every epoch.
        {_epoch_loss, epoch_params} =
          for {{x, y}, j} <- Enum.with_index(data),
              reduce: {Nx.tensor(0.0), params} do
            {running_loss, params} ->
              {batch_loss, new_params} = step(params, x, y, learning_rate)

              # Incremental mean over the j + 1 batches seen so far:
              # mean_{j+1} = (mean_j * j + batch) / (j + 1)
              avg_loss =
                running_loss
                |> Nx.multiply(j)
                |> Nx.add(Nx.mean(batch_loss))
                |> Nx.divide(j + 1)

              IO.write("\rEpoch: #{epoch}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end

        epoch_params
    end
  end
end
# Generate train/test data from a randomly parameterised "true" function.

key = Nx.Random.key(42)

{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

# Linear target: has the same form as SGD.model/2.
true_function = fn params, x -> Nx.dot(x, params) end

# Non-linear target, used to illustrate poorer performance when the model
# is a less close match to the true function.
true_function_2 = fn params, x -> x |> Nx.dot(params) |> Nx.cos() end

# Labels for BOTH the train and the test split come from the same target, so
# that evaluation measures fit to the function the model was trained on.
# Switch the key to :linear to use the well-matched target instead.
target_functions = %{linear: true_function, cosine: true_function_2}
target_function = Map.fetch!(target_functions, :cosine)

{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})

train_y = target_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

# `new_key` was rebound above, so the test inputs differ from the train inputs.
{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})

test_y = target_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))

# Get baseline for loss: evaluate untrained, randomly initialised parameters.
key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_params(key)
SGD.evaluate(random_params, test_data)

# Train for one iteration, then evaluate on the same test data.
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)
SGD.evaluate(trained_params, test_data)