Harnessing the Power of Math
Mix.install([
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:kino, "~> 0.8"},
{:stb_image, "~> 0.6"},
{:vega_lite, "~> 0.1"},
{:kino_vega_lite, "~> 0.1"}
])
Speaking the Language of Data
# Make EXLA the default Nx backend; the tensors inspected in the cells below all
# report `EXLA.Backend`. The call's return value (shown as this cell's output)
# is the backend that was configured before the switch.
Nx.default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}
The Building Blocks of Linear Algebra
# Three one-dimensional tensors that differ only in element type: integers,
# floats, and an explicitly requested unsigned 8-bit type. Each one is printed
# with a label as soon as it is built; `IO.inspect/2` returns its argument, so
# the bindings (and the cell's final value, `c`) are the tensors themselves.
a = [1, 2, 3] |> Nx.tensor() |> IO.inspect(label: :a)
b = [4.0, 5.0, 6.0] |> Nx.tensor() |> IO.inspect(label: :b)
c = [1, 0, 1] |> Nx.tensor(type: {:u, 8}) |> IO.inspect(label: :c)
a: #Nx.Tensor<
s64[3]
EXLA.Backend
[1, 2, 3]
>
b: #Nx.Tensor<
f32[3]
EXLA.Backend
[4.0, 5.0, 6.0]
>
c: #Nx.Tensor<
u8[3]
EXLA.Backend
[1, 0, 1]
>
#Nx.Tensor<
u8[3]
EXLA.Backend
[1, 0, 1]
>
# One stock described by three numeric features (the names suggest price,
# P/E ratio and market capitalisation), packed into a single vector.
{goog_current_price, goog_pe, goog_mkt_cap} = {2677.32, 23.86, 1760}

# Mixing floats with an integer yields a float tensor (f32 in the cell output).
goog = Nx.tensor([goog_current_price, goog_pe, goog_mkt_cap])
#Nx.Tensor<
f32[3]
EXLA.Backend
[2677.320068359375, 23.860000610351562, 1760.0]
>
# The same value, 5, held two ways: first wrapped in an Nx tensor...
i_am_a_scalar = Nx.tensor(5)
# ...and then as a plain Elixir integer.
i_am_also_a_scalar = 5
5
# Feature values for two stocks; each triple becomes one row of the matrix.
{goog_current_price, goog_pe, goog_mkt_cap} = {2677.32, 23.86, 1760}
{meta_current_price, meta_pe, meta_mkt_cap} = {133.93, 11.10, 360}

# Stack the two feature vectors into a 2x3 matrix and print it. `IO.inspect/1`
# hands the tensor back, so `stocks_matrix` is bound to the tensor itself.
stocks_matrix =
  [
    [goog_current_price, goog_pe, goog_mkt_cap],
    [meta_current_price, meta_pe, meta_mkt_cap]
  ]
  |> Nx.tensor()
  |> IO.inspect()
#Nx.Tensor<
f32[2][3]
EXLA.Backend
[
[2677.320068359375, 23.860000610351562, 1760.0],
[133.92999267578125, 11.100000381469727, 360.0]
]
>
#Nx.Tensor<
f32[2][3]
EXLA.Backend
[
[2677.320068359375, 23.860000610351562, 1760.0],
[133.92999267578125, 11.100000381469727, 360.0]
]
>
Important Operations in Linear Algebra
Vector Addition
# Sales figures for two days, one entry per product (three products each day).
sales_day_1 = Nx.tensor([32, 10, 14])
sales_day_2 = Nx.tensor([10, 24, 21])
#Nx.Tensor<
s64[3]
EXLA.Backend
[10, 24, 21]
>
# Vector addition is element-wise: [32 + 10, 10 + 24, 14 + 21] = [42, 34, 35].
total_sales = Nx.add(sales_day_1, sales_day_2)
#Nx.Tensor<
s64[3]
EXLA.Backend
[42, 34, 35]
>
Scalar Multiplication
# Inputs: per-product sales for two days, and the fraction of sales kept
# (i.e. not returned).
sales_day_1 = Nx.tensor([32, 10, 14])
sales_day_2 = Nx.tensor([10, 24, 21])
keep_rate = 0.9

