Ch 3: Harness the Power of Math
Mix.install([
{:nx, "~> 0.7.2"},
{:exla, "~> 0.7.2"},
{:kino, "~> 0.12.3"},
{:stb_image, "~> 0.6.8"},
{:vega_lite, "~> 0.1.9"},
{:kino_vega_lite, "~> 0.1.11"}
])
Setup
Nx.default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}
The Building Blocks of Linear Algebra
The fundamental object in linear algebra is the vector. Nx doesn’t explicitly differentiate between scalars, vectors, and matrices. Everything in Nx is a tensor.
a = Nx.tensor([1, 2, 3])
b = Nx.tensor([4.0, 5.0, 6.0])
c = Nx.tensor([1, 0, 1], type: {:u, 8})
IO.inspect(a, label: :a)
IO.inspect(b, label: :b)
IO.inspect(c, label: :c)
a: #Nx.Tensor<
s64[3]
EXLA.Backend
[1, 2, 3]
>
b: #Nx.Tensor<
f32[3]
EXLA.Backend
[4.0, 5.0, 6.0]
>
c: #Nx.Tensor<
u8[3]
EXLA.Backend
[1, 0, 1]
>
#Nx.Tensor<
u8[3]
EXLA.Backend
[1, 0, 1]
>
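Because everything is a tensor, what separates a scalar from a vector or a matrix is only its shape. The following is a minimal sketch of that idea, assuming the setup cell above has already installed Nx. The variable names are illustrative and not part of the original notebook.

```elixir
# Assumes Nx is available from the Mix.install cell at the top of the notebook.
scalar = Nx.tensor(5)
vector = Nx.tensor([1, 2, 3])
matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# Nx.shape/1 returns a tuple with one entry per dimension.
IO.inspect(Nx.shape(scalar), label: :scalar_shape)
IO.inspect(Nx.shape(vector), label: :vector_shape)
IO.inspect(Nx.shape(matrix), label: :matrix_shape)

# Nx.rank/1 returns the number of dimensions of each tensor.
ranks = Enum.map([scalar, vector, matrix], &Nx.rank/1)
IO.inspect(ranks, label: :ranks)
```

The scalar has an empty shape, the vector has one dimension of size 3, and the matrix has two dimensions (2 rows, 3 columns).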
Important Operations in Linear Algebra
Some of the most important and fundamental operations in linear algebra are:
- Vector addition
- Scalar multiplication
- Transpose
- Linear Transformations
sales_day_1 = Nx.tensor([32, 10, 14])
sales_day_2 = Nx.tensor([10, 24, 21])
#Nx.Tensor<
s64[3]
EXLA.Backend
[10, 24, 21]
>
total_sales = Nx.add(sales_day_1, sales_day_2)
#Nx.Tensor<
s64[3]
EXLA.Backend
[42, 34, 35]
>
keep_rate = 0.9
unreturned_sales = Nx.multiply(keep_rate, total_sales)
#Nx.Tensor<
f32[3]
EXLA.Backend
[37.79999923706055, 30.599998474121094, 31.5]
>
price_per_product = Nx.tensor([9.95, 10.95, 5.99])
revenue_per_product = Nx.multiply(unreturned_sales, price_per_product)
#Nx.Tensor<
f32[3]
EXLA.Backend
[376.1099853515625, 335.0699768066406, 188.68499755859375]
>
sales_matrix =
Nx.tensor([
[32, 10, 14],
[10, 24, 21]
])
Nx.transpose(sales_matrix)
#Nx.Tensor<
s64[3][2]
EXLA.Backend
[
[32, 10],
[10, 24],
[14, 21]
]
>
Kino.FS.file_path("Cat.jpg")
|> StbImage.read_file!()
|> StbImage.resize(256, 256)
|> StbImage.to_nx()
|> Nx.as_type({:u, 8})
|> Kino.Image.new()
invert_color_channels =
Nx.tensor([
[-1, 0, 0],
[0, -1, 0],
[0, 0, -1]
])
Kino.FS.file_path("Cat.jpg")
|> StbImage.read_file!()
|> StbImage.resize(256, 256)
|> StbImage.to_nx()
|> Nx.dot(invert_color_channels)
|> Nx.as_type({:u, 8})
|> Kino.Image.new()
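The image pipeline above is a linear transformation: Nx.dot multiplies each pixel's [red, green, blue] vector by a 3x3 matrix. To make that concrete, here is a small sketch on a single pixel. The matrix swap_red_blue and the pixel values are hypothetical and chosen only for illustration. They do not come from the original notebook.

```elixir
# Assumes Nx is available from the Mix.install cell at the top of the notebook.
# One pixel as a [red, green, blue] vector (made-up values).
pixel = Nx.tensor([200, 100, 50])

# Hypothetical transformation: a permutation matrix that swaps red and blue.
swap_red_blue =
  Nx.tensor([
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 0]
  ])

# Nx.dot/2 contracts the pixel's channel axis with the rows of the matrix.
swapped = Nx.dot(pixel, swap_red_blue)
IO.inspect(swapped, label: :swapped)
```

The result holds the same three values with the first and last channels exchanged (50, 100, 200). Applied to a whole image tensor, the same matrix would transform every pixel at once.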
Thinking Probabilistically
Probability helps you reason about data and make predictions from it.
The main difference between modern ML and earlier generations, which were largely based on logic, is the acceptance of probabilistic methods over logic-based approaches.
There are three primary tools used in ML to reason about data and make predictions: probability theory, decision theory, and information theory.
Reasoning about Uncertainty
Probability theory is a framework for understanding uncertainty. Uncertainty is a fact of life, so it is better to embrace it than to ignore it. ML shines in uncertain situations because uncertainty is a built-in assumption.
All about Bayes
Bayes’ theorem describes the conditional probability of an event given available information or the occurrence of another event. In other words, it is a rule for how probabilities should be updated in the face of new evidence.
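As a worked illustration, Bayes’ theorem can be written as P(H | E) = P(E | H) * P(H) / P(E). The sketch below applies it to a made-up spam example. The module name BayesSketch and all of the probabilities are assumptions for illustration only.

```elixir
defmodule BayesSketch do
  # Bayes' theorem: P(H | E) = P(E | H) * P(H) / P(E),
  # where P(E) = P(E | H) * P(H) + P(E | not H) * P(not H).
  def posterior(prior, likelihood, likelihood_given_not) do
    evidence = likelihood * prior + likelihood_given_not * (1.0 - prior)
    likelihood * prior / evidence
  end
end

# Hypothetical numbers: 20% of mail is spam, and the word "free" appears
# in 60% of spam messages and 5% of legitimate messages.
posterior = BayesSketch.posterior(0.2, 0.6, 0.05)
IO.inspect(posterior, label: :p_spam_given_free)
```

Seeing the word moves the belief that a message is spam from the prior of 0.2 to a posterior of about 0.75, which is the kind of update the theorem describes.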
As an ML model is trained, it makes inferences or predictions about its surroundings (the data). After a prediction is made, a decision must follow.
Making Decisions
Decision theory provides a framework for acting optimally in the presence of uncertainty. After inference, an ML model must weigh its predictions against its objective in order to decide how to act.
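One common way to make this concrete is to choose the action with the lowest expected loss. The sketch below is only an illustration under assumed numbers. DecisionSketch, the probabilities, and the loss values are all hypothetical.

```elixir
defmodule DecisionSketch do
  # Expected loss of one action: the sum over outcomes of
  # P(outcome) * loss(action, outcome).
  def expected_loss(probabilities, losses) do
    probabilities
    |> Enum.zip(losses)
    |> Enum.map(fn {p, loss} -> p * loss end)
    |> Enum.sum()
  end

  # Pick the action whose expected loss is smallest.
  def best_action(probabilities, loss_table) do
    {action, _losses} =
      Enum.min_by(loss_table, fn {_action, losses} ->
        expected_loss(probabilities, losses)
      end)

    action
  end
end

# Hypothetical model output for the outcomes [spam, not spam].
probabilities = [0.75, 0.25]

# Hypothetical losses per action, in the same outcome order. Blocking a
# legitimate message is assumed to be ten times worse than delivering spam.
loss_table = %{
  deliver: [1.0, 0.0],
  block: [0.0, 10.0]
}

action = DecisionSketch.best_action(probabilities, loss_table)
IO.inspect(action, label: :best_action)
```

Even though spam is the more likely outcome here, delivering has the lower expected loss (0.75 versus 2.5), so the objective changes the decision.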
Learning from Observations
Information theory provides a framework for reasoning about information systems. It is the mathematical study of coding information in sequences and symbols, and of how much information can be stored and transmitted in these media.
In information theory, the main concern is the degree of surprise of a particular observation.
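The degree of surprise is usually measured as the negative logarithm of an event's probability, so rare events carry more information than common ones. A minimal sketch using only the Erlang :math module follows. InfoSketch and the example probabilities are illustrative assumptions.

```elixir
defmodule InfoSketch do
  # Surprise (self-information) of an event with probability p, in bits.
  def surprise(p) when p > 0 and p <= 1, do: -:math.log2(p)

  # Entropy: the expected surprise over a whole distribution.
  # Zero-probability outcomes contribute nothing and are skipped.
  def entropy(probabilities) do
    probabilities
    |> Enum.filter(&(&1 > 0))
    |> Enum.map(&(&1 * surprise(&1)))
    |> Enum.sum()
  end
end

IO.inspect(InfoSketch.surprise(0.5), label: :fair_coin_heads)
IO.inspect(InfoSketch.surprise(0.01), label: :rare_event)
IO.inspect(InfoSketch.entropy([0.5, 0.5]), label: :fair_coin_entropy)
```

A fair coin flip carries exactly one bit of surprise, while an event with probability 0.01 carries several bits.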
Tracking Changes
While exploring the field of ML, the terms derivative, gradient, and automatic differentiation come up often. All of these terms come from vector calculus.
Vector calculus is concerned with the differentiation and integration of vector fields or functions.
Understanding Differentiation
The most important concept from calculus to understand in ML is the derivative. A derivative measures the instantaneous rate of change of a function.
For example, to maximize profits for the year, you could compute the derivative of your profit function to find the quantity of a product that maximizes profit.
defmodule BerryFarm do
import Nx.Defn
defn profits(trees) do
trees
|> Nx.subtract(1)
|> Nx.pow(4)
|> Nx.negate()
|> Nx.add(Nx.pow(trees, 3))
|> Nx.add(Nx.pow(trees, 2))
end
defn profits_derivative(trees) do
grad(trees, &profits/1)
end
end
{:module, BerryFarm, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
grad/2 is an Nx function capable of taking the gradient of a function.
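To see what the gradient computes for this profit curve, it helps to work the derivative out by hand and check it numerically. The sketch below uses plain floats instead of Nx so it runs on its own. DerivativeSketch and its function names are hypothetical, and the hand-derived formula is an application of the power rule to the same expression used in BerryFarm.profits/1.

```elixir
defmodule DerivativeSketch do
  # The same profit curve as BerryFarm.profits/1, written with plain floats:
  # profits(t) = -(t - 1)^4 + t^3 + t^2
  def profits(t), do: -:math.pow(t - 1, 4) + :math.pow(t, 3) + :math.pow(t, 2)

  # Derivative worked out by hand with the power rule:
  # profits'(t) = -4(t - 1)^3 + 3t^2 + 2t
  def derivative_by_hand(t), do: -4 * :math.pow(t - 1, 3) + 3 * t * t + 2 * t

  # Central finite difference: the slope between two nearby points,
  # which approximates the instantaneous rate of change at t.
  def derivative_numeric(t, h \\ 1.0e-5) do
    (profits(t + h) - profits(t - h)) / (2 * h)
  end
end

by_hand = DerivativeSketch.derivative_by_hand(2.0)
numeric = DerivativeSketch.derivative_numeric(2.0)
IO.inspect(by_hand, label: :by_hand)
IO.inspect(numeric, label: :numeric)
```

Both approaches agree closely at 2 trees. Automatic differentiation gives the exact derivative without the manual algebra or the approximation error of the finite difference.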
The gradient points in the direction of greatest change of a scalar function. It is like a heat-seeking missile attracted to change.
As an analogy, imagine you are on a boat in a lake and want to find the deepest point as quickly as possible using a depth finder. How do you solve this problem? One method is to measure the change in depth in every direction and always travel in the direction of steepest descent. That way you are always following a path toward deeper water. If the bottom of the lake resembles a valley, you are all but guaranteed to find the deepest point.
The optimization of objective functions in ML models follows the same process as the analogy, using gradient descent. Gradient descent is useful because you often don’t have access to the complete gradient of the objective function. Instead, the model looks at a few samples at a time and moves accordingly.
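The boat procedure can be sketched as a short loop. Because the goal for the berry farm is to maximize profit, this version steps uphill (gradient ascent), which is the same idea as descent on the negated function. It is a simplified illustration in plain Elixir: AscentSketch, the starting point, the learning rate, and the step count are all assumptions, and it uses the full hand-derived derivative instead of a few samples at a time.

```elixir
defmodule AscentSketch do
  # Hand-derived derivative of the profit curve -(t - 1)^4 + t^3 + t^2.
  def derivative(t), do: -4 * :math.pow(t - 1, 3) + 3 * t * t + 2 * t

  # Repeatedly take a small step in the direction the derivative points.
  def climb(start, learning_rate, steps) do
    Enum.reduce(1..steps, start, fn _step, current ->
      current + learning_rate * derivative(current)
    end)
  end
end

# Hypothetical settings: start at 1 tree, step size 0.01, 200 steps.
best = AscentSketch.climb(1.0, 0.01, 200)
IO.inspect(best, label: :best_trees)
IO.inspect(AscentSketch.derivative(best), label: :derivative_at_best)
```

With these settings the loop settles slightly above 3 trees, where the derivative is essentially zero. A learning rate that is too large could overshoot the peak, so the step size is a tuning choice and not a fixed rule.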
Gradient-based optimization is powerful and widely used in ML, especially deep learning (DL). With automatic differentiation, all you need is an objective function, and you can optimize it with respect to some model parameters.
trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)
VegaLite.new(title: "Berry Profits", width: 600, height: 400)
|> VegaLite.data_from_values(%{
trees: Nx.to_flat_list(trees),
profits: Nx.to_flat_list(profits),
profits_derivative: Nx.to_flat_list(profits_derivative)
})
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.mark(:line, interpolate: :basis)
|> VegaLite.encode_field(:x, "trees", type: :quantitative)
|> VegaLite.encode_field(:y, "profits", type: :quantitative),
VegaLite.new()
|> VegaLite.mark(:line, interpolate: :basis)
|> VegaLite.encode_field(:x, "trees", type: :quantitative)
|> VegaLite.encode_field(:y, "profits_derivative", type: :quantitative)
|> VegaLite.encode(:color, value: "#ff0000")
])
[Chart output: "Berry Profits". x-axis: trees (0 to 4). y-axis: profits, with profits_derivative drawn in red.]
From the above graph, notice that the derivative of the profit function is 0 at the point where the profit function reaches its maximum, at roughly 3 trees.