Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Machine Learning in Elixir (Ch. 1-5)

ml.livemd

Machine Learning in Elixir (Ch. 1-5)

Mix.install(
  [
    {:axon, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:explorer, "~> 0.5"},
    {:kino, "~> 0.8"},
    {:scholar, "~> 0.3.1"},
    {:exla, "~> 0.5"},
    {:benchee, github: "bencheeorg/benchee", override: true},
    {:scidata, "~> 0.1"},
    {:stb_image, "~> 0.6"},
    {:vega_lite, "~> 0.1"},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config: [
    nx: [
      default_backend: {EXLA.Backend, []},
      default_defn_options: [compiler: EXLA]
    ]
  ]
)

Chapter 1

require Explorer.DataFrame, as: DF
Explorer.DataFrame
iris = Explorer.Datasets.iris()
#Explorer.DataFrame<
  Polars[150 x 5]
  sepal_length f64 [5.1, 4.9, 4.7, 4.6, 5.0, ...]
  sepal_width f64 [3.5, 3.0, 3.2, 3.1, 3.6, ...]
  petal_length f64 [1.4, 1.4, 1.3, 1.5, 1.4, ...]
  petal_width f64 [0.2, 0.2, 0.2, 0.2, 0.2, ...]
  species string ["Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", ...]
>

Now we must normalize the data by iterating across the col that we want, and subtracting by the mean then taking the standard deviation.

feature_cols = ~w(sepal_width sepal_length petal_length petal_width)
normalized_iris =
  DF.mutate(
    iris,
    for col <- across(^feature_cols) do
      {col.name, (col - mean(col)) / standard_deviation(col)}
    end
  )
normalized_iris = DF.mutate(normalized_iris, [
  species: Explorer.Series.cast(species, :category)
])
shuffled_normalized_iris = DF.shuffle(normalized_iris)
#Explorer.DataFrame<
  Polars[150 x 5]
  sepal_length f64 [1.5175921634851133, -0.2938573685262962, 0.5514857464123618, 0.5514857464123618,
   0.4307224442782682, ...]
  sepal_width f64 [-0.12454037930145855, -0.35517071134119754, -0.8164313754206746,
   -1.7389527035796306, -0.35517071134119754, ...]
  petal_length f64 [1.213618539617258, -0.08992565766777902, 0.6468601929715899,
   0.36348101964875573, 0.30680518498418863, ...]
  petal_width f64 [1.1810530653405331, 0.13278111385485278, 0.787951083533403, 0.13278111385485278,
   0.13278111385485278, ...]
  species category ["Iris-virginica", "Iris-versicolor", "Iris-virginica", "Iris-versicolor",
   "Iris-versicolor", ...]
>

Split the data set into a training and test.

train_df = DF.slice(shuffled_normalized_iris, 0..119)
test_df = DF.slice(shuffled_normalized_iris, 120..149)
#Explorer.DataFrame<
  Polars[30 x 5]
  sepal_length f64 [0.4307224442782682, -0.897673879196766, 0.4307224442782682, -0.2938573685262962,
   0.5514857464123618, ...]
  sepal_width f64 [0.7979809488574964, 0.5673506168177574, -0.5858010433809365, -0.8164313754206746,
   0.7979809488574964, ...]
  petal_length f64 [0.930239366294424, -1.1667665162945489, 0.5901843583070228, 0.25012935031962197,
   1.0435910356235572, ...]
  petal_width f64 [1.443121053211953, -0.9154908376308276, 0.787951083533403, 0.13278111385485278,
   1.574155047147663, ...]
  species category ["Iris-virginica", "Iris-setosa", "Iris-virginica", "Iris-versicolor",
   "Iris-virginica", ...]
>

One hot encoding

x_train = Nx.stack(train_df[feature_cols], axis: -1)
y_train =
  train_df["species"]
  |> Nx.stack(axis: -1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))

x_test = Nx.stack(test_df[feature_cols], axis: -1)
y_test =
  test_df["species"]
  |> Nx.stack(axis: -1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
#Nx.Tensor<
  u8[30][3]
  EXLA.Backend
  [
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 0],
    [0, 0, 1],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, ...],
    ...
  ]
>

Defining the model

model =
  Axon.input("iris_features", shape: {nil, 4})
  |> Axon.dense(3, activation: :softmax)
#Axon<
  inputs: %{"iris_features" => {nil, 4}}
  outputs: "softmax_0"
  nodes: 3
>
Axon.Display.as_graph(model, Nx.template({1, 4}, :f32))
data_stream = Stream.repeatedly(fn ->
  {x_train, y_train}
end)
#Function<53.38948127/2 in Stream.repeatedly/1>
trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(data_stream, %{}, iterations: 500, epochs: 10)
Epoch: 0, Batch: 450, accuracy: 0.7216556 loss: 0.6444226
Epoch: 1, Batch: 450, accuracy: 0.8651145 loss: 0.5075818
Epoch: 2, Batch: 450, accuracy: 0.8833323 loss: 0.4441806
Epoch: 3, Batch: 450, accuracy: 0.8918683 loss: 0.4041536
Epoch: 4, Batch: 450, accuracy: 0.9085174 loss: 0.3751374
Epoch: 5, Batch: 450, accuracy: 0.9254994 loss: 0.3524860
Epoch: 6, Batch: 450, accuracy: 0.9357361 loss: 0.3339924
Epoch: 7, Batch: 450, accuracy: 0.9416658 loss: 0.3184373
Epoch: 8, Batch: 450, accuracy: 0.9416473 loss: 0.3050718
Epoch: 9, Batch: 450, accuracy: 0.9478759 loss: 0.2934021
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3]
      EXLA.Backend
      [-0.43029528856277466, 1.3603023290634155, -0.9300051927566528]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][3]
      EXLA.Backend
      [
        [1.3022854328155518, -0.5877847671508789, -0.5648554563522339],
        [-1.0349063873291016, 0.4400143027305603, 0.5383542776107788],
        [-0.9887109398841858, 0.2664645314216614, 0.9659081101417542],
        [-1.6879781484603882, -0.7090367078781128, 2.7165253162384033]
      ]
    >
  }
}
data = [{x_test, y_test}]

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(data, trained_model_state)
Batch: 0, accuracy: 1.0000000
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      1.0
    >
  }
}

Chapter 2

a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  EXLA.Backend
  [1, 2, 3]
>
a
|> Nx.as_type({:f, 32})
|> Nx.reshape({1, 3, 1})
#Nx.Tensor<
  f32[1][3][1]
  EXLA.Backend
  [
    [
      [1.0],
      [2.0],
      [3.0]
    ]
  ]
>
a = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
Nx.abs(a)
#Nx.Tensor<
  s64[2][2][3]
  EXLA.Backend
  [
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>
one = Nx.tensor([1, 2, 3])
b = Nx.tensor([[4, 5, 6], [7, 8, 9]])
Nx.add(one, b)
#Nx.Tensor<
  s64[2][3]
  EXLA.Backend
  [
    [5, 7, 9],
    [8, 10, 12]
  ]
>
revs =
  Nx.tensor(
    [
      [21, 64, 86, 26, 74, 81, 38, 79, 70, 48, 85, 33],
      [64, 82, 48, 39, 70, 71, 81, 53, 50, 67, 36, 50],
      [68, 74, 39, 78, 95, 62, 53, 21, 43, 59, 51, 88],
      [47, 74, 97, 51, 98, 47, 61, 36, 83, 55, 74, 43]
    ],
    names: [:year, :month]
  )
Nx.sum(revs, axes: [:year])
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
  s64[year: 4]
  EXLA.Backend
  [705, 711, 731, 766]
>

Using defn

defmodule MyModule do
  import Nx.Defn

  defn adds_one(x) do
    Nx.add(x, 1) |> print_expr()
  end
end

MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  
  Nx.Defn.Expr
  parameter a:0   s64[3]
  b = add 1, a    s64[3]
>
#Nx.Tensor<
  s64[3]
  EXLA.Backend
  [2, 3, 4]
>
defmodule Softmax do
  import Nx.Defn

  defn softmax(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n))