# Add the two days element-wise, then scale every entry by the scalar rate.
# Multiplying an integer tensor by a float scalar gives a float tensor.
total_sales = Nx.add(sales_day_1, sales_day_2)
unreturned_sales = Nx.multiply(total_sales, keep_rate)
#Nx.Tensor<
f32[3]
EXLA.Backend
[37.79999923706055, 30.599998474121094, 31.5]
>
# Unit price for each of the three products, in the same order as the sales vectors.
price_per_product = Nx.tensor([9.95, 10.95, 5.99])
# Element-wise product of two equally sized vectors: kept sales times unit price.
# `unreturned_sales` comes from the scalar-multiplication cell above.
revenue_per_product = Nx.multiply(unreturned_sales, price_per_product)
#Nx.Tensor<
f32[3]
EXLA.Backend
[376.1099853515625, 335.0699768066406, 188.68499755859375]
>
Transpose
# Two days of sales as a 2x3 matrix: one row per day, one column per product.
sales_matrix = Nx.tensor([[32, 10, 14], [10, 24, 21]])

# Transposing swaps the axes, giving a 3x2 matrix with one row per product.
Nx.transpose(sales_matrix)
#Nx.Tensor<
s64[3][2]
EXLA.Backend
[
[32, 10],
[10, 24],
[14, 21]
]
>
# `__DIR__` is the directory of the file being evaluated; the cell output below
# shows the absolute path it resolved to when this notebook was last run.
notebook_dir = __DIR__
"/Users/daniel/Code/iMetrical/elixir-garden/projects/machine-learning-in-elixir/code"
Linear Transformation
# Linear map applied to every pixel: each colour channel is negated and no
# channel is mixed into another (the matrix is -1 times the 3x3 identity).
invert_color_channels =
  Nx.tensor([
    [-1, 0, 0],
    [0, -1, 0],
    [0, 0, -1]
  ])

# Cat.jpg is resolved relative to the directory that contains this notebook.
image_path = Path.join([__DIR__, "Cat.jpg"])

image_path
# Ask for exactly three channels so the pixel axis always matches the 3x3
# matrix, even if the file is grayscale or carries an alpha channel.
|> StbImage.read_file!(channels: 3)
|> StbImage.resize(256, 256)
|> StbImage.to_nx()
|> Nx.dot(invert_color_channels)
# The dot product leaves every channel as -v, which is outside the u8 range.
# Shifting by 255 gives the true inverse 255 - v (for 8-bit channel values), so
# the cast below no longer depends on how the backend converts negative
# integers, and a channel value of 0 correctly becomes 255.
|> Nx.add(255)
|> Nx.as_type({:u, 8})
|> Kino.Image.new()
# `Nx.dot/2` with operands of increasing rank. The variable names describe the
# operands, not the result: vector.vector, vector.matrix and matrix.matrix.
vector =
  Nx.dot(
    Nx.tensor([1, 2, 3]),
    Nx.tensor([1, 2, 3])
  )

vector_matrix =
  Nx.dot(
    Nx.tensor([1, 2]),
    Nx.tensor([[1], [2]])
  )

matrix_matrix =
  Nx.dot(
    Nx.tensor([[1, 2]]),
    Nx.tensor([[3], [4]])
  )

