Ch 4: Optimisation
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"}
])
Nx.default_backend(EXLA.Backend)
Loss Functions
defmodule Loss do
  import Nx.Defn

  # Example implementations, not intended for real use.
  # y_pred is a predicted probability in the range [0, 1].
  # BCE is -(y * log(p) + (1 - y) * log(1 - p)), averaged over the last axis.
  defn binary_cross_entropy(y_true, y_pred) do
    bce = y_true * Nx.log(y_pred) + (1 - y_true) * Nx.log(1 - y_pred)
    Nx.mean(Nx.negate(bce), axes: [-1])
  end

  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end
end
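As a quick sanity check, both losses can be evaluated directly on a small batch. The tensors below are hypothetical values for illustration, not part of the original notebook:

y_true = Nx.tensor([[0.0, 1.0, 1.0]])
y_pred = Nx.tensor([[0.1, 0.9, 0.8]])
# both losses should be small, since the predictions are close to the labels
{Loss.mean_squared_error(y_true, y_pred), Loss.binary_cross_entropy(y_true, y_pred)}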
Gradient Descent
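The step/4 function below implements the standard gradient descent update: with parameters $\theta$, objective $\mathcal{L}$, and learning rate $\alpha$, each step computes

$$\theta_{t+1} = \theta_t - \alpha \, \nabla_\theta \mathcal{L}(\theta_t)$$

value_and_grad/2 evaluates $\mathcal{L}(\theta_t)$ and $\nabla_\theta \mathcal{L}(\theta_t)$ in a single pass, which lets us log the loss while training.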
defmodule SGD do
  import Nx.Defn

  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end
  # This model happens to match the true function below, but the example
  # is contrived: in practice we wouldn't know the true function and
  # would have to guess at a suitable model.
  defn model(params, inputs) do
    Nx.dot(inputs, params)
  end
  defn loss(predicted_label, actual_label) do
    Loss.mean_squared_error(actual_label, predicted_label)
  end

  defn objective(params, actual_inputs, actual_labels) do
    model(params, actual_inputs)
    |> loss(actual_labels)
  end
  defn step(params, actual_inputs, actual_labels, learning_rate \\ 1.0e-2) do
    {loss, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    # gradient descent update: step against the gradient
    new_params = params - learning_rate * params_grad
    {loss, new_params}
  end
  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(prediction, y)
    end)
    # total loss across all test batches
    |> Enum.reduce(0, &Nx.add/2)
  end
  def train(data, iterations, key, learning_rate \\ 1.0e-2) do
    {params, _key} = init_random_params(key)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {Nx.tensor(0.0), params} do
        {_loss, params} ->
          # restart the running average each epoch
          for {{x, y}, j} <- Enum.with_index(data), reduce: {Nx.tensor(0.0), params} do
            {avg_loss, params} ->
              {batch_loss, new_params} = step(params, x, y, learning_rate)

              # incremental running mean of this epoch's batch losses:
              # avg <- avg + (batch - avg) / (j + 1)
              avg_loss =
                batch_loss
                |> Nx.mean()
                |> Nx.subtract(avg_loss)
                |> Nx.divide(j + 1)
                |> Nx.add(avg_loss)

              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end

    trained_params
  end
end
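To see a single update in isolation, here is a minimal, hypothetical sketch (the key, batch, and target below are illustrative, not from the original notebook). Because step/4 returns the loss computed before the update, calling it twice on the same batch should report a non-increasing loss at this small learning rate:

key = Nx.Random.key(7)
{params, key} = SGD.init_random_params(key)
{x, _key} = Nx.Random.uniform(key, shape: {1, 32})
# an arbitrary target the random params are unlikely to fit yet
y = Nx.tensor([[0.5]])
{loss_before, params} = SGD.step(params, x, y)
{loss_after, _params} = SGD.step(params, x, y)
{loss_before, loss_after}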
# Generate training and test data from a random ideal (true) function
key = Nx.Random.key(42)

{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

true_function = fn params, x ->
  Nx.dot(x, params)
end
# Used to illustrate the poorer performance we get when the model is a
# worse match for the true function
true_function_2 = fn params, x ->
  Nx.dot(x, params) |> Nx.cos()
end
{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
# Swap in true_function_2 here and below to see the model struggle
# against a target it cannot represent.
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
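Nx.to_batched/2 splits a tensor along its leading axis, so with a batch size of 1 each element of train_data pairs a {1, 32} input with its {1, 1} label. A quick illustrative check (not in the original notebook):

{x, y} = Enum.at(train_data, 0)
{Nx.shape(x), Nx.shape(y)}
# => {{1, 32}, {1, 1}}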
# Get a baseline loss from untrained, random parameters
key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_params(key)
SGD.evaluate(random_params, test_data)
# Train for a single epoch (one full pass over the training data)
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)
SGD.evaluate(trained_params, test_data)
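As a hypothetical follow-up (not in the original notebook), training for more epochs should drive the evaluation loss further below the baseline, since each pass applies another round of gradient updates:

trained_params = SGD.train(train_data, 10, key)
SGD.evaluate(trained_params, test_data)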