end

key = Nx.Random.key(42)
{tensor, _key} = Nx.Random.uniform(key, shape: {1_000_000})

Benchee.run(
  %{
    "JIT with EXLA" => fn ->
      apply(EXLA.jit(&amp;Softmax.softmax/1), [tensor])
    end,
    "Regular Elixir" => fn ->
      Softmax.softmax(tensor)
    end
  },
  time: 10
)
Warning: the benchmark JIT with EXLA is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark Regular Elixir is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
Number of Available Cores: 8
Available memory: 31.14 GB
Elixir 1.17.1
Erlang 27.0
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...

Name                     ips        average  deviation         median         99th %
JIT with EXLA         231.00        4.33 ms    ±32.16%        4.75 ms        6.92 ms
Regular Elixir        216.16        4.63 ms    ±26.95%        5.07 ms        6.55 ms

Comparison: 
JIT with EXLA         231.00
Regular Elixir        216.16 - 1.07x slower +0.30 ms
%Benchee.Suite{
  system: %Benchee.System{
    elixir: "1.17.1",
    erlang: "27.0",
    jit_enabled?: true,
    num_cores: 8,
    os: :Linux,
    available_memory: "31.14 GB",
    cpu_speed: "11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz"
  },
  configuration: %Benchee.Configuration{
    parallel: 1,
    time: 10000000000.0,
    warmup: 2000000000.0,
    memory_time: 0.0,
    reduction_time: 0.0,
    pre_check: false,
    formatters: [Benchee.Formatters.Console],
    percentiles: ~c"2c",
    print: %{configuration: true, benchmarking: true, fast_warning: true},
    inputs: nil,
    input_names: [],
    save: false,
    load: false,
    unit_scaling: :best,
    assigns: %{},
    before_each: nil,
    after_each: nil,
    before_scenario: nil,
    after_scenario: nil,
    measure_function_call_overhead: false,
    title: nil,
    profile_after: false
  },
  scenarios: [
    %Benchee.Scenario{
      name: "JIT with EXLA",
      job_name: "JIT with EXLA",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 4328979.287944493,
          ips: 231.00133622372329,
          std_dev: 1392391.75478593,
          std_dev_ratio: 0.32164435590244467,
          std_dev_ips: 74.30027600228354,
          median: 4747913.5,
          percentiles: %{50 => 4747913.5, 99 => 6919670.18},
          mode: 5232723,
          minimum: 1864560,
          maximum: 9329758,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 2306
        },
        samples: [5114107, 5362198, 4916142, 5115965, 4796499, 5703386, 5614357, 5303840, 5742211,
         5860846, 4901829, 5208238, 5928216, 5388787, 5668236, 5083852, 4350200, 4655715, 5678152,
         4923900, 4713065, 5709610, 4713120, 4203843, 5219188, 5244548, 5236738, 5324473, 5476706,
         5160914, 5544268, 5291099, 5198318, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    },
    %Benchee.Scenario{
      name: "Regular Elixir",
      job_name: "Regular Elixir",
      function: #Function<43.39164016/0 in :erl_eval.expr/6>,
      input_name: :__no_input,
      input: :__no_input,
      before_each: nil,
      after_each: nil,
      before_scenario: nil,
      after_scenario: nil,
      tag: nil,
      run_time_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: 4626242.521779425,
          ips: 216.15814460487098,
          std_dev: 1246865.5461491058,
          std_dev_ratio: 0.26952014302732985,
          std_dev_ips: 58.25897405042708,
          median: 5066220.0,
          percentiles: %{50 => 5066220.0, 99 => 6549032.129999999},
          mode: nil,
          minimum: 1824331,
          maximum: 8035928,
          relative_more: 1.0686682042260545,
          relative_less: 0.9357441309149971,
          absolute_difference: 297263.23383493256,
          sample_size: 2158
        },
        samples: [6386470, 6109751, 6198511, 6310616, 5633511, 4599288, 4838860, 4341488, 3289374,
         3488793, 3283210, 3338509, 2762387, 3214241, 2655245, 2518645, 3734936, 3049036, 2591736,
         2483865, 2338518, 2475519, 2921886, 2951746, 2460539, 2000266, 1932629, 2273828, 2356111,
         2984978, 3135985, 2583151, ...]
      },
      memory_usage_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      },
      reductions_data: %Benchee.CollectionData{
        statistics: %Benchee.Statistics{
          average: nil,
          ips: nil,
          std_dev: nil,
          std_dev_ratio: nil,
          std_dev_ips: nil,
          median: nil,
          percentiles: nil,
          mode: nil,
          minimum: nil,
          maximum: nil,
          relative_more: nil,
          relative_less: nil,
          absolute_difference: nil,
          sample_size: 0
        },
        samples: []
      }
    }
  ]
}

Chapter 3

Nx.add(Nx.iota({2, 2, 2}), Nx.iota({2, 2}))
#Nx.Tensor<
  s64[2][2][2]
  EXLA.Backend
  [
    [
      [0, 2],
      [4, 6]
    ],
    [
      [4, 6],
      [8, 10]
    ]
  ]
>
r = Nx.iota({2, 2, 3}) |> IO.inspect()
s = Nx.iota({3, 2}) |> IO.inspect()

Nx.dot(r, s)
#Nx.Tensor<
  s64[2][2][3]
  EXLA.Backend
  [
    [
      [0, 1, 2],
      [3, 4, 5]
    ],
    [
      [6, 7, 8],
      [9, 10, 11]
    ]
  ]
>
#Nx.Tensor<
  s64[3][2]
  EXLA.Backend
  [
    [0, 1],
    [2, 3],
    [4, 5]
  ]
>
#Nx.Tensor<
  s64[2][2][2]
  EXLA.Backend
  [
    [
      [10, 13],
      [28, 40]
    ],
    [
      [46, 67],
      [64, 94]
    ]
  ]
>
simulation = fn key ->
  {value, key} = Nx.Random.uniform(key)
  if Nx.to_number(value) < 0.5, do: {0, key}, else: {1, key}
end

key = Nx.Random.key(42)

for n <- [10, 100] do
  Enum.map_reduce(1..n, key, fn _, key -> simulation.(key) end)
  |> elem(0)
  |> Enum.sum()
  |> IO.inspect()
end
6
49
[6, 49]
defmodule BerryFarm do
  import Nx.Defn

  defn profits(trees) do
    -((trees - 1) ** 4) + (trees ** 3) + (trees ** 2)
  end

  defn profits_derivative(trees) do
    grad(trees, &amp;profits/1)
  end
end