# Print each result with its label; the last call also provides the cell's value.
IO.inspect(vector, label: :vector)
IO.inspect(vector_matrix, label: :vector_matrix)
IO.inspect(matrix_matrix, label: :matrix_matrix)
vector: #Nx.Tensor<
s64
EXLA.Backend
14
>
vector_matrix: #Nx.Tensor<
s64[1]
EXLA.Backend
[5]
>
matrix_matrix: #Nx.Tensor<
s64[1][1]
EXLA.Backend
[
[11]
]
>
#Nx.Tensor<
s64[1][1]
EXLA.Backend
[
[11]
]
>
Thinking Probabilistically
Reasoning About Uncertainty
# One simulated two-outcome trial. Draws a uniform sample with the given random
# key and returns `{outcome, next_key}`, where the outcome is 0 when the sample
# is below 0.5 and 1 otherwise. The returned key must be used for the next draw.
simulation = fn key ->
  {sample, next_key} = Nx.Random.uniform(key)
  outcome = if Nx.to_number(sample) < 0.5, do: 0, else: 1
  {outcome, next_key}
end
#Function<42.39164016/1 in :erl_eval.expr/6>
key = Nx.Random.key(42)

# Repeat the trial n times for growing n and print how many trials came out as 1.
# Every run starts from the same initial key; within a run the key returned by
# one trial is threaded into the next through the accumulator.
for n <- [10, 100, 1000, 10000] do
  {outcomes, _last_key} =
    Enum.map_reduce(1..n, key, fn _trial, current_key -> simulation.(current_key) end)

  outcomes
  |> Enum.sum()
  |> IO.inspect()
end
6
49
501
5025
[6, 49, 501, 5025]
Tracking Change
Understanding Differentiation
defmodule BerryFarm do
  @moduledoc """
  Toy profit model for a berry farm, written as numerical definitions (`defn`)
  so that it can be differentiated automatically.
  """

  import Nx.Defn

  @doc """
  Profit for a given number of `trees`: `-(trees - 1)^4 + trees^3 + trees^2`.
  """
  defn profits(trees) do
    quartic_term = -((trees - 1) ** 4)
    cubic_term = trees ** 3
    quadratic_term = trees ** 2

    quartic_term + cubic_term + quadratic_term
  end

  @doc """
  Gradient of `profits/1` with respect to `trees`, evaluated at `trees`.
  """
  defn profits_derivative(trees) do
    trees |> grad(&profits/1)
  end
end
{:module, BerryFarm, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
alias VegaLite, as: Vl

# Sample 100 evenly spaced tree counts from 0 to 4 and evaluate the profit at each.
trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)

