# chapter04
Mix.install([
{:nx, "~> 0.5"}
])
## Section
# Synthetic regression data: labels are produced by a fixed "true" parameter
# tensor of shape {32, 1}, so the learning code has a known target to recover.
key = Nx.Random.key(42)
{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

# Ground-truth function: a plain dot product of the inputs with the parameters.
true_function = fn params, x ->
  Nx.dot(x, params)
end

# Every draw must consume the key returned by the previous draw. Reusing the
# original `key` here would sample the training inputs from the same random
# stream that already produced `true_params`.
{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10_000, 32})
train_y = true_function.(true_params, train_x)

# Pair inputs with labels one example at a time (batch size 1).
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

# The test set is drawn from the key left over by the training draw.
{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10_000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
defmodule SGD do
  @moduledoc """
  Stochastic gradient descent for a linear model, built on `Nx.Defn`.

  The numerical pieces (`model/2`, `loss/2`, `step/3`, ...) are `defn`
  functions; `train/3` and `evaluate/2` are plain Elixir loops that drive them
  over an enumerable of `{inputs, labels}` pairs.
  """

  import Nx.Defn

  @doc """
  Draws a `{32, 1}` parameter tensor from `Nx.Random.uniform/2`.

  Returns whatever `Nx.Random.uniform/2` returns, i.e. a `{params, new_key}`
  tuple.
  """
  defn init_random_param(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  @doc """
  Linear model: the dot product of `inputs` with `params`.
  """
  defn model(params, inputs) do
    Nx.dot(inputs, params)
  end

  @doc """
  Squared difference between `y_true` and `y_pred`, averaged over the last axis.
  """
  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end

  @doc """
  Loss between the actual and predicted labels; currently
  `mean_squared_error/2`.
  """
  defn loss(actual_label, predicted_label) do
    mean_squared_error(actual_label, predicted_label)
  end

  @doc """
  Loss of the model's predictions for `actual_inputs` against `actual_labels`,
  as a function of `params`. This is the quantity `step/3` differentiates.
  """
  defn objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end

  @doc """
  Performs one gradient-descent update on a single batch.

  Returns `{loss, new_params}` where `loss` is the objective evaluated at the
  incoming `params` and `new_params` is `params` moved against the gradient.
  """
  defn step(params, actual_inputs, actual_labels) do
    {loss_value, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    # Fixed learning rate of 1.0e-2.
    new_params = params - 1.0e-2 * params_grad
    {loss_value, new_params}
  end

  @doc """
  Initialises the model parameters from a random `key`.

  Delegates to `init_random_param/1` and returns its `{params, new_key}` tuple,
  which is the shape `train/3` pattern-matches on.
  """
  def init_random_params(key) do
    init_random_param(key)
  end

  @doc """
  Sums the loss of `trained_params` over every `{x, y}` pair in `test_data`.

  The reduction is seeded with the integer `0`, so an empty `test_data` yields
  `0` rather than a tensor.
  """
  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(y, prediction)
    end)
    |> Enum.reduce(0, &Nx.add/2)
  end

  @doc """
  Trains the model for `iterations` passes over `data`.

  `data` is an enumerable of `{inputs, labels}` pairs and `key` seeds the
  parameter initialisation. The running mean loss is written to standard
  output after every batch. Returns the trained parameters.
  """
  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      # The explicit `//1` step makes `iterations == 0` run zero epochs instead
      # of counting down through 1..0.
      for i <- 1..iterations//1, reduce: {loss, params} do
        {loss, params} ->
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, new_params} = step(params, x, y)

              # Incremental mean over the j + 1 batches seen so far this epoch.
              # At j == 0 the previous value is multiplied by zero, so each
              # epoch starts its average afresh.
              avg_loss =
                loss
                |> Nx.multiply(j)
                |> Nx.add(Nx.mean(batch_loss))
                |> Nx.divide(j + 1)

              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end

    trained_params
  end
end
# Score a freshly drawn (untrained) parameter tensor against the test pairs.
key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_param(key)
SGD.evaluate(random_params, test_data)
# Train on the training pairs with a different key; the second argument is
# the `iterations` count passed to SGD.train/3.
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)