
Ch 5: Traditional Machine Learning

ch5_traditional_ml.livemd


Mix.install([
  {:scholar, "~> 0.2.1"},
  {:nx, "~> 0.7.2"},
  {:exla, "~> 0.7.2"},
  {:vega_lite, "~> 0.1.9"},
  {:kino_vega_lite, "~> 0.1.11"},
  {:scidata, "~> 0.1.11"}
])

Learning Linearly

Recent advances in machine learning (ML) are largely attributed to deep learning (DL), the subset of ML that uses neural networks. But the world of ML is vast and contains many other useful algorithms and architectures.

Shallow ML refers to algorithms that are not based on deep learning. Reaching for DL might be the most exciting thing to do right now, but a simple linear regression model can yield results that are just as useful.

One simple class of traditional ML algorithms is linear models. Linear models assume that the underlying relationship between inputs and outputs is linear; put another way, the outputs can be modeled as a line (or, with several inputs, a flat hyperplane) through the input data. Real-world relationships are almost never perfectly linear, but linear models are still powerful tools for modeling them.
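As a rough illustration in plain Elixir (not Scholar's API), a linear model's prediction is just a weighted sum of the inputs plus an intercept. The module `LinearSketch` and its functions are hypothetical names used only for this sketch.

```elixir
defmodule LinearSketch do
  # With a single input, a linear model is the familiar line y = m * x + b.
  def predict(x, m, b), do: m * x + b

  # With several input features, the slope becomes a list of weights (one per
  # feature) and the prediction is their weighted sum plus the intercept b.
  def predict_many(features, weights, b) do
    Enum.zip_with(features, weights, fn x, w -> x * w end)
    |> Enum.sum()
    |> Kernel.+(b)
  end
end

LinearSketch.predict(2.0, 3.0, 1.0)
# => 7.0

LinearSketch.predict_many([1.0, 2.0], [0.5, -1.0], 3.0)
# => 1.5
```

The multi-feature form matters later in this chapter, where each wine sample has many input columns rather than a single x.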

Linear Regression with Scholar

Nx.default_backend(EXLA.Backend)
Nx.Defn.default_options(compiler: EXLA)
[]
m = :rand.uniform() * 10
b = :rand.uniform() * 10
key = Nx.Random.key(42)
size = 100
{x, new_key} = Nx.Random.normal(key, 0.0, 1.0, shape: {size, 1})
{noise_x, new_key} = Nx.Random.normal(new_key, 0.0, 1.0, shape: {size, 1})
{#Nx.Tensor<
   f32[100][1]
   EXLA.Backend
   [
     [-0.2715127170085907],
     [-0.044035933911800385],
     [-0.013251191936433315],
     [-1.528984785079956],
     [-0.9314795732498169],
     [0.1732589155435562],
     [0.5015048980712891],
     [1.0113904476165771],
     [0.8057566285133362],
     [0.0336226150393486],
     [-0.2858802080154419],
     [1.6919035911560059],
     [-0.6257336139678955],
     [0.7363636493682861],
     [0.8088681697845459],
     [0.0884171724319458],
     [-0.6270047426223755],
     [-0.9407665133476257],
     [1.697560429573059],
     [-1.209107756614685],
     [-1.2284613847732544],
     [0.6404958963394165],
     [1.5345338582992554],
     [0.8201174736022949],
     [1.1316134929656982],
     [0.5564783811569214],
     [1.7121386528015137],
     [1.0940232276916504],
     [0.3224945068359375],
     [-1.6643178462982178],
     [-0.5513103604316711],
     [0.6727809309959412],
     [1.5617197751998901],
     [0.07896070182323456],
     [1.4174407720565796],
     [1.3826879262924194],
     [0.2648073136806488],
     [-1.2951784133911133],
     [-1.0437983274459839],
     [0.44099709391593933],
     [0.20175093412399292],
     [2.5967326164245605],
     [0.4468730390071869],
     [-0.38861167430877686],
     [0.15156064927577972],
     [1.034748911857605],
     [0.7082746624946594],
     [-1.183918833732605],
     [-1.221583366394043],
     ...
   ]
 >,
 #Nx.Tensor<
   u32[2]
   EXLA.Backend
   [3164236999, 3984487275]
 >}
IO.inspect({m, b})
{5.448026224760931, 2.7821239930131636}
{5.448026224760931, 2.7821239930131636}
y =
  m
  |> Nx.multiply(Nx.add(x, noise_x))
  |> Nx.add(b)
#Nx.Tensor<
  f32[100][1]
  EXLA.Backend
  [
    [-1.9394278526306152],
    [2.614504814147949],
    [6.729975700378418],
    [-4.367074966430664],
    [4.617206573486328],
    [4.172173023223877],
    [5.454224586486816],
    [0.3875296115875244],
    [4.223370552062988],
    [7.745285511016846],
    [-1.435800552368164],
    [18.743972778320312],
    [4.32711124420166],
    [7.776525020599365],
    [4.153141975402832],
    [7.245344161987305],
    [1.2556202411651611],
    [0.5421628952026367],
    [12.23556900024414],
    [-11.486454010009766],
    [-11.733905792236328],
    [0.9676380157470703],
    [3.8558783531188965],
    [3.5971550941467285],
    [8.387855529785156],
    [7.569920539855957],
    [9.651819229125977],
    [8.704238891601562],
    [18.26831817626953],
    [-8.990447998046875],
    [-12.304100036621094],
    [6.3741865158081055],
    [13.109359741210938],
    [7.561161041259766],
    [11.644691467285156],
    [21.11670684814453],
    [6.721423149108887],
    [-1.2887849807739258],
    [-9.464832305908203],
    [16.10643768310547],
    [4.784914970397949],
    [26.214235305786133],
    [4.338867664337158],
    [7.44886589050293],
    [4.444395542144775],
    [16.616832733154297],
    [6.716985702514648],
    [3.852010726928711],
    [-5.76701021194458],
    [-0.6862599849700928],
    ...
  ]
>
alias VegaLite, as: Vl
VegaLite
Vl.new(title: "Scatterplot", width: 720, height: 480)
|> Vl.data_from_values(%{
  x: Nx.to_flat_list(x),
  y: Nx.to_flat_list(y)
})
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
[Chart: "Scatterplot", the 100 generated (x, y) points plotted with x on the horizontal axis and y on the vertical axis]

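The data generally follows a line. The line is not perfect because of the Gaussian noise introduced: since the noise is added to x before multiplying by m, each point sits off the true line m * x + b by m times a standard normal sample.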
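A plain-Elixir sketch of why the scatter is as wide as it is. It uses Erlang's `:rand.normal/0` instead of `Nx.Random`, and the slope and intercept values are hypothetical (the names `true_m` and `true_b` avoid rebinding the notebook's `m` and `b`). Because the noise is added to x before multiplying by the slope, the residuals from the noise-free line should have a standard deviation close to `true_m`.

```elixir
# Hypothetical "true" slope and intercept; the notebook draws its own with :rand.uniform/0.
true_m = 5.0
true_b = 2.5
n_points = 10_000

# Standard normal inputs and noise, mirroring the notebook's Nx.Random.normal calls.
xs = for _ <- 1..n_points, do: :rand.normal()
noise = for _ <- 1..n_points, do: :rand.normal()

# Same recipe as the notebook: y = m * (x + noise) + b.
ys = Enum.zip_with(xs, noise, fn x, e -> true_m * (x + e) + true_b end)

# Distance of each point from the noise-free line m * x + b (this equals m * noise).
residuals = Enum.zip_with(xs, ys, fn x, y -> y - (true_m * x + true_b) end)

mean = Enum.sum(residuals) / n_points

# Population standard deviation of the residuals; expect a value near true_m.
residual_std =
  residuals
  |> Enum.map(fn r -> (r - mean) * (r - mean) end)
  |> Enum.sum()
  |> Kernel./(n_points)
  |> :math.sqrt()
```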

Scholar can fit a linear regression model to the input variables x and the output variables y. Under the hood, Scholar.Linear.LinearRegression.fit/2 finds the best-fit line using ordinary least-squares regression.
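For a single input feature, ordinary least squares has a closed-form solution: the slope is the covariance of x and y divided by the variance of x, and the intercept makes the line pass through the point of means. The module below (`OLSSketch`, a hypothetical name) is a plain-Elixir sketch of that formula for one feature only; it is not a description of how Scholar implements `fit/2`.

```elixir
defmodule OLSSketch do
  # Fit y ≈ slope * x + intercept by ordinary least squares for one feature.
  # Returns {slope, intercept}.
  def fit(xs, ys) when length(xs) == length(ys) and length(xs) > 1 do
    n = length(xs)
    mean_x = Enum.sum(xs) / n
    mean_y = Enum.sum(ys) / n

    # Sum of (x - mean_x) * (y - mean_y): proportional to the covariance of x and y.
    cov =
      Enum.zip_with(xs, ys, fn x, y -> (x - mean_x) * (y - mean_y) end)
      |> Enum.sum()

    # Sum of (x - mean_x)^2: proportional to the variance of x (must be non-zero).
    var = xs |> Enum.map(fn x -> (x - mean_x) * (x - mean_x) end) |> Enum.sum()

    slope = cov / var
    # The least-squares line always passes through (mean_x, mean_y).
    intercept = mean_y - slope * mean_x
    {slope, intercept}
  end
end

# Points lying exactly on y = 2x + 1 recover slope 2.0 and intercept 1.0.
OLSSketch.fit([1.0, 2.0, 3.0, 4.0], [3.0, 5.0, 7.0, 9.0])
# => {2.0, 1.0}
```

On noisy data like the scatterplot above, the recovered slope and intercept only approximate the true m and b, just as Scholar's fitted coefficients do.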

model = Scholar.Linear.LinearRegression.fit(x, y)
%Scholar.Linear.LinearRegression{
  coefficients: #Nx.Tensor<
    f32[1][1]
    EXLA.Backend
    [
      [6.127957344055176]
    ]
  >,
  intercept: #Nx.Tensor<
    f32[1]
    EXLA.Backend
    [3.203505754470825]
  >
}
pred_xs = Nx.linspace(-3.0, 3.0, n: 100) |> Nx.new_axis(-1)
pred_ys = Scholar.Linear.LinearRegression.predict(model, pred_xs)
#Nx.Tensor<
  f32[100][1]
  EXLA.Backend
  [
    [-15.180366516113281],
    [-14.808975219726562],
    [-14.437583923339844],
    [-14.066191673278809],
    [-13.69480037689209],
    [-13.323410034179688],
    [-12.952017784118652],
    [-12.580626487731934],
    [-12.209235191345215],
    [-11.837843894958496],
    [-11.466453552246094],
    [-11.095061302185059],
    [-10.72367000579834],
    [-10.352278709411621],
    [-9.980886459350586],
    [-9.609495162963867],
    [-9.238104820251465],
    [-8.866713523864746],
    [-8.495321273803711],
    [-8.123929977416992],
    [-7.752539157867432],
    [-7.381147861480713],
    [-7.009756088256836],
    [-6.638364315032959],
    [-6.266973495483398],
    [-5.895582675933838],
    [-5.524190902709961],
    [-5.152799129486084],
    [-4.781407833099365],
    [-4.410017013549805],
    [-4.038625240325928],
    [-3.66723370552063],
    [-3.2958426475524902],
    [-2.9244515895843506],
    [-2.553060531616211],
    [-2.1816680431365967],
    [-1.8102771043777466],
    [-1.4388861656188965],
    [-1.0674936771392822],
    [-0.6961026787757874],
    [-0.3247116804122925],
    [0.046679332852363586],
    [0.41807034611701965],
    [0.7894628047943115],
    [1.1608537435531616],
    [1.5322448015213013],
    [1.9036372900009155],
    [2.2750282287597656],
    [2.6464192867279053],
    [3.017810344696045],
    ...
  ]
>
Vl.new(
  title: "Scatterplot Distribution and Fit Curve",
  width: 720,
  height: 480
)
|> Vl.data_from_values(%{
  x: Nx.to_flat_list(x),
  y: Nx.to_flat_list(y),
  pred_x: Nx.to_flat_list(pred_xs),
  pred_y: Nx.to_flat_list(pred_ys)
})
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:point)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative),
  Vl.new()
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "pred_x", type: :quantitative)
  |> Vl.encode_field(:y, "pred_y", type: :quantitative)
])
[Chart: "Scatterplot Distribution and Fit Curve", the (x, y) points layered with the fitted line plotted over pred_x from -3 to 3]

Logistic Regression with Scholar

Logistic regression is often used for classification. It is almost identical to linear regression; the difference is that, after applying the linear transformation to the input variables, a logistic function is also applied. The logistic function squeezes the output into the range between 0 and 1, which often represents a probability in a binary classification problem.
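A minimal plain-Elixir sketch of that idea, not Scholar's implementation; `LogisticSketch` and its functions are hypothetical names. It computes the same weighted sum plus bias as a linear model, passes the result through the logistic (sigmoid) function 1 / (1 + e^(-z)) so the output lands between 0 and 1, and then thresholds that probability to get a class label.

```elixir
defmodule LogisticSketch do
  # Logistic (sigmoid) function in a numerically stable form: both clauses only
  # call :math.exp/1 on a non-positive number, so the exponential cannot overflow.
  def sigmoid(z) when z >= 0, do: 1.0 / (1.0 + :math.exp(-z))

  def sigmoid(z) do
    e = :math.exp(z)
    e / (1.0 + e)
  end

  # Linear part (weights · features + bias), then squeezed by the sigmoid
  # and read as the probability of class 1.
  def predict_proba(features, weights, bias) do
    Enum.zip_with(features, weights, fn x, w -> x * w end)
    |> Enum.sum()
    |> Kernel.+(bias)
    |> sigmoid()
  end

  # Turn the probability into a 0/1 label with a decision threshold (0.5 by default).
  def predict(features, weights, bias, threshold \\ 0.5) do
    if predict_proba(features, weights, bias) >= threshold, do: 1, else: 0
  end
end

LogisticSketch.sigmoid(0.0)
# => 0.5
```

This covers the binary case described above; the wine targets loaded next take three values (0, 1 and 2).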

{inputs, targets} = Scidata.Wine.download()
{[
   [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0],
   [13.2, 1.78, 2.14, 11.2, 100.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.4, 1050.0],
   [13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.68, 1.03, 3.17, 1185.0],
   [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0],
   [13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0],
   [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0],
   [14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0],
   [14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0],
   [14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0],
   [13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.22, 1.01, 3.55, 1045.0],
   [14.1, 2.16, 2.3, 18.0, 105.0, 2.95, 3.32, 0.22, 2.38, 5.75, 1.25, 3.17, 1510.0],
   [14.12, 1.48, 2.32, 16.8, 95.0, 2.2, 2.43, 0.26, 1.57, 5.0, 1.17, 2.82, 1280.0],
   [13.75, 1.73, 2.41, 16.0, 89.0, 2.6, 2.76, 0.29, 1.81, 5.6, 1.15, 2.9, 1320.0],
   [14.75, 1.73, 2.39, 11.4, 91.0, 3.1, 3.69, 0.43, 2.81, 5.4, 1.25, 2.73, 1150.0],
   [14.38, 1.87, 2.38, 12.0, 102.0, 3.3, 3.64, 0.29, 2.96, 7.5, 1.2, 3.0, 1547.0],
   [13.63, 1.81, 2.7, 17.2, 112.0, 2.85, 2.91, 0.3, 1.46, 7.3, 1.28, 2.88, 1310.0],
   [14.3, 1.92, 2.72, 20.0, 120.0, 2.8, 3.14, 0.33, 1.97, 6.2, 1.07, 2.65, 1280.0],
   [13.83, 1.57, 2.62, 20.0, 115.0, 2.95, 3.4, 0.4, 1.72, 6.6, 1.13, 2.57, 1130.0],
   [14.19, 1.59, 2.48, 16.5, 108.0, 3.3, 3.93, 0.32, 1.86, 8.7, 1.23, 2.82, 1680.0],
   [13.64, 3.1, 2.56, 15.2, 116.0, 2.7, 3.03, 0.17, 1.66, 5.1, 0.96, 3.36, 845.0],
   [14.06, 1.63, 2.28, 16.0, 126.0, 3.0, 3.17, 0.24, 2.1, 5.65, 1.09, 3.71, 780.0],
   [12.93, 3.8, 2.65, 18.6, 102.0, 2.41, 2.41, 0.25, 1.98, 4.5, 1.03, 3.52, 770.0],
   [13.71, 1.86, 2.36, 16.6, 101.0, 2.61, 2.88, 0.27, 1.69, 3.8, 1.11, 4.0, 1035.0],
   [12.85, 1.6, 2.52, 17.8, 95.0, 2.48, 2.37, 0.26, 1.46, 3.93, 1.09, 3.63, 1015.0],
   [13.5, 1.81, 2.61, 20.0, 96.0, 2.53, 2.61, 0.28, 1.66, 3.52, 1.12, 3.82, 845.0],
   [13.05, 2.05, 3.22, 25.0, 124.0, 2.63, 2.68, 0.47, 1.92, 3.58, 1.13, 3.2, 830.0],
   [13.39, 1.77, 2.62, 16.1, 93.0, 2.85, 2.94, 0.34, 1.45, 4.8, 0.92, 3.22, 1195.0],
   [13.3, 1.72, 2.14, 17.0, 94.0, 2.4, 2.19, 0.27, 1.35, 3.95, 1.02, 2.77, 1285.0],
   [13.87, 1.9, 2.8, 19.4, 107.0, 2.95, 2.97, 0.37, 1.76, 4.5, 1.25, 3.4, 915.0],
   [14.02, 1.68, 2.21, 16.0, 96.0, 2.65, 2.33, 0.26, 1.98, 4.7, 1.04, 3.59, 1035.0],
   [13.73, 1.5, 2.7, 22.5, 101.0, 3.0, 3.25, 0.29, 2.38, 5.7, 1.19, 2.71, 1285.0],
   [13.58, 1.66, 2.36, 19.1, 106.0, 2.86, 3.19, 0.22, 1.95, 6.9, 1.09, 2.88, 1515.0],
   [13.68, 1.83, 2.36, 17.2, 104.0, 2.42, 2.69, 0.42, 1.97, 3.84, 1.23, 2.87, 990.0],
   [13.76, 1.53, 2.7, 19.5, 132.0, 2.95, 2.74, 0.5, 1.35, 5.4, 1.25, 3.0, 1235.0],
   [13.51, 1.8, 2.65, 19.0, 110.0, 2.35, 2.53, 0.29, 1.54, 4.2, 1.1, 2.87, 1095.0],
   [13.48, 1.81, 2.41, 20.5, 100.0, 2.7, 2.98, 0.26, 1.86, 5.1, 1.04, 3.47, 920.0],
   [13.28, 1.64, 2.84, 15.5, 110.0, 2.6, 2.68, 0.34, 1.36, 4.6, 1.09, 2.78, ...],
   [13.05, 1.65, 2.55, 18.0, 98.0, 2.45, 2.43, 0.29, 1.44, 4.25, 1.12, ...],
   [13.07, 1.5, 2.1, 15.5, 98.0, 2.4, 2.64, 0.28, 1.37, 3.7, ...],
   [14.22, 3.99, 2.51, 13.2, 128.0, 3.0, 3.04, 0.2, 2.08, ...],
   [13.56, 1.71, 2.31, 16.2, 117.0, 3.15, 3.29, 0.34, ...],
   [13.41, 3.84, 2.12, 18.8, 90.0, 2.45, 2.68, ...],
   [13.88, 1.89, 2.59, 15.0, 101.0, 3.25, ...],
   [13.24, 3.98, 2.29, 17.5, 103.0, ...],
   [13.05, 1.77, 2.1, 17.0, ...],
   [14.21, 4.04, 2.44, ...],
   [14.38, 3.59, ...],
   [13.9, ...],
   [...],
   ...
 ],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]}
{train, test} =
  inputs
  |> Enum.zip(targets)
  |> Enum.shuffle()
  |> Enum.split(floor(length(inputs) * 0.8))
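`Enum.shuffle/1` draws its randomness from Erlang's `:rand` module (its documentation points there for choosing a seed), so the split above changes on every run. A minimal sketch, assuming a reproducible split is wanted: seed `:rand` in the same process right before shuffling. The seed tuple is arbitrary, and `samples` is a hypothetical stand-in for the zipped `{input, target}` pairs.

```elixir
# Hypothetical stand-in for the zipped {input, target} pairs from the cell above.
samples = Enum.map(1..10, fn i -> {[i * 1.0], rem(i, 3)} end)

split = fn samples ->
  # Seeding :rand in this process makes the following Enum.shuffle/1 call repeatable.
  :rand.seed(:exsss, {1, 2, 3})

  samples
  |> Enum.shuffle()
  |> Enum.split(floor(length(samples) * 0.8))
end

# Two calls with the same seed produce the same 80/20 split.
{train_a, test_a} = split.(samples)
{train_b, test_b} = split.(samples)
```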
{[
   {[13.56, 1.71, 2.31, 16.2, 117.0, 3.15, 3.29, 0.34, 2.34, 6.13, 0.95, 3.38, 795.0], 0},
   {[13.11, 1.01, 1.7, 15.0, 78.0, 2.98, 3.18, 0.26, 2.28, 5.3, 1.12, 3.18, 502.0], 1},
   {[13.58, 2.58, 2.69, 24.5, 105.0, 1.55, 0.84, 0.39, 1.54, 8.66, 0.74, 1.8, 750.0], 2},
   {[13.03, 0.9, 1.71, 16.0, 86.0, 1.95, 2.03, 0.24, 1.46, 4.6, 1.19, 2.48, 392.0], 1},
   {[13.86, 1.51, 2.67, 25.0, 86.0, 2.95, 2.86, 0.21, 1.87, 3.38, 1.36, 3.16, 410.0], 1},
   {[11.84, 2.89, 2.23, 18.0, 112.0, 1.72, 1.32, 0.43, 0.95, 2.65, 0.96, 2.52, 500.0], 1},
   {[12.37, 1.21, 2.56, 18.1, 98.0, 2.42, 2.65, 0.37, 2.08, 4.6, 1.19, 2.3, 678.0], 1},
   {[13.34, 0.94, 2.36, 17.0, 110.0, 2.53, 1.3, 0.55, 0.42, 3.17, 1.02, 1.93, 750.0], 1},
   {[12.86, 1.35, 2.32, 18.0, 122.0, 1.51, 1.25, 0.21, 0.94, 4.1, 0.76, 1.29, 630.0], 2},
   {[11.03, 1.51, 2.2, 21.5, 85.0, 2.46, 2.17, 0.52, 2.01, 1.9, 1.71, 2.87, 407.0], 1},
   {[12.07, 2.16, 2.17, 21.0, 85.0, 2.6, 2.65, 0.37, 1.35, 2.76, 0.86, 3.28, 378.0], 1},
   {[12.21, 1.19, 1.75, 16.8, 151.0, 1.85, 1.28, 0.14, 2.5, 2.85, 1.28, 3.07, 718.0], 1},
   {[11.64, 2.06, 2.46, 21.6, 84.0, 1.95, 1.69, 0.48, 1.35, 2.8, 1.0, 2.75, 680.0], 1},
   {[12.42, 4.43, 2.73, 26.5, 102.0, 2.2, 2.13, 0.43, 1.71, 2.08, 0.92, 3.12, 365.0], 1},
   {[12.77, 3.43, 1.98, 16.0, 80.0, 1.63, 1.25, 0.43, 0.83, 3.4, 0.7, 2.12, 372.0], 1},
   {[12.37, 0.94, 1.36, 10.6, 88.0, 1.98, 0.57, 0.28, 0.42, 1.95, 1.05, 1.82, 520.0], 1},
   {[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0], 0},
   {[12.37, 1.63, 2.3, 24.5, 88.0, 2.22, 2.45, 0.4, 1.9, 2.12, 0.89, 2.78, 342.0], 1},
   {[14.21, 4.04, 2.44, 18.9, 111.0, 2.85, 2.65, 0.3, 1.25, 5.24, 0.87, 3.33, 1080.0], 0},
   {[12.43, 1.53, 2.29, 21.5, 86.0, 2.74, 3.15, 0.39, 1.77, 3.94, 0.69, 2.84, 352.0], 1},
   {[13.71, 1.86, 2.36, 16.6, 101.0, 2.61, 2.88, 0.27, 1.69, 3.8, 1.11, 4.0, 1035.0], 0},
   {[13.4, 3.91, 2.48, 23.0, 102.0, 1.8, 0.75, 0.43, 1.41, 7.3, 0.7, 1.56, 750.0], 2},
   {[11.82, 1.47, 1.99, 20.8, 86.0, 1.98, 1.6, 0.3, 1.53, 1.95, 0.95, 3.33, 495.0], 1},
   {[12.37, 1.13, 2.16, 19.0, 87.0, 3.5, 3.1, 0.19, 1.87, 4.45, 1.22, 2.87, 420.0], 1},
   {[12.72, 1.75, 2.28, 22.5, 84.0, 1.38, 1.76, 0.48, 1.63, 3.3, 0.88, 2.42, 488.0], 1},
   {[12.52, 2.43, 2.17, 21.0, 88.0, 2.55, 2.27, 0.26, 1.22, 2.0, 0.9, 2.78, 325.0], 1},
   {[13.49, 1.66, 2.24, 24.0, 87.0, 1.88, 1.84, 0.27, 1.03, 3.74, 0.98, 2.78, 472.0], 1},
   {[12.37, 1.17, 1.92, 19.6, 78.0, 2.11, 2.0, 0.27, 1.04, 4.68, 1.12, 3.48, 510.0], 1},
   {[13.78, 2.76, 2.3, 22.0, 90.0, 1.35, 0.68, 0.41, 1.03, 9.58, 0.7, 1.68, 615.0], 2},
   {[13.24, 3.98, 2.29, 17.5, 103.0, 2.64, 2.63, 0.32, 1.66, 4.36, 0.82, 3.0, 680.0], 0},
   {[12.51, 1.24, 2.25, 17.5, 85.0, 2.0, 0.58, 0.6, 1.25, 5.45, 0.75, 1.51, 650.0], 2},
   {[12.7, 3.87, 2.4, 23.0, 101.0, 2.83, 2.55, 0.43, 1.95, 2.57, 1.19, 3.13, 463.0], 1},
   {[12.22, 1.29, 1.94, 19.0, 92.0, 2.36, 2.04, 0.39, 2.08, 2.7, 0.86, 3.02, 312.0], 1},
   {[12.29, 1.41, 1.98, 16.0, 85.0, 2.55, 2.5, 0.29, 1.77, 2.9, 1.23, 2.74, 428.0], 1},
   {[13.05, 1.65, 2.55, 18.0, 98.0, 2.45, 2.43, 0.29, 1.44, 4.25, 1.12, 2.51, 1105.0], 0},
   {[13.74, 1.67, 2.25, 16.4, 118.0, 2.6, 2.9, 0.21, 1.62, 5.85, 0.92, 3.2, ...], 0},
   {[12.42, 1.61, 2.19, 22.5, 108.0, 2.0, 2.09, 0.34, 1.61, 2.06, 1.06, ...], 1},
   {[11.84, 0.89, 2.58, 18.0, 94.0, 2.2, 2.21, 0.22, 2.35, 3.05, ...], 1},
   {[14.22, 1.7, 2.3, 16.3, 118.0, 3.2, 3.0, 0.26, 2.03, ...], 0},
   {[12.96, 3.45, 2.35, 18.5, 106.0, 1.39, 0.7, 0.4, ...], 2},
   {[13.68, 1.83, 2.36, 17.2, 104.0, 2.42, 2.69, ...], 0},
   {[12.85, 1.6, 2.52, 17.8, 95.0, 2.48, ...], 0},
   {[14.12, 1.48, 2.32, 16.8, 95.0, ...], 0},
   {[12.04, 4.3, 2.38, 22.0, ...], 1},
   {[12.82, 3.37, 2.3, ...], 2},
   {[11.66, 1.88, ...], 1},
   {[13.77, ...], 0},
   {[...], ...},
   {...},
   ...
 ],
 [
   {[13.05, 1.73, 2.04, 12.4, 92.0, 2.72, 3.27, 0.17, 2.91, 7.2, 1.12, 2.91, 1150.0], 0},
   {[13.71, 5.65, 2.45, 20.5, 95.0, 1.68, 0.61, 0.52, 1.06, 7.7, 0.64, 1.74, 740.0], 2},
   {[13.05, 5.8, 2.13, 21.5, 86.0, 2.62, 2.65, 0.3, 2.01, 2.6, 0.73, 3.1, 380.0], 1},
   {[13.5, 3.12, 2.62, 24.0, 123.0, 1.4, 1.57, 0.22, 1.25, 8.6, 0.59, 1.3, 500.0], 2},
   {[12.0, 0.92, 2.0, 19.0, 86.0, 2.42, 2.26, 0.3, 1.43, 2.5, 1.38, 3.12, 278.0], 1},
   {[12.6, 1.34, 1.9, 18.5, 88.0, 1.45, 1.36, 0.29, 1.35, 2.45, 1.04, 2.77, 562.0], 1},
   {[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0], 0},
   {[13.87, 1.9, 2.8, 19.4, 107.0, 2.95, 2.97, 0.37, 1.76, 4.5, 1.25, 3.4, 915.0], 0},
   {[13.83, 1.57, 2.62, 20.0, 115.0, 2.95, 3.4, 0.4, 1.72, 6.6, 1.13, 2.57, 1130.0], 0},
   {[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], 0},
   {[13.29, 1.97, 2.68, 16.8, 102.0, 3.0, 3.23, 0.31, 1.66, 6.0, 1.07, 2.84, 1270.0], 0},
   {[13.73, 1.5, 2.7, 22.5, 101.0, 3.0, 3.25, 0.29, 2.38, 5.7, 1.19, 2.71, 1285.0], 0},
   {[12.29, 1.61, 2.21, 20.4, 103.0, 1.1, 1.02, 0.37, 1.46, 3.05, 0.906, 1.82, 870.0], 1},
   {[11.79, 2.13, 2.78, 28.5, 92.0, 2.13, 2.24, 0.58, 1.76, 3.0, 0.97, 2.44, 466.0], 1},
   {[13.5, 1.81, 2.61, 20.0, 96.0, 2.53, 2.61, 0.28, 1.66, 3.52, 1.12, 3.82, 845.0], 0},
   {[13.62, 4.95, 2.35, 20.0, 92.0, 2.0, 0.8, 0.47, 1.02, 4.4, 0.91, 2.05, 550.0], 2},
   {[11.45, 2.4, 2.42, 20.0, 96.0, 2.9, 2.79, 0.32, 1.83, 3.25, 0.8, 3.39, 625.0], 1},
   {[14.19, 1.59, 2.48, 16.5, 108.0, 3.3, 3.93, 0.32, 1.86, 8.7, 1.23, 2.82, 1680.0], 0},
   {[12.17, 1.45, 2.53, 19.0, 104.0, 1.89, 1.75, 0.45, 1.03, 2.95, 1.45, 2.23, 355.0], 1},
   {[12.16, 1.61, 2.31, 22.8, 90.0, 1.78, 1.69, 0.43, 1.56, 2.45, 1.33, 2.26, 495.0], 1},
   {[11.81, 2.12, 2.74, 21.5, 134.0, 1.6, 0.99, 0.14, 1.56, 2.5, 0.95, 2.26, 625.0], 1},
   {[13.28, 1.64, 2.84, 15.5, 110.0, 2.6, 2.68, 0.34, 1.36, 4.6, 1.09, 2.78, 880.0], 0},
   {[12.25, 1.73, 2.12, 19.0, 80.0, 1.65, 2.03, 0.37, 1.63, 3.4, 1.0, 3.17, 510.0], 1},
   {[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0], 0},
   {[13.88, 1.89, 2.59, 15.0, 101.0, 3.25, 3.56, 0.17, 1.7, 5.43, 0.88, 3.56, 1095.0], 0},
   {[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], 0},
   {[11.41, 0.74, 2.5, 21.0, 88.0, 2.48, 2.01, 0.42, 1.44, 3.08, 1.1, 2.31, 434.0], 1},
   {[12.25, 4.72, 2.54, 21.0, 89.0, 1.38, 0.47, 0.53, 0.8, 3.85, 0.75, 1.27, 720.0], 2},
   {[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.68, 1.03, 3.17, 1185.0], 0},
   {[13.63, 1.81, 2.7, 17.2, 112.0, 2.85, 2.91, 0.3, 1.46, 7.3, 1.28, 2.88, 1310.0], 0},
   {[12.7, 3.55, 2.36, 21.5, 106.0, 1.7, 1.2, 0.17, 0.84, 5.0, 0.78, 1.29, 600.0], 2},
   {[13.3, 1.72, 2.14, 17.0, 94.0, 2.4, 2.19, 0.27, 1.35, 3.95, 1.02, 2.77, 1285.0], 0},
   {[12.87, 4.61, 2.48, 21.5, 86.0, 1.7, 0.65, 0.47, 0.86, 7.65, 0.54, 1.86, 625.0], 2},
   {[14.3, 1.92, 2.72, 20.0, 120.0, 2.8, 3.14, 0.33, 1.97, 6.2, 1.07, 2.65, 1280.0], 0},
   {[13.58, 1.66, 2.36, 19.1, 106.0, 2.86, 3.19, 0.22, 1.95, 6.9, 1.09, 2.88, ...], 0},
   {[13.39, 1.77, 2.62, 16.1, 93.0, 2.85, 2.94, 0.34, 1.45, 4.8, 0.92, ...], 0}
 ]}
{train_inputs, train_targets} = Enum.unzip(train)
train_inputs = Nx.tensor(train_inputs)
train_targets = Nx.tensor(train_targets)
#Nx.Tensor<
  s64[142]
  EXLA.Backend
  [0, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 1, 1, 2, 0, 2, 1, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 1, 0, 1, 2, 1, ...]
>
{test_inputs, test_targets} = Enum.unzip(test)
test_inputs = Nx.tensor(test_inputs)
test_targets = Nx.tensor(test_targets)
#Nx.Tensor<
  s64[36]
  EXLA.Backend
  [0, 2, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 2, 0, 0, 2, 0, 2, 0, 0, 0]
>
train_inputs = Scholar.Preprocessing.min_max_scale(train_inputs)
#Nx.Tensor<
  f32[142][13]
  EXLA.Backend
  [
    [0.008682048879563808, 0.0010214174399152398, 0.0014092973433434963, 0.010388721711933613, 0.0755525678396225, 0.0019523295341059566, 0.0020428348798304796, 1.357580185867846e-4, 0.0014286915538832545, 0.003878800431266427, 5.301026976667345e-4, 0.002101016929373145, 0.5138570070266724],
    [0.008391138166189194, 5.688907112926245e-4, 0.0010149527806788683, 0.009612960740923882, 0.05034036561846733, 0.0018424300942569971, 0.0019717237446457148, 8.404067193623632e-5, 0.0013899035984650254, 0.003342233132570982, 6.40002079308033e-4, 0.0019717237446457148, 0.3244422674179077],
    [0.00869497749954462, 0.001583843375556171, 0.0016549548599869013, 0.015754394233226776, 0.06779497116804123, 9.179827175103128e-4, 4.589913587551564e-4, 1.6808134387247264e-4, 9.115180582739413e-4, 0.005514360964298248, 3.943447081837803e-4, 0.0010795993730425835, 0.48476600646972656],
    [0.008339420892298222, 4.977794014848769e-4, 0.0010214174399152398, 0.01025942713022232, 0.05551210045814514, 0.0011765694944187999, 0.001228286768309772, 7.111133891157806e-5, 8.598007843829691e-4, 0.002889706287533045, 6.852548103779554e-4, ...],
    ...
  ]
>
test_inputs = Scholar.Preprocessing.min_max_scale(test_inputs)
#Nx.Tensor<
  f32[36][13]
  EXLA.Backend
  [
    [0.007685163989663124, 9.465074981562793e-4, 0.001131046679802239, 0.007298226933926344, 0.054683130234479904, 0.0015358421951532364, 0.001863250508904457, 1.7858632418210618e-5, 0.0016489468980580568, 0.00420273095369339, 5.833819741383195e-4, 0.0016489468980580568, 0.6844975352287292],
    [0.008078053593635559, 0.003280035452917218, 0.001375114545226097, 0.012120057828724384, 0.05646899342536926, 9.167430689558387e-4, 2.7978525031358004e-4, 2.2620933305006474e-4, 5.476646474562585e-4, 0.0045003751292824745, 2.976438554469496e-4, 9.524603374302387e-4, 0.44042956829071045],
    [0.007685163989663124, 0.0033693285658955574, 0.0011846226407215, 0.012715346179902554, 0.05111140385270119, 0.001476313336752355, 0.0014941721456125379, 9.524603956378996e-5, 0.001113187987357378, 0.0014644076582044363, 3.5121975815854967e-4, 0.0017620514845475554, 0.22612598538398743],
    [0.007953044027090073, 0.001773957279510796, 0.001476313336752355, 0.014203565195202827, 0.07313704490661621, 7.500625215470791e-4, 8.512614876963198e-4, 4.7623016143916175e-5, 6.60769350361079e-4, 0.00503613380715251, 2.678794553503394e-4, ...],
    ...
  ]
>
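The printed tensors suggest that `min_max_scale` computed one minimum and one maximum over the entire tensor, not per column. The last column, whose raw values run into the hundreds, lands near 0.5, while every other feature is squeezed close to 0. The test set is also scaled with its own minimum and maximum rather than the training set's.

This matters for the distance-based models later in the chapter (KNN, K-means), because the largest-valued column dominates the distance. Below is a minimal pure-Elixir sketch of per-feature min-max scaling, where the statistics are learned from the training rows and reused for the test rows. The module name is hypothetical, and it works on plain lists rather than Nx tensors to stay self-contained.

```elixir
defmodule MinMaxSketch do
  # Hypothetical per-feature min-max scaler over lists of rows.

  # Compute one {min, max} pair per column from the (training) rows.
  def fit(rows) do
    rows
    |> Enum.zip_with(& &1)
    |> Enum.map(&Enum.min_max/1)
  end

  # Scale each value with its column's fitted stats.
  # Constant columns (hi == lo) map to 0.0 to avoid dividing by zero.
  # Values outside the fitted range fall outside [0, 1].
  def transform(rows, stats) do
    Enum.map(rows, fn row ->
      row
      |> Enum.zip(stats)
      |> Enum.map(fn
        {_x, {lo, hi}} when hi == lo -> 0.0
        {x, {lo, hi}} -> (x - lo) / (hi - lo)
      end)
    end)
  end
end

train = [[1.0, 100.0], [3.0, 300.0], [2.0, 200.0]]
test = [[2.0, 400.0]]

stats = MinMaxSketch.fit(train)
MinMaxSketch.transform(train, stats)
# → [[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]
MinMaxSketch.transform(test, stats)
# → [[0.5, 1.5]]
```

Reusing the training statistics keeps train and test on the same scale. The price is that unseen test values can fall slightly outside [0, 1].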
model =
  Scholar.Linear.LogisticRegression.fit(
    train_inputs,
    train_targets,
    num_classes: 3
  )
%Scholar.Linear.LogisticRegression{
  coefficients: #Nx.Tensor<
    f32[13][3]
    EXLA.Backend
    [
      [1.0080410242080688, 0.8926408886909485, 1.099318265914917],
      [0.9865338802337646, 0.8100578188896179, 1.2034083604812622],
      [0.9965129494667053, 0.9742255210876465, 1.0292614698410034],
      [0.6327518820762634, 0.9616338014602661, 1.4056137800216675],
      [0.8146526217460632, 0.8948448300361633, 1.290502667427063],
      [1.042885422706604, 1.080212116241455, 0.8769035935401917],
      [1.099364161491394, 1.1582599878311157, 0.742372453212738],
      [0.9907945394515991, 0.9903120994567871, 1.0188928842544556],
      [1.0206292867660522, 1.0675230026245117, 0.9118484258651733],
      [0.9449230432510376, 0.39299315214157104, 1.662082314491272],
      [1.0057761669158936, 1.0561885833740234, 0.9380345344543457],
      [1.0640884637832642, 1.1418187618255615, 0.7940915822982788],
      [13.501482963562012, -8.62801742553711, -1.8734524250030518]
    ]
  >,
  bias: #Nx.Tensor<
    f32[3]
    EXLA.Backend
    [-6.4759135246276855, 4.697933673858643, 1.7779799699783325]
  >
}
test_preds = Scholar.Linear.LogisticRegression.predict(model, test_inputs)
#Nx.Tensor<
  s64[36]
  EXLA.Backend
  [0, 2, 1, 1, 1, 1, 0, 0, 0, 2, 0, 0, 2, 1, 2, 1, 1, 0, 1, 1, 1, 2, 1, 0, 0, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 0]
>

Scholar also implements routines for evaluating models. To compute the accuracy of a model:

Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.7777777910232544
>

Another common tool for evaluating classifiers is the confusion matrix, a table that breaks down a model's performance by class. It is two-dimensional: rows represent the actual class and columns represent the predicted class.
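The sketch below makes that layout concrete. It counts (actual, predicted) pairs into a rows-are-actual, columns-are-predicted table, then recovers accuracy as the diagonal divided by the total. The module name is hypothetical and it uses plain lists instead of Nx tensors. As a check against the matrix printed below, its diagonal is 16 + 10 + 2 = 28 correct out of 36 test samples, about 0.778, which agrees with the accuracy reported above.

```elixir
defmodule ConfusionSketch do
  # Rows are actual classes, columns are predicted classes.
  def confusion_matrix(actual, predicted, num_classes) do
    counts = Enum.frequencies(Enum.zip(actual, predicted))

    for a <- 0..(num_classes - 1) do
      for p <- 0..(num_classes - 1), do: Map.get(counts, {a, p}, 0)
    end
  end

  # Accuracy = correct predictions (the diagonal) / all predictions.
  def accuracy(matrix) do
    correct =
      matrix
      |> Enum.with_index()
      |> Enum.map(fn {row, i} -> Enum.at(row, i) end)
      |> Enum.sum()

    correct / (matrix |> List.flatten() |> Enum.sum())
  end
end

m = ConfusionSketch.confusion_matrix([0, 0, 1, 1, 2, 2], [0, 2, 1, 1, 1, 2], 3)
# → [[1, 0, 1], [0, 2, 0], [0, 1, 1]]
ConfusionSketch.accuracy(m)
# → 0.6666666666666666
```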

Scholar.Metrics.Classification.confusion_matrix(
  test_targets,
  test_preds,
  num_classes: 3
)
#Nx.Tensor<
  u64[3][3]
  EXLA.Backend
  [
    [16, 0, 3],
    [0, 10, 1],
    [0, 4, 2]
  ]
>
Vl.new(
  title: "Confusion Matrix",
  width: 500,
  height: 500
)
|> Vl.data_from_values(%{
  predicted: Nx.to_flat_list(test_preds),
  actual: Nx.to_flat_list(test_targets)
})
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "predicted")
|> Vl.encode_field(:y, "actual")
|> Vl.encode(:color, aggregate: :count)
[Figure: Confusion Matrix. Heatmap of the test set with predicted class on the x-axis, actual class on the y-axis, and color encoding the count.]

Dealing with Non-linear Data

One way to deal with non-linear data is to transform it until a linear relationship emerges. This is a form of feature engineering: the process of manipulating features by hand to give your model more useful information.
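Here is a minimal illustration with hypothetical data and a hypothetical module name, using plain Elixir lists rather than Scholar. Suppose the target follows y = 3x^2 + 1. A straight line fit on x alone cannot capture the curve. Fitting the same one-feature least-squares line on the engineered feature x^2 recovers the relationship exactly.

```elixir
defmodule LinearizeSketch do
  # One-feature ordinary least squares: returns {slope, intercept}.
  def fit_line(xs, ys) do
    n = length(xs)
    mean_x = Enum.sum(xs) / n
    mean_y = Enum.sum(ys) / n

    cov = Enum.zip(xs, ys) |> Enum.map(fn {x, y} -> (x - mean_x) * (y - mean_y) end) |> Enum.sum()
    variance = xs |> Enum.map(fn x -> (x - mean_x) * (x - mean_x) end) |> Enum.sum()

    slope = cov / variance
    {slope, mean_y - slope * mean_x}
  end

  # Mean squared error of a fitted line on (xs, ys).
  def mse(xs, ys, {slope, intercept}) do
    total =
      Enum.zip(xs, ys)
      |> Enum.map(fn {x, y} -> (y - (slope * x + intercept)) * (y - (slope * x + intercept)) end)
      |> Enum.sum()

    total / length(xs)
  end
end

# Hypothetical non-linear data: y = 3x^2 + 1.
xs = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
ys = Enum.map(xs, fn x -> 3 * x * x + 1 end)

# A line in x alone is flat here (slope 0.0, intercept 13.0) and fits poorly.
raw = LinearizeSketch.fit_line(xs, ys)
LinearizeSketch.mse(xs, ys, raw)
# → 108.0

# Engineer the feature x^2, then fit the same kind of line.
squared = Enum.map(xs, fn x -> x * x end)
fitted = LinearizeSketch.fit_line(squared, ys)
# → {3.0, 1.0}
LinearizeSketch.mse(squared, ys, fitted)
# → 0.0
```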

Learning from Your Surroundings

Another common model in traditional ML is K-Nearest Neighbors (KNN). KNN classifies a new point as the most common class among its K closest neighbors in the training data. It can be used for both classification and regression (for regression it averages the neighbors' target values instead of voting). Despite its simplicity, it often performs well.
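The voting rule fits in a few lines. The pure-Elixir sketch below uses a hypothetical module name, toy data, and plain lists instead of Nx tensors. It measures Euclidean distance (the Minkowski distance with p = 2), sorts the training points by distance, and takes a majority vote over the K closest labels. It illustrates the idea rather than mirroring Scholar's implementation; in particular, ties here are broken by map iteration order.

```elixir
defmodule KNNSketch do
  # Euclidean distance, i.e. Minkowski distance with p = 2.
  def distance(a, b) do
    Enum.zip(a, b)
    |> Enum.map(fn {x, y} -> (x - y) * (x - y) end)
    |> Enum.sum()
    |> :math.sqrt()
  end

  # Majority vote among the k nearest training points.
  # `train` is a list of {features, label} tuples.
  def predict(train, point, k) do
    train
    |> Enum.sort_by(fn {features, _label} -> distance(features, point) end)
    |> Enum.take(k)
    |> Enum.frequencies_by(fn {_features, label} -> label end)
    |> Enum.max_by(fn {_label, count} -> count end)
    |> elem(0)
  end
end

train = [
  {[0.0, 0.0], 0}, {[0.0, 1.0], 0}, {[1.0, 0.0], 0},
  {[5.0, 5.0], 1}, {[5.0, 6.0], 1}, {[6.0, 5.0], 1}
]

KNNSketch.predict(train, [0.5, 0.5], 3)
# → 0
KNNSketch.predict(train, [5.5, 5.5], 3)
# → 1
```

KNN has no real training step: "fitting" amounts to storing the data, which is why the Scholar struct below holds `data` and `labels`. The cost moves to prediction, where every query is compared against the stored points.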

model =
  Scholar.Neighbors.KNearestNeighbors.fit(
    train_inputs,
    train_targets,
    num_classes: 3
  )
%Scholar.Neighbors.KNearestNeighbors{
  data: #Nx.Tensor<
    f32[142][13]
    EXLA.Backend
    [
      [0.008682048879563808, 0.0010214174399152398, 0.0014092973433434963, 0.010388721711933613, 0.0755525678396225, 0.0019523295341059566, 0.0020428348798304796, 1.357580185867846e-4, 0.0014286915538832545, 0.003878800431266427, 5.301026976667345e-4, 0.002101016929373145, 0.5138570070266724],
      [0.008391138166189194, 5.688907112926245e-4, 0.0010149527806788683, 0.009612960740923882, 0.05034036561846733, 0.0018424300942569971, 0.0019717237446457148, 8.404067193623632e-5, 0.0013899035984650254, 0.003342233132570982, 6.40002079308033e-4, 0.0019717237446457148, 0.3244422674179077],
      [0.00869497749954462, 0.001583843375556171, 0.0016549548599869013, 0.015754394233226776, 0.06779497116804123, 9.179827175103128e-4, 4.589913587551564e-4, 1.6808134387247264e-4, 9.115180582739413e-4, 0.005514360964298248, 3.943447081837803e-4, 0.0010795993730425835, 0.48476600646972656],
      [0.008339420892298222, 4.977794014848769e-4, 0.0010214174399152398, 0.01025942713022232, 0.05551210045814514, 0.0011765694944187999, 0.001228286768309772, 7.111133891157806e-5, 8.598007843829691e-4, 0.002889706287533045, ...],
      ...
    ]
  >,
  labels: #Nx.Tensor<
    s64[142]
    EXLA.Backend
    [0, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 1, 1, 2, 0, 2, 1, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 1, 0, 1, ...]
  >,
  default_num_neighbors: 5,
  weights: :uniform,
  num_classes: 3,
  task: :classification,
  metric: {:minkowski, 2}
}
test_preds =
  Scholar.Neighbors.KNearestNeighbors.predict(
    model,
    test_inputs
  )
#Nx.Tensor<
  s64[36]
  EXLA.Backend
  [0, 1, 1, 1, 1, 1, 0, 2, 0, 1, 0, 0, 0, 1, 2, 1, 1, 0, 1, 1, 2, 2, 1, 0, 0, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 0]
>
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.694444477558136
>
Scholar.Metrics.Classification.confusion_matrix(
  test_targets,
  test_preds,
  num_classes: 3
)
#Nx.Tensor<
  u64[3][3]
  EXLA.Backend
  [
    [15, 1, 3],
    [1, 9, 1],
    [0, 5, 1]
  ]
>
Vl.new(
  title: "KNN Confusion Matrix",
  width: 500,
  height: 500
)
|> Vl.data_from_values(%{
  predicted: Nx.to_flat_list(test_preds),
  actual: Nx.to_flat_list(test_targets)
})
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "predicted")
|> Vl.encode_field(:y, "actual")
|> Vl.encode(:color, aggregate: :count)
[Figure: KNN Confusion Matrix. Heatmap of the test set with predicted class on the x-axis, actual class on the y-axis, and color encoding the count.]

Using Clustering

The previous methods are examples of supervised learning: they learn from labeled targets. Scholar also provides methods for unsupervised learning and analysis.

A common type of unsupervised learning is clustering: the process of identifying groups of similar data points in a dataset. Approaches to clustering include K-means clustering, hierarchical clustering, spectral clustering, and more.

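K-means clustering starts by placing K centroids, often at randomly chosen points in the dataset. It then alternates between two steps: assign each point to its nearest centroid, then move each centroid to the mean of the points assigned to it. It stops when the assignments no longer change or an iteration limit is reached. The result is a locally optimal configuration, not necessarily the best possible one.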
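A compact version of that loop, often called Lloyd's algorithm, is sketched below in plain Elixir under simplifying assumptions. The module name is hypothetical and points are plain lists rather than Nx tensors. The caller supplies the initial centroids; in practice these are usually K randomly chosen points or a smarter scheme, and Scholar's own initialization may differ.

```elixir
defmodule KMeansSketch do
  # Squared Euclidean distance between two points.
  defp sq_dist(a, b) do
    Enum.zip(a, b) |> Enum.map(fn {x, y} -> (x - y) * (x - y) end) |> Enum.sum()
  end

  # Index of the centroid nearest to `point` (assignment step).
  def nearest(centroids, point) do
    centroids
    |> Enum.with_index()
    |> Enum.min_by(fn {c, _i} -> sq_dist(c, point) end)
    |> elem(1)
  end

  # Per-coordinate mean of a non-empty list of points (update step).
  defp mean(points) do
    n = length(points)
    points |> Enum.zip_with(& &1) |> Enum.map(&(Enum.sum(&1) / n))
  end

  # Alternate assignment and update until centroids stop moving or max_iter is hit.
  def fit(points, initial_centroids, max_iter \\ 100) do
    Enum.reduce_while(1..max_iter, initial_centroids, fn _iter, centroids ->
      groups = Enum.group_by(points, &nearest(centroids, &1))

      new_centroids =
        centroids
        |> Enum.with_index()
        |> Enum.map(fn {c, i} ->
          # A centroid with no assigned points stays where it is.
          case Map.get(groups, i) do
            nil -> c
            members -> mean(members)
          end
        end)

      if new_centroids == centroids,
        do: {:halt, new_centroids},
        else: {:cont, new_centroids}
    end)
  end
end

points = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
centroids = KMeansSketch.fit(points, [[0.0, 0.0], [10.0, 10.0]])
# Each centroid ends at its group's mean: about [0.33, 0.33] and [10.33, 10.33].
```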

model =
  Scholar.Cluster.KMeans.fit(
    train_inputs,
    num_clusters: 3
  )
%Scholar.Cluster.KMeans{
  num_iterations: #Nx.Tensor<
    s64
    EXLA.Backend
    4
  >,
  clusters: #Nx.Tensor<
    f32[3][13]
    EXLA.Backend
    [
      [0.008846291340887547, 0.0011630341177806258, 0.0014636411797255278, 0.01087559200823307, 0.06813840568065643, 0.0017448541475459933, 0.0018127331277355552, 1.0141448001377285e-4, 0.0011305087246000767, 0.003449707990512252, 6.163656362332404e-4, 0.0019596023485064507, 0.7514739632606506],
      [0.008014325052499771, 0.0015410014893859625, 0.0013915469171479344, 0.013321926817297935, 0.05955525487661362, 0.0012643354712054133, 0.0010442081838846207, 1.68957922141999e-4, 8.549796184524894e-4, 0.00261764251627028, 5.124618764966726e-4, 0.0015278528444468975, 0.29672086238861084],
      [0.00829061958938837, 0.0014706484507769346, 0.0014494797214865685, 0.012729943729937077, 0.06688231229782104, 0.001284313853830099, 9.240671643055975e-4, 1.7175734683405608e-4, 9.072083048522472e-4, 0.0037235214840620756, 4.856106243096292e-4, 0.0014383250381797552, 0.4686296880245209]
    ]
  >,
  inertia: #Nx.Tensor<
    f32
    EXLA.Backend
    0.6683441400527954
  >,
  labels: #Nx.Tensor<
    s64[142]
    EXLA.Backend
    [2, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 1, ...]
  >
}

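To visualize the centroids relative to the data points, plot two of the scaled features (columns 1 and 2, zero-indexed) for every training sample, colored by its true class. Then overlay the same two coordinates of each centroid: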

wine_features = %{
  "feature_1" => train_inputs[[.., 1]] |> Nx.to_flat_list(),
  "feature_2" => train_inputs[[.., 2]] |> Nx.to_flat_list(),
  "class" => train_targets |> Nx.to_flat_list()
}

coords = [
  cluster_feature_1: model.clusters[[.., 1]] |> Nx.to_flat_list(),
  cluster_feature_2: model.clusters[[.., 2]] |> Nx.to_flat_list()
]

title = "Scatterplot of data samples projected on plane wine feature 1 x wine feature 2"
"Scatterplot of data samples projected on plane wine feature 1 x wine feature 2"
Vl.new(
  width: 1440,
  height: 1080,
  title: [text: title, offset: 25]
)
|> Vl.layers([
  Vl.new()
  |> Vl.data_from_values(wine_features)
  |> Vl.mark(:circle)
  |> Vl.encode_field(:x, "feature_1", type: :quantitative)
  |> Vl.encode_field(:y, "feature_2", type: :quantitative)
  |> Vl.encode_field(:color, "class"),
  Vl.new()
  |> Vl.data_from_values(coords)
  |> Vl.mark(:circle, color: :green, size: 100)
  |> Vl.encode_field(:x, "cluster_feature_1", type: :quantitative)
  |> Vl.encode_field(:y, "cluster_feature_2", type: :quantitative)
])
[Figure: Scatterplot of data samples projected on plane wine feature 1 x wine feature 2. Training points are colored by class, and the three K-means centroids are drawn as larger green circles.]
test_preds = Scholar.Cluster.KMeans.predict(model, test_inputs)
#Nx.Tensor<
  s64[36]
  EXLA.Backend
  [0, 2, 1, 1, 1, 1, 0, 2, 0, 2, 0, 0, 2, 1, 2, 1, 1, 0, 1, 1, 1, 2, 1, 0, 0, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 0]
>
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.75
>
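One caveat about this accuracy: K-means never sees the class labels, so the index it gives each cluster (0, 1 or 2) is arbitrary. Comparing `test_preds` directly with `test_targets` only makes sense if the cluster numbering happens to line up with the class numbering.

A common workaround is to map each cluster to the majority true label among the training points assigned to it, and only then score. It is sketched below in plain Elixir with a hypothetical module name. With the tensors in this notebook, the inputs would come from `Nx.to_flat_list/1` applied to `model.labels`, `train_targets`, `test_preds` and `test_targets`.

```elixir
defmodule ClusterLabelSketch do
  # Map each cluster id to the most frequent true label within that cluster.
  def majority_mapping(cluster_ids, labels) do
    Enum.zip(cluster_ids, labels)
    |> Enum.group_by(fn {cluster, _label} -> cluster end, fn {_cluster, label} -> label end)
    |> Map.new(fn {cluster, cluster_labels} ->
      {label, _count} =
        cluster_labels
        |> Enum.frequencies()
        |> Enum.max_by(fn {_label, count} -> count end)

      {cluster, label}
    end)
  end

  # Translate cluster ids into class labels, then compute accuracy.
  def mapped_accuracy(mapping, cluster_ids, targets) do
    correct =
      Enum.zip(cluster_ids, targets)
      |> Enum.count(fn {cluster, target} -> Map.get(mapping, cluster) == target end)

    correct / length(targets)
  end
end

# Toy example: clusters are "numbered differently" from the classes.
mapping = ClusterLabelSketch.majority_mapping([2, 2, 0, 0, 1, 1], [0, 0, 1, 1, 2, 2])
# → %{0 => 1, 1 => 2, 2 => 0}
ClusterLabelSketch.mapped_accuracy(mapping, [2, 0, 1, 1], [0, 1, 2, 0])
# → 0.75 (comparing the raw cluster ids with the targets would score 0.0)
```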

Making Decisions

Some of the most popular non-DL algorithms in use today are decision trees and their variants, such as gradient boosting.

A decision tree is a hierarchical decision flow built from the input features. Each internal node splits the data on a condition over one feature, and each leaf assigns the input to one of the target classes.
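Prediction with a trained tree is just a walk from the root to a leaf. The sketch below shows that flow using a hypothetical hand-built tree encoded as nested tuples, with made-up thresholds. Learning the splits from data is the part a library handles for you, for example by choosing the split that most reduces class impurity.

```elixir
defmodule TreeSketch do
  # A tree is either {:leaf, class} or
  # {:split, feature_index, threshold, left_subtree, right_subtree}.
  # Inputs whose feature value is <= threshold go left; others go right.
  def predict({:leaf, class}, _features), do: class

  def predict({:split, index, threshold, left, right}, features) do
    if Enum.at(features, index) <= threshold do
      predict(left, features)
    else
      predict(right, features)
    end
  end
end

# Hypothetical tree over two features; the thresholds are illustrative only.
tree =
  {:split, 0, 0.5,
   {:leaf, 0},
   {:split, 1, 2.0, {:leaf, 1}, {:leaf, 2}}}

TreeSketch.predict(tree, [0.2, 9.9])
# → 0
TreeSketch.predict(tree, [0.8, 1.0])
# → 1
TreeSketch.predict(tree, [0.8, 3.0])
# → 2
```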

Gradient boosting is an ensemble method that builds weak learners (usually shallow decision trees) iteratively. Each new learner is fit to the errors left by the learners before it.
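Here is a minimal gradient boosting sketch that makes "fitting the errors of previous learners" concrete. It does regression with squared-error loss, where the negative gradient is simply the residual (target minus current prediction). Each round fits a one-split regression tree (a stump) to the residuals and adds a scaled-down copy of it to the ensemble.

The module name, the 1-D toy data, and the learning rate are assumptions for illustration. Libraries such as XGBoost support other losses, deeper trees, and regularization.

```elixir
defmodule BoostSketch do
  defp mean([]), do: 0.0
  defp mean(values), do: Enum.sum(values) / length(values)

  defp sse(values, m), do: values |> Enum.map(&((&1 - m) * (&1 - m))) |> Enum.sum()

  # Fit a stump to residuals: try each threshold, predict the mean residual
  # on each side, and keep the split with the lowest squared error.
  def fit_stump(xs, residuals) do
    pairs = Enum.zip(xs, residuals)

    xs
    |> Enum.uniq()
    |> Enum.sort()
    |> Enum.map(fn t ->
      {left, right} = Enum.split_with(pairs, fn {x, _r} -> x <= t end)
      left_r = Enum.map(left, &elem(&1, 1))
      right_r = Enum.map(right, &elem(&1, 1))
      {lm, rm} = {mean(left_r), mean(right_r)}
      {sse(left_r, lm) + sse(right_r, rm), {t, lm, rm}}
    end)
    |> Enum.min_by(&elem(&1, 0))
    |> elem(1)
  end

  defp stump_predict({t, lm, rm}, x), do: if(x <= t, do: lm, else: rm)

  # Start from the mean target; each round adds learning_rate * stump.
  def fit(xs, ys, rounds, learning_rate) do
    base = mean(ys)

    {stumps, _preds} =
      Enum.reduce(1..rounds, {[], List.duplicate(base, length(ys))}, fn _, {stumps, preds} ->
        residuals = Enum.zip_with(ys, preds, &(&1 - &2))
        stump = fit_stump(xs, residuals)
        preds = Enum.zip_with(xs, preds, fn x, p -> p + learning_rate * stump_predict(stump, x) end)
        {[stump | stumps], preds}
      end)

    {base, learning_rate, Enum.reverse(stumps)}
  end

  def predict({base, learning_rate, stumps}, x) do
    Enum.reduce(stumps, base, fn stump, acc -> acc + learning_rate * stump_predict(stump, x) end)
  end
end

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
model = BoostSketch.fit(xs, ys, 10, 0.5)
BoostSketch.predict(model, 2.0)
# ≈ 1.002, approaching the true 1.0 as rounds increase
```

The learning rate of 0.5 means each stump corrects only half of the remaining error. That is why several rounds are needed here even though a single stump could fit this toy data exactly.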

EXGBoost, an Elixir library that provides bindings to XGBoost for gradient-boosted decision trees, integrates directly with Nx tensors and is relatively simple to use.