# Line chart of profit (y) against number of trees (x).
Vl.new(width: 1440, height: 1080, title: "Berry Profits")
|> Vl.data_from_values(%{
  profits: Nx.to_flat_list(profits),
  trees: Nx.to_flat_list(trees)
})
|> Vl.encode_field(:x, "trees", type: :quantitative)
|> Vl.encode_field(:y, "profits", type: :quantitative)
|> Vl.mark(:line, interpolate: :basis)
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"profits":-1.0,"trees":0.0},{"profits":-0.8462191820144653,"trees":0.04040404036641121},{"profits":-0.706821620464325,"trees":0.08080808073282242},{"profits":-0.5799248218536377,"trees":0.12121212482452393},{"profits":-0.46370965242385864,"trees":0.16161616146564484},{"profits":-0.3564212918281555,"trees":0.20202019810676575},{"profits":-0.2563686668872833,"trees":0.24242424964904785},{"profits":-0.16192500293254852,"trees":0.28282827138900757},{"profits":-0.0715271383523941,"trees":0.3232323229312897},{"profits":0.01632404327392578,"trees":0.3636363744735718},{"profits":0.10306348651647568,"trees":0.4040403962135315},{"profits":0.19006246328353882,"trees":0.4444444477558136},{"profits":0.2786282002925873,"trees":0.4848484992980957},{"profits":0.37000375986099243,"trees":0.5252525210380554},{"profits":0.46536850929260254,"trees":0.5656565427780151},{"profits":0.5658379197120667,"trees":0.6060606241226196},{"profits":0.6724629998207092,"trees":0.6464646458625793},{"profits":0.786231279373169,"trees":0.6868686676025391},{"profits":0.90806645154953,"trees":0.7272727489471436},{"profits":1.0388275384902954,"trees":0.7676767706871033},{"profits":1.1793103218078613,"trees":0.808080792427063},{"profits":1.3302464485168457,"trees":0.8484848737716675},{"profits":1.4923030138015747,"trees":0.8888888955116272},{"profits":1.6660840511322021,"trees":0.9292929172515869},{"profits":1.8521294593811035,"trees":0.9696969985961914},{"profits":2.0509138107299805,"trees":1.0101009607315063},{"profits":2.262850522994995,"trees":1.0505050420761108},{"profits":2.4882864952087402,"trees":1.0909091234207153},{"profits":2.7275047302246094,"trees":1.1313130855560303},{"profits":2.980726718902588,"trees":1.1717171669006348},{"profits":3.24810791015625,"trees":1.2121212482452393},{"profits":3.5297389030456543,"trees":1.2525252103805542},{"profits":3.8256494998931885,"trees":1.2929292917251587},{"profits":4.1358027458
19092,"trees":1.3333333730697632},{"profits":4.460097789764404,"trees":1.3737373352050781},{"profits":4.798373222351074,"trees":1.4141414165496826},{"profits":5.150400161743164,"trees":1.454545497894287},{"profits":5.5158843994140625,"trees":1.494949460029602},{"profits":5.894474029541016,"trees":1.5353535413742065},{"profits":6.285747528076172,"trees":1.575757622718811},{"profits":6.6892194747924805,"trees":1.616161584854126},{"profits":7.104344844818115,"trees":1.6565656661987305},{"profits":7.530511379241943,"trees":1.696969747543335},{"profits":7.967041969299316,"trees":1.73737370967865},{"profits":8.413199424743652,"trees":1.7777777910232544},{"profits":8.868179321289062,"trees":1.8181818723678589},{"profits":9.331111907958984,"trees":1.8585858345031738},{"profits":9.801070213317871,"trees":1.8989899158477783},{"profits":10.277055740356445,"trees":1.9393939971923828},{"profits":10.75800895690918,"trees":1.9797979593276978},{"profits":11.242805480957031,"trees":2.0202019214630127},{"profits":11.730262756347656,"trees":2.060606002807617},{"profits":12.219127655029297,"trees":2.1010100841522217},{"profits":12.708084106445312,"trees":2.141414165496826},{"profits":13.19575309753418,"trees":2.1818182468414307},{"profits":13.68069076538086,"trees":2.222222328186035},{"profits":14.16138744354248,"trees":2.2626261711120605},{"profits":14.636279106140137,"trees":2.303030252456665},{"profits":15.10372543334961,"trees":2.3434343338012695},{"profits":15.562030792236328,"trees":2.383838415145874},{"profits":16.009429931640625,"trees":2.4242424964904785},{"profits":16.444095611572266,"trees":2.464646577835083},{"profits":16.864139556884766,"trees":2.5050504207611084},{"profits":17.26760482788086,"trees":2.545454502105713},{"profits":17.65247344970703,"trees":2.5858585834503174},{"profits":18.016660690307617,"trees":2.626262664794922},{"profits":18.35802459716797,"trees":2.6666667461395264},{"profits":18.67435073852539,"trees":2.7070705890655518},{"profits":18.963363647460938,
"trees":2.7474746704101562},{"profits":19.222728729248047,"trees":2.7878787517547607},{"profits":19.450042724609375,"trees":2.8282828330993652},{"profits":19.642837524414062,"trees":2.8686869144439697},{"profits":19.798580169677734,"trees":2.909090995788574},{"profits":19.91468048095703,"trees":2.9494948387145996},{"profits":19.988475799560547,"trees":2.989898920059204},{"profits":20.017250061035156,"trees":3.0303030014038086},{"profits":19.998214721679688,"trees":3.070707082748413},{"profits":19.928516387939453,"trees":3.1111111640930176},{"profits":19.80524444580078,"trees":3.151515245437622},{"profits":19.62541961669922,"trees":3.1919190883636475},{"profits":19.385997772216797,"trees":3.232323169708252},{"profits":19.083873748779297,"trees":3.2727272510528564},{"profits":18.715879440307617,"trees":3.313131332397461},{"profits":18.27878189086914,"trees":3.3535354137420654},{"profits":17.769275665283203,"trees":3.39393949508667},{"profits":17.184011459350586,"trees":3.4343433380126953},{"profits":16.51955223083496,"trees":3.4747474193573},{"profits":15.772408485412598,"trees":3.5151515007019043},{"profits":14.939033508300781,"trees":3.555555582046509},{"profits":14.015803337097168,"trees":3.5959596633911133},{"profits":12.999043464660645,"trees":3.6363637447357178},{"profits":11.885002136230469,"trees":3.676767587661743},{"profits":10.66987419128418,"trees":3.7171716690063477},{"profits":9.349775314331055,"trees":3.757575750350952},{"profits":7.920779228210449,"trees":3.7979798316955566},{"profits":6.37888240814209,"trees":3.838383913040161},{"profits":4.720010757446289,"trees":3.8787879943847656},{"profits":2.9400548934936523,"trees":3.919191837310791},{"profits":1.0348033905029297,"trees":3.9595959186553955},{"profits":-1.0,"trees":4.0}]},"encoding":{"x":{"field":"trees","type":"quantitative"},"y":{"field":"profits","type":"quantitative"}},"height":1080,"mark":{"interpolate":"basis","type":"line"},"title":"Berry Profits","width":1440}
alias VegaLite, as: Vl