trees = Nx.linspace(0, 4, n: 100)
profits = BerryFarm.profits(trees)
profits_derivative = BerryFarm.profits_derivative(trees)
#Nx.Tensor<
  f32[100]
  EXLA.Backend
  [4.0, 3.620183229446411, 3.287757396697998, 3.001140832901001, 2.7587497234344482, 2.5590009689331055, 2.4003114700317383, 2.2810988426208496, 2.199779748916626, 2.154770851135254, 2.1444895267486572, 2.1673526763916016, 2.2217769622802734, 2.306180000305176, 2.418977975845337, 2.558588743209839, 2.72342848777771, 2.911914587020874, 3.122464418411255, 3.353494167327881, 3.603421211242676, 3.8706626892089844, 4.153635025024414, 4.450756072998047, 4.760441780090332, 5.081110000610352, 5.411177158355713, 5.749061107635498, 6.09317684173584, 6.441944122314453, 6.793778419494629, 7.147095680236816, 7.500314712524414, 7.8518524169921875, 8.200122833251953, 8.543547630310059, 8.88054084777832, 9.209519386291504, 9.528902053833008, 9.837104797363281, 10.132542610168457, 10.41363525390625, 10.678799629211426, 10.926450729370117, 11.155006408691406, 11.362885475158691, 11.548501968383789, 11.710275650024414, 11.846620559692383, 11.955955505371094, ...]
>
alias VegaLite, as: Vl

Vl.new(title: "Berry Profits", width: 480, height: 320)
|> Vl.data_from_values(%{
  trees: Nx.to_flat_list(trees),
  profits: Nx.to_flat_list(profits),
  profits_derivative: Nx.to_flat_list(profits_derivative)
})
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:line, interpolate: :basis)
  |> Vl.encode_field(:x, "trees", type: :quantitative)
  |> Vl.encode_field(:y, "profits", type: :quantitative),

  Vl.new()
  |> Vl.mark(:line, interpolate: :basis)
  |> Vl.encode_field(:x, "trees", type: :quantitative)
  |> Vl.encode_field(:y, "profits_derivative", type: :quantitative)
  |> Vl.encode(:color, value: "#ff0000")
])
defmodule GradFun do
  import Nx.Defn

  defn my_function(x) do
    x
    |> Nx.cos()
    |> Nx.exp()
    |> Nx.sum()
    |> print_expr()
  end

  defn grad_my_function(x) do
    grad(x, &amp;my_function/1) |> print_expr()
  end
end

GradFun.grad_my_function(Nx.tensor([1.0, 2.0, 3.0]))
#Nx.Tensor<
  f32
  
  Nx.Defn.Expr
  parameter a:0                            f32[3]
  b = cos a                                f32[3]
  c = exp b                                f32[3]
  d = sum c, axes: nil, keep_axes: false   f32
>
#Nx.Tensor<
  f32[3]
  
  Nx.Defn.Expr
  parameter a:0       f32[3]
  b = cos a           f32[3]
  c = exp b           f32[3]
  d = sin a           f32[3]
  e = negate d        f32[3]
  f = multiply c, e   f32[3]
>
#Nx.Tensor<
  f32[3]
  EXLA.Backend
  [-1.444406509399414, -0.5997574925422668, -0.05243729427456856]
>

Chapter 4

defmodule CrossEntropy do
  import Nx.Defn

  defn binary_cross_entropy(y_true, y_pred) do
    y_true * Nx.log(y_pred) - (1 - y_true) * Nx.log(1 - y_pred)
  end
end

for x <- [0.455, 0.333, 0.999, 0.8], do:
  CrossEntropy.binary_cross_entropy(1, x)
  |> IO.inspect()
#Nx.Tensor<
  f32
  EXLA.Backend
  -0.7874578237533569
>
#Nx.Tensor<
  f32
  EXLA.Backend
  -1.0996127128601074
>
#Nx.Tensor<
  f32
  EXLA.Backend
  -0.0010004874784499407
>
#Nx.Tensor<
  f32
  EXLA.Backend
  -0.2231435328722
>
[
  #Nx.Tensor<
    f32
    EXLA.Backend
    -0.7874578237533569
  >,
  #Nx.Tensor<
    f32
    EXLA.Backend
    -1.0996127128601074
  >,
  #Nx.Tensor<
    f32
    EXLA.Backend
    -0.0010004874784499407
  >,
  #Nx.Tensor<
    f32
    EXLA.Backend
    -0.2231435328722
  >
]

Stochastic gradient descent

key = Nx.Random.key(42)

{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})
true_function = fn params, x ->
  Nx.dot(x, params) |> Nx.cos()
end

{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))

