Chapter 5 - Traditional Machine Learning
Mix.install([
{:scholar, "~> 0.1"},
{:nx, "~> 0.5"},
{:scidata, "~> 0.1"},
{:vega_lite, "~> 0.1.6"},
{:kino_vega_lite, "~> 0.1.6"}
])
Linear Regression
m = :rand.uniform() * 10
b = :random.uniform() * 10
key = Nx.Random.key(42)
size = 100
{x, new_key} = Nx.Random.normal(key, 0.0, 1.0, shape: {size, 1})
{noise_x, new_key} = Nx.Random.normal(new_key, 0.0, 1.0, shape: {size, 1})
{noise_b, _} = Nx.Random.normal(new_key, 0.0, 1.0, shape: {size, 1})
y =
m
|> Nx.multiply(Nx.add(x, noise_x))
|> Nx.add(b)
#Nx.Tensor<
f32[100][1]
[
[0.03887796401977539],
[4.279749870300293],
[8.112302780151367],
[-2.2218804359436035],
[6.144775867462158],
[5.7303361892700195],
[6.924253463745117],
[2.205869197845459],
[5.778014183044434],
[9.057815551757812],
[0.5078833103179932],
[19.300395965576172],
[5.8746232986450195],
[9.086906433105469],
[5.712613582611084],
[8.592241287231445],
[3.0142822265625],
[2.3498716354370117],
[13.239413261413574],
[-8.851838111877441],
[-9.08227825164795],
[2.7460973262786865],
[5.435785293579102],
[5.19484806060791],
[9.656211853027344],
[8.89450454711914],
[10.833284378051758],
[9.950844764709473],
[18.857440948486328],
[-6.527419090270996],
[-9.613273620605469],
[7.780972480773926],
[14.053135871887207],
[8.886348724365234],
[12.689155578613281],
[21.51001739501953],
[8.104337692260742],
[0.6447916030883789],
[-6.969191551208496],
[16.84417724609375],
[6.300955295562744],
[26.257116317749023],
[5.885571479797363],
[8.78177261352539],
[5.983844757080078],
[17.31948471069336],
[8.100205421447754],
[5.43218469619751],
[-3.525576591491699],
[1.2058968544006348],
...
]
>
alias VegaLite, as: Vl
Vl.new(title: "Scatterplot Distribution", width: 800, height: 600)
|> Vl.data_from_values(%{
x: Nx.to_flat_list(x),
y: Nx.to_flat_list(y)
})
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"x":-0.5951409339904785,"y":0.03887796401977539},{"x":0.0132689718157053,"y":4.279749870300293},{"x":0.7378900647163391,"y":8.112302780151367},{"x":0.21672981977462769,"y":-2.2218804359436035},{"x":1.268314003944397,"y":6.144775867462158},{"x":0.08188837021589279,"y":5.7303361892700195},{"x":-0.01103353314101696,"y":6.924253463745117},{"x":-1.4509247541427612,"y":2.205869197845459},{"x":-0.5412119030952454,"y":5.778014183044434},{"x":0.8773791790008545,"y":9.057815551757812},{"x":-0.48833128809928894,"y":0.5078833103179932},{"x":1.237937092781067,"y":19.300395965576172},{"x":0.9093202352523804,"y":5.8746232986450195},{"x":0.1803722381591797,"y":9.086906433105469},{"x":-0.557214081287384,"y":5.712613582611084},{"x":0.7308189868927002,"y":8.592241287231445},{"x":0.34681078791618347,"y":3.0142822265625},{"x":0.529615581035614,"y":2.3498716354370117},{"x":0.03764510527253151,"y":13.239413261413574},{"x":-1.409928560256958,"y":-8.851838111877441},{"x":-1.4359955787658691,"y":-9.08227825164795},{"x":-0.9735497236251831,"y":2.7460973262786865},{"x":-1.3374433517456055,"y":5.435785293579102},{"x":-0.670516312122345,"y":5.19484806060791},{"x":-0.10266610980033875,"y":9.656211853027344},{"x":0.3223346173763275,"y":8.89450454711914},{"x":-0.4511873126029968,"y":10.833284378051758},{"x":-0.007002972066402435,"y":9.950844764709473},{"x":2.520038604736328,"y":18.857440948486328},{"x":-0.4965696632862091,"y":-6.527419090270996},{"x":-2.2178070545196533,"y":-9.613273620605469},{"x":-0.01344812661409378,"y":7.780972480773926},{"x":0.33387258648872375,"y":14.053135871887207},{"x":0.7982445359230042,"y":8.886348724365234},{"x":0.2093075066804886,"y":12.689155578613281},{"x":1.9826748371124268,"y":21.51001739501953},{"x":0.4582616984844208,"y":8.104337692260742},{"x":0.5479519963264465,"y":0.6447916030883789},{"x":-1.2041635513305664,"y":-6.969191551208496},{"x":2.004716634750366,"y":16.84417724609375},{"x":0.16586682200431824,"y":6.300955295562744},{"x":1.70429527759552,"y":26.257116317749023},{"x":-0.16112852096557617,"y":5.885571479797363},{"x":1.2452048063278198,"y":8.78177261352539},{"x":0.1535537838935852,"y":5.983844757080078},{"x":1.5046491622924805,"y":17.31948471069336},{"x":0.013979914598166943,"y":8.100205421447754},{"x":1.3802995681762695,"y":5.43218469619751},{"x":-0.3476333022117615,"y":-3.525576591491699},{"x":0.9467588663101196,"y":1.2058968544006348},{"x":-1.1229523420333862,"y":-1.2019143104553223},{"x":1.3731544017791748,"y":10.388106346130371},{"x":1.5248984098434448,"y":4.284648895263672},{"x":-0.83209627866745,"y":-0.37822628021240234},{"x":-0.4361415207386017,"y":-4.965084075927734},{"x":-0.9708553552627563,"y":4.431140422821045},{"x":-0.02753218077123165,"y":7.802522659301758},{"x":-0.6075385808944702,"y":-1.3150444030761719},{"x":0.69698166847229,"y":9.997228622436523},{"x":-1.5198605060577393,"y":0.13942718505859375},{"x":1.4279040098190308,"y":18.642736434936523},{"x":-0.7271621227264404,"y":-0.5941271781921387},{"x":0.6438418030738831,"y":5.2171454429626465},{"x":0.028946582227945328,"y":-2.6058268547058105},{"x":-0.7837653756141663,"y":0.8168387413024902},{"x":-0.7311482429504395,"y":-2.3730826377868652},{"x":-1.8962322473526,"y":-6.138404846191406},{"x":-0.09240783005952835,"y":-3.5298047065734863},{"x":0.8312447667121887,"y":13.766128540039062},{"x":0.08850807696580887,"y":4.256069660186768},{"x":1.0102241039276123,"y":7.030539512634277},{"x":-0.9956196546554565,"y":4.493330478668213},{"x":0.050282858312129974,"y":7.704846382141113},{"x":1.5256069898605347,"y":19.215499877929688},{"x":-1.3656820058822632,"y":-10.412747383117676},{"x":0.3665124475955963,"y":7.216827869415283},{"x":-1.4678500890731812,"y":0.7908694744110107},{"x":-1.353937029838562,"y":-3.7343931198120117},{"x":-0.7861266136169434,"y":-1.5413298606872559},{"x":-1.6958733797073364,"y":-11.22181510925293},{"x":0.10122057795524597,"y":-10.72789478302002},{"x":1.315024971961975,"y":7.041434288024902},{"x":0.6036592125892639,"y":6.795568466186523},{"x":-2.478543281555176,"y":-0.0038051605224609375},{"x":-0.7435774207115173,"y":7.070751190185547},{"x":-2.687535285949707,"y":-7.40959358215332},{"x":0.6872367858886719,"y":1.9924001693725586},{"x":1.3363817930221558,"y":14.002706527709961},{"x":-0.579086184501648,"y":-1.187077522277832},{"x":0.3402725160121918,"y":9.303038597106934},{"x":-0.5970014929771423,"y":-1.6615219116210938},{"x":1.373253345489502,"y":22.889440536499023},{"x":1.02512788772583,"y":19.163171768188477},{"x":-1.0098217725753784,"y":-0.3724198341369629},{"x":-0.14097319543361664,"y":4.102351188659668},{"x":-0.027422772720456123,"y":10.54636001586914},{"x":0.9244212508201599,"y":7.857834815979004},{"x":0.7690575122833252,"y":4.098662376403809},{"x":1.485957384109497,"y":17.97329330444336},{"x":-0.8444464802742004,"y":-9.950044631958008}]},"encoding":{"x":{"field":"x","type":"quantitative"},"y":{"field":"y","type":"quantitative"}},"height":600,"mark":"point","title":"Scatterplot Distribution","width":800}
model = Scholar.Linear.LinearRegression.fit(x, y)
%Scholar.Linear.LinearRegression{
coefficients: #Nx.Tensor<
f32[1][1]
[
[5.706689834594727]
]
>,
intercept: #Nx.Tensor<
f32[1]
[4.8282599449157715]
>
}
Scholar.Linear.LinearRegression.predict(model, Nx.iota({3, 1}))
#Nx.Tensor<
f32[3][1]
[
[4.8282599449157715],
[10.534950256347656],
[16.241640090942383]
]
>
pred_xs = Nx.linspace(-3.0, 3.0, n: 100) |> Nx.new_axis(-1)
pred_ys = Scholar.Linear.LinearRegression.predict(model, pred_xs)
#Nx.Tensor<
f32[100][1]
[
[-12.29180908203125],
[-11.94594955444336],
[-11.600090026855469],
[-11.254228591918945],
[-10.908369064331055],
[-10.562509536743164],
[-10.21664810180664],
[-9.87078857421875],
[-9.52492904663086],
[-9.179069519042969],
[-8.833209991455078],
[-8.487348556518555],
[-8.141489028930664],
[-7.795629978179932],
[-7.449768543243408],
[-7.103909015655518],
[-6.758049488067627],
[-6.412189960479736],
[-6.066329479217529],
[-5.720468997955322],
[-5.374609470367432],
[-5.028749942779541],
[-4.682889461517334],
[-4.337028980255127],
[-3.9911694526672363],
[-3.6453099250793457],
[-3.2994494438171387],
[-2.95358943939209],
[-2.607729434967041],
[-2.2618699073791504],
[-1.9160094261169434],
[-1.5701494216918945],
[-1.2242894172668457],
[-0.8784298896789551],
[-0.5325703620910645],
[-0.18670940399169922],
[0.1591506004333496],
[0.5050101280212402],
[0.8508710861206055],
[1.1967308521270752],
[1.5425903797149658],
[1.8884501457214355],
[2.234309673309326],
[2.5801708698272705],
[2.926030397415161],
[3.271890163421631],
[3.617751121520996],
[3.9636106491088867],
[4.3094706535339355],
[4.655330181121826],
...
]
>
Vl.new(title: "Scatterplot Distribution and Fit Curve", width: 1440, height: 1080)
|> Vl.data_from_values(%{
x: Nx.to_flat_list(x),
y: Nx.to_flat_list(y),
pred_x: Nx.to_flat_list(pred_xs),
pred_y: Nx.to_flat_list(pred_ys)
})
|> Vl.layers([
Vl.new()
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative),
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "pred_x", type: :quantitative)
|> Vl.encode_field(:y, "pred_y", type: :quantitative)
])
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"pred_x":-3.0,"pred_y":-12.29180908203125,"x":-0.5951409339904785,"y":0.03887796401977539},{"pred_x":-2.939393997192383,"pred_y":-11.94594955444336,"x":0.0132689718157053,"y":4.279749870300293},{"pred_x":-2.8787879943847656,"pred_y":-11.600090026855469,"x":0.7378900647163391,"y":8.112302780151367},{"pred_x":-2.8181817531585693,"pred_y":-11.254228591918945,"x":0.21672981977462769,"y":-2.2218804359436035},{"pred_x":-2.757575750350952,"pred_y":-10.908369064331055,"x":1.268314003944397,"y":6.144775867462158},{"pred_x":-2.696969747543335,"pred_y":-10.562509536743164,"x":0.08188837021589279,"y":5.7303361892700195},{"pred_x":-2.6363635063171387,"pred_y":-10.21664810180664,"x":-0.01103353314101696,"y":6.924253463745117},{"pred_x":-2.5757575035095215,"pred_y":-9.87078857421875,"x":-1.4509247541427612,"y":2.205869197845459},{"pred_x":-2.5151515007019043,"pred_y":-9.52492904663086,"x":-0.5412119030952454,"y":5.778014183044434},{"pred_x":-2.454545497894287,"pred_y":-9.179069519042969,"x":0.8773791790008545,"y":9.057815551757812},{"pred_x":-2.39393949508667,"pred_y":-8.833209991455078,"x":-0.48833128809928894,"y":0.5078833103179932},{"pred_x":-2.3333332538604736,"pred_y":-8.487348556518555,"x":1.237937092781067,"y":19.300395965576172},{"pred_x":-2.2727272510528564,"pred_y":-8.141489028930664,"x":0.9093202352523804,"y":5.8746232986450195},{"pred_x":-2.2121212482452393,"pred_y":-7.795629978179932,"x":0.1803722381591797,"y":9.086906433105469},{"pred_x":-2.151515007019043,"pred_y":-7.449768543243408,"x":-0.557214081287384,"y":5.712613582611084},{"pred_x":-2.090909004211426,"pred_y":-7.103909015655518,"x":0.7308189868927002,"y":8.592241287231445},{"pred_x":-2.0303030014038086,"pred_y":-6.758049488067627,"x":0.34681078791618347,"y":3.0142822265625},{"pred_x":-1.9696969985961914,"pred_y":-6.412189960479736,"x":0.529615581035614,"y":2.3498716354370117},{"pred_x":-1.9090908765792847,"pred_y":-6.066329479217529,"x":0.03764510527253151,"y":13.239413261413574},{"pred_x":-1.848484754562378,"pred_y":-5.720468997955322,"x":-1.409928560256958,"y":-8.851838111877441},{"pred_x":-1.7878787517547607,"pred_y":-5.374609470367432,"x":-1.4359955787658691,"y":-9.08227825164795},{"pred_x":-1.7272727489471436,"pred_y":-5.028749942779541,"x":-0.9735497236251831,"y":2.7460973262786865},{"pred_x":-1.6666666269302368,"pred_y":-4.682889461517334,"x":-1.3374433517456055,"y":5.435785293579102},{"pred_x":-1.60606050491333,"pred_y":-4.337028980255127,"x":-0.670516312122345,"y":5.19484806060791},{"pred_x":-1.545454502105713,"pred_y":-3.9911694526672363,"x":-0.10266610980033875,"y":9.656211853027344},{"pred_x":-1.4848484992980957,"pred_y":-3.6453099250793457,"x":0.3223346173763275,"y":8.89450454711914},{"pred_x":-1.424242377281189,"pred_y":-3.2994494438171387,"x":-0.4511873126029968,"y":10.833284378051758},{"pred_x":-1.3636362552642822,"pred_y":-2.95358943939209,"x":-0.007002972066402435,"y":9.950844764709473},{"pred_x":-1.303030252456665,"pred_y":-2.607729434967041,"x":2.520038604736328,"y":18.857440948486328},{"pred_x":-1.2424242496490479,"pred_y":-2.2618699073791504,"x":-0.4965696632862091,"y":-6.527419090270996},{"pred_x":-1.1818181276321411,"pred_y":-1.9160094261169434,"x":-2.2178070545196533,"y":-9.613273620605469},{"pred_x":-1.1212120056152344,"pred_y":-1.5701494216918945,"x":-0.01344812661409378,"y":7.780972480773926},{"pred_x":-1.0606060028076172,"pred_y":-1.2242894172668457,"x":0.33387258648872375,"y":14.053135871887207},{"pred_x":-1.0,"pred_y":-0.8784298896789551,"x":0.7982445359230042,"y":8.886348724365234},{"pred_x":-0.9393939971923828,"pred_y":-0.5325703620910645,"x":0.2093075066804886,"y":12.689155578613281},{"pred_x":-0.8787877559661865,"pred_y":-0.18670940399169922,"x":1.9826748371124268,"y":21.51001739501953},{"pred_x":-0.8181817531585693,"pred_y":0.1591506004333496,"x":0.4582616984844208,"y":8.104337692260742},{"pred_x":-0.7575757503509521,"pred_y":0.5050101280212402,"x":0.5479519963264465,"y":0.6447916030883789},{"pred_x":-0.6969695091247559,"pred_y":0.8508710861206055,"x":-1.2041635513305664,"y":-6.969191551208496},{"pred_x":-0.6363635063171387,"pred_y":1.1967308521270752,"x":2.004716634750366,"y":16.84417724609375},{"pred_x":-0.5757575035095215,"pred_y":1.5425903797149658,"x":0.16586682200431824,"y":6.300955295562744},{"pred_x":-0.5151515007019043,"pred_y":1.8884501457214355,"x":1.70429527759552,"y":26.257116317749023},{"pred_x":-0.4545454978942871,"pred_y":2.234309673309326,"x":-0.16112852096557617,"y":5.885571479797363},{"pred_x":-0.3939392566680908,"pred_y":2.5801708698272705,"x":1.2452048063278198,"y":8.78177261352539},{"pred_x":-0.33333325386047363,"pred_y":2.926030397415161,"x":0.1535537838935852,"y":5.983844757080078},{"pred_x":-0.27272725105285645,"pred_y":3.271890163421631,"x":1.5046491622924805,"y":17.31948471069336},{"pred_x":-0.21212100982666016,"pred_y":3.617751121520996,"x":0.013979914598166943,"y":8.100205421447754},{"pred_x":-0.15151500701904297,"pred_y":3.9636106491088867,"x":1.3802995681762695,"y":5.43218469619751},{"pred_x":-0.09090900421142578,"pred_y":4.3094706535339355,"x":-0.3476333022117615,"y":-3.525576591491699},{"pred_x":-0.030303001403808594,"pred_y":4.655330181121826,"x":0.9467588663101196,"y":1.2058968544006348},{"pred_x":0.030303001403808594,"pred_y":5.001189708709717,"x":-1.1229523420333862,"y":-1.2019143104553223},{"pred_x":0.09090924263000488,"pred_y":5.347050666809082,"x":1.3731544017791748,"y":10.388106346130371},{"pred_x":0.15151524543762207,"pred_y":5.692910194396973,"x":1.5248984098434448,"y":4.284648895263672},{"pred_x":0.21212124824523926,"pred_y":6.0387701988220215,"x":-0.83209627866745,"y":-0.37822628021240234},{"pred_x":0.27272748947143555,"pred_y":6.384631156921387,"x":-0.4361415207386017,"y":-4.965084075927734},{"pred_x":0.33333349227905273,"pred_y":6.730490684509277,"x":-0.9708553552627563,"y":4.431140422821045},{"pred_x":0.3939394950866699,"pred_y":7.076350212097168,"x":-0.02753218077123165,"y":7.802522659301758},{"pred_x":0.4545454978942871,"pred_y":7.422210216522217,"x":-0.6075385808944702,"y":-1.3150444030761719},{"pred_x":0.5151515007019043,"pred_y":7.768069744110107,"x":0.69698166847229,"y":9.997228622436523},{"pred_x":0.5757577419281006,"pred_y":8.113930702209473,"x":-1.5198605060577393,"y":0.13942718505859375},{"pred_x":0.6363637447357178,"pred_y":8.459790229797363,"x":1.4279040098190308,"y":18.642736434936523},{"pred_x":0.696969747543335,"pred_y":8.80565071105957,"x":-0.7271621227264404,"y":-0.5941271781921387},{"pred_x":0.7575759887695312,"pred_y":9.151511192321777,"x":0.6438418030738831,"y":5.2171454429626465},{"pred_x":0.8181819915771484,"pred_y":9.497370719909668,"x":0.028946582227945328,"y":-2.6058268547058105},{"pred_x":0.8787879943847656,"pred_y":9.843230247497559,"x":-0.7837653756141663,"y":0.8168387413024902},{"pred_x":0.9393939971923828,"pred_y":10.189090728759766,"x":-0.7311482429504395,"y":-2.3730826377868652},{"pred_x":1.0,"pred_y":10.534950256347656,"x":-1.8962322473526,"y":-6.138404846191406},{"pred_x":1.0606060028076172,"pred_y":10.880809783935547,"x":-0.09240783005952835,"y":-3.5298047065734863},{"pred_x":1.1212120056152344,"pred_y":11.226669311523438,"x":0.8312447667121887,"y":13.766128540039062},{"pred_x":1.1818184852600098,"pred_y":11.572531700134277,"x":0.08850807696580887,"y":4.256069660186768},{"pred_x":1.242424488067627,"pred_y":11.918391227722168,"x":1.0102241039276123,"y":7.030539512634277},{"pred_x":1.3030304908752441,"pred_y":12.264250755310059,"x":-0.9956196546554565,"y":4.493330478668213},{"pred_x":1.3636364936828613,"pred_y":12.61011028289795,"x":0.050282858312129974,"y":7.704846382141113},{"pred_x":1.4242424964904785,"pred_y":12.955970764160156,"x":1.5256069898605347,"y":19.215499877929688},{"pred_x":1.4848484992980957,"pred_y":13.301830291748047,"x":-1.3656820058822632,"y":-10.412747383117676},{"pred_x":1.545454502105713,"pred_y":13.647689819335938,"x":0.3665124475955963,"y":7.216827869415283},{"pred_x":1.6060609817504883,"pred_y":13.993551254272461,"x":-1.4678500890731812,"y":0.7908694744110107},{"pred_x":1.6666669845581055,"pred_y":14.339410781860352,"x":-1.353937029838562,"y":-3.7343931198120117},{"pred_x":1.7272729873657227,"pred_y":14.685270309448242,"x":-0.7861266136169434,"y":-1.5413298606872559},{"pred_x":1.7878789901733398,"pred_y":15.031131744384766,"x":-1.6958733797073364,"y":-11.22181510925293},{"pred_x":1.848484992980957,"pred_y":15.376991271972656,"x":0.10122057795524597,"y":-10.72789478302002},{"pred_x":1.9090909957885742,"pred_y":15.722850799560547,"x":1.315024971961975,"y":7.041434288024902},{"pred_x":1.9696969985961914,"pred_y":16.068710327148438,"x":0.6036592125892639,"y":6.795568466186523},{"pred_x":2.0303030014038086,"pred_y":16.414569854736328,"x":-2.478543281555176,"y":-0.0038051605224609375},{"pred_x":2.090909004211426,"pred_y":16.76042938232422,"x":-0.7435774207115173,"y":7.070751190185547},{"pred_x":2.151515483856201,"pred_y":17.106290817260742,"x":-2.687535285949707,"y":-7.40959358215332},{"pred_x":2.2121214866638184,"pred_y":17.452150344848633,"x":0.6872367858886719,"y":1.9924001693725586},{"pred_x":2.2727274894714355,"pred_y":17.798009872436523,"x":1.3363817930221558,"y":14.002706527709961},{"pred_x":2.3333334922790527,"pred_y":18.143871307373047,"x":-0.579086184501648,"y":-1.187077522277832},{"pred_x":2.39393949508667,"pred_y":18.489730834960938,"x":0.3402725160121918,"y":9.303038597106934},{"pred_x":2.454545497894287,"pred_y":18.835590362548828,"x":-0.5970014929771423,"y":-1.6615219116210938},{"pred_x":2.5151515007019043,"pred_y":19.18144989013672,"x":1.373253345489502,"y":22.889440536499023},{"pred_x":2.5757579803466797,"pred_y":19.527311325073242,"x":1.02512788772583,"y":19.163171768188477},{"pred_x":2.636363983154297,"pred_y":19.873170852661133,"x":-1.0098217725753784,"y":-0.3724198341369629},{"pred_x":2.696969985961914,"pred_y":20.219030380249023,"x":-0.14097319543361664,"y":4.102351188659668},{"pred_x":2.7575759887695312,"pred_y":20.564889907836914,"x":-0.027422772720456123,"y":10.54636001586914},{"pred_x":2.8181819915771484,"pred_y":20.910751342773438,"x":0.9244212508201599,"y":7.857834815979004},{"pred_x":2.8787879943847656,"pred_y":21.256610870361328,"x":0.7690575122833252,"y":4.098662376403809},{"pred_x":2.939393997192383,"pred_y":21.60247039794922,"x":1.485957384109497,"y":17.97329330444336},{"pred_x":3.0,"pred_y":21.94832992553711,"x":-0.8444464802742004,"y":-9.950044631958008}]},"height":1080,"layer":[{"encoding":{"x":{"field":"x","type":"quantitative"},"y":{"field":"y","type":"quantitative"}},"mark":"point"},{"encoding":{"x":{"field":"pred_x","type":"quantitative"},"y":{"field":"pred_y","type":"quantitative"}},"mark":"line"}],"title":"Scatterplot Distribution and Fit Curve","width":1440}
Transformations
x = Nx.linspace(0, 5, n: 100)
y = Nx.exp(x)
log_y = Nx.log(y)
#Nx.Tensor<
f32[100]
[0.0, 0.05050503462553024, 0.10101012140512466, 0.1515151411294937, 0.20202024281024933, 0.2525252401828766, 0.3030303120613098, 0.35353538393974304, 0.4040404260158539, 0.45454543828964233, 0.5050504803657532, 0.555555522441864, 0.6060606837272644, 0.6565656661987305, 0.7070707082748413, 0.7575757503509521, 0.808080792427063, 0.858585774898529, 0.9090908765792847, 0.9595958590507507, 1.0101009607315063, 1.0606060028076172, 1.111111044883728, 1.1616160869598389, 1.2121212482452393, 1.26262629032135, 1.313131332397461, 1.3636363744735718, 1.4141414165496826, 1.4646464586257935, 1.5151515007019043, 1.5656565427780151, 1.616161584854126, 1.6666666269302368, 1.7171716690063477, 1.7676767110824585, 1.8181817531585693, 1.8686867952346802, 1.919191837310791, 1.9696968793869019, 2.0202019214630127, 2.070707082748413, 2.1212120056152344, 2.1717171669006348, 2.222222089767456, 2.2727272510528564, 2.3232321739196777, 2.373737335205078, 2.4242424964904785, 2.4747474193573, ...]
>
Vl.new(title: "Transformations")
|> Vl.concat([
Vl.new(width: 720, height: 540)
|> Vl.mark(:point)
|> Vl.data_from_values(%{
x: Nx.to_flat_list(x),
y: Nx.to_flat_list(y)
})
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative),
Vl.new(width: 720, height: 540)
|> Vl.mark(:point)
|> Vl.data_from_values(%{
x: Nx.to_flat_list(x),
log_y: Nx.to_flat_list(log_y)
})
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "log_y", type: :quantitative)
])
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","concat":[{"data":{"values":[{"x":0.0,"y":1.0},{"x":0.05050504952669144,"y":1.051802158355713},{"x":0.10101009905338287,"y":1.1062878370285034},{"x":0.1515151560306549,"y":1.1635959148406982},{"x":0.20202019810676575,"y":1.2238727807998657},{"x":0.2525252401828766,"y":1.2872719764709473},{"x":0.3030303120613098,"y":1.3539555072784424},{"x":0.35353535413742065,"y":1.4240933656692505},{"x":0.4040403962135315,"y":1.4978644847869873},{"x":0.45454543828964233,"y":1.5754570960998535},{"x":0.5050504803657532,"y":1.657069206237793},{"x":0.555555522441864,"y":1.7429089546203613},{"x":0.6060606241226196,"y":1.8331955671310425},{"x":0.6565656661987305,"y":1.9281589984893799},{"x":0.7070707082748413,"y":2.0280418395996094},{"x":0.7575757503509521,"y":2.133098840713501},{"x":0.808080792427063,"y":2.243597984313965},{"x":0.8585858345031738,"y":2.359821081161499},{"x":0.9090908765792847,"y":2.482064962387085},{"x":0.9595959186553955,"y":2.6106412410736084},{"x":1.0101009607315063,"y":2.745878219604492},{"x":1.0606060028076172,"y":2.888120651245117},{"x":1.111111044883728,"y":3.037731647491455},{"x":1.1616160869598389,"y":3.1950926780700684},{"x":1.2121212482452393,"y":3.3606057167053223},{"x":1.26262629032135,"y":3.5346925258636475},{"x":1.313131332397461,"y":3.71779727935791},{"x":1.3636363744735718,"y":3.9103870391845703},{"x":1.4141414165496826,"y":4.1129536628723145},{"x":1.4646464586257935,"y":4.326013565063477},{"x":1.5151515007019043,"y":4.550110340118408},{"x":1.5656565427780151,"y":4.785816192626953},{"x":1.616161584854126,"y":5.033731460571289},{"x":1.6666666269302368,"y":5.294489860534668},{"x":1.7171716690063477,"y":5.568756103515625},{"x":1.7676767110824585,"y":5.857229709625244},{"x":1.8181817531585693,"y":6.160646915435791},{"x":1.8686867952346802,"y":6.479781627655029},{"x":1.919191837310791,"y":6.81544828414917},{"x":1.9696968793869019,"y":7.168503284454346},{"x":2.0202019214630127,"y":7.539847373962402},{"x":2.070707082748413,"y":7.930428504943848},{"x":2.1212120056152344,"y":8.341240882873535},{"x":2.1717171669006348,"y":8.773336410522461},{"x":2.222222089767456,"y":9.227812767028809},{"x":2.2727272510528564,"y":9.705835342407227},{"x":2.3232321739196777,"y":10.208617210388184},{"x":2.373737335205078,"y":10.737446784973145},{"x":2.4242424964904785,"y":11.293671607971191},{"x":2.4747474193573,"y":11.878705978393555},{"x":2.5252525806427,"y":12.494050979614258},{"x":2.5757575035095215,"y":13.141267776489258},{"x":2.626262664794922,"y":13.822015762329102},{"x":2.676767587661743,"y":14.538023948669434},{"x":2.7272727489471436,"y":15.29112720489502},{"x":2.777777671813965,"y":16.08323860168457},{"x":2.8282828330993652,"y":16.9163875579834},{"x":2.8787877559661865,"y":17.79269027709961},{"x":2.929292917251587,"y":18.714393615722656},{"x":2.979797840118408,"y":19.683837890625},{"x":3.0303030014038086,"y":20.70350456237793},{"x":3.08080792427063,"y":21.775989532470703},{"x":3.1313130855560303,"y":22.904035568237305},{"x":3.1818180084228516,"y":24.090511322021484},{"x":3.232323169708252,"y":25.33845329284668},{"x":3.2828283309936523,"y":26.651044845581055},{"x":3.3333332538604736,"y":28.0316219329834},{"x":3.383838415145874,"y":29.48372459411621},{"x":3.4343433380126953,"y":31.01104164123535},{"x":3.4848484992980957,"y":32.61748504638672},{"x":3.535353422164917,"y":34.30713653564453},{"x":3.5858585834503174,"y":36.084327697753906},{"x":3.6363635063171387,"y":37.95356750488281},{"x":3.686868667602539,"y":39.91965103149414},{"x":3.7373735904693604,"y":41.98756790161133},{"x":3.7878787517547607,"y":44.162620544433594},{"x":3.838383674621582,"y":46.45033645629883},{"x":3.8888888359069824,"y":48.8565673828125},{"x":3.9393937587738037,"y":51.3874397277832},{"x":3.989898920059204,"y":54.0494270324707},{"x":4.040403842926025,"y":56.84929656982422},{"x":4.090909004211426,"y":59.794219970703125},{"x":4.141414165496826,"y":62.89169692993164},{"x":4.191919326782227,"y":66.1496353149414},{"x":4.242424011230469,"y":69.57630157470703},{"x":4.292929172515869,"y":73.18051147460938},{"x":4.3434343338012695,"y":76.97142791748047},{"x":4.39393949508667,"y":80.95872497558594},{"x":4.444444179534912,"y":85.15253448486328},{"x":4.4949493408203125,"y":89.56362915039062},{"x":4.545454502105713,"y":94.20323181152344},{"x":4.595959663391113,"y":99.08317565917969},{"x":4.6464643478393555,"y":104.21586608886719},{"x":4.696969509124756,"y":109.61448669433594},{"x":4.747474670410156,"y":115.29276275634766},{"x":4.797979831695557,"y":121.26519012451172},{"x":4.848484992980957,"y":127.54701232910156},{"x":4.898989677429199,"y":134.1541748046875},{"x":4.9494948387146,"y":141.10366821289062},{"x":5.0,"y":148.4131622314453}]},"encoding":{"x":{"field":"x","type":"quantitative"},"y":{"field":"y","type":"quantitative"}},"height":540,"mark":"point","width":720},{"data":{"values":[{"log_y":0.0,"x":0.0},{"log_y":0.05050503462553024,"x":0.05050504952669144},{"log_y":0.10101012140512466,"x":0.10101009905338287},{"log_y":0.1515151411294937,"x":0.1515151560306549},{"log_y":0.20202024281024933,"x":0.20202019810676575},{"log_y":0.2525252401828766,"x":0.2525252401828766},{"log_y":0.3030303120613098,"x":0.3030303120613098},{"log_y":0.35353538393974304,"x":0.35353535413742065},{"log_y":0.4040404260158539,"x":0.4040403962135315},{"log_y":0.45454543828964233,"x":0.45454543828964233},{"log_y":0.5050504803657532,"x":0.5050504803657532},{"log_y":0.555555522441864,"x":0.555555522441864},{"log_y":0.6060606837272644,"x":0.6060606241226196},{"log_y":0.6565656661987305,"x":0.6565656661987305},{"log_y":0.7070707082748413,"x":0.7070707082748413},{"log_y":0.7575757503509521,"x":0.7575757503509521},{"log_y":0.808080792427063,"x":0.808080792427063},{"log_y":0.858585774898529,"x":0.8585858345031738},{"log_y":0.9090908765792847,"x":0.9090908765792847},{"log_y":0.9595958590507507,"x":0.9595959186553955},{"log_y":1.0101009607315063,"x":1.0101009607315063},{"log_y":1.0606060028076172,"x":1.0606060028076172},{"log_y":1.111111044883728,"x":1.111111044883728},{"log_y":1.1616160869598389,"x":1.1616160869598389},{"log_y":1.2121212482452393,"x":1.2121212482452393},{"log_y":1.26262629032135,"x":1.26262629032135},{"log_y":1.313131332397461,"x":1.313131332397461},{"log_y":1.3636363744735718,"x":1.3636363744735718},{"log_y":1.4141414165496826,"x":1.4141414165496826},{"log_y":1.4646464586257935,"x":1.4646464586257935},{"log_y":1.5151515007019043,"x":1.5151515007019043},{"log_y":1.5656565427780151,"x":1.5656565427780151},{"log_y":1.616161584854126,"x":1.616161584854126},{"log_y":1.6666666269302368,"x":1.6666666269302368},{"log_y":1.7171716690063477,"x":1.7171716690063477},{"log_y":1.7676767110824585,"x":1.7676767110824585},{"log_y":1.8181817531585693,"x":1.8181817531585693},{"log_y":1.8686867952346802,"x":1.8686867952346802},{"log_y":1.919191837310791,"x":1.919191837310791},{"log_y":1.9696968793869019,"x":1.9696968793869019},{"log_y":2.0202019214630127,"x":2.0202019214630127},{"log_y":2.070707082748413,"x":2.070707082748413},{"log_y":2.1212120056152344,"x":2.1212120056152344},{"log_y":2.1717171669006348,"x":2.1717171669006348},{"log_y":2.222222089767456,"x":2.222222089767456},{"log_y":2.2727272510528564,"x":2.2727272510528564},{"log_y":2.3232321739196777,"x":2.3232321739196777},{"log_y":2.373737335205078,"x":2.373737335205078},{"log_y":2.4242424964904785,"x":2.4242424964904785},{"log_y":2.4747474193573,"x":2.4747474193573},{"log_y":2.5252525806427,"x":2.5252525806427},{"log_y":2.5757575035095215,"x":2.5757575035095215},{"log_y":2.626262664794922,"x":2.626262664794922},{"log_y":2.676767587661743,"x":2.676767587661743},{"log_y":2.7272727489471436,"x":2.7272727489471436},{"log_y":2.777777671813965,"x":2.777777671813965},{"log_y":2.8282828330993652,"x":2.8282828330993652},{"log_y":2.8787877559661865,"x":2.8787877559661865},{"log_y":2.929292917251587,"x":2.929292917251587},{"log_y":2.979797840118408,"x":2.979797840118408},{"log_y":3.0303030014038086,"x":3.0303030014038086},{"log_y":3.08080792427063,"x":3.08080792427063},{"log_y":3.1313130855560303,"x":3.1313130855560303},{"log_y":3.1818180084228516,"x":3.1818180084228516},{"log_y":3.232323169708252,"x":3.232323169708252},{"log_y":3.2828283309936523,"x":3.2828283309936523},{"log_y":3.3333332538604736,"x":3.3333332538604736},{"log_y":3.383838415145874,"x":3.383838415145874},{"log_y":3.4343433380126953,"x":3.4343433380126953},{"log_y":3.4848484992980957,"x":3.4848484992980957},{"log_y":3.535353422164917,"x":3.535353422164917},{"log_y":3.5858585834503174,"x":3.5858585834503174},{"log_y":3.6363635063171387,"x":3.6363635063171387},{"log_y":3.686868667602539,"x":3.686868667602539},{"log_y":3.7373735904693604,"x":3.7373735904693604},{"log_y":3.7878787517547607,"x":3.7878787517547607},{"log_y":3.838383674621582,"x":3.838383674621582},{"log_y":3.8888888359069824,"x":3.8888888359069824},{"log_y":3.9393937587738037,"x":3.9393937587738037},{"log_y":3.989898920059204,"x":3.989898920059204},{"log_y":4.040403842926025,"x":4.040403842926025},{"log_y":4.090909004211426,"x":4.090909004211426},{"log_y":4.141414165496826,"x":4.141414165496826},{"log_y":4.191919326782227,"x":4.191919326782227},{"log_y":4.242424011230469,"x":4.242424011230469},{"log_y":4.292929172515869,"x":4.292929172515869},{"log_y":4.3434343338012695,"x":4.3434343338012695},{"log_y":4.39393949508667,"x":4.39393949508667},{"log_y":4.444444179534912,"x":4.444444179534912},{"log_y":4.4949493408203125,"x":4.4949493408203125},{"log_y":4.545454502105713,"x":4.545454502105713},{"log_y":4.595959663391113,"x":4.595959663391113},{"log_y":4.6464643478393555,"x":4.6464643478393555},{"log_y":4.696969509124756,"x":4.696969509124756},{"log_y":4.747474670410156,"x":4.747474670410156},{"log_y":4.797979831695557,"x":4.797979831695557},{"log_y":4.848484992980957,"x":4.848484992980957},{"log_y":4.898989677429199,"x":4.898989677429199},{"log_y":4.9494948387146,"x":4.9494948387146},{"log_y":5.0,"x":5.0}]},"encoding":{"x":{"field":"x","type":"quantitative"},"y":{"field":"log_y","type":"quantitative"}},"height":540,"mark":"point","width":720}],"title":"Transformations"}
Logistic Regression
{inputs, targets} = Scidata.Wine.download()
{[
[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0],
[13.2, 1.78, 2.14, 11.2, 100.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.4, 1050.0],
[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.68, 1.03, 3.17, 1185.0],
[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0],
[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0],
[14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0],
[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0],
[14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0],
[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0],
[13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.22, 1.01, 3.55, 1045.0],
[14.1, 2.16, 2.3, 18.0, 105.0, 2.95, 3.32, 0.22, 2.38, 5.75, 1.25, 3.17, 1510.0],
[14.12, 1.48, 2.32, 16.8, 95.0, 2.2, 2.43, 0.26, 1.57, 5.0, 1.17, 2.82, 1280.0],
[13.75, 1.73, 2.41, 16.0, 89.0, 2.6, 2.76, 0.29, 1.81, 5.6, 1.15, 2.9, 1320.0],
[14.75, 1.73, 2.39, 11.4, 91.0, 3.1, 3.69, 0.43, 2.81, 5.4, 1.25, 2.73, 1150.0],
[14.38, 1.87, 2.38, 12.0, 102.0, 3.3, 3.64, 0.29, 2.96, 7.5, 1.2, 3.0, 1547.0],
[13.63, 1.81, 2.7, 17.2, 112.0, 2.85, 2.91, 0.3, 1.46, 7.3, 1.28, 2.88, 1310.0],
[14.3, 1.92, 2.72, 20.0, 120.0, 2.8, 3.14, 0.33, 1.97, 6.2, 1.07, 2.65, 1280.0],
[13.83, 1.57, 2.62, 20.0, 115.0, 2.95, 3.4, 0.4, 1.72, 6.6, 1.13, 2.57, 1130.0],
[14.19, 1.59, 2.48, 16.5, 108.0, 3.3, 3.93, 0.32, 1.86, 8.7, 1.23, 2.82, 1680.0],
[13.64, 3.1, 2.56, 15.2, 116.0, 2.7, 3.03, 0.17, 1.66, 5.1, 0.96, 3.36, 845.0],
[14.06, 1.63, 2.28, 16.0, 126.0, 3.0, 3.17, 0.24, 2.1, 5.65, 1.09, 3.71, 780.0],
[12.93, 3.8, 2.65, 18.6, 102.0, 2.41, 2.41, 0.25, 1.98, 4.5, 1.03, 3.52, 770.0],
[13.71, 1.86, 2.36, 16.6, 101.0, 2.61, 2.88, 0.27, 1.69, 3.8, 1.11, 4.0, 1035.0],
[12.85, 1.6, 2.52, 17.8, 95.0, 2.48, 2.37, 0.26, 1.46, 3.93, 1.09, 3.63, 1015.0],
[13.5, 1.81, 2.61, 20.0, 96.0, 2.53, 2.61, 0.28, 1.66, 3.52, 1.12, 3.82, 845.0],
[13.05, 2.05, 3.22, 25.0, 124.0, 2.63, 2.68, 0.47, 1.92, 3.58, 1.13, 3.2, 830.0],
[13.39, 1.77, 2.62, 16.1, 93.0, 2.85, 2.94, 0.34, 1.45, 4.8, 0.92, 3.22, 1195.0],
[13.3, 1.72, 2.14, 17.0, 94.0, 2.4, 2.19, 0.27, 1.35, 3.95, 1.02, 2.77, 1285.0],
[13.87, 1.9, 2.8, 19.4, 107.0, 2.95, 2.97, 0.37, 1.76, 4.5, 1.25, 3.4, 915.0],
[14.02, 1.68, 2.21, 16.0, 96.0, 2.65, 2.33, 0.26, 1.98, 4.7, 1.04, 3.59, 1035.0],
[13.73, 1.5, 2.7, 22.5, 101.0, 3.0, 3.25, 0.29, 2.38, 5.7, 1.19, 2.71, 1285.0],
[13.58, 1.66, 2.36, 19.1, 106.0, 2.86, 3.19, 0.22, 1.95, 6.9, 1.09, 2.88, 1515.0],
[13.68, 1.83, 2.36, 17.2, 104.0, 2.42, 2.69, 0.42, 1.97, 3.84, 1.23, 2.87, 990.0],
[13.76, 1.53, 2.7, 19.5, 132.0, 2.95, 2.74, 0.5, 1.35, 5.4, 1.25, 3.0, 1235.0],
[13.51, 1.8, 2.65, 19.0, 110.0, 2.35, 2.53, 0.29, 1.54, 4.2, 1.1, 2.87, 1095.0],
[13.48, 1.81, 2.41, 20.5, 100.0, 2.7, 2.98, 0.26, 1.86, 5.1, 1.04, 3.47, 920.0],
[13.28, 1.64, 2.84, 15.5, 110.0, 2.6, 2.68, 0.34, 1.36, 4.6, 1.09, 2.78, ...],
[13.05, 1.65, 2.55, 18.0, 98.0, 2.45, 2.43, 0.29, 1.44, 4.25, 1.12, ...],
[13.07, 1.5, 2.1, 15.5, 98.0, 2.4, 2.64, 0.28, 1.37, 3.7, ...],
[14.22, 3.99, 2.51, 13.2, 128.0, 3.0, 3.04, 0.2, 2.08, ...],
[13.56, 1.71, 2.31, 16.2, 117.0, 3.15, 3.29, 0.34, ...],
[13.41, 3.84, 2.12, 18.8, 90.0, 2.45, 2.68, ...],
[13.88, 1.89, 2.59, 15.0, 101.0, 3.25, ...],
[13.24, 3.98, 2.29, 17.5, 103.0, ...],
[13.05, 1.77, 2.1, 17.0, ...],
[14.21, 4.04, 2.44, ...],
[14.38, 3.59, ...],
[13.9, ...],
[...],
...
],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]}
{train, test} =
inputs
|> Enum.zip(targets)
|> Enum.shuffle()
|> Enum.split(floor(length(inputs) * 0.8))
{train_inputs, train_targets} = Enum.unzip(train)
{train_inputs, train_targets} = {Nx.tensor(train_inputs), Nx.tensor(train_targets)}
{test_inputs, test_targets} = Enum.unzip(test)
{test_inputs, test_targets} = {Nx.tensor(test_inputs), Nx.tensor(test_targets)}
{#Nx.Tensor<
f32[36][13]
[
[13.5600004196167, 1.7100000381469727, 2.309999942779541, 16.200000762939453, 117.0, 3.1500000953674316, 3.2899999618530273, 0.3400000035762787, 2.3399999141693115, 6.130000114440918, 0.949999988079071, 3.380000114440918, 795.0],
[12.930000305175781, 2.809999942779541, 2.700000047683716, 21.0, 96.0, 1.5399999618530273, 0.5, 0.5299999713897705, 0.75, 4.599999904632568, 0.7699999809265137, 2.309999942779541, 600.0],
[13.630000114440918, 1.809999942779541, 2.700000047683716, 17.200000762939453, 112.0, 2.8499999046325684, 2.9100000858306885, 0.30000001192092896, 1.4600000381469727, 7.300000190734863, 1.2799999713897705, 2.880000114440918, 1310.0],
[12.0, 0.9200000166893005, 2.0, 19.0, 86.0, 2.4200000762939453, 2.259999990463257, 0.30000001192092896, 1.4299999475479126, 2.5, ...],
...
]
>,
#Nx.Tensor<
s64[36]
[0, 2, 0, 1, 2, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 2, 1, 0, 0, 2, 1, 0, 2, 2, 1, 2, 2, 1, 0, 1, 2, 0, 2, 0]
>}
train_inputs = Scholar.Preprocessing.min_max_scale(train_inputs)
test_inputs = Scholar.Preprocessing.min_max_scale(test_inputs)
#Nx.Tensor<
f32[36][13]
[
[0.008656414225697517, 9.955846471711993e-4, 0.001383474562317133, 0.010363130830228329, 0.07552866637706757, 0.0019265207229182124, 0.002017028396949172, 1.0990219016093761e-4, 0.0014028690056875348, 0.003853041445836425, 5.04257099237293e-4, 0.002075212076306343, 0.5138444900512695],
[0.008249129168689251, 0.0017067162552848458, 0.00163560314103961, 0.013466251082718372, 0.061952512711286545, 8.856823551468551e-4, 2.1333953191060573e-4, 2.3273401893675327e-4, 3.749603929463774e-4, 0.0028639216907322407, 3.878900606650859e-4, 0.001383474562317133, 0.38778018951416016],
[0.008701667189598083, 0.00106023286934942, 0.00163560314103961, 0.01100961398333311, 0.07229624688625336, 0.0017325755907222629, 0.0017713647102937102, 8.404285472352058e-5, 8.33963742479682e-4, 0.004609427414834499, 7.175966748036444e-4, 0.0017519702669233084, 0.8467833995819092],
[0.007647899445146322, 4.8486259765923023e-4, 0.0011830647708848119, 0.012173283845186234, 0.055487677454948425, 0.0014545877929776907, 0.0013511504512280226, 8.404285472352058e-5, 8.145691826939583e-4, 0.0015063064638525248, 7.82245013397187e-4, ...],
...
]
>
model = Scholar.Linear.LogisticRegression.fit(train_inputs, train_targets, num_classes: 3)
%Scholar.Linear.LogisticRegression{
coefficients: #Nx.Tensor<
f32[13][3]
[
[1.0102200508117676, 0.8912280797958374, 1.0985535383224487],
[0.9858412742614746, 0.813435435295105, 1.2007232904434204],
[1.0022029876708984, 0.9788302779197693, 1.018965721130371],
[0.6923989057540894, 0.9967496395111084, 1.3108508586883545],
[0.8633057475090027, 1.329253077507019, 0.8074418902397156],
[1.0406365394592285, 1.054695725440979, 0.9046669006347656],
[1.089746117591858, 1.1372164487838745, 0.7730366587638855],
[0.9934093356132507, 0.9891678094863892, 1.0174230337142944],
[1.0170220136642456, 1.0611008405685425, 0.9218771457672119],
[0.9716349244117737, 0.4915732741355896, 1.5367921590805054],
[1.005954384803772, 1.0449081659317017, 0.9491406679153442],
[1.0509141683578491, 1.1178494691848755, 0.8312413096427917],
[14.072208404541016, -8.818758964538574, -2.2534501552581787]
]
>,
bias: #Nx.Tensor<
f32[3]
[-6.293769836425781, 4.521202087402344, 1.7725601196289062]
>
}
test_preds = Scholar.Linear.LogisticRegression.predict(model, test_inputs)
#Nx.Tensor<
s64[36]
[0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0]
>
# Scholar.Metrics.accuracy(test_targets, test_preds)
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
f32
0.7222222089767456
>
# cmat = Scholar.Metrics.confusion_matrix(test_targets, test_preds, num_classes: 3)
cmat = Scholar.Metrics.Classification.confusion_matrix(test_targets, test_preds, num_classes: 3)
#Nx.Tensor<
u64[3][3]
[
[15, 0, 0],
[0, 11, 0],
[2, 8, 0]
]
>
Vl.new(title: "Confusion Matrix", width: 640, height: 400)
|> Vl.data_from_values(%{
predicted: Nx.to_flat_list(test_preds),
actual: Nx.to_flat_list(test_targets)
})
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "predicted")
|> Vl.encode_field(:y, "actual")
|> Vl.encode(:color, aggregate: :count)
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"actual":0,"predicted":0},{"actual":2,"predicted":1},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":2,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":2,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":2,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":0},{"actual":0,"predicted":0},{"actual":2,"predicted":1},{"actual":0,"predicted":0}]},"encoding":{"color":{"aggregate":"count"},"x":{"field":"predicted"},"y":{"field":"actual"}},"height":400,"mark":"rect","title":"Confusion Matrix","width":640}
K-Nearest Neighbors
# model = Scholar.Neighbors.KNearestNeighbors.fit(train_inputs, train_targets, num_classes: 3)
# Reshape train_targets to have shape {142, 1}
train_targets = Nx.reshape(train_targets, {142, 1})
model = Scholar.Neighbors.KNNRegressor.fit(train_inputs, train_targets, num_neighbors: 3)
%Scholar.Neighbors.KNNRegressor{
algorithm: %Scholar.Neighbors.BruteKNN{
num_neighbors: 3,
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2,
data: #Nx.Tensor<
f32[142][13]
[
[0.008173251524567604, 7.262466824613512e-4, 0.0012739080702885985, 0.009447159245610237, 0.058260463178157806, 0.0016965598333626986, 0.0017977581592276692, 5.357557529350743e-5, 0.0010238888207823038, 0.004220564384013414, 5.238500307314098e-4, 0.002035871846601367, 0.6219945549964905],
[0.007167221046984196, 7.85775133408606e-4, 0.0014286820078268647, 0.011233012191951275, 0.06183216720819473, 0.0010477000614628196, 9.643603698350489e-4, 1.904909295262769e-4, 5.357557092793286e-4, 0.001678701490163803, 7.85775133408606e-4, 0.001250096713192761, 0.2112484872341156],
[0.007572014816105366, 8.75067722517997e-4, 0.0014227291103452444, 0.01051867101341486, 0.05647461116313934, 0.0013989177532494068, 0.0013334363466128707, 7.738693966530263e-5, 7.917279726825655e-4, 0.002262079855427146, 5.714728031307459e-4, 0.002083494560793042, 0.604136049747467],
[0.00769107136875391, 0.0011429456062614918, 0.001839428092353046, 0.014804717153310776, 0.07373785227537155, 0.0014882104005664587, 0.0015179747715592384, ...],
...
]
>,
batch_size: nil
},
weights: :uniform,
labels: #Nx.Tensor<
s64[142][1]
[
[0],
[1],
[0],
[0],
[0],
[1],
[2],
[0],
[2],
[1],
[0],
[1],
[0],
[1],
[1],
[2],
[0],
[1],
[1],
[2],
[1],
[1],
[0],
[1],
[0],
[2],
[2],
[2],
[0],
[2],
[2],
[0],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[2],
[1],
[1],
[0],
[1],
[0],
...
]
>
}
# test_preds = Scholar.Neighbors.KNearestNeighbors.predict(model, test_inputs)
test_preds = Scholar.Neighbors.KNNRegressor.predict(model, test_inputs)
#Nx.Tensor<
f32[36][1]
[
[1.0],
[2.0],
[0.0],
[1.0],
[1.6666666269302368],
[1.0],
[1.0],
[0.0],
[0.0],
[1.0],
[0.0],
[1.3333333730697632],
[0.0],
[1.0],
[0.0],
[0.0],
[1.6666666269302368],
[1.6666666269302368],
[1.3333333730697632],
[0.0],
[0.3333333432674408],
[0.0],
[1.0],
[0.0],
[1.6666666269302368],
[1.3333333730697632],
[2.0],
[1.6666666269302368],
[1.6666666269302368],
[1.0],
[0.0],
[1.3333333730697632],
[0.0],
[0.0],
[2.0],
[0.0]
]
>
# Scholar.Metrics.accuracy(test_targets, test_preds)
# Reshape test_preds to have shape {36}
test_preds = Nx.reshape(test_preds, {36})
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
f32
0.5833333134651184
>
# Scholar.Metrics.confusion_matrix(test_targets, test_preds, num_classes: 3)
# Convert test_targets and test_preds to integer tensors
test_targets = Nx.as_type(test_targets, {:s, 64})
test_preds = Nx.as_type(test_preds, {:s, 64})
Scholar.Metrics.Classification.confusion_matrix(test_targets, test_preds, num_classes: 3)
#Nx.Tensor<
u64[3][3]
[
[13, 2, 0],
[0, 10, 1],
[2, 6, 2]
]
>
Vl.new(title: "Confusion Matrix", width: 640, height: 400)
|> Vl.data_from_values(%{
predicted: Nx.to_flat_list(test_preds),
actual: Nx.to_flat_list(test_targets)
})
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "predicted")
|> Vl.encode_field(:y, "actual")
|> Vl.encode(:color, aggregate: :count)
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","data":{"values":[{"actual":0,"predicted":1},{"actual":2,"predicted":2},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":1},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":0,"predicted":0},{"actual":2,"predicted":0},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":2,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":2},{"actual":2,"predicted":1},{"actual":2,"predicted":1},{"actual":1,"predicted":1},{"actual":0,"predicted":0},{"actual":1,"predicted":1},{"actual":2,"predicted":0},{"actual":0,"predicted":0},{"actual":2,"predicted":2},{"actual":0,"predicted":0}]},"encoding":{"color":{"aggregate":"count"},"x":{"field":"predicted"},"y":{"field":"actual"}},"height":400,"mark":"rect","title":"Confusion Matrix","width":640}
K-Means Clustering
model = Scholar.Cluster.KMeans.fit(train_inputs, num_clusters: 3)
%Scholar.Cluster.KMeans{
num_iterations: #Nx.Tensor<
s64
2
>,
clusters: #Nx.Tensor<
f32[3][13]
[
[0.007588574197143316, 0.0014265172649174929, 0.0013569231377914548, 0.011884577572345734, 0.06121523678302765, 0.0011656746501103044, 8.236567955464125e-4, 1.5715500921942294e-4, 8.027677540667355e-4, 0.003267027670517564, 4.480433417484164e-4, 0.001284623285755515, 0.4283323884010315],
[0.008128604851663113, 0.001064897165633738, 0.0013650195905938745, 0.010186304338276386, 0.06313848495483398, 0.0016264485893771052, 0.00172582792583853, 9.75604634732008e-5, 0.0010622515110298991, 0.0032957245130091906, 5.703152855858207e-4, 0.0017532772617414594, 0.7059130668640137],
[0.007372769061475992, 0.0014313665451481938, 0.0012837128015235066, 0.012285848148167133, 0.05398842319846153, 0.0011668737279251218, 9.935408597812057e-4, 1.566414430271834e-4, 8.234764682129025e-4, 0.0022190092131495476, 4.8498151591047645e-4, 0.00145237660035491, 0.27157062292099]
]
>,
inertia: #Nx.Tensor<
f32
0.6447487473487854
>,
labels: #Nx.Tensor<
s64[142]
[1, 2, 1, 0, 1, 2, 0, 0, 0, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 0, 1, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 2, 0, 1, 2, ...]
>
}
test_preds = Scholar.Cluster.KMeans.predict(model, test_inputs)
#Nx.Tensor<
s64[36]
[0, 0, 1, 2, 0, 2, 2, 1, 1, 2, 1, 0, 1, 2, 1, 1, 2, 0, 0, 1, 0, 0, 2, 1, 0, 2, 2, 2, 2, 2, 1, 2, 0, 1, 2, 0]
>
wine_features = %{
"feature_1" => train_inputs[[.., 1]] |> Nx.to_flat_list(),
"feature_2" => train_inputs[[.., 2]] |> Nx.to_flat_list(),
"class" => train_targets |> Nx.to_flat_list()
}
coords = [
cluster_feature_1: model.clusters[[.., 1]] |> Nx.to_flat_list(),
cluster_feature_2: model.clusters[[.., 2]] |> Nx.to_flat_list()
]
Vl.new(
width: 1440,
height: 1080,
title: [
text: "Scatterplot of data samples pojected on plane wine feature 1 x wine feature 2",
offset: 25
]
)
|> Vl.layers([
Vl.new()
|> Vl.data_from_values(wine_features)
|> Vl.mark(:circle)
|> Vl.encode_field(:x, "feature_1", type: :quantitative)
|> Vl.encode_field(:y, "feature_2", type: :quantitative)
|> Vl.encode_field(:color, "class"),
Vl.new()
|> Vl.data_from_values(coords)
|> Vl.mark(:circle, color: :green, size: 100)
|> Vl.encode_field(:x, "cluster_feature_1", type: :quantitative)
|> Vl.encode_field(:y, "cluster_feature_2", type: :quantitative)
])
{"$schema":"https://vega.github.io/schema/vega-lite/v5.json","height":1080,"layer":[{"data":{"values":[{"class":0,"feature_1":7.262466824613512e-4,"feature_2":0.0012739080702885985},{"class":1,"feature_1":7.85775133408606e-4,"feature_2":0.0014286820078268647},{"class":0,"feature_1":8.75067722517997e-4,"feature_2":0.0014227291103452444},{"class":0,"feature_1":0.0011429456062614918,"feature_2":0.001839428092353046},{"class":0,"feature_1":8.15539329778403e-4,"feature_2":0.0015298804501071572},{"class":1,"feature_1":7.321995217353106e-4,"feature_2":0.0011250870302319527},{"class":2,"feature_1":9.226904367096722e-4,"feature_2":0.0015298804501071572},{"class":0,"feature_1":8.988790796138346e-4,"feature_2":0.001613220083527267},{"class":2,"feature_1":0.0026668731588870287,"feature_2":0.0013989177532494068},{"class":1,"feature_1":0.001208427012898028,"feature_2":0.0012143796775490046},{"class":0,"feature_1":0.001208427012898028,"feature_2":0.0012917666463181376},{"class":1,"feature_1":0.0024882876314222813,"feature_2":0.001345342374406755},{"class":0,"feature_1":8.333977893926203e-4,"feature_2":0.0015298804501071572},{"class":1,"feature_1":8.214921108447015e-4,"feature_2":0.0015120217576622963},{"class":1,"feature_1":0.002559721702709794,"feature_2":0.0015477387933060527},{"class":2,"feature_1":0.0017263240879401565,"feature_2":0.0013036723248660564},{"class":0,"feature_1":9.94124566204846e-4,"feature_2":0.0015001160791143775},{"class":1,"feature_1":0.0013810594100505114,"feature_2":0.001387012074701488},{"class":1,"feature_1":0.0011429456062614918,"feature_2":0.0018453808734193444},{"class":2,"feature_1":0.001809663837775588,"feature_2":0.0015417860122397542},{"class":1,"feature_1":8.333977893926203e-4,"feature_2":0.0012858137488365173},{"class":1,"feature_1":0.001041747280396521,"feature_2":0.0010655586374923587},{"class":0,"feature_1":9.524546912871301e-4,"feature_2":0.001345342374406755},{"class":1,"feature_1":0.0010000773472711444,"feature_2":0.0012322383699938655},{"class":0,"feature_1":0.0023275609128177166,"feature_2":0.0013751063961535692},{"class":2,"feature_1":7.262466824613512e-4,"feature_2":0.0013036723248660564},{"class":2,"feature_1":0.0014465403510257602,"feature_2":0.001321530668064952},{"class":2,"feature_1":0.002035871846601367,"feature_2":0.001327483681961894},{"class":0,"feature_1":0.0014643990434706211,"feature_2":0.0016310784267261624},{"class":2,"feature_1":0.002125164493918419,"feature_2":0.0014703517081215978},{"class":2,"feature_1":0.0010536529589444399,"feature_2":0.0015596444718539715},{"class":0,"feature_1":9.405490127392113e-4,"feature_2":0.0013691537315025926},{"class":1,"feature_1":8.810205617919564e-4,"feature_2":0.0012381910346448421},{"class":1,"feature_1":9.524546912871301e-4,"feature_2":0.001101275673136115},{"class":1,"feature_1":5.595671245828271e-4,"feature_2":0.0011727097444236279},{"class":0,"feature_1":0.002291843993589282,"feature_2":0.0012858137488365173},{"class":0,"feature_1":9.04831918887794e-4,"feature_2":0.0014405876863747835},{"class":0,"feature_1":0.0012024739990010858,"feature_2":0.00147630472201854},{"class":0,"feature_1":8.988790796138346e-4,"feature_2":0.0012143796775490046},{"class":0,"feature_1":9.762659901753068e-4,"feature_2":0.0011727097444236279},{"class":0,"feature_1":0.001011983142234385,"feature_2":0.001327483681961894},{"class":2,"feature_1":0.0020477776415646076,"feature_2":0.0012024739990010858},{"class":1,"feature_1":6.905295886099339e-4,"feature_2":0.0010774643160402775},{"class":1,"feature_1":5.119444103911519e-4,"feature_2":0.0010834172135218978},{"class":0,"feature_1":9.04831918887794e-4,"feature_2":0.0014703517081215978},{"class":1,"feature_1":9.524546912871301e-4,"feature_2":0.0011846154229715466},{"class":0,"feature_1":9.226904367096722e-4,"feature_2":0.0011846154229715466},{"class":1,"feature_1":0.001190568320453167,"feature_2":0.0015775030478835106},{"class":1,"feature_1":0.0011072285706177354,"feature_2":0.0012798609677702188},{"class":2,"feature_1":0.001416776329278946,"feature_2":0.0013989177532494068},{"class":2,"feature_1":0.001345342374406755,"feature_2":0.0012798609677702188},{"class":1,"feature_1":6.310012540780008e-4,"feature_2":9.643603698350489e-4},{"class":1,"feature_1":6.667182897217572e-4,"feature_2":0.0010655586374923587},{"class":1,"feature_1":9.465018520131707e-4,"feature_2":0.001041747280396521},{"class":1,"feature_1":4.8218018491752446e-4,"feature_2":0.001327483681961894},{"class":1,"feature_1":0.0024823350831866264,"feature_2":0.0013393893605098128},{"class":0,"feature_1":0.0010655586374923587,"feature_2":0.0015417860122397542},{"class":0,"feature_1":9.822188876569271e-4,"feature_2":0.001196521334350109},{"class":0,"feature_1":0.0022085043601691723,"feature_2":0.0011846154229715466},{"class":0,"feature_1":0.0011250870302319527,"feature_2":0.0013512950390577316},{"class":2,"feature_1":0.002363278064876795,"feature_2":0.001553691690787673},{"class":0,"feature_1":0.0010953228920698166,"feature_2":0.0015179747715592384},{"class":1,"feature_1":6.190955173224211e-4,"feature_2":0.0010655586374923587},{"class":0,"feature_1":9.345961734652519e-4,"feature_2":0.0012917666463181376},{"class":2,"feature_1":0.001458446029573679,"feature_2":0.001523927436210215},{"class":1,"feature_1":8.333977893926203e-4,"feature_2":0.0012679552892223},{"class":0,"feature_1":0.002297797007486224,"feature_2":0.001416776329278946},{"class":2,"feature_1":0.0015655973693355918,"feature_2":0.0012917666463181376},{"class":2,"feature_1":0.0020596832036972046,"feature_2":0.0012262853560969234},{"class":1,"feature_1":0.0022204099223017693,"feature_2":0.0013036723248660564},{"class":0,"feature_1":9.107847581617534e-4,"feature_2":0.001327483681961894},{"class":1,"feature_1":7.976808119565248e-4,"feature_2":0.0011072285706177354},{"class":1,"feature_1":8.810205617919564e-4,"feature_2":0.0012977193109691143},{"class":2,"feature_1":0.0027323542162775993,"feature_2":0.0014346347888931632},{"class":1,"feature_1":5.059915711171925e-4,"feature_2":0.0012560496106743813},{"class":1,"feature_1":0.0016072670696303248,"feature_2":0.0012441439321264625},{"class":2,"feature_1":0.002518052002415061,"feature_2":0.0012679552892223},{"class":2,"feature_1":0.002375183627009392,"feature_2":0.0013393893605098128},{"class":2,"feature_1":0.0019287205068394542,"feature_2":0.0012917666463181376},{"class":1,"feature_1":6.429069326259196e-4,"feature_2":0.0014465403510257602},{"class":1,"feature_1":7.619637181051075e-4,"feature_2":0.001101275673136115},{"class":0,"feature_1":9.465018520131707e-4,"feature_2":0.001196521334350109},{"class":2,"feature_1":0.001387012074701488,"feature_2":0.0012322383699938655},{"class":0,"feature_1":9.524546912871301e-4,"feature_2":0.0012739080702885985},{"class":1,"feature_1":5.238500307314098e-4,"feature_2":9.345961734652519e-4},{"class":0,"feature_1":0.0020596832036972046,"feature_2":0.0012798609677702188},{"class":0,"feature_1":9.643603698350489e-4,"feature_2":0.0013632007176056504},{"class":1,"feature_1":8.810205617919564e-4,"feature_2":0.0012262853560969234},{"class":0,"feature_1":8.691148832440376e-4,"feature_2":0.0013989177532494068},{"class":1,"feature_1":8.274449501186609e-4,"feature_2":0.0012322383699938655},{"class":1,"feature_1":0.001160804065875709,"feature_2":9.345961734652519e-4},{"class":2,"feature_1":0.003012137720361352,"feature_2":0.0013036723248660564},{"class":1,"feature_1":0.0022263627033680677,"feature_2":0.0013512950390577316},{"class":1,"feature_1":0.0011846154229715466,"feature_2":0.001553691690787673},{"class":1,"feature_1":5.714728031307459e-4,"feature_2":0.0012917666463181376},{"class":2,"feature_1":0.001976343570277095,"feature_2":0.001321530668064952},{"class":1,"feature_1":0.001011983142234385,"feature_2":0.0013036723248660564},{"class":1,"feature_1":7.262466824613512e-4,"feature_2":0.0015298804501071572},{"class":2,"feature_1":0.0026609201449900866,"feature_2":0.0016251257620751858},{"class":0,"feature_1":0.0010536529589444399,"feature_2":0.0015894087264314294},{"class":1,"feature_1":0.0013691537315025926,"feature_2":0.0012143796775490046},{"class":2,"feature_1":0.0032026288099586964,"feature_2":0.0014941634144634008},{"class":2,"feature_1":6.905295886099339e-4,"feature_2":0.0011727097444236279},{"class":1,"feature_1":9.167375974357128e-4,"feature_2":0.0014703517081215978},{"class":0,"feature_1":0.0010357944993302226,"feature_2":0.0013810594100505114},{"class":2,"feature_1":0.0018870508065447211,"feature_2":0.0012798609677702188},{"class":1,"feature_1":0.0021489758510142565,"feature_2":0.0010060302447527647},{"class":2,"feature_1":0.0016846541548147798,"feature_2":0.00147630472201854},{"class":2,"feature_1":6.607654504477978e-4,"feature_2":0.0012620023917406797},{"class":1,"feature_1":0.0019644377753138542,"feature_2":0.001101275673136115},{"class":2,"feature_1":0.003285968443378806,"feature_2":0.0013810594100505114},{"class":0,"feature_1":0.001327483681961894,"feature_2":0.0015120217576622963},{"class":1,"feature_1":0.0015179747715592384,"feature_2":0.0016608427977189422},{"class":2,"feature_1":0.002250174293294549,"feature_2":0.0013989177532494068},{"class":1,"feature_1":4.52415959443897e-4,"feature_2":0.001458446029573679},{"class":1,"feature_1":7.202938431873918e-4,"feature_2":0.0010536529589444399},{"class":1,"feature_1":0.0013512950390577316,"feature_2":0.0013632007176056504},{"class":1,"feature_1":0.0019644377753138542,"feature_2":0.0011131813516840339},{"class":2,"feature_1":0.0028692695777863264,"feature_2":0.001321530668064952},{"class":1,"feature_1":8.214921108447015e-4,"feature_2":0.0013632007176056504},{"class":0,"feature_1":9.703131509013474e-4,"feature_2":0.0013810594100505114},{"class":0,"feature_1":0.0021846930030733347,"feature_2":0.0015001160791143775},{"class":1,"feature_1":9.167375974357128e-4,"feature_2":0.0014822573866695166},{"class":0,"feature_1":7.738693966530263e-4,"feature_2":0.0014108234317973256},{"class":2,"feature_1":0.0018632394494488835,"feature_2":0.0014346347888931632},{"class":0,"feature_1":0.0010536529589444399,"feature_2":0.0015179747715592384},{"class":1,"feature_1":4.583687987178564e-4,"feature_2":9.405490127392113e-4},{"class":1,"feature_1":5.774256424047053e-4,"feature_2":0.0012798609677702188},{"class":1,"feature_1":5.952841602265835e-4,"feature_2":0.001416776329278946},{"class":2,"feature_1":0.00223231571726501,"feature_2":0.0012322383699938655},{"class":2,"feature_1":0.0018513337709009647,"feature_2":0.0013393893605098128},{"class":0,"feature_1":8.572092046961188e-4,"feature_2":0.0014822573866695166},{"class":1,"feature_1":9.643603698350489e-4,"feature_2":0.0012798609677702188},{"class":1,"feature_1":0.0011488983873277903,"feature_2":0.001387012074701488},{"class":2,"feature_1":0.0029228450730443,"feature_2":0.001250096713192761},{"class":1,"feature_1":0.0033752613235265017,"feature_2":0.001190568320453167},{"class":0,"feature_1":9.524546912871301e-4,"feature_2":0.0011369927087798715},{"class":0,"feature_1":8.15539329778403e-4,"feature_2":0.0011727097444236279},{"class":2,"feature_1":0.0017263240879401565,"feature_2":0.0014941634144634008},{"class":2,"feature_1":9.167375974357128e-4,"feature_2":0.0014941634144634008},{"class":1,"feature_1":3.631233412306756e-4,"feature_2":0.0014108234317973256},{"class":0,"feature_1":0.0010000773472711444,"feature_2":0.0013572480529546738}]},"encoding":{"color":{"field":"class"},"x":{"field":"feature_1","type":"quantitative"},"y":{"field":"feature_2","type":"quantitative"}},"mark":"circle"},{"data":{"values":[{"cluster_feature_1":0.0014265172649174929,"cluster_feature_2":0.0013569231377914548},{"cluster_feature_1":0.001064897165633738,"cluster_feature_2":0.0013650195905938745},{"cluster_feature_1":0.0014313665451481938,"cluster_feature_2":0.0012837128015235066}]},"encoding":{"x":{"field":"cluster_feature_1","type":"quantitative"},"y":{"field":"cluster_feature_2","type":"quantitative"}},"mark":{"color":"green","size":100,"type":"circle"}}],"title":{"offset":25,"text":"Scatterplot of data samples pojected on plane wine feature 1 x wine feature 2"},"width":1440}
wine_features
%{
"class" => [0, 1, 0, 0, 0, 1, 2, 0, 2, 1, 0, 1, 0, 1, 1, 2, 0, 1, 1, 2, 1, 1, 0, 1, 0, 2, 2, 2, 0,
2, 2, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 1, ...],
"feature_1" => [7.262466824613512e-4, 7.85775133408606e-4, 8.75067722517997e-4,
0.0011429456062614918, 8.15539329778403e-4, 7.321995217353106e-4, 9.226904367096722e-4,
8.988790796138346e-4, 0.0026668731588870287, 0.001208427012898028, 0.001208427012898028,
0.0024882876314222813, 8.333977893926203e-4, 8.214921108447015e-4, 0.002559721702709794,
0.0017263240879401565, 9.94124566204846e-4, 0.0013810594100505114, 0.0011429456062614918,
0.001809663837775588, 8.333977893926203e-4, 0.001041747280396521, 9.524546912871301e-4,
0.0010000773472711444, 0.0023275609128177166, 7.262466824613512e-4, 0.0014465403510257602,
0.002035871846601367, 0.0014643990434706211, 0.002125164493918419, 0.0010536529589444399,
9.405490127392113e-4, 8.810205617919564e-4, 9.524546912871301e-4, 5.595671245828271e-4,
0.002291843993589282, 9.04831918887794e-4, 0.0012024739990010858, 8.988790796138346e-4,
9.762659901753068e-4, 0.001011983142234385, 0.0020477776415646076, 6.905295886099339e-4,
5.119444103911519e-4, 9.04831918887794e-4, 9.524546912871301e-4, 9.226904367096722e-4,
0.001190568320453167, ...],
"feature_2" => [0.0012739080702885985, 0.0014286820078268647, 0.0014227291103452444,
0.001839428092353046, 0.0015298804501071572, 0.0011250870302319527, 0.0015298804501071572,
0.001613220083527267, 0.0013989177532494068, 0.0012143796775490046, 0.0012917666463181376,
0.001345342374406755, 0.0015298804501071572, 0.0015120217576622963, 0.0015477387933060527,
0.0013036723248660564, 0.0015001160791143775, 0.001387012074701488, 0.0018453808734193444,
0.0015417860122397542, 0.0012858137488365173, 0.0010655586374923587, 0.001345342374406755,
0.0012322383699938655, 0.0013751063961535692, 0.0013036723248660564, 0.001321530668064952,
0.001327483681961894, 0.0016310784267261624, 0.0014703517081215978, 0.0015596444718539715,
0.0013691537315025926, 0.0012381910346448421, 0.001101275673136115, 0.0011727097444236279,
0.0012858137488365173, 0.0014405876863747835, 0.00147630472201854, 0.0012143796775490046,
0.0011727097444236279, 0.001327483681961894, 0.0012024739990010858, 0.0010774643160402775,
0.0010834172135218978, 0.0014703517081215978, 0.0011846154229715466, 0.0011846154229715466, ...]
}
# Scholar.Metrics.accuracy(test_targets, test_preds)
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
f32
0.2222222238779068
>