Chapter 3
Mix.install([
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:kino, "~> 0.8"},
{:stb_image, "~> 0.6"},
{:vega_lite, "~> 0.1"},
{:kino_vega_lite, "~> 0.1"}
])
Vectors
Everything in Nx is a tensor. A vector is represented as a rank-1 tensor, that is, a tensor with a single dimension.
# Use EXLA.Backend as the default Nx backend.
Nx.default_backend(EXLA.Backend)
Scalars
Scalars are represented as 0-dimensional tensors.
Matrices
Matrices are two-dimensional tensors.
Vector Addition
# Sales figures for day 1 and day 2 as rank-1 tensors
# (presumably one entry per product, matching price_per_product further down).
sales_day_1 = Nx.tensor([32, 10, 14])
sales_day_2 = Nx.tensor([10, 24, 21])
# Vector addition is element-wise: entries at the same position are summed.
total_sales = Nx.add(sales_day_1, sales_day_2)
Scalar Multiplication
# Scalar factor applied to the sales totals (presumably the share of sales not returned).
keep_rate = 0.9
# Scalar multiplication: every entry of total_sales is scaled by keep_rate.
unreturned_sales = Nx.multiply(keep_rate, total_sales)
# One price per entry of the sales vectors (same length, same order assumed).
price_per_product = Nx.tensor([9.95, 10.95, 5.99])
# Element-wise product of two equal-length vectors.
revenue_per_product = Nx.multiply(unreturned_sales, price_per_product)
# The same two sales vectors stacked as the rows of a 2x3 matrix.
sales_matrix = Nx.tensor([[32, 10, 14], [10, 24, 21]])
# Transposing swaps rows and columns, giving a 3x2 matrix.
Nx.transpose(sales_matrix)
Linear Transformations
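A linear transformation maps inputs to outputs while preserving linearity: it preserves vector addition and scalar multiplication, so $T(a\mathbf{u} + b\mathbf{v}) = aT(\mathbf{u}) + bT(\mathbf{v})$.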
Probability
# simulation = fn key ->
# {value, key} = Nx.Random.uniform(key)
# rounded = if Nx.to_number(value) < 0.5, do: 0, else: 1
# {rounded, key}
# end
# key = Nx.Random.key(42)
# for n <- 1..4 do
# 1..round(:math.pow(10, n))
# |> Enum.map_reduce(key, fn _, key -> simulation.(key) end)
# |> elem(0)
# |> Enum.sum()
# |> IO.inspect()
# end
Tracking Change
defmodule BerryFarm do
  import Nx.Defn

  # Profit as a function of the number of trees:
  #
  #     -(trees - 1)^4 + trees^3 + trees^2
  #
  # The same computation can also be spelled as a pipeline of Nx calls:
  #
  #     trees
  #     |> Nx.subtract(1)
  #     |> Nx.pow(4)
  #     |> Nx.negate()
  #     |> Nx.add(Nx.pow(trees, 3))
  #     |> Nx.add(Nx.pow(trees, 2))
  #
  # Inside defn the arithmetic operators are used instead.
  defn profits(trees) do
    quartic_term = -((trees - 1) ** 4)
    quartic_term + trees ** 3 + trees ** 2
  end

  # Rate of change of profits/1 with respect to trees, obtained with grad.
  defn profits_derivative(trees), do: grad(trees, &profits/1)
end
# Short alias used by the chart-building code.
alias VegaLite, as: V
# 100 evenly spaced tree counts between 0 and 4 at which to evaluate the model.
trees = Nx.linspace(0, 4, n: 100)
# Profit and its derivative evaluated at each of those points.
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)
# Builds one line layer plotting `y_field` against the shared "trees" x-axis.
line_layer = fn y_field ->
  V.new()
  |> V.mark(:line, interpolate: :basis)
  |> V.encode_field(:x, "trees", type: :quantitative)
  |> V.encode_field(:y, y_field, type: :quantitative)
end

# Column-oriented chart data: one flat list per series.
chart_data = %{
  trees: Nx.to_flat_list(trees),
  profits: Nx.to_flat_list(profits),
  profits_derivative: Nx.to_flat_list(profits_derivative)
}

V.new(width: 1440, height: 1080, title: "Berry Profits and Profits Rate of Change")
|> V.data_from_values(chart_data)
|> V.layers([
  line_layer.("profits"),
  # The derivative curve gets a fixed red color to tell it apart from the profit curve.
  line_layer.("profits_derivative") |> V.encode(:color, value: "#ff0000")
])
Differentiation
defmodule GradFun do
  import Nx.Defn

  # Computes sum(exp(cos(x))) and passes the result through print_expr
  # under the label "my_function".
  defn my_function(x) do
    summed = Nx.sum(Nx.exp(Nx.cos(x)))
    print_expr(summed, label: "my_function")
  end

  # A gradient gives the direction of steepest change of a scalar function.
  # The function being differentiated must return a floating-point type.
  defn grad_my_function(x) do
    gradient = grad(x, &my_function/1)
    print_expr(gradient, label: "grad_my_function")
  end
end
# Evaluate the gradient of my_function at x = [1.0, 2.0, 3.0].
GradFun.grad_my_function(Nx.tensor([1.0, 2.0, 3.0]))