{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
[
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.9657397270202637, 0.9266661405563354, 0.2524207830429077, 0.506806492805481, 0.03272294998168945, 0.6381621360778809, 0.4016733169555664, 0.4144333600997925, 0.8692346811294556, 0.19583988189697266, 0.4356701374053955, 0.037007689476013184, 0.4367654323577881, 0.9086041450500488, 0.4730778932571411, 0.29556596279144287, 0.49315857887268066, 0.3683987855911255, 0.8670364618301392, 0.527277946472168, 0.028360843658447266, 0.13743293285369873, 0.8709059953689575, 0.1861327886581421, 0.4181276559829712, 0.9427480697631836, 0.4339343309402466, 0.8707499504089355, 0.6826666593551636, 0.528895378112793, 0.17522680759429932, 0.4048128128051758]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9324023127555847]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.4392540454864502, 0.9165053367614746, 0.9777518510818481, 0.879123330116272, 0.612689733505249, 0.01696908473968506, 0.133436918258667, 0.4318392276763916, 0.5053318738937378, 0.7980244159698486, 0.1885296106338501, 0.9951480627059937, 0.3975728750228882, 0.226912260055542, 0.4825739860534668, 0.9671891927719116, 0.24038493633270264, 0.13231432437896729, 0.38793301582336426, 0.05815780162811279, 0.43374860286712646, 0.2860398292541504, 0.6426401138305664, 0.8966696262359619, 0.09666109085083008, 0.4394463300704956, 0.35843217372894287, 0.34688258171081543, 0.5460761785507202, 0.5041118860244751, 0.5477373600006104, 0.8824354410171509]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.810077965259552]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7762134075164795, 0.7822309732437134, 0.24949312210083008, 0.24509012699127197, 0.9004219770431519, 0.8151938915252686, 0.9005923271179199, 0.8640304803848267, 0.4731714725494385, 0.5921633243560791, 0.4887489080429077, 0.8375271558761597, 0.9577419757843018, 0.5891522169113159, 0.12607717514038086, 0.708527684211731, 0.41328203678131104, 0.6296629905700684, 0.6268273591995239, 0.35883355140686035, 0.36125707626342773, 0.6910197734832764, 0.7902359962463379, 0.7439805269241333, 0.4775749444961548, 0.9078165292739868, 0.3568282127380371, 0.15519630908966064, 0.11200845241546631, 0.7795575857162476, 0.468631386756897, 0.9759647846221924]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.6395251154899597]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.3197380304336548, 0.17616689205169678, 0.8258638381958008, 0.3432091474533081, 0.35468900203704834, 0.5186667442321777, 0.7499172687530518, 0.4087836742401123, 0.4280334711074829, 0.2278900146484375, 0.6885714530944824, 0.6648629903793335, 0.5448073148727417, 0.1430720090866089, 0.842303991317749, 0.8900002241134644, 0.4492759704589844, 0.8455078601837158, 0.4587341547012329, 0.3691824674606323, 0.2542390823364258, 0.871134877204895, 0.26322853565216064, 0.10538768768310547, 0.355352520942688, 0.8888055086135864, 0.488552451133728, 0.6250888109207153, 0.9855941534042358, 0.738310694694519, 0.6712085008621216, 0.04661083221435547]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.3580564856529236]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.40412354469299316, 0.7216566801071167, 0.25867438316345215, 0.9800392389297485, 0.9075496196746826, 0.6819312572479248, 0.0149993896484375, 0.04356968402862549, 0.9605370759963989, 0.02644956111907959, 0.28210699558258057, 0.7568575143814087, 0.21303772926330566, 0.002684950828552246, 0.6519417762756348, 0.28664088249206543, 0.15737569332122803, 0.37507736682891846, 0.05415797233581543, 0.03802788257598877, 0.8071837425231934, 0.06110048294067383, 0.6388435363769531, 0.44481122493743896, 0.23555970191955566, 0.61528480052948, 0.8113986253738403, 0.012137651443481445, 0.9276052713394165, 0.9450554847717285, 0.9840184450149536, 0.20820486545562744]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9738354682922363]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.5034708976745605, 0.7755038738250732, 0.13867413997650146, 0.29906952381134033, 0.014742374420166016, 0.7755328416824341, 0.9173959493637085, 0.0935053825378418, 0.31686699390411377, 0.06115245819091797, 0.8989229202270508, 0.19432556629180908, 0.7501810789108276, 0.2113250494003296, 0.5822068452835083, 0.6005638837814331, 0.625515341758728, 0.6752986907958984, 0.9507982730865479, 0.7879356145858765, 0.5397478342056274, 0.3113539218902588, 0.8102543354034424, 0.2979027032852173, 0.7655726671218872, 0.42514193058013916, 0.09351170063018799, 0.8037655353546143, 0.4778313636779785, 0.44777703285217285, 0.3096567392349243, 0.33784306049346924]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.10948850959539413]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7518961429595947, 0.22414886951446533, 0.15714240074157715, 0.8663241863250732, 0.508256196975708, 0.30795371532440186, 0.11998486518859863, 0.18344223499298096, 0.4011112451553345, 0.924648642539978, 0.5058850049972534, 0.5193443298339844, 0.9716345071792603, 0.948397159576416, 0.5351895093917847, 0.5134536027908325, 0.6595708131790161, 0.06837213039398193, 0.05189085006713867, 0.8435298204421997, 0.7968239784240723, 0.12332558631896973, 0.7250438928604126, 0.7147141695022583, 0.8842874765396118, 0.9462226629257202, 0.843963623046875, 0.8965095281600952, 0.4305756092071533, 0.5930991172790527, 0.11764276027679443, 0.5833710432052612]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.22802595794200897]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.9475998878479004, 0.6403565406799316, 0.5817118883132935, 0.5467712879180908, 0.8543612957000732, 0.06530475616455078, 0.14756488800048828, 0.15206146240234375, 0.68489670753479, 0.932651162147522, 0.9179636240005493, 0.8796118497848511, 0.2882128953933716, 0.7526837587356567, 0.1426788568496704, 0.18050217628479004, 0.6951268911361694, 0.7308511734008789, 0.6911174058914185, 0.19187331199645996, 0.925081729888916, 0.8188349008560181, 0.5788781642913818, 0.33968937397003174, 0.8412926197052002, 0.50633704662323, 0.40607786178588867, 0.39345502853393555, 0.9535032510757446, 0.0635685920715332, 0.7170870304107666, 0.8757264614105225]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.6047953367233276]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.01942753791809082, 0.899272084236145, 0.5304620265960693, 0.5620572566986084, 0.128362774848938, 0.31026554107666016, 0.6900253295898438, 0.2783416509628296, 0.004452347755432129, 0.5778182744979858, 0.026953697204589844, 0.014599919319152832, 0.3131972551345825, 0.6139227151870728, 0.6718645095825195, 0.9182217121124268, 0.4055975675582886, 0.9959360361099243, 0.3222285509109497, 0.1344226598739624, 0.8531616926193237, 0.1252962350845337, 0.7893067598342896, 0.6823188066482544, 0.38434433937072754, 0.0016857385635375977, 0.9079246520996094, 0.33411502838134766, 0.05022597312927246, 0.5846171379089355, 0.889033317565918, 0.7293587923049927]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9851313829421997]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.8360906839370728, 0.4355432987213135, 0.11632812023162842, 0.7702548503875732, 0.24256396293640137, 0.36099421977996826, 0.6917792558670044, 0.288962721824646, 0.6243956089019775, 0.9861946105957031, 0.3847709894180298, 0.6880143880844116, 0.589323878288269, 0.4923354387283325, 0.3279656171798706, 0.4151395559310913, 0.852401614189148, 0.0718458890914917, 0.01529836654663086, 0.06954300403594971, 0.7971522808074951, 0.7249754667282104, 0.25757861137390137, 0.906819224357605, 0.6608389616012573, 0.40988433361053467, 0.26649951934814453, 0.6497167348861694, 0.31986987590789795, 0.8541487455368042, 0.7966134548187256, 0.23020529747009277]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9304757714271545]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7711091041564941, 0.6362074613571167, 0.8661282062530518, 0.6193832159042358, 0.6161632537841797, 0.19212019443511963, 0.7516170740127563, 0.023564815521240234, 0.7833913564682007, 0.8175530433654785, 0.029859185218811035, 0.30578577518463135, 0.9423370361328125, 0.20194673538208008, 0.6542264223098755, 0.9779584407806396, 0.11775898933410645, 0.5317223072052002, 0.3922593593597412, 0.832879900932312, 0.657945990562439, 0.43512094020843506, 0.32924580574035645, 0.21120929718017578, 0.76695716381073, 0.9446995258331299, 0.02226400375366211, 0.8510793447494507, 0.06922507286071777, 0.03070998191833496, 0.9929032325744629, 0.6356418132781982]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.30979853868484497]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.5236676931381226, 0.2989940643310547, 0.03134489059448242, 0.12583446502685547, 0.042815566062927246, 0.8364819288253784, 0.31674015522003174, 0.04060947895050049, 0.05521893501281738, 0.7784453630447388, 0.48672759532928467, 0.8186780214309692, 0.00807809829711914, 0.39795219898223877, 0.22978472709655762, 0.5791110992431641, 0.6117820739746094, 0.7412447929382324, 0.42317402362823486, 0.28765225410461426, 0.36166059970855713, 0.5173482894897461, 0.9059319496154785, 0.3208935260772705, 0.3955960273742676, 0.5770881175994873, 0.963921308517456, 0.05305802822113037, 0.009126543998718262, 0.30502188205718994, 0.348180890083313, 0.28527331352233887]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.736284077167511]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7139891386032104, 0.12666535377502441, 0.15333807468414307, 0.7544382810592651, 0.9695196151733398, 0.49136245250701904, 0.20070266723632812, 0.005652427673339844, 0.02549123764038086, 0.3883492946624756, 0.7958550453186035, 0.4632751941680908, 0.14279115200042725, 0.5663028955459595, 0.31234872341156006, 0.6877082586288452, 0.04052889347076416, 0.19631445407867432, 0.8272514343261719, 0.7589792013168335, 0.8727586269378662, 0.9460961818695068, 0.7840994596481323, 0.1846456527709961, 0.7626980543136597, 0.5093346834182739, 0.5205307006835938, 0.2435612678527832, 0.4535341262817383, 0.3754945993423462, 0.9493304491043091, 0.05621170997619629]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.5529653429985046]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.8532519340515137, 0.6970062255859375, 0.652370810508728, 0.4061359167098999, 0.11044573783874512, 0.11151957511901855, 0.851331353187561, 0.6565314531326294, 0.33859121799468994, 0.7652009725570679, 0.3588383197784424, 0.07348513603210449, 0.7815285921096802, 0.9533407688140869, 0.8006638288497925, 0.04949069023132324, 0.8293572664260864, 0.3746068477630615, 0.8676903247833252, 0.9169406890869141, 0.9336735010147095, 0.06596994400024414, 0.8225301504135132, 0.18987727165222168, 0.24470460414886475, 0.8587253093719482, 0.8066114187240601, 0.4743626117706299, 0.8888722658157349, 0.36300718784332275, 0.2819058895111084, 0.5664075613021851]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.34149232506752014]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.3902559280395508, 0.14534735679626465, 0.09916174411773682, 0.7248067855834961, 0.20137739181518555, 0.6646915674209595, 0.0778660774230957, 0.8685482740402222, 0.641196608543396, 0.5231763124465942, 0.6376198530197144, 0.6526670455932617, 0.2163105010986328, 0.8063833713531494, 0.6756443977355957, 0.4447154998779297, 0.20969760417938232, 0.951002836227417, 0.045929908752441406, 0.17532849311828613, 0.9260181188583374, 0.5131326913833618, 0.30024540424346924, 0.3300440311431885, 0.9004764556884766, 0.8441228866577148, 0.9477831125259399, 0.6751878261566162, 0.4996030330657959, 0.2638866901397705, 0.7265427112579346, 0.49843263626098633]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.34535130858421326]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.8431416749954224, 0.20925045013427734, 0.882979154586792, 0.7653672695159912, 0.8408812284469604, 0.9956172704696655, 0.039161086082458496, 0.2890273332595825, 0.8148702383041382, 0.7094781398773193, 0.45367276668548584, 0.49341559410095215, 0.40397489070892334, 0.044689178466796875, 0.2746458053588867, 0.14785456657409668, 0.5764540433883667, 0.9384440183639526, 0.16802644729614258, 0.5668824911117554, 0.7575173377990723, 0.9617680311203003, 0.34545373916625977, 0.9809876680374146, 0.9966757297515869, 0.9557144641876221, 0.9793694019317627, 0.49138343334198, 0.327367901802063, 0.8423446416854858, 0.41049087047576904, 0.16183257102966309]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.12882225215435028]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.22011446952819824, 0.34421753883361816, 0.6294939517974854, 0.5911561250686646, 0.3088088035583496, 0.23263812065124512, 0.3219066858291626, 0.03805994987487793, 0.1934577226638794, 0.9446452856063843, 0.9956920146942139, 0.5987362861633301, 0.26041269302368164, 0.8809354305267334, 0.768547534942627, 0.8348363637924194, 0.826107382774353, 0.6921907663345337, 0.046318769454956055, 0.5803121328353882, 0.9755550622940063, 0.917121410369873, 0.30008530616760254, 0.2571007013320923, 0.6167714595794678, 0.022228240966796875, 0.7143242359161377, 0.5102230310440063, 0.16165518760681152, 0.9506291151046753, 0.9326122999191284, 0.2996530532836914]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.5534984469413757]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.38776493072509766, 0.7318899631500244, 0.1307905912399292, 0.09548139572143555, 0.8834151029586792, 0.25518178939819336, 0.8047440052032471, 0.5709458589553833, 0.022022604942321777, 0.8857442140579224, 0.802797794342041, 0.5964082479476929, 0.6515809297561646, 0.06845474243164062, 0.37185752391815186, 0.12856698036193848, 0.09193015098571777, 0.738864541053772, 0.43469250202178955, 0.7443429231643677, 0.016843795776367188, 0.13896071910858154, 0.9344608783721924, 0.04187476634979248, 0.021153926849365234, 0.5739061832427979, 0.11942613124847412, 0.6132626533508301, 0.5382595062255859, 0.9019054174423218, 0.3720097541809082, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9716433882713318]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.5642416477203369, 0.830094575881958, 0.11100232601165771, 0.38808906078338623, 0.2167491912841797, 0.9876217842102051, 0.8534290790557861, 0.38251447677612305, 0.9997782707214355, 0.5055009126663208, 0.836064338684082, 0.26873278617858887, 0.44322919845581055, 0.9741466045379639, 0.9353281259536743, 0.6532653570175171, 0.21104705333709717, 0.9035766124725342, 0.11548709869384766, 0.793843150138855, 0.591930627822876, 0.41485321521759033, 0.41184914112091064, 0.5477373600006104, 0.08824586868286133, 0.7526575326919556, 0.44268131256103516, 0.0763329267501831, 0.4934626817703247, 0.8778635263442993, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.13812902569770813]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.8263260126113892, 0.47554445266723633, 0.7326714992523193, 0.4069685935974121, 0.48123836517333984, 0.9155631065368652, 0.19115734100341797, 0.8787575960159302, 0.5277758836746216, 0.08595383167266846, 0.6884660720825195, 0.3847067356109619, 0.4597644805908203, 0.11427831649780273, 0.611096978187561, 0.17932391166687012, 0.046775102615356445, 0.15377891063690186, 0.17494094371795654, 0.6756585836410522, 0.6878424882888794, 0.23508989810943604, 0.795313835144043, 0.28050291538238525, 0.006268501281738281, 0.423947811126709, 0.3076256513595581, 0.712846040725708, 0.33599305152893066, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9905521273612976]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.30186140537261963, 0.9873268604278564, 0.035573720932006836, 0.4155128002166748, 0.24097764492034912, 0.03020632266998291, 0.20567595958709717, 0.5887387990951538, 0.6643663644790649, 0.31537091732025146, 0.9205235242843628, 0.6104476451873779, 0.9955638647079468, 0.08851909637451172, 0.34864628314971924, 0.5111007690429688, 0.8216805458068848, 0.9719328880310059, 0.7817020416259766, 0.5537395477294922, 0.42271876335144043, 0.07025539875030518, 0.20106756687164307, 0.06972110271453857, 0.26150715351104736, 0.41637277603149414, 0.36588919162750244, 0.3534187078475952, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9978509545326233]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.8590803146362305, 0.5636343955993652, 0.03443634510040283, 0.5104682445526123, 0.606324315071106, 0.9271607398986816, 0.4498816728591919, 0.8567771911621094, 0.0051795244216918945, 0.3874478340148926, 0.1921311616897583, 0.6639413833618164, 0.46572935581207275, 0.8383623361587524, 0.9469438791275024, 0.020708560943603516, 0.15898573398590088, 0.944486141204834, 0.3688013553619385, 0.3565635681152344, 0.4661533832550049, 0.28654932975769043, 0.9991466999053955, 0.395352840423584, 0.1945810317993164, 0.37531542778015137, 0.34606099128723145, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.7442944049835205]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.40698301792144775, 0.39485931396484375, 0.7665799856185913, 0.45240330696105957, 0.3180277347564697, 0.6046972274780273, 0.27099859714508057, 0.32419121265411377, 0.15665650367736816, 0.8422553539276123, 0.17924022674560547, 0.9220238924026489, 0.6719664335250854, 0.3106119632720947, 0.7598171234130859, 0.41273176670074463, 0.24574947357177734, 0.4455568790435791, 0.7285884618759155, 0.5802567005157471, 0.7816479206085205, 0.43250608444213867, 0.7760515213012695, 0.9805769920349121, 0.48999035358428955, 0.6657071113586426, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.32514989376068115]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.6499532461166382, 0.33332526683807373, 0.18239235877990723, 0.28092002868652344, 0.6100766658782959, 0.8831055164337158, 0.5425313711166382, 0.46851158142089844, 0.0425875186920166, 0.4390711784362793, 0.24202609062194824, 0.4561119079589844, 0.30997657775878906, 0.8286688327789307, 0.9777672290802002, 0.7183065414428711, 0.994769811630249, 0.9684973955154419, 0.516169548034668, 0.7496813535690308, 0.1341181993484497, 0.23441553115844727, 0.4797624349594116, 0.7527389526367188, 0.706782341003418, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.17931008338928223]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.29753637313842773, 0.9054058790206909, 0.5384794473648071, 0.008930683135986328, 0.9904316663742065, 0.012788891792297363, 0.6981593370437622, 0.21004319190979004, 0.24448156356811523, 0.3577829599380493, 0.8826776742935181, 0.556861162185669, 0.7953556776046753, 0.755801796913147, 0.8585830926895142, 0.8141891956329346, 0.611968994140625, 0.02731025218963623, 0.8326984643936157, 0.026828527450561523, 0.04161202907562256, 0.72672438621521, 0.16502487659454346, 0.3540531396865845, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.08244838565587997]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.22602331638336182, 0.938198447227478, 0.6433297395706177, 0.22689640522003174, 0.2643556594848633, 0.38158679008483887, 0.18027138710021973, 0.7762355804443359, 0.8237777948379517, 0.16852951049804688, 0.2584739923477173, 0.38265061378479004, 0.028604865074157715, 0.05207550525665283, 0.08737969398498535, 0.831161379814148, 0.337926983833313, 0.578020453453064, 0.1244196891784668, 0.3897353410720825, 0.6503366231918335, 0.5011011362075806, 0.3202742338180542, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9999148845672607]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.1591346263885498, 0.3551875352859497, 0.19456207752227783, 0.08360016345977783, 0.49338340759277344, 0.8102091550827026, 0.038229942321777344, 0.5133059024810791, 0.42447197437286377, 0.1302204132080078, 0.032562255859375, 0.4770599603652954, 0.10980844497680664, 0.39250481128692627, 0.9826400279998779, 0.6408740282058716, 0.4889005422592163, 0.8684313297271729, 0.6371150016784668, 0.5141459703445435, 0.04483532905578613, 0.008337259292602539, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9912962317466736]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.6458258628845215, 0.14830482006072998, 0.1367412805557251, 0.05629110336303711, 0.6188018321990967, 0.09611082077026367, 0.022463202476501465, 0.16802668571472168, 0.13714361190795898, 0.9738653898239136, 0.22950482368469238, 0.14387571811676025, 0.3699824810028076, 0.037640929222106934, 0.8094890117645264, 0.16305005550384521, 0.9263873100280762, 0.4807380437850952, 0.35078299045562744, 0.5584464073181152, 0.42825543880462646, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9269740581512451]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.6350167989730835, 0.7023348808288574, 0.9816912412643433, 0.3239823579788208, 0.3320831060409546, 0.9436564445495605, 0.8940634727478027, 0.9500092267990112, 0.039070963859558105, 0.7722084522247314, 0.014232039451599121, 0.5869686603546143, 0.6310857534408569, 0.754490852355957, 0.4464530944824219, 0.28466737270355225, 0.12341678142547607, 0.28795409202575684, 0.03844261169433594, 0.05093085765838623, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.7743555903434753]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.1708146333694458, 0.7656099796295166, 0.17411589622497559, 0.13566994667053223, 0.5877799987792969, 0.8375746011734009, 0.7500095367431641, 0.43630826473236084, 0.6984328031539917, 0.458881139755249, 0.3605443239212036, 0.6612114906311035, 0.7714365720748901, 0.18256628513336182, 0.6045645475387573, 0.6824719905853271, 0.093025803565979, 0.8150986433029175, 0.6484025716781616, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.10910505801439285]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.37816083431243896, 0.3764007091522217, 0.366793155670166, 0.5291237831115723, 0.4347909688949585, 0.051524996757507324, 0.7034255266189575, 0.21059012413024902, 0.4240748882293701, 0.33774685859680176, 0.602317214012146, 0.8616265058517456, 0.09764957427978516, 0.8331103324890137, 0.6966776847839355, 0.9676048755645752, 0.39181971549987793, 0.5993291139602661, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.6128256320953369]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.9014437198638916, 0.6692262887954712, 0.2199995517730713, 0.34047043323516846, 0.5105847120285034, 0.5800155401229858, 0.6711152791976929, 0.24664950370788574, 0.837967038154602, 0.1854724884033203, 0.4090847969055176, 0.5367509126663208, 0.9298523664474487, 0.8082610368728638, 0.4719582796096802, 0.9887839555740356, 0.8493902683258057, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.046389952301979065]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.18410146236419678, 0.16370105743408203, 0.00571596622467041, 0.6846107244491577, 0.8010516166687012, 0.5752760171890259, 0.7778260707855225, 0.570926308631897, 0.2175813913345337, 0.15920031070709229, 0.7178256511688232, 0.7729295492172241, 0.09782063961029053, 0.5135018825531006, 0.6833776235580444, 0.07111799716949463, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.5898432731628418]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.9505445957183838, 0.6332504749298096, 0.4598352909088135, 0.821086049079895, 0.4521135091781616, 0.40192127227783203, 0.37758028507232666, 0.7068672180175781, 0.5899826288223267, 0.3339945077896118, 0.5181410312652588, 0.5538294315338135, 0.32791435718536377, 0.6153789758682251, 0.21849405765533447, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.7477481961250305]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.5818334817886353, 0.008952617645263672, 0.817952036857605, 0.8802107572555542, 0.2041914463043213, 0.6401019096374512, 0.5620921850204468, 0.6981043815612793, 0.9288793802261353, 0.17047274112701416, 0.8074302673339844, 0.027227401733398438, 0.8451176881790161, 0.8029766082763672, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9530194401741028]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.21364152431488037, 0.9299622774124146, 0.176169753074646, 0.40923750400543213, 0.0010215044021606445, 0.6993589401245117, 0.2511633634567261, 0.6728019714355469, 0.2957075834274292, 0.5348055362701416, 0.8762129545211792, 0.07801520824432373, 0.3015238046646118, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.08481946587562561]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.3406723737716675, 0.8631099462509155, 0.12643325328826904, 0.38934123516082764, 0.46326422691345215, 0.039055824279785156, 0.5590641498565674, 0.24887871742248535, 0.38112592697143555, 0.7917823791503906, 0.8130742311477661, 0.016570329666137695, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.4675636291503906]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.05386042594909668, 0.0680922269821167, 0.31601762771606445, 0.18086862564086914, 0.7679719924926758, 0.6589527130126953, 0.9141957759857178, 0.402393102645874, 0.8808540105819702, 0.9081192016601562, 0.6332200765609741, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.6133341789245605]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7115259170532227, 0.19793486595153809, 0.3831000328063965, 0.12318956851959229, 0.048275113105773926, 0.6922358274459839, 0.8630118370056152, 0.6173487901687622, 0.4392777681350708, 0.7511276006698608, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.932788610458374]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.15324389934539795, 0.608772873878479, 0.30232787132263184, 0.42286384105682373, 0.6527767181396484, 0.8560984134674072, 0.95783531665802, 0.35850799083709717, 0.20661818981170654, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.748504102230072]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.9347794055938721, 0.5582127571105957, 0.8688193559646606, 0.9338347911834717, 0.6681429147720337, 0.09503507614135742, 0.1364145278930664, 0.2338886260986328, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9955204129219055]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.44904494285583496, 0.6156696081161499, 0.24088788032531738, 0.04646635055541992, 0.1615074872970581, 0.3069014549255371, 0.476338267326355, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.3215336799621582]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.26967108249664307, 0.34533655643463135, 0.32481229305267334, 0.032245635986328125, 0.33962345123291016, 0.5251162052154541, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.7001439929008484]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.16084957122802734, 0.4878326654434204, 0.31659018993377686, 0.9245070219039917, 0.24317359924316406, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9867671728134155]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.2458251714706421, 0.5735921859741211, 0.2629578113555908, 0.3110496997833252, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.9086368083953857]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.7443956136703491, 0.07421326637268066, 0.4091228246688843, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [-0.40484946966171265]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.6418168544769287, 0.5278061628341675, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       [0.7656022310256958]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       [0.5448164939880371, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     EXLA.Backend
     [
       ...
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     EXLA.Backend
     [
       ...
     ]
   >, ...},
  {...},
  ...
]
defmodule SGD do
  import Nx.Defn

  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  defn model(params, inputs) do
    labels = Nx.dot(inputs, params)
    labels
  end

  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end

  defn loss(actual_label, predicted_label) do
    loss_value = mean_squared_error(actual_label, predicted_label)
    loss_value
  end

  defn objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end 

  defn step(params, actual_inputs, actual_labels) do
    {loss, params_grad} = value_and_grad(params, fn params ->
      objective(params, actual_inputs, actual_labels)
    end)
    new_params = params - 1.0e-2 * params_grad
    {loss, new_params}
  end

  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(y, prediction)
    end)
    |> Enum.reduce(0, &amp;Nx.add/2)
  end

  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {loss, params} do
        {loss, params} ->
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, new_params} = step(params, x, y)
              avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1)
              IO.write("\rEpic: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end
    trained_params
  end
end
{:module, SGD, <<70, 79, 82, 49, 0, 0, 31, ...>>, {:train, 3}}
key = Nx.Random.key(100)
{random_params, _} = SGD.init_random_params(key)
SGD.evaluate(random_params, test_data)
#Nx.Tensor<
  f32[1]
  EXLA.Backend
  [427584.0625]
>
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)
SGD.evaluate(trained_params, test_data)
Epic: 1, Loss: 4.419022843649145e-6
#Nx.Tensor<
  f32[1]
  EXLA.Backend
  [2648.607421875]
