Chapter Three
# Notebook dependencies, installed on first evaluation.
# In the portion of the notebook visible here only Nx, EXLA and VegaLite are
# referenced directly; the remaining packages are presumably used by other
# cells or for rendering -- confirm before pruning.
Mix.install([
{:nx, "~> 0.6"},
{:exla, "~> 0.6"},
{:kino, "~> 0.10"},
{:stb_image, "~> 0.6"},
{:vega_lite, "~> 0.1"},
{:kino_vega_lite, "~> 0.1"}
])
# Make EXLA the default backend for tensors created from here on.
Nx.default_backend(EXLA.Backend)
Section
defmodule BerryFarm do
  @moduledoc """
  Toy profit model used to illustrate derivatives with `Nx.Defn`.

  The profit curve is `-(t - 1)^4 + t^3 + t^2`, where `t` is the value(s)
  passed in as `trees`.
  """

  import Nx.Defn

  @doc """
  Evaluates the profit curve `-(trees - 1)^4 + trees^3 + trees^2`.

  Every operation is element-wise, so the result has the same shape as
  `trees`.
  """
  defn profits(trees) do
    trees
    |> Nx.subtract(1)
    |> Nx.pow(4)
    |> Nx.negate()
    |> Nx.add(Nx.pow(trees, 3))
    |> Nx.add(Nx.pow(trees, 2))
  end

  @doc """
  Differentiates `profits/1` with respect to `trees` via automatic
  differentiation (`grad/2`).

  Analytically this is `-4 * (trees - 1)^3 + 3 * trees^2 + 2 * trees`.
  """
  defn profits_derivative(trees) do
    grad(trees, &profits/1)
  end
end
alias VegaLite, as: Vl

# Sample 100 evenly spaced points on [0, 4] and evaluate both curves there.
trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)

title = "Berry profits and profits derivatives"

# Columnar data: one list per series, all of the same length.
chart_data = %{
  trees: Nx.to_flat_list(trees),
  profits: Nx.to_flat_list(profits),
  profits_derivative: Nx.to_flat_list(profits_derivative)
}

# Both layers share the same x encoding, pinned to the sampled interval.
with_trees_axis = fn layer ->
  Vl.encode_field(layer, :x, "trees", type: :quantitative, scale: %{domain: [0, 4]})
end

# Profit curve, drawn against a fixed y range.
profits_layer =
  Vl.new()
  |> Vl.mark(:line, interpolate: :basis)
  |> with_trees_axis.()
  |> Vl.encode_field(:y, "profits", type: :quantitative, scale: %{domain: [-49, 21]})

# Derivative curve, drawn in red.
derivative_layer =
  Vl.new()
  |> Vl.mark(:line, interpolate: :basis)
  |> with_trees_axis.()
  |> Vl.encode_field(:y, "profits_derivative", type: :quantitative)
  |> Vl.encode(:color, value: "#FF0000")

# The chart stays the final expression so the notebook renders it.
Vl.new(title: title, width: 800, height: 500)
|> Vl.data_from_values(chart_data)
|> Vl.layers([profits_layer, derivative_layer])
defmodule GradientPlayground do
  @moduledoc """
  Small sandbox for inspecting the expressions that `Nx.Defn` builds for a
  function and for its gradient.
  """

  import Nx.Defn

  @doc """
  Computes `sum(exp(cos(x)))` and prints the resulting expression with
  `print_expr/1` before returning it.
  """
  defn some_function(x) do
    x
    |> Nx.cos()
    |> Nx.exp()
    |> Nx.sum()
    |> print_expr()
  end

  @doc """
  Computes the gradient of `some_function/1` with respect to `x` and prints
  the gradient expression with `print_expr/1` before returning it.
  """
  defn gradient_some_function(x) do
    gradient = grad(x, &some_function/1)
    print_expr(gradient)
  end
end
# Run the gradient on a small sample tensor built from the integers 1, 2, 3.
# The print_expr calls inside GradientPlayground should dump the expressions
# they wrap to the terminal as a side effect of this call.
GradientPlayground.gradient_some_function(Nx.tensor([1, 2, 3]))