Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapter Three

Chapter03.livemd

Chapter Three

# Install the packages this notebook depends on.
Mix.install([
  {:nx, "~> 0.6"},
  {:exla, "~> 0.6"},
  {:kino, "~> 0.10"},
  {:stb_image, "~> 0.6"},
  {:vega_lite, "~> 0.1"},
  {:kino_vega_lite, "~> 0.1"}
])

# Use EXLA as the default Nx backend for the tensors created below.
Nx.default_backend(EXLA.Backend)

Section

defmodule BerryFarm do
  @moduledoc """
  Toy profit model used to illustrate differentiation with `Nx.Defn`.
  """

  import Nx.Defn

  @doc """
  Evaluates the profit curve `-(trees - 1)^4 + trees^3 + trees^2`.
  """
  defn profits(trees) do
    quartic_penalty =
      trees
      |> Nx.subtract(1)
      |> Nx.pow(4)
      |> Nx.negate()

    quartic_penalty
    |> Nx.add(Nx.pow(trees, 3))
    |> Nx.add(Nx.pow(trees, 2))
  end

  @doc """
  Differentiates `profits/1` with respect to `trees` using `grad/2`.
  """
  defn profits_derivative(trees) do
    grad(trees, &profits/1)
  end
end
# Evaluate the profit curve and its derivative at 100 points between 0 and 4.
trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)

title = "Berry profits and profits derivatives"

# One column per series; tensors are flattened to plain lists for VegaLite.
chart_data = %{
  trees: Nx.to_flat_list(trees),
  profits: Nx.to_flat_list(profits),
  profits_derivative: Nx.to_flat_list(profits_derivative)
}

# Both layers use the same x-axis encoding options.
x_axis = [type: :quantitative, scale: %{domain: [0, 4]}]

profits_layer =
  VegaLite.new()
  |> VegaLite.mark(:line, interpolate: :basis)
  |> VegaLite.encode_field(:x, "trees", x_axis)
  |> VegaLite.encode_field(:y, "profits", type: :quantitative, scale: %{domain: [-49, 21]})

# The derivative is drawn in red on top of the profit curve.
derivative_layer =
  VegaLite.new()
  |> VegaLite.mark(:line, interpolate: :basis)
  |> VegaLite.encode_field(:x, "trees", x_axis)
  |> VegaLite.encode_field(:y, "profits_derivative", type: :quantitative)
  |> VegaLite.encode(:color, value: "#FF0000")

VegaLite.new(title: title, width: 800, height: 500)
|> VegaLite.data_from_values(chart_data)
|> VegaLite.layers([profits_layer, derivative_layer])
defmodule GradientPlayground do
  @moduledoc """
  Small sandbox for inspecting the expressions `Nx.Defn` builds for a function
  and for its gradient.
  """

  import Nx.Defn

  @doc """
  Computes `sum(exp(cos(x)))` and prints the resulting expression.
  """
  defn some_function(x) do
    x
    |> Nx.cos()
    |> Nx.exp()
    |> Nx.sum()
    |> print_expr()
  end

  @doc """
  Takes the gradient of `some_function/1` with respect to `x` and prints the
  gradient's expression.
  """
  defn gradient_some_function(x) do
    gradient = grad(x, &some_function/1)
    print_expr(gradient)
  end
end
# Run the gradient on a small sample tensor; the expressions are printed as a side effect.
[1, 2, 3]
|> Nx.tensor()
|> GradientPlayground.gradient_some_function()