>

Chapter 5

m = :rand.uniform() * 10
b = :rand.uniform() * 10

key = Nx.Random.key(42)
size = 100
{x, new_key} = Nx.Random.normal(key, 0.0, 1.0, shape: {size, 1})
{noise_x, new_key} = Nx.Random.normal(new_key, 0.0, 1.0, shape: {size, 1})

y =
  m
  |> Nx.multiply(Nx.add(x, noise_x))
  |> Nx.add(b)
#Nx.Tensor<
  f32[100][1]
  EXLA.Backend
  [
    [1.239382266998291],
    [5.344666957855225],
    [9.054686546325684],
    [-0.9490962028503418],
    [7.150064945220947],
    [6.748875617980957],
    [7.904621124267578],
    [3.337090492248535],
    [6.795029163360596],
    [9.969968795776367],
    [1.6933927536010742],
    [19.88507843017578],
    [6.888549327850342],
    [9.998130798339844],
    [6.731719017028809],
    [9.519281387329102],
    [4.11965799331665],
    [3.476489543914795],
    [14.017874717712402],
    [-7.367081642150879],
    [-7.590155601501465],
    [3.8600473403930664],
    [6.463741779327393],
    [6.2305073738098145],
    [10.549234390258789],
    [9.811880111694336],
    [11.68867301940918],
    [10.834447860717773],
    [19.456283569335938],
    [-5.116979598999023],
    [-8.104174613952637],
    [8.733948707580566],
    [14.805582046508789],
    [9.803984642028809],
    [13.485210418701172],
    [22.02405548095703],
    [9.046976089477539],
    [1.8259243965148926],
    [-5.544628143310547],
    [17.507389068603516],
    [7.30125093460083],
    [26.619380950927734],
    [6.8991475105285645],
    [9.702752113342285],
    [6.994278907775879],
    [17.967500686645508],
    [9.042976379394531],
    [6.460255146026611],
    [-2.2111105918884277],
    [2.369089365005493],
    ...
  ]
