Ch 4: Optimize Everything

Mix.install([
  {:nx, "~> 0.7.2"}
])

Learning with Optimization

Optimization is the search for the best: in ML, the best parameters for a model under some measure of performance.

An ML system can be defined as any computing system that is capable of improving from experience on a specific task with respect to some arbitrary performance measure. The goal is a system that makes good predictions about unseen examples.

A basic architecture of such a system is:

def predict(input) do
  label = do_something(input)
  label
end

At its core, a trained ML system transforms inputs into labels. When given some unseen input X, the model can predict its expected output Y by following “the line”. These “lines” are nothing more than visualizations of two parameters: m, the slope of the “line”, and b, its intercept.

Here input is equivalent to X, label is equivalent to Y, and m and b are the parameters of the “line” drawn:

def predict(input, m, b) do
  label = m * input + b
  label
end
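
As a quick worked illustration of the function above, the sketch below wraps it in a module and evaluates it at one point. The module name `Line` and the values m = 2.0 and b = 1.0 are assumptions chosen only for this example.

```elixir
defmodule Line do
  # Predicts a label for `input` by following the line y = m * x + b.
  def predict(input, m, b) do
    m * input + b
  end
end

# Hypothetical parameters: slope 2.0, intercept 1.0.
Line.predict(3.0, 2.0, 1.0)
# => 7.0
```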

ML, in practice, revolves around finding parameters that best transform inputs into labels. The parameters might not always look like m and b; the above is a simplified example of linear regression in two dimensions.

In practice, the forms of the transformations and parameters in an ML algorithm vary from problem to problem and typically have much higher dimensionality. A more practical and generalized view of the predict function could be:

def predict(input) do
  label = f(input, params)
  label
end

Here f represents the transformation performed by the ML algorithm of choice, and params represents the learned parameters for the algorithm. The goal is to find the parameters that best map inputs to labels under f, that is, to optimize params to achieve the best performance on the given task.

Minimizing Risk to Maximize Performance

Performance on an unobserved test set cannot be optimized directly. Instead, optimize for performance on the available training set and hope that the training set captures enough of the features of the unseen test set that the model’s parameters also assign labels well to unseen inputs.

During training, a loss function (or cost function) is defined and then optimized with respect to the model’s parameters. The loss function parameterized by the model and its parameters is the overall objective function. The relationship between model, parameters, loss function, and objective function in an ML problem looks like:

@doc """
A function that uses `params` to transform `inputs`
into `labels`
"""
def model(params, inputs) do
  labels = f(params, inputs)
  labels
end

@doc """
A function that measures the difference between an
actual label and a predicted label

Gives indication or measure of correctness
"""
def loss(actual_labels, predicted_labels) do
  loss_value = measure_diff(actual_labels, predicted_labels)
  loss_value
end

@doc """
The loss function parameterized with the model
and parameters
"""
def objective(params, actual_inputs, actual_labels) do
  predicted_labels = model(params, actual_inputs)
  loss(actual_labels, predicted_labels)
end
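
To make the three roles concrete, here is a minimal sketch that fills in the placeholders with one possible choice: a line as the model and mean squared error as the loss. The module name `LinearObjective`, the `{m, b}` parameter tuple, and the toy data are assumptions for illustration, not part of the original notebook.

```elixir
defmodule LinearObjective do
  # Model: applies the line y = m * x + b to every input.
  def model({m, b}, inputs) do
    Enum.map(inputs, fn x -> m * x + b end)
  end

  # Loss: mean squared difference between actual and predicted labels.
  def loss(actual_labels, predicted_labels) do
    squared_diffs =
      actual_labels
      |> Enum.zip(predicted_labels)
      |> Enum.map(fn {actual, predicted} ->
        diff = actual - predicted
        diff * diff
      end)

    Enum.sum(squared_diffs) / length(squared_diffs)
  end

  # Objective: the loss viewed as a function of the parameters.
  def objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end
end

# Toy data generated by y = 2x + 1.
inputs = [0.0, 1.0, 2.0]
labels = [1.0, 3.0, 5.0]

LinearObjective.objective({2.0, 1.0}, inputs, labels)
# => 0.0
LinearObjective.objective({2.0, 0.0}, inputs, labels)
# => 1.0
```

The parameters that generated the data give an objective of zero, while a wrong intercept gives a larger value. Optimization is the search for the parameters that drive this number down.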

ML is concerned with reducing generalization error on unseen inputs. In statistical theory, generalization error is also known as risk, and the model’s goal is to minimize risk. A model can’t optimize for risk directly because it does not have access to the entire distribution of possible inputs. Instead, the model optimizes empirical risk, which is performance on the empirical distribution, that is, the training set.

Minimizing empirical risk is almost equivalent to minimizing true risk. The two are never exactly equivalent, because there are always sources of error when trying to capture infinite distributions with finite data and computing power.

The process of minimizing empirical risk is known as empirical risk minimization (ERM).
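
In symbols, the two quantities can be written side by side. The notation is an assumption introduced here, not taken from the original: a model $f_\theta$ with parameters $\theta$, a per-example loss $\ell$, a data distribution $\mathcal{D}$, and $N$ training pairs $(x_i, y_i)$.

```latex
R(\theta) = \mathbb{E}_{(x, y) \sim \mathcal{D}}
            \left[ \ell\big(f_\theta(x), y\big) \right]
\qquad
\hat{R}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell\big(f_\theta(x_i), y_i\big)
```

ERM chooses the $\theta$ that minimizes the empirical risk $\hat{R}(\theta)$ as a stand-in for the true risk $R(\theta)$, which cannot be computed.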

Defining Objectives

ML algorithms almost never directly optimize for the performance measures of the problem being solved. Instead, ML algorithms use surrogate loss functions

Surrogate loss functions serve as proxies for true objectives. They are typically easier functions to optimize, and optimizing them typically also indirectly optimizes the performance objectives of the problem to be solved. The choice of a surrogate loss often depends on the ML task and type of model

Likelihood Estimation

The likelihood function describes the probability of the observed data as a function of the model’s parameters. Choosing the parameters that maximize this likelihood is an optimization process known as maximum likelihood estimation (MLE). In practice this is usually done by minimizing the negative log-likelihood as the loss.
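
Written out, and assuming $N$ independent training pairs $(x_i, y_i)$ and a model that assigns probability $p(y \mid x; \theta)$ to a label (notation introduced here for illustration), the MLE problem is:

```latex
\hat{\theta}_{\mathrm{MLE}}
  = \arg\max_{\theta} \prod_{i=1}^{N} p(y_i \mid x_i; \theta)
  = \arg\min_{\theta} \; -\sum_{i=1}^{N} \log p(y_i \mid x_i; \theta)
```

The second form holds because the logarithm is monotonic. It turns a product into a sum, which is why likelihood-based losses are usually stated as negative log-likelihoods.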

Cross-entropy

Cross-entropy as a loss function is typically used in the context of classification tasks

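# Requires `import Nx.Defn` in the enclosing module.
# Per-example binary cross-entropy; y_pred must lie strictly between 0 and 1.
defn binary_cross_entropy(y_true, y_pred) do
  -(y_true * Nx.log(y_pred) + (1 - y_true) * Nx.log(1 - y_pred))
end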
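
To get a feel for the values this loss produces, the sketch below evaluates the standard definition, -[y log(p) + (1 - y) log(1 - p)], on plain numbers using Erlang's `:math.log` instead of Nx. The module name `ScalarLoss` and the probabilities 0.9 and 0.1 are assumptions chosen for illustration.

```elixir
defmodule ScalarLoss do
  # Binary cross-entropy for a single example, using natural logarithms.
  # `y_true` is 0 or 1; `y_pred` is a probability strictly between 0 and 1.
  def binary_cross_entropy(y_true, y_pred) do
    -(y_true * :math.log(y_pred) + (1 - y_true) * :math.log(1 - y_pred))
  end
end

# The true label is 1 in both calls; only the predicted probability changes.
confident_and_right = ScalarLoss.binary_cross_entropy(1, 0.9)
confident_and_wrong = ScalarLoss.binary_cross_entropy(1, 0.1)

# The confident but wrong prediction is penalized far more heavily.
confident_and_right < confident_and_wrong
# => true
```

The first loss is about 0.105 and the second about 2.303, so the loss is small when the predicted probability agrees with the label and grows quickly as the prediction becomes confidently wrong.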

Mean Squared Error

Mean squared error measures per-example loss as the average squared difference between true labels and predicted labels

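# Requires `import Nx.Defn` in the enclosing module.
defn mean_squared_error(y_true, y_pred) do
  y_true
  |> Nx.subtract(y_pred)
  |> Nx.pow(2)
  # `axes` is given as a list; [-1] averages over the last axis
  |> Nx.mean(axes: [-1])
end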

Regularize to Generalize

The objective of training an ML model is generalization

Given two models:

  • A performs noticeably worse on the training set than B
  • A performs noticeably better on the test set than B

Which model is preferred?

Given that the primary objective is performance on unseen data, A is the preferred choice from a model performance perspective. But why would one model perform noticeably worse on the training set yet noticeably better on the test set? The answer is regularization.

Overfitting, Underfitting and Capacity

Overfitting is a scenario in which a trained model has low training error but high generalization error. Both ERM and MLE optimize model parameters with respect to errors on a training set, which means they explicitly fit functions to match the training data. Both techniques are therefore prone to overfitting.

Underfitting is a scenario in which a model doesn’t even achieve a low training error.

Both overfitting and underfitting are typically functions of a model’s capacity. A model’s capacity is its ability to fit many different functions. The set of functions a model can fit is chosen while designing the model and is known as the hypothesis space of functions.

Defining Regularization

Regularization is any technique used to combat overfitting or, more generally, any technique used to reduce generalization error.

Since ERM and MLE are prone to overfitting, regularization is a necessary step for many ML algorithms

Complexity Penalties

Complexity penalties are a commonly used regularizer for training ML models that generalize. They impose a cost whenever the objective is evaluated by adding a penalty term, scaled by some penalty weight, to the loss. Weight decay is a common regularization penalty that introduces a penalty term equal to the L2-norm of the model’s parameters. The L2-norm can be interpreted as “distance from the origin”.

Weight decay expresses a preference for smaller model weights. Mathematically, it constrains the feasible parameter space of a given model to parameters that lie closer to the origin. Intuitively, weight decay can be thought of as penalizing a model that gets too confident in particular weights.
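
A minimal sketch of how such a penalty changes the value being minimized, using plain lists instead of tensors. The module name `WeightDecay`, the flat parameter list, and the penalty weight 0.1 are assumptions for illustration. Many implementations penalize the squared L2-norm instead; this sketch follows the description above and uses the norm itself.

```elixir
defmodule WeightDecay do
  # L2-norm of a flat list of parameters: the distance from the origin.
  def l2_norm(params) do
    params
    |> Enum.map(fn p -> p * p end)
    |> Enum.sum()
    |> :math.sqrt()
  end

  # Adds the penalty term, scaled by `penalty_weight`, to an unpenalized loss.
  def penalized_loss(loss_value, params, penalty_weight) do
    loss_value + penalty_weight * l2_norm(params)
  end
end

WeightDecay.l2_norm([3.0, 4.0])
# => 5.0
WeightDecay.penalized_loss(1.0, [3.0, 4.0], 0.1)
# => 1.5
```

Two parameter settings with the same unpenalized loss now score differently: the one closer to the origin has the lower penalized loss and is preferred.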

Early-stopping

Early stopping is a regularizer that stops model training when overfitting is detected. It’s not possible to perfectly detect overfitting, so the typical approach is to monitor for it with a validation set.

A validation set is a portion of the original training data that is not used for training but is instead used to periodically monitor model performance.
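
One common way to turn validation monitoring into a stopping rule is a patience counter: stop once the validation loss has failed to improve for a fixed number of consecutive epochs. The sketch below is one possible formulation, not the method of any particular library. The module name `EarlyStopping`, the `patience` parameter, and the loss values are hypothetical.

```elixir
defmodule EarlyStopping do
  # Returns {:stop, epoch} for the first epoch at which the validation loss
  # has failed to improve on the best value seen so far for `patience`
  # consecutive epochs, or :continue if that never happens.
  # Epochs are numbered from 1.
  def check(validation_losses, patience) do
    validation_losses
    |> Enum.with_index(1)
    |> Enum.reduce_while({:infinity, 0}, fn {loss, epoch}, {best, stale} ->
      cond do
        # First epoch, or a new best loss: reset the stale counter.
        best == :infinity or loss < best -> {:cont, {loss, 0}}
        # No improvement and patience is exhausted: stop here.
        stale + 1 >= patience -> {:halt, {:stop, epoch}}
        # No improvement, but still within patience.
        true -> {:cont, {best, stale + 1}}
      end
    end)
    |> case do
      {:stop, epoch} -> {:stop, epoch}
      _still_improving -> :continue
    end
  end
end

# Validation loss improves for three epochs, then stalls.
EarlyStopping.check([1.0, 0.8, 0.7, 0.75, 0.72, 0.9], 2)
# => {:stop, 5}
```

In the example the best loss, 0.7, is reached at epoch 3. Epochs 4 and 5 both fail to beat it, so with a patience of 2 training stops at epoch 5.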

Descending Gradients

Gradient descent is an iterative optimization routine that minimizes a function by repeatedly evaluating its gradient at the current point and stepping in the opposite direction.
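
Before bringing in automatic differentiation, the idea can be seen on a one-dimensional toy problem where the gradient is written by hand. The function f(x) = (x - 3)^2, the module name `ToyDescent`, the learning rate, and the step count are all assumptions chosen for illustration.

```elixir
defmodule ToyDescent do
  # Gradient of f(x) = (x - 3)^2, derived by hand: f'(x) = 2 * (x - 3).
  def gradient(x), do: 2 * (x - 3)

  # Takes `steps` gradient steps from `x0`, each moving against the gradient.
  def minimize(x0, learning_rate, steps) do
    Enum.reduce(1..steps, x0, fn _step, x ->
      x - learning_rate * gradient(x)
    end)
  end
end

# Starting from 0.0, the iterates approach the minimizer x = 3.
ToyDescent.minimize(0.0, 0.1, 100)
```

With a learning rate of 0.1, each step shrinks the distance to the minimizer by a factor of 0.8, so 100 steps land very close to 3. Too large a learning rate would overshoot instead.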

Implementing stochastic gradient descent with Nx is simple thanks to its automatic differentiation capabilities.

key = Nx.Random.key(42)
#Nx.Tensor<
  u32[2]
  [0, 42]
>
{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})

true_function = fn params, x ->
  Nx.dot(x, params)
end
#Function<41.105768164/2 in :erl_eval.expr/6>
{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))
[
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.21419739723205566, 0.4202451705932617, 0.9174779653549194, 0.6951805353164673, 0.5770437717437744, 0.015613555908203125, 0.9820147752761841, 0.12217259407043457, 0.6313685178756714, 0.7669196128845215, 0.16057193279266357, 0.9078190326690674, 0.23667335510253906, 0.6146631240844727, 0.12274301052093506, 0.09529983997344971, 0.01969766616821289, 0.920687198638916, 0.6566638946533203, 0.8590985536575317, 0.336417555809021, 0.8502820730209351, 0.40605008602142334, 7.408857345581055e-4, 0.19105935096740723, 0.0871502161026001, 0.35579657554626465, 0.0827188491821289, 0.30659806728363037, 0.8906278610229492, 0.21179890632629395, 0.23324048519134521]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [4.911402702331543]
     ]
   >},
  ...
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.37081921100616455, 0.6535693407058716, 0.011877059936523438, 0.9298231601715088, 0.36428725719451904, 0.6512700319290161, 0.7696858644485474, 0.19250214099884033, 0.47986137866973877, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.702021598815918]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7482943534851074, 0.7727749347686768, 0.1144859790802002, 0.7724059820175171, 0.6368030309677124, 0.8291548490524292, 0.313435435295105, 0.6373882293701172, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.593993186950684]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.39354753494262695, 0.6011625528335571, 0.21247148513793945, 0.2802027463912964, 0.3761352300643921, 0.11047065258026123, 0.1385045051574707, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.230188846588135]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8625725507736206, 0.20491087436676025, 0.09915947914123535, 0.12683987617492676, 0.5620198249816895, 0.13202905654907227, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [5.756672382354736]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.4777796268463135, 0.5366525650024414, 0.8587156534194946, 0.9605873823165894, 0.6869806051254272, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.652935981750488]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9411790370941162, 0.8997765779495239, 0.2766681909561157, 0.33345866203308105, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.9773478507995605]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9947642087936401, 0.2692502737045288, 0.9714730978012085, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.865835666656494]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7531495094299316, 0.8768380880355835, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.616130352020264]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.18010473251342773, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       ...
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       ...
     ]
   >, ...},
  {...},
  ...
]
{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
[
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9657397270202637, 0.9266661405563354, 0.2524207830429077, 0.506806492805481, 0.03272294998168945, 0.6381621360778809, 0.4016733169555664, 0.4144333600997925, 0.8692346811294556, 0.19583988189697266, 0.4356701374053955, 0.037007689476013184, 0.4367654323577881, 0.9086041450500488, 0.4730778932571411, 0.29556596279144287, 0.49315857887268066, 0.3683987855911255, 0.8670364618301392, 0.527277946472168, 0.028360843658447266, 0.13743293285369873, 0.8709059953689575, 0.1861327886581421, 0.4181276559829712, 0.9427480697631836, 0.4339343309402466, 0.8707499504089355, 0.6826666593551636, 0.528895378112793, 0.17522680759429932, 0.4048128128051758]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.652977466583252]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.4392540454864502, 0.9165053367614746, 0.9777518510818481, 0.879123330116272, 0.612689733505249, 0.01696908473968506, 0.133436918258667, 0.4318392276763916, 0.5053318738937378, 0.7980244159698486, 0.1885296106338501, 0.9951480627059937, 0.3975728750228882, 0.226912260055542, 0.4825739860534668, 0.9671891927719116, 0.24038493633270264, 0.13231432437896729, 0.38793301582336426, 0.05815780162811279, 0.43374860286712646, 0.2860398292541504, 0.6426401138305664, 0.8966696262359619, 0.09666109085083008, 0.4394463300704956, 0.35843217372894287, 0.34688258171081543, 0.5460761785507202, 0.5041118860244751, 0.5477373600006104, 0.8824354410171509]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.90969705581665]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7762134075164795, 0.7822309732437134, 0.24949312210083008, 0.24509012699127197, 0.9004219770431519, 0.8151938915252686, 0.9005923271179199, 0.8640304803848267, 0.4731714725494385, 0.5921633243560791, 0.4887489080429077, 0.8375271558761597, 0.9577419757843018, 0.5891522169113159, 0.12607717514038086, 0.708527684211731, 0.41328203678131104, 0.6296629905700684, 0.6268273591995239, 0.35883355140686035, 0.36125707626342773, 0.6910197734832764, 0.7902359962463379, 0.7439805269241333, 0.4775749444961548, 0.9078165292739868, 0.3568282127380371, 0.15519630908966064, 0.11200845241546631, 0.7795575857162476, 0.468631386756897, 0.9759647846221924]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [8.54786205291748]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.3197380304336548, 0.17616689205169678, 0.8258638381958008, 0.3432091474533081, 0.35468900203704834, 0.5186667442321777, 0.7499172687530518, 0.4087836742401123, 0.4280334711074829, 0.2278900146484375, 0.6885714530944824, 0.6648629903793335, 0.5448073148727417, 0.1430720090866089, 0.842303991317749, 0.8900002241134644, 0.4492759704589844, 0.8455078601837158, 0.4587341547012329, 0.3691824674606323, 0.2542390823364258, 0.871134877204895, 0.26322853565216064, 0.10538768768310547, 0.355352520942688, 0.8888055086135864, 0.488552451133728, 0.6250888109207153, 0.9855941534042358, 0.738310694694519, 0.6712085008621216, 0.04661083221435547]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [8.220166206359863]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.40412354469299316, 0.7216566801071167, 0.25867438316345215, 0.9800392389297485, 0.9075496196746826, 0.6819312572479248, 0.0149993896484375, 0.04356968402862549, 0.9605370759963989, 0.02644956111907959, 0.28210699558258057, 0.7568575143814087, 0.21303772926330566, 0.002684950828552246, 0.6519417762756348, 0.28664088249206543, 0.15737569332122803, 0.37507736682891846, 0.05415797233581543, 0.03802788257598877, 0.8071837425231934, 0.06110048294067383, 0.6388435363769531, 0.44481122493743896, 0.23555970191955566, 0.61528480052948, 0.8113986253738403, 0.012137651443481445, 0.9276052713394165, 0.9450554847717285, 0.9840184450149536, 0.20820486545562744]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.5124430656433105]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.5034708976745605, 0.7755038738250732, 0.13867413997650146, 0.29906952381134033, 0.014742374420166016, 0.7755328416824341, 0.9173959493637085, 0.0935053825378418, 0.31686699390411377, 0.06115245819091797, 0.8989229202270508, 0.19432556629180908, 0.7501810789108276, 0.2113250494003296, 0.5822068452835083, 0.6005638837814331, 0.625515341758728, 0.6752986907958984, 0.9507982730865479, 0.7879356145858765, 0.5397478342056274, 0.3113539218902588, 0.8102543354034424, 0.2979027032852173, 0.7655726671218872, 0.42514193058013916, 0.09351170063018799, 0.8037655353546143, 0.4778313636779785, 0.44777703285217285, 0.3096567392349243, 0.33784306049346924]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.7442731857299805]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7518961429595947, 0.22414886951446533, 0.15714240074157715, 0.8663241863250732, 0.508256196975708, 0.30795371532440186, 0.11998486518859863, 0.18344223499298096, 0.4011112451553345, 0.924648642539978, 0.5058850049972534, 0.5193443298339844, 0.9716345071792603, 0.948397159576416, 0.5351895093917847, 0.5134536027908325, 0.6595708131790161, 0.06837213039398193, 0.05189085006713867, 0.8435298204421997, 0.7968239784240723, 0.12332558631896973, 0.7250438928604126, 0.7147141695022583, 0.8842874765396118, 0.9462226629257202, 0.843963623046875, 0.8965095281600952, 0.4305756092071533, 0.5930991172790527, 0.11764276027679443, 0.5833710432052612]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.623931884765625]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9475998878479004, 0.6403565406799316, 0.5817118883132935, 0.5467712879180908, 0.8543612957000732, 0.06530475616455078, 0.14756488800048828, 0.15206146240234375, 0.68489670753479, 0.932651162147522, 0.9179636240005493, 0.8796118497848511, 0.2882128953933716, 0.7526837587356567, 0.1426788568496704, 0.18050217628479004, 0.6951268911361694, 0.7308511734008789, 0.6911174058914185, 0.19187331199645996, 0.925081729888916, 0.8188349008560181, 0.5788781642913818, 0.33968937397003174, 0.8412926197052002, 0.50633704662323, 0.40607786178588867, 0.39345502853393555, 0.9535032510757446, 0.0635685920715332, 0.7170870304107666, 0.8757264614105225]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [8.50348949432373]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.01942753791809082, 0.899272084236145, 0.5304620265960693, 0.5620572566986084, 0.128362774848938, 0.31026554107666016, 0.6900253295898438, 0.2783416509628296, 0.004452347755432129, 0.5778182744979858, 0.026953697204589844, 0.014599919319152832, 0.3131972551345825, 0.6139227151870728, 0.6718645095825195, 0.9182217121124268, 0.4055975675582886, 0.9959360361099243, 0.3222285509109497, 0.1344226598739624, 0.8531616926193237, 0.1252962350845337, 0.7893067598342896, 0.6823188066482544, 0.38434433937072754, 0.0016857385635375977, 0.9079246520996094, 0.33411502838134766, 0.05022597312927246, 0.5846171379089355, 0.889033317565918, 0.7293587923049927]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.110525608062744]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8360906839370728, 0.4355432987213135, 0.11632812023162842, 0.7702548503875732, 0.24256396293640137, 0.36099421977996826, 0.6917792558670044, 0.288962721824646, 0.6243956089019775, 0.9861946105957031, 0.3847709894180298, 0.6880143880844116, 0.589323878288269, 0.4923354387283325, 0.3279656171798706, 0.4151395559310913, 0.852401614189148, 0.0718458890914917, 0.01529836654663086, 0.06954300403594971, 0.7971522808074951, 0.7249754667282104, 0.25757861137390137, 0.906819224357605, 0.6608389616012573, 0.40988433361053467, 0.26649951934814453, 0.6497167348861694, 0.31986987590789795, 0.8541487455368042, 0.7966134548187256, 0.23020529747009277]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.6582722663879395]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7711091041564941, 0.6362074613571167, 0.8661282062530518, 0.6193832159042358, 0.6161632537841797, 0.19212019443511963, 0.7516170740127563, 0.023564815521240234, 0.7833913564682007, 0.8175530433654785, 0.029859185218811035, 0.30578577518463135, 0.9423370361328125, 0.20194673538208008, 0.6542264223098755, 0.9779584407806396, 0.11775898933410645, 0.5317223072052002, 0.3922593593597412, 0.832879900932312, 0.657945990562439, 0.43512094020843506, 0.32924580574035645, 0.21120929718017578, 0.76695716381073, 0.9446995258331299, 0.02226400375366211, 0.8510793447494507, 0.06922507286071777, 0.03070998191833496, 0.9929032325744629, 0.6356418132781982]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.539000988006592]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.5236676931381226, 0.2989940643310547, 0.03134489059448242, 0.12583446502685547, 0.042815566062927246, 0.8364819288253784, 0.31674015522003174, 0.04060947895050049, 0.05521893501281738, 0.7784453630447388, 0.48672759532928467, 0.8186780214309692, 0.00807809829711914, 0.39795219898223877, 0.22978472709655762, 0.5791110992431641, 0.6117820739746094, 0.7412447929382324, 0.42317402362823486, 0.28765225410461426, 0.36166059970855713, 0.5173482894897461, 0.9059319496154785, 0.3208935260772705, 0.3955960273742676, 0.5770881175994873, 0.963921308517456, 0.05305802822113037, 0.009126543998718262, 0.30502188205718994, 0.348180890083313, 0.28527331352233887]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [5.539951324462891]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7139891386032104, 0.12666535377502441, 0.15333807468414307, 0.7544382810592651, 0.9695196151733398, 0.49136245250701904, 0.20070266723632812, 0.005652427673339844, 0.02549123764038086, 0.3883492946624756, 0.7958550453186035, 0.4632751941680908, 0.14279115200042725, 0.5663028955459595, 0.31234872341156006, 0.6877082586288452, 0.04052889347076416, 0.19631445407867432, 0.8272514343261719, 0.7589792013168335, 0.8727586269378662, 0.9460961818695068, 0.7840994596481323, 0.1846456527709961, 0.7626980543136597, 0.5093346834182739, 0.5205307006835938, 0.2435612678527832, 0.4535341262817383, 0.3754945993423462, 0.9493304491043091, 0.05621170997619629]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.268062591552734]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8532519340515137, 0.6970062255859375, 0.652370810508728, 0.4061359167098999, 0.11044573783874512, 0.11151957511901855, 0.851331353187561, 0.6565314531326294, 0.33859121799468994, 0.7652009725570679, 0.3588383197784424, 0.07348513603210449, 0.7815285921096802, 0.9533407688140869, 0.8006638288497925, 0.04949069023132324, 0.8293572664260864, 0.3746068477630615, 0.8676903247833252, 0.9169406890869141, 0.9336735010147095, 0.06596994400024414, 0.8225301504135132, 0.18987727165222168, 0.24470460414886475, 0.8587253093719482, 0.8066114187240601, 0.4743626117706299, 0.8888722658157349, 0.36300718784332275, 0.2819058895111084, 0.5664075613021851]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.505477428436279]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.3902559280395508, 0.14534735679626465, 0.09916174411773682, 0.7248067855834961, 0.20137739181518555, 0.6646915674209595, 0.0778660774230957, 0.8685482740402222, 0.641196608543396, 0.5231763124465942, 0.6376198530197144, 0.6526670455932617, 0.2163105010986328, 0.8063833713531494, 0.6756443977355957, 0.4447154998779297, 0.20969760417938232, 0.951002836227417, 0.045929908752441406, 0.17532849311828613, 0.9260181188583374, 0.5131326913833618, 0.30024540424346924, 0.3300440311431885, 0.9004764556884766, 0.8441228866577148, 0.9477831125259399, 0.6751878261566162, 0.4996030330657959, 0.2638866901397705, 0.7265427112579346, 0.49843263626098633]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.501368045806885]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8431416749954224, 0.20925045013427734, 0.882979154586792, 0.7653672695159912, 0.8408812284469604, 0.9956172704696655, 0.039161086082458496, 0.2890273332595825, 0.8148702383041382, 0.7094781398773193, 0.45367276668548584, 0.49341559410095215, 0.40397489070892334, 0.044689178466796875, 0.2746458053588867, 0.14785456657409668, 0.5764540433883667, 0.9384440183639526, 0.16802644729614258, 0.5668824911117554, 0.7575173377990723, 0.9617680311203003, 0.34545373916625977, 0.9809876680374146, 0.9966757297515869, 0.9557144641876221, 0.9793694019317627, 0.49138343334198, 0.327367901802063, 0.8423446416854858, 0.41049087047576904, 0.16183257102966309]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.983162879943848]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.22011446952819824, 0.34421753883361816, 0.6294939517974854, 0.5911561250686646, 0.3088088035583496, 0.23263812065124512, 0.3219066858291626, 0.03805994987487793, 0.1934577226638794, 0.9446452856063843, 0.9956920146942139, 0.5987362861633301, 0.26041269302368164, 0.8809354305267334, 0.768547534942627, 0.8348363637924194, 0.826107382774353, 0.6921907663345337, 0.046318769454956055, 0.5803121328353882, 0.9755550622940063, 0.917121410369873, 0.30008530616760254, 0.2571007013320923, 0.6167714595794678, 0.022228240966796875, 0.7143242359161377, 0.5102230310440063, 0.16165518760681152, 0.9506291151046753, 0.9326122999191284, 0.2996530532836914]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.267423152923584]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.38776493072509766, 0.7318899631500244, 0.1307905912399292, 0.09548139572143555, 0.8834151029586792, 0.25518178939819336, 0.8047440052032471, 0.5709458589553833, 0.022022604942321777, 0.8857442140579224, 0.802797794342041, 0.5964082479476929, 0.6515809297561646, 0.06845474243164062, 0.37185752391815186, 0.12856698036193848, 0.09193015098571777, 0.738864541053772, 0.43469250202178955, 0.7443429231643677, 0.016843795776367188, 0.13896071910858154, 0.9344608783721924, 0.04187476634979248, 0.021153926849365234, 0.5739061832427979, 0.11942613124847412, 0.6132626533508301, 0.5382595062255859, 0.9019054174423218, 0.3720097541809082, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.044473648071289]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.5642416477203369, 0.830094575881958, 0.11100232601165771, 0.38808906078338623, 0.2167491912841797, 0.9876217842102051, 0.8534290790557861, 0.38251447677612305, 0.9997782707214355, 0.5055009126663208, 0.836064338684082, 0.26873278617858887, 0.44322919845581055, 0.9741466045379639, 0.9353281259536743, 0.6532653570175171, 0.21104705333709717, 0.9035766124725342, 0.11548709869384766, 0.793843150138855, 0.591930627822876, 0.41485321521759033, 0.41184914112091064, 0.5477373600006104, 0.08824586868286133, 0.7526575326919556, 0.44268131256103516, 0.0763329267501831, 0.4934626817703247, 0.8778635263442993, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.992553234100342]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8263260126113892, 0.47554445266723633, 0.7326714992523193, 0.4069685935974121, 0.48123836517333984, 0.9155631065368652, 0.19115734100341797, 0.8787575960159302, 0.5277758836746216, 0.08595383167266846, 0.6884660720825195, 0.3847067356109619, 0.4597644805908203, 0.11427831649780273, 0.611096978187561, 0.17932391166687012, 0.046775102615356445, 0.15377891063690186, 0.17494094371795654, 0.6756585836410522, 0.6878424882888794, 0.23508989810943604, 0.795313835144043, 0.28050291538238525, 0.006268501281738281, 0.423947811126709, 0.3076256513595581, 0.712846040725708, 0.33599305152893066, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.1456146240234375]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.30186140537261963, 0.9873268604278564, 0.035573720932006836, 0.4155128002166748, 0.24097764492034912, 0.03020632266998291, 0.20567595958709717, 0.5887387990951538, 0.6643663644790649, 0.31537091732025146, 0.9205235242843628, 0.6104476451873779, 0.9955638647079468, 0.08851909637451172, 0.34864628314971924, 0.5111007690429688, 0.8216805458068848, 0.9719328880310059, 0.7817020416259766, 0.5537395477294922, 0.42271876335144043, 0.07025539875030518, 0.20106756687164307, 0.06972110271453857, 0.26150715351104736, 0.41637277603149414, 0.36588919162750244, 0.3534187078475952, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.348757266998291]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.8590803146362305, 0.5636343955993652, 0.03443634510040283, 0.5104682445526123, 0.606324315071106, 0.9271607398986816, 0.4498816728591919, 0.8567771911621094, 0.0051795244216918945, 0.3874478340148926, 0.1921311616897583, 0.6639413833618164, 0.46572935581207275, 0.8383623361587524, 0.9469438791275024, 0.020708560943603516, 0.15898573398590088, 0.944486141204834, 0.3688013553619385, 0.3565635681152344, 0.4661533832550049, 0.28654932975769043, 0.9991466999053955, 0.395352840423584, 0.1945810317993164, 0.37531542778015137, 0.34606099128723145, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.0145039558410645]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.40698301792144775, 0.39485931396484375, 0.7665799856185913, 0.45240330696105957, 0.3180277347564697, 0.6046972274780273, 0.27099859714508057, 0.32419121265411377, 0.15665650367736816, 0.8422553539276123, 0.17924022674560547, 0.9220238924026489, 0.6719664335250854, 0.3106119632720947, 0.7598171234130859, 0.41273176670074463, 0.24574947357177734, 0.4455568790435791, 0.7285884618759155, 0.5802567005157471, 0.7816479206085205, 0.43250608444213867, 0.7760515213012695, 0.9805769920349121, 0.48999035358428955, 0.6657071113586426, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.522810935974121]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.6499532461166382, 0.33332526683807373, 0.18239235877990723, 0.28092002868652344, 0.6100766658782959, 0.8831055164337158, 0.5425313711166382, 0.46851158142089844, 0.0425875186920166, 0.4390711784362793, 0.24202609062194824, 0.4561119079589844, 0.30997657775878906, 0.8286688327789307, 0.9777672290802002, 0.7183065414428711, 0.994769811630249, 0.9684973955154419, 0.516169548034668, 0.7496813535690308, 0.1341181993484497, 0.23441553115844727, 0.4797624349594116, 0.7527389526367188, 0.706782341003418, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.673696041107178]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.29753637313842773, 0.9054058790206909, 0.5384794473648071, 0.008930683135986328, 0.9904316663742065, 0.012788891792297363, 0.6981593370437622, 0.21004319190979004, 0.24448156356811523, 0.3577829599380493, 0.8826776742935181, 0.556861162185669, 0.7953556776046753, 0.755801796913147, 0.8585830926895142, 0.8141891956329346, 0.611968994140625, 0.02731025218963623, 0.8326984643936157, 0.026828527450561523, 0.04161202907562256, 0.72672438621521, 0.16502487659454346, 0.3540531396865845, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.771440029144287]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.22602331638336182, 0.938198447227478, 0.6433297395706177, 0.22689640522003174, 0.2643556594848633, 0.38158679008483887, 0.18027138710021973, 0.7762355804443359, 0.8237777948379517, 0.16852951049804688, 0.2584739923477173, 0.38265061378479004, 0.028604865074157715, 0.05207550525665283, 0.08737969398498535, 0.831161379814148, 0.337926983833313, 0.578020453453064, 0.1244196891784668, 0.3897353410720825, 0.6503366231918335, 0.5011011362075806, 0.3202742338180542, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.270139217376709]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.1591346263885498, 0.3551875352859497, 0.19456207752227783, 0.08360016345977783, 0.49338340759277344, 0.8102091550827026, 0.038229942321777344, 0.5133059024810791, 0.42447197437286377, 0.1302204132080078, 0.032562255859375, 0.4770599603652954, 0.10980844497680664, 0.39250481128692627, 0.9826400279998779, 0.6408740282058716, 0.4889005422592163, 0.8684313297271729, 0.6371150016784668, 0.5141459703445435, 0.04483532905578613, 0.008337259292602539, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.415218830108643]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.6458258628845215, 0.14830482006072998, 0.1367412805557251, 0.05629110336303711, 0.6188018321990967, 0.09611082077026367, 0.022463202476501465, 0.16802668571472168, 0.13714361190795898, 0.9738653898239136, 0.22950482368469238, 0.14387571811676025, 0.3699824810028076, 0.037640929222106934, 0.8094890117645264, 0.16305005550384521, 0.9263873100280762, 0.4807380437850952, 0.35078299045562744, 0.5584464073181152, 0.42825543880462646, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.667717456817627]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.6350167989730835, 0.7023348808288574, 0.9816912412643433, 0.3239823579788208, 0.3320831060409546, 0.9436564445495605, 0.8940634727478027, 0.9500092267990112, 0.039070963859558105, 0.7722084522247314, 0.014232039451599121, 0.5869686603546143, 0.6310857534408569, 0.754490852355957, 0.4464530944824219, 0.28466737270355225, 0.12341678142547607, 0.28795409202575684, 0.03844261169433594, 0.05093085765838623, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.96828556060791]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.1708146333694458, 0.7656099796295166, 0.17411589622497559, 0.13566994667053223, 0.5877799987792969, 0.8375746011734009, 0.7500095367431641, 0.43630826473236084, 0.6984328031539917, 0.458881139755249, 0.3605443239212036, 0.6612114906311035, 0.7714365720748901, 0.18256628513336182, 0.6045645475387573, 0.6824719905853271, 0.093025803565979, 0.8150986433029175, 0.6484025716781616, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.744658946990967]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.37816083431243896, 0.3764007091522217, 0.366793155670166, 0.5291237831115723, 0.4347909688949585, 0.051524996757507324, 0.7034255266189575, 0.21059012413024902, 0.4240748882293701, 0.33774685859680176, 0.602317214012146, 0.8616265058517456, 0.09764957427978516, 0.8331103324890137, 0.6966776847839355, 0.9676048755645752, 0.39181971549987793, 0.5993291139602661, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.194350242614746]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9014437198638916, 0.6692262887954712, 0.2199995517730713, 0.34047043323516846, 0.5105847120285034, 0.5800155401229858, 0.6711152791976929, 0.24664950370788574, 0.837967038154602, 0.1854724884033203, 0.4090847969055176, 0.5367509126663208, 0.9298523664474487, 0.8082610368728638, 0.4719582796096802, 0.9887839555740356, 0.8493902683258057, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.900388240814209]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.18410146236419678, 0.16370105743408203, 0.00571596622467041, 0.6846107244491577, 0.8010516166687012, 0.5752760171890259, 0.7778260707855225, 0.570926308631897, 0.2175813913345337, 0.15920031070709229, 0.7178256511688232, 0.7729295492172241, 0.09782063961029053, 0.5135018825531006, 0.6833776235580444, 0.07111799716949463, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.223117351531982]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9505445957183838, 0.6332504749298096, 0.4598352909088135, 0.821086049079895, 0.4521135091781616, 0.40192127227783203, 0.37758028507232666, 0.7068672180175781, 0.5899826288223267, 0.3339945077896118, 0.5181410312652588, 0.5538294315338135, 0.32791435718536377, 0.6153789758682251, 0.21849405765533447, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.009317874908447]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.5818334817886353, 0.008952617645263672, 0.817952036857605, 0.8802107572555542, 0.2041914463043213, 0.6401019096374512, 0.5620921850204468, 0.6981043815612793, 0.9288793802261353, 0.17047274112701416, 0.8074302673339844, 0.027227401733398438, 0.8451176881790161, 0.8029766082763672, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.590928554534912]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.21364152431488037, 0.9299622774124146, 0.176169753074646, 0.40923750400543213, 0.0010215044021606445, 0.6993589401245117, 0.2511633634567261, 0.6728019714355469, 0.2957075834274292, 0.5348055362701416, 0.8762129545211792, 0.07801520824432373, 0.3015238046646118, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.7690606117248535]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.3406723737716675, 0.8631099462509155, 0.12643325328826904, 0.38934123516082764, 0.46326422691345215, 0.039055824279785156, 0.5590641498565674, 0.24887871742248535, 0.38112592697143555, 0.7917823791503906, 0.8130742311477661, 0.016570329666137695, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [8.340514183044434]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.05386042594909668, 0.0680922269821167, 0.31601762771606445, 0.18086862564086914, 0.7679719924926758, 0.6589527130126953, 0.9141957759857178, 0.402393102645874, 0.8808540105819702, 0.9081192016601562, 0.6332200765609741, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.193706512451172]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7115259170532227, 0.19793486595153809, 0.3831000328063965, 0.12318956851959229, 0.048275113105773926, 0.6922358274459839, 0.8630118370056152, 0.6173487901687622, 0.4392777681350708, 0.7511276006698608, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [5.914463043212891]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.15324389934539795, 0.608772873878479, 0.30232787132263184, 0.42286384105682373, 0.6527767181396484, 0.8560984134674072, 0.95783531665802, 0.35850799083709717, 0.20661818981170654, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.008178234100342]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.9347794055938721, 0.5582127571105957, 0.8688193559646606, 0.9338347911834717, 0.6681429147720337, 0.09503507614135742, 0.1364145278930664, 0.2338886260986328, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.377873420715332]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.44904494285583496, 0.6156696081161499, 0.24088788032531738, 0.04646635055541992, 0.1615074872970581, 0.3069014549255371, 0.476338267326355, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [5.039738178253174]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.26967108249664307, 0.34533655643463135, 0.32481229305267334, 0.032245635986328125, 0.33962345123291016, 0.5251162052154541, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [7.07838249206543]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.16084957122802734, 0.4878326654434204, 0.31659018993377686, 0.9245070219039917, 0.24317359924316406, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.446048259735107]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.2458251714706421, 0.5735921859741211, 0.2629578113555908, 0.3110496997833252, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [5.8523969650268555]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.7443956136703491, 0.07421326637268066, 0.4091228246688843, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [8.270795822143555]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.6418168544769287, 0.5278061628341675, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       [6.982004642486572]
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       [0.5448164939880371, ...]
     ]
   >,
   #Nx.Tensor<
     f32[1][1]
     [
       ...
     ]
   >},
  {#Nx.Tensor<
     f32[1][32]
     [
       ...
     ]
   >, ...},
  {...},
  ...
]

Each example is a tuple of {input, label}, where input is a uniform random vector of shape {1, 32} and label is the {1, 1} tensor obtained by evaluating the true data-generating function, true_function, on that input.

defmodule SGD do
  import Nx.Defn

  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  defn model(params, inputs) do
    labels = Nx.dot(inputs, params)
    labels
  end

  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end

  defn loss(actual_label, predicted_label) do
    loss_value = mean_squared_error(actual_label, predicted_label)
    loss_value
  end

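  # Defined with defn so that value_and_grad can differentiate through it
  defn objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end

  # One gradient descent update; defn is required for value_and_grad
  # and for the tensor-aware - and * operators
  defn step(params, actual_inputs, actual_labels) do
    {loss, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    # Move against the gradient with a fixed learning rate of 1.0e-2
    new_params = params - 1.0e-2 * params_grad
    {loss, new_params}
  end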

  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(y, prediction)
    end)
    |> Enum.reduce(0, &Nx.add/2)
  end

  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {loss, params} do
        {loss, params} ->
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, new_params} = step(params, x, y)
              avg_loss = Nx.add(Nx.mean(batch_loss), loss) |> Nx.divide(j + 1)
              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end

    trained_params
  end
end
{:module, SGD, <<70, 79, 82, 49, 0, 0, 28, ...>>, {:train, 3}}
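To make the update in step concrete, here is a minimal sketch of one gradient descent step written in plain Elixir lists, with no Nx dependency. It is not the notebook's implementation: ManualSGD, predict/2, and the example numbers are hypothetical names and values chosen for illustration, and the gradient is derived by hand for a single example instead of being computed by value_and_grad. For a linear model with squared error loss (label - w·x)², the gradient with respect to each weight w_i is -2 * (label - w·x) * x_i.

```elixir
defmodule ManualSGD do
  # Hypothetical helper module for illustration; not part of the notebook.
  # Same fixed learning rate as the step function above.
  @learning_rate 1.0e-2

  # Linear model: dot product of weights and inputs.
  def predict(weights, inputs) do
    weights
    |> Enum.zip(inputs)
    |> Enum.reduce(0.0, fn {w, x}, acc -> acc + w * x end)
  end

  # One gradient descent step on a single {inputs, label} example.
  # Returns the loss before the update and the updated weights.
  def step(weights, inputs, label) do
    error = label - predict(weights, inputs)
    loss = error * error

    new_weights =
      weights
      |> Enum.zip(inputs)
      |> Enum.map(fn {w, x} ->
        # Hand-derived gradient of the squared error with respect to w
        grad = -2.0 * error * x
        w - @learning_rate * grad
      end)

    {loss, new_weights}
  end
end

# Assumed toy example: two weights starting at zero, one training example.
{loss, new_weights} = ManualSGD.step([0.0, 0.0], [1.0, 2.0], 5.0)
IO.inspect({loss, new_weights})
```

With zero initial weights the prediction is 0.0, so the error is 5.0 and the loss is 25.0. Each weight moves in the direction that reduces the error, in proportion to its input, ending close to 0.1 and 0.2. Calling the function repeatedly on the same example should shrink the loss further, which is the behavior train relies on when it loops over the dataset.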
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)