# Same 100 sample points as the previous chart, now with the derivative as well.
trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)

# Two line layers over a shared data set: the profit curve in the default
# colour and its rate of change in red.
Vl.new(width: 1440, height: 1080, title: "Berry Profits and Profits Rate of Change")
|> Vl.data_from_values(%{
  profits: Nx.to_flat_list(profits),
  profits_derivative: Nx.to_flat_list(profits_derivative),
  trees: Nx.to_flat_list(trees)
})
|> Vl.layers([
  # Layer 1: profits.
  Vl.new()
  |> Vl.encode_field(:x, "trees", type: :quantitative)
  |> Vl.encode_field(:y, "profits", type: :quantitative)
  |> Vl.mark(:line, interpolate: :basis),
  # Layer 2: derivative of profits, drawn in red.
  Vl.new()
  |> Vl.encode_field(:x, "trees", type: :quantitative)
  |> Vl.encode_field(:y, "profits_derivative", type: :quantitative)
  |> Vl.encode(:color, value: "#ff0000")
  |> Vl.mark(:line, interpolate: :basis)
])
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"profits":-1.0,"profits_derivative":4.0,"trees":0.0},{"profits":-0.8462191820144653,"profits_derivative":3.620183229446411,"trees":0.04040404036641121},{"profits":-0.706821620464325,"profits_derivative":3.287757396697998,"trees":0.08080808073282242},{"profits":-0.5799248218536377,"profits_derivative":3.001140832901001,"trees":0.12121212482452393},{"profits":-0.46370965242385864,"profits_derivative":2.7587499618530273,"trees":0.16161616146564484},{"profits":-0.3564212918281555,"profits_derivative":2.5590012073516846,"trees":0.20202019810676575},{"profits":-0.2563686668872833,"profits_derivative":2.4003114700317383,"trees":0.24242424964904785},{"profits":-0.16192500293254852,"profits_derivative":2.2810988426208496,"trees":0.28282827138900757},{"profits":-0.0715271383523941,"profits_derivative":2.199779748916626,"trees":0.3232323229312897},{"profits":0.01632404327392578,"profits_derivative":2.154770851135254,"trees":0.3636363744735718},{"profits":0.10306348651647568,"profits_derivative":2.144489288330078,"trees":0.4040403962135315},{"profits":0.19006246328353882,"profits_derivative":2.1673526763916016,"trees":0.4444444477558136},{"profits":0.2786282002925873,"profits_derivative":2.2217769622802734,"trees":0.4848484992980957},{"profits":0.37000375986099243,"profits_derivative":2.3061797618865967,"trees":0.5252525210380554},{"profits":0.46536850929260254,"profits_derivative":2.418977975845337,"trees":0.5656565427780151},{"profits":0.5658379197120667,"profits_derivative":2.558588743209839,"trees":0.6060606241226196},{"profits":0.6724629998207092,"profits_derivative":2.723428726196289,"trees":0.6464646458625793},{"profits":0.786231279373169,"profits_derivative":2.911914825439453,"trees":0.6868686676025391},{"profits":0.90806645154953,"profits_derivative":3.122464418411255,"trees":0.7272727489471436},{"profits":1.0388275384902954,"profits_derivative":3.353494167327881,"trees":0.7676767706871033}
,{"profits":1.1793103218078613,"profits_derivative":3.603421211242676,"trees":0.808080792427063},{"profits":1.3302464485168457,"profits_derivative":3.8706626892089844,"trees":0.8484848737716675},{"profits":1.4923030138015747,"profits_derivative":4.153635025024414,"trees":0.8888888955116272},{"profits":1.6660840511322021,"profits_derivative":4.450756072998047,"trees":0.9292929172515869},{"profits":1.8521294593811035,"profits_derivative":4.760441780090332,"trees":0.9696969985961914},{"profits":2.0509138107299805,"profits_derivative":5.081110000610352,"trees":1.0101009607315063},{"profits":2.262850522994995,"profits_derivative":5.411177158355713,"trees":1.0505050420761108},{"profits":2.4882864952087402,"profits_derivative":5.749061107635498,"trees":1.0909091234207153},{"profits":2.7275047302246094,"profits_derivative":6.09317684173584,"trees":1.1313130855560303},{"profits":2.980726718902588,"profits_derivative":6.441944122314453,"trees":1.1717171669006348},{"profits":3.24810791015625,"profits_derivative":6.793778419494629,"trees":1.2121212482452393},{"profits":3.5297389030456543,"profits_derivative":7.147095680236816,"trees":1.2525252103805542},{"profits":3.8256494998931885,"profits_derivative":7.500314712524414,"trees":1.2929292917251587},{"profits":4.135802745819092,"profits_derivative":7.8518524169921875,"trees":1.3333333730697632},{"profits":4.460097789764404,"profits_derivative":8.200122833251953,"trees":1.3737373352050781},{"profits":4.798373222351074,"profits_derivative":8.543547630310059,"trees":1.4141414165496826},{"profits":5.150400161743164,"profits_derivative":8.88054084777832,"trees":1.454545497894287},{"profits":5.5158843994140625,"profits_derivative":9.209519386291504,"trees":1.494949460029602},{"profits":5.894474029541016,"profits_derivative":9.528902053833008,"trees":1.5353535413742065},{"profits":6.285747528076172,"profits_derivative":9.837104797363281,"trees":1.575757622718811},{"profits":6.6892194747924805,"profits_derivative":10.13254165649414,"tre
es":1.616161584854126},{"profits":7.104344844818115,"profits_derivative":10.41363525390625,"trees":1.6565656661987305},{"profits":7.530511379241943,"profits_derivative":10.678799629211426,"trees":1.696969747543335},{"profits":7.967041969299316,"profits_derivative":10.926450729370117,"trees":1.73737370967865},{"profits":8.413199424743652,"profits_derivative":11.155006408691406,"trees":1.7777777910232544},{"profits":8.868179321289062,"profits_derivative":11.362886428833008,"trees":1.8181818723678589},{"profits":9.331111907958984,"profits_derivative":11.548501968383789,"trees":1.8585858345031738},{"profits":9.801070213317871,"profits_derivative":11.710275650024414,"trees":1.8989899158477783},{"profits":10.277055740356445,"profits_derivative":11.846620559692383,"trees":1.9393939971923828},{"profits":10.75800895690918,"profits_derivative":11.95595645904541,"trees":1.9797979593276978},{"profits":11.242805480957031,"profits_derivative":12.036697387695312,"trees":2.0202019214630127},{"profits":11.730262756347656,"profits_derivative":12.087263107299805,"trees":2.060606002807617},{"profits":12.219127655029297,"profits_derivative":12.106069564819336,"trees":2.1010100841522217},{"profits":12.708084106445312,"profits_derivative":12.091534614562988,"trees":2.141414165496826},{"profits":13.19575309753418,"profits_derivative":12.042073249816895,"trees":2.1818182468414307},{"profits":13.68069076538086,"profits_derivative":11.956103324890137,"trees":2.222222328186035},{"profits":14.16138744354248,"profits_derivative":11.832043647766113,"trees":2.2626261711120605},{"profits":14.636279106140137,"profits_derivative":11.66830825805664,"trees":2.303030252456665},{"profits":15.10372543334961,"profits_derivative":11.463315963745117,"trees":2.3434343338012695},{"profits":15.562030792236328,"profits_derivative":11.21548080444336,"trees":2.383838415145874},{"profits":16.009429931640625,"profits_derivative":10.923226356506348,"trees":2.4242424964904785},{"profits":16.444095611572266,"profits_de
rivative":10.584964752197266,"trees":2.464646577835083},{"profits":16.864139556884766,"profits_derivative":10.199111938476562,"trees":2.5050504207611084},{"profits":17.26760482788086,"profits_derivative":9.764086723327637,"trees":2.545454502105713},{"profits":17.65247344970703,"profits_derivative":9.27830696105957,"trees":2.5858585834503174},{"profits":18.016660690307617,"profits_derivative":8.740188598632812,"trees":2.626262664794922},{"profits":18.35802459716797,"profits_derivative":8.148149490356445,"trees":2.6666667461395264},{"profits":18.67435073852539,"profits_derivative":7.50060510635376,"trees":2.7070705890655518},{"profits":18.963363647460938,"profits_derivative":6.79597282409668,"trees":2.7474746704101562},{"profits":19.222728729248047,"profits_derivative":6.032668590545654,"trees":2.7878787517547607},{"profits":19.450042724609375,"profits_derivative":5.209111213684082,"trees":2.8282828330993652},{"profits":19.642837524414062,"profits_derivative":4.3237175941467285,"trees":2.8686869144439697},{"profits":19.798580169677734,"profits_derivative":3.3749046325683594,"trees":2.909090995788574},{"profits":19.91468048095703,"profits_derivative":2.3610944747924805,"trees":2.9494948387145996},{"profits":19.988475799560547,"profits_derivative":1.2806897163391113,"trees":2.989898920059204},{"profits":20.017250061035156,"profits_derivative":0.1321239471435547,"trees":3.0303030014038086},{"profits":19.998214721679688,"profits_derivative":-1.0862011909484863,"trees":3.070707082748413},{"profits":19.928516387939453,"profits_derivative":-2.375861167907715,"trees":3.1111111640930176},{"profits":19.80524444580078,"profits_derivative":-3.738431453704834,"trees":3.151515245437622},{"profits":19.62541961669922,"profits_derivative":-5.175499439239502,"trees":3.1919190883636475},{"profits":19.385997772216797,"profits_derivative":-6.688662528991699,"trees":3.232323169708252},{"profits":19.083873748779297,"profits_derivative":-8.279485702514648,"trees":3.2727272510528564},{"profit
s":18.715879440307617,"profits_derivative":-9.949562072753906,"trees":3.313131332397461},{"profits":18.27878189086914,"profits_derivative":-11.700475692749023,"trees":3.3535354137420654},{"profits":17.769275665283203,"profits_derivative":-13.533799171447754,"trees":3.39393949508667},{"profits":17.184011459350586,"profits_derivative":-15.451114654541016,"trees":3.4343433380126953},{"profits":16.51955223083496,"profits_derivative":-17.454015731811523,"trees":3.4747474193573},{"profits":15.772408485412598,"profits_derivative":-19.544090270996094,"trees":3.5151515007019043},{"profits":14.939033508300781,"profits_derivative":-21.722911834716797,"trees":3.555555582046509},{"profits":14.015803337097168,"profits_derivative":-23.992055892944336,"trees":3.5959596633911133},{"profits":12.999043464660645,"profits_derivative":-26.353118896484375,"trees":3.6363637447357178},{"profits":11.885002136230469,"profits_derivative":-28.80767059326172,"trees":3.676767587661743},{"profits":10.66987419128418,"profits_derivative":-31.35731315612793,"trees":3.7171716690063477},{"profits":9.349775314331055,"profits_derivative":-34.00361633300781,"trees":3.757575750350952},{"profits":7.920779228210449,"profits_derivative":-36.7481689453125,"trees":3.7979798316955566},{"profits":6.37888240814209,"profits_derivative":-39.59254455566406,"trees":3.838383913040161},{"profits":4.720010757446289,"profits_derivative":-42.53834533691406,"trees":3.8787879943847656},{"profits":2.9400548934936523,"profits_derivative":-45.58710861206055,"trees":3.919191837310791},{"profits":1.0348033905029297,"profits_derivative":-48.7404670715332,"trees":3.9595959186553955},{"profits":-1.0,"profits_derivative":-52.0,"trees":4.0}]},"height":1080,"layer":[{"encoding":{"x":{"field":"trees","type":"quantitative"},"y":{"field":"profits","type":"quantitative"}},"mark":{"interpolate":"basis","type":"line"}},{"encoding":{"color":{"value":"#ff0000"},"x":{"field":"trees","type":"quantitative"},"y":{"field":"profits_derivative","type
":"quantitative"}},"mark":{"interpolate":"basis","type":"line"}}],"title":"Berry Profits and Profits Rate of Change","width":1440}
Automatic differentiation with defn
defmodule GradFun do
  @moduledoc """
  Small example that prints the expression built for a function and for its
  gradient, to show what automatic differentiation produces.
  """

  import Nx.Defn

  @doc """
  Computes `sum(exp(cos(x)))` and prints the expression that was built for it.
  """
  defn my_function(x) do
    cosines = Nx.cos(x)
    exponentials = Nx.exp(cosines)
    total = Nx.sum(exponentials)

    print_expr(total)
  end

  @doc """
  Gradient of `my_function/1` with respect to `x`; the gradient expression is
  printed as well.
  """
  defn grad_my_function(x) do
    gradient = grad(x, &my_function/1)

    print_expr(gradient)
  end
end
{:module, GradFun, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
# Evaluate the gradient at [1.0, 2.0, 3.0]. As the output below shows, two
# expressions are printed first (the function itself, then its gradient) before
# the numeric result is returned.
[1.0, 2.0, 3.0]
|> Nx.tensor()
|> GradFun.grad_my_function()
#Nx.Tensor<
f32
Nx.Defn.Expr
parameter a:0 f32[3]
b = cos a f32[3]
c = exp b f32[3]
d = sum c, axes: nil, keep_axes: false f32
>
#Nx.Tensor<
f32[3]
Nx.Defn.Expr
parameter a:0 f32[3]
b = cos a f32[3]
c = exp b f32[3]
d = sin a f32[3]
e = negate d f32[3]
f = multiply c, e f32[3]
>
#Nx.Tensor<
f32[3]
EXLA.Backend
[-1.444406509399414, -0.5997574925422668, -0.05243729427456856]
>