>
Vl.new(title: "Scatterplot", width: 720, height: 480)
|> Vl.data_from_values(%{
  x: Nx.to_flat_list(x),
  y: Nx.to_flat_list(y)
})
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
model = Scholar.Linear.LinearRegression.fit(x, y)
%Scholar.Linear.LinearRegression{
  coefficients: #Nx.Tensor<
    f32[1][1]
    EXLA.Backend
    [
      [5.524238109588623]
    ]
  >,
  intercept: #Nx.Tensor<
    f32[1]
    EXLA.Backend
    [5.875639915466309]
  >
}
Scholar.Linear.LinearRegression.predict(model, Nx.iota({3, 1}))
#Nx.Tensor<
  f32[3][1]
  EXLA.Backend
  [
    [5.875639915466309],
    [11.399877548217773],
    [16.924116134643555]
  ]
>
pred_xs = Nx.linspace(-3.0, 3.0, n: 100) |> Nx.new_axis(-1)
pred_ys = Scholar.Linear.LinearRegression.predict(model, pred_xs)

Vl.new(title: "Scatterplot Distribution and Fit Curve", width: 720, height: 480)
|> Vl.data_from_values(%{
  x: Nx.to_flat_list(x),
  y: Nx.to_flat_list(y),
  pred_x: Nx.to_flat_list(pred_xs),
  pred_y: Nx.to_flat_list(pred_ys)
})
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:point)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative),
  Vl.new()
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "pred_x", type: :quantitative)
  |> Vl.encode_field(:y, "pred_y", type: :quantitative)
])
{inputs, targets} = Scidata.Wine.download()
{train, test} =
  inputs
  |> Enum.zip(targets)
  |> Enum.shuffle()
  |> Enum.split(floor(length(inputs) * 0.8))

{train_inputs, train_targets} = Enum.unzip(train)
train_inputs = Nx.tensor(train_inputs)
train_targets = Nx.tensor(train_targets)

{test_inputs, test_targets} = Enum.unzip(test)
test_inputs = Nx.tensor(test_inputs)
test_targets = Nx.tensor(test_targets)

train_inputs = Scholar.Preprocessing.min_max_scale(train_inputs)
test_inputs = Scholar.Preprocessing.min_max_scale(test_inputs)
#Nx.Tensor<
  f32[36][13]
  EXLA.Backend
  [
    [0.007685163989663124, 0.0022144701797515154, 0.0012977271107956767, 0.013310633599758148, 0.05051611363887787, 8.988844347186387e-4, 8.631672244518995e-4, 2.7978525031358004e-4, 8.810258004814386e-4, 0.002774040913209319, 4.1670139762572944e-4, 0.001113187987357378, 0.30648982524871826],
    [0.007185122463852167, 6.250521400943398e-4, 9.58413234911859e-4, 0.009917492978274822, 0.08980510383844376, 0.0010179419768974185, 6.78627984598279e-4, 0.0, 0.0014048789162188768, 0.001613229513168335, 6.78627984598279e-4, 0.0017441929085180163, 0.4273332357406616],
    [0.0071077351458370686, 0.0010060362983494997, 0.0012977271107956767, 0.010929482989013195, 0.04813496395945549, 8.691200637258589e-4, 8.095912635326385e-4, 2.2620933305006474e-4, 8.929315372370183e-4, 0.0013453501742333174, 5.595705006271601e-4, 0.001267962739802897, 0.285654753446579],
    [0.0073637086898088455, 9.465074981562793e-4, 0.001095329411327839, 0.012120057828724384, 0.05051611363887787, 0.0012262926902621984, 0.0010596121428534389, 1.0715178359532729e-4, 7.976855267770588e-4, 0.0016668055905029178, 5.357589107006788e-4, ...],
    ...
  ]
>
model = Scholar.Linear.LogisticRegression.fit(
  train_inputs,
  train_targets,
  num_classes: 3
)
%Scholar.Linear.LogisticRegression{
  coefficients: #Nx.Tensor<
    f32[13][3]
    EXLA.Backend
    [
      [1.0946117639541626, 1.0140817165374756, 0.8913055062294006],
      [1.2052838802337646, 0.9849055409431458, 0.8098109364509583],
      [1.0240973234176636, 1.0001778602600098, 0.9757255911827087],
      [1.3684810400009155, 0.6371003985404968, 0.9944186806678772],
      [1.1884140968322754, 0.8866520524024963, 0.9249335527420044],
      [0.880347728729248, 1.045974612236023, 1.0736756324768066],
      [0.7496421933174133, 1.098419427871704, 1.1519380807876587],
      [1.0152549743652344, 0.9921272993087769, 0.9926177859306335],
      [0.9138743281364441, 1.0223854780197144, 1.0637397766113281],
      [1.5787770748138428, 0.9660248756408691, 0.4551984965801239],
      [0.943101167678833, 1.006646752357483, 1.0502517223358154],
      [0.8121234774589539, 1.0624675750732422, 1.1254044771194458],
      [-2.150484085083008, 13.511106491088867, -8.360634803771973]
    ]
  >,
  bias: #Nx.Tensor<
    f32[3]
    EXLA.Backend
    [1.8604915142059326, -6.460078239440918, 4.5995893478393555]
  >
}
test_preds = Scholar.Linear.LogisticRegression.predict(model, test_inputs)
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
Scholar.Metrics.Classification.confusion_matrix(test_targets, test_preds, num_classes: 3)
#Nx.Tensor<
  u64[3][3]
  EXLA.Backend
  [
    [3, 11, 0],
    [0, 0, 13],
    [1, 0, 8]
  ]
>
Vl.new(title: "Confusion Matrix", width: 480, height: 240)
|> Vl.data_from_values(%{
  predicted: Nx.to_flat_list(test_preds),
  actual: Nx.to_flat_list(test_targets)
})
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "predicted")
|> Vl.encode_field(:y, "actual")
|> Vl.encode(:color, aggregate: :count)

K-Nearest Neighbor

model = Scholar.Neighbors.KNNClassifier.fit(
  train_inputs, train_targets, num_neighbors: 5, num_classes: 3 
)

test_preds = Scholar.Neighbors.KNNClassifier.predict(model, test_inputs)

Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.694444477558136
>

K-Means Clustering

model = Scholar.Cluster.KMeans.fit(train_inputs, num_clusters: 3)
%Scholar.Cluster.KMeans{
  num_iterations: #Nx.Tensor<
    s64
    EXLA.Backend
    2
  >,
  clusters: #Nx.Tensor<
    f32[3][13]
    EXLA.Backend
    [
      [0.008282720111310482, 0.0015221600187942386, 0.0014713852433487773, 0.01258805487304926, 0.0666232481598854, 0.0012808121973648667, 9.521916508674622e-4, 1.6673454956617206e-4, 8.787906845100224e-4, 0.0036122675519436598, 4.810251703020185e-4, 0.0014167047338560224, 0.4691700041294098],
      [0.008868475444614887, 0.0011699299793690443, 0.0014671300305053592, 0.010855224914848804, 0.06807452440261841, 0.0017753373831510544, 0.00185955292545259, 9.592168498784304e-5, 0.0011823351960629225, 0.0035062956158071756, 6.024370668455958e-4, 0.001944467076100409, 0.7536961436271667],
      [0.007999742403626442, 0.0015280431834980845, 0.0013884290819987655, 0.013398759998381138, 0.05932284891605377, 0.001259022275917232, 0.0010443272767588496, 1.7318504978902638e-4, 8.409738074988127e-4, 0.0025327885523438454, 5.191014497540891e-4, 0.00151579431258142, 0.29649674892425537]
    ]
  >,
  inertia: #Nx.Tensor<
    f32
    EXLA.Backend
    0.6819310188293457
  >,
  labels: #Nx.Tensor<
    s64[142]
    EXLA.Backend
    [2, 0, 0, 1, 0, 0, 2, 1, 0, 0, 2, 2, 2, 2, 0, 0, 2, 1, 2, 0, 2, 2, 0, 1, 2, 2, 0, 1, 0, 2, 0, 0, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 0, 0, 2, 0, ...]
  >
}
wine_features = %{
  "feature_1" => train_inputs[[.., 1]] |> Nx.to_flat_list(),
  "feature_2" => train_inputs[[.., 2]] |> Nx.to_flat_list(),
  "class" => train_targets |> Nx.to_flat_list()
}
coords = [
  cluster_feature_1: model.clusters[[.., 1]] |> Nx.to_flat_list(),
  cluster_feature_2: model.clusters[[.., 2]] |> Nx.to_flat_list()
]
title =
  "Scatterplot of data samples projected on plane wine"
  <> " feature 1 x wine feature 2"
Vl.new(
  width: 720,
  height: 480,
  title: [
    text: title,
    offset: 25
  ]
)
|> Vl.layers([
  Vl.new()
  |> Vl.data_from_values(wine_features)
  |> Vl.mark(:circle)
  |> Vl.encode_field(:x, "feature_1", type: :quantitative)
  |> Vl.encode_field(:y, "feature_2", type: :quantitative)
  |> Vl.encode_field(:color, "class"),
  Vl.new()
  |> Vl.data_from_values(coords)
  |> Vl.mark(:circle, color: :green, size: 100)
  |> Vl.encode_field(:x, "cluster_feature_1", type: :quantitative)
  |> Vl.encode_field(:y, "cluster_feature_2", type: :quantitative)
])
test_preds = Scholar.Cluster.KMeans.predict(model, test_inputs)
Scholar.Metrics.Classification.accuracy(test_targets, test_preds)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.25
>