Ch 4: Optimize Everything
Mix.install([
{:nx, "~> 0.7.2"}
])
Learning with Optimization
Optimization is the search for the best solution to a problem according to some measure of quality.
An ML system can be defined as any computing system that is capable of improving from experience on a specific task with respect to some chosen performance measure. The goal is a system that makes accurate predictions about unseen examples.
A basic architecture of such a system is:
def predict(input) do
  label = do_something(input)
  label
end
At its core, a trained ML system transforms inputs into labels. When given some unseen input X, the model predicts its expected output Y by following "the line" it has fit to the training data. In the simplest case, such a line is fully described by two parameters: m, the slope of the line, and b, its intercept.
Here, input corresponds to X, label corresponds to Y, and m and b are the parameters of the line being drawn:
def predict(input, m, b) do
  label = m * input + b
  label
end
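Below is a runnable version of the line-based predict, wrapped in a module so it can be evaluated on its own. The module name LinePredict and the sample slope and intercept are hypothetical. They were chosen only to make the arithmetic easy to follow.

```elixir
defmodule LinePredict do
  # Predicts a label for a scalar input by following the line y = m * x + b
  def predict(input, m, b), do: m * input + b
end

# With slope m = 3.0 and intercept b = 1.0, the input 2.0 maps to 3.0 * 2.0 + 1.0
LinePredict.predict(2.0, 3.0, 1.0)
# → 7.0

# The same parameters applied to several unseen inputs
Enum.map([0.0, 1.0, 2.0], &LinePredict.predict(&1, 3.0, 1.0))
# → [1.0, 4.0, 7.0]
```

Once m and b are fixed, prediction is just arithmetic. All of the learning lies in choosing those two numbers.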
In practice, ML revolves around finding the parameters that best transform inputs into labels. These parameters will not always take the form of m and b; the example above is a simplified case of linear regression in two dimensions.
The forms of the transformations and parameters in an ML algorithm vary from problem to problem, and the parameters typically have much higher dimensionality. A more practical and general view of the predict function could be:
def predict(input, params) do
  label = f(input, params)
  label
end
Here f represents the transformation performed by the ML algorithm of choice, and params represents the algorithm's learned parameters. The goal is to find the parameters that, through f, best map inputs to labels. In other words, optimize params to achieve the best performance on the given task.
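To make the generalized form concrete, here is a minimal sketch that assumes f is a weighted sum of input features plus a bias. That choice of f and the parameter values are hypothetical, picked only for illustration. params is a map, so the same predict/2 would accept a vector of any length.

```elixir
defmodule GeneralPredict do
  # Hypothetical choice of f: a weighted sum of the input features plus a bias.
  # `params` holds the learned parameters; the values used below are made up.
  def f(input, %{weights: weights, bias: bias}) do
    weights
    |> Enum.zip(input)
    |> Enum.map(fn {w, x} -> w * x end)
    |> Enum.sum()
    |> Kernel.+(bias)
  end

  def predict(input, params), do: f(input, params)
end

params = %{weights: [0.5, -1.0, 2.0], bias: 0.25}
GeneralPredict.predict([1.0, 2.0, 3.0], params)
# 0.5 * 1.0 + -1.0 * 2.0 + 2.0 * 3.0 + 0.25 → 4.75
```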
Minimizing Risk to Maximize Performance
Because the test set is unobserved, you cannot optimize for performance on it directly. Instead, you optimize for performance on the available training set. The hope is that the model captures enough of the structure it shares with unseen data that its parameters also assign good labels to unseen inputs.
During training, you define a loss function (or cost function) and optimize it with respect to your model's parameters. The loss function, parameterized by the model and its parameters, is the overall objective function. The relationship between the model, parameters, loss function, and objective function in an ML problem looks like this:
@doc """
A function that uses `params` to transform `inputs`
into `labels`
"""
def model(params, inputs) do
  labels = f(params, inputs)
  labels
end

@doc """
A function that measures the difference between
actual labels and predicted labels, giving a
measure of how wrong the predictions are
"""
def loss(actual_labels, predicted_labels) do
  loss_value = measure_diff(actual_labels, predicted_labels)
  loss_value
end

@doc """
The loss function parameterized by the model
and its parameters
"""
def objective(params, actual_inputs, actual_labels) do
  predicted_labels = model(params, actual_inputs)
  loss(actual_labels, predicted_labels)
end
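Here is a runnable instance of the model/loss/objective relationship. It makes two hypothetical concrete choices: f is a line with params = {m, b}, and measure_diff is mean squared error. Once the data is fixed, the objective depends only on params, and lower values mean better parameters.

```elixir
defmodule ConcreteObjective do
  # Hypothetical concrete model: a line applied to every input.
  def model({m, b}, inputs), do: Enum.map(inputs, fn x -> m * x + b end)

  # Hypothetical concrete loss: mean squared difference between labels and predictions.
  def loss(actual_labels, predicted_labels) do
    actual_labels
    |> Enum.zip(predicted_labels)
    |> Enum.map(fn {y, y_hat} -> (y - y_hat) * (y - y_hat) end)
    |> Enum.sum()
    |> Kernel./(length(actual_labels))
  end

  # The loss parameterized by the model and its parameters.
  def objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end
end

inputs = [0.0, 1.0, 2.0, 3.0]
labels = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1

ConcreteObjective.objective({2.0, 1.0}, inputs, labels)  # true line → 0.0
ConcreteObjective.objective({1.0, 1.0}, inputs, labels)  # worse line → 3.5
```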
ML is concerned with reducing generalization error on unseen inputs. In statistical learning theory, generalization error is also known as risk, and the model's goal is to minimize risk. A model cannot optimize risk directly because it does not have access to the entire distribution of possible inputs. Instead, it optimizes empirical risk: its performance on the empirical distribution, that is, the training set.
Minimizing empirical risk approximates minimizing true risk, but the two are never exactly equivalent. Trying to capture an infinite distribution with finite data and finite computing power always introduces some error.
The process of minimizing empirical risk is known as empirical risk minimization (ERM).
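The same ideas can be written in symbols. The notation here is ours rather than the source's. For a model f with parameters θ, a loss L comparing an actual label with a prediction, data drawn from a distribution p_data, and n training examples, risk, empirical risk, and ERM are:

```latex
R(\theta) = \mathbb{E}_{(x, y) \sim p_{\text{data}}}\left[ L\big(y, f(x; \theta)\big) \right]
\qquad
\hat{R}(\theta) = \frac{1}{n} \sum_{i=1}^{n} L\big(y_i, f(x_i; \theta)\big)
\qquad
\theta^{*}_{\text{ERM}} = \arg\min_{\theta} \hat{R}(\theta)
```

The true risk averages over the whole distribution. The empirical risk replaces that expectation with an average over the training set, which is the only quantity a model can actually compute.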
Defining Objectives
ML algorithms almost never optimize directly for the performance measures of the problem being solved. Instead, they use surrogate loss functions.
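Surrogate loss functions serve as proxies for the true objectives. They are typically easier to optimize, and optimizing them usually also indirectly improves the performance measure of the problem being solved. For example, classification accuracy changes only when a prediction crosses the decision threshold, so its gradient is zero almost everywhere; cross-entropy, by contrast, changes smoothly with the predicted probabilities. The choice of surrogate loss often depends on the ML task and the type of model.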
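The following sketch illustrates why a surrogate helps. It assumes binary labels, a 0.5 threshold, and binary cross-entropy as the surrogate. The module name and the sample predictions are hypothetical. Making a correct prediction more confident leaves accuracy unchanged, but it lowers cross-entropy, so only the surrogate registers the improvement.

```elixir
defmodule Surrogate do
  # Accuracy: fraction of predictions on the correct side of a 0.5 threshold.
  # It is piecewise constant in the predicted probabilities.
  def accuracy(y_true, y_prob) do
    correct =
      y_true
      |> Enum.zip(y_prob)
      |> Enum.count(fn {y, p} -> (p >= 0.5 and y == 1.0) or (p < 0.5 and y == 0.0) end)

    correct / length(y_true)
  end

  # Binary cross-entropy (log loss): smooth in the predicted probabilities.
  # Assumes every probability lies strictly between 0.0 and 1.0.
  def log_loss(y_true, y_prob) do
    y_true
    |> Enum.zip(y_prob)
    |> Enum.map(fn {y, p} -> -(y * :math.log(p) + (1 - y) * :math.log(1 - p)) end)
    |> Enum.sum()
    |> Kernel./(length(y_true))
  end
end

y_true = [1.0, 0.0]
Surrogate.accuracy(y_true, [0.6, 0.2])  # → 1.0
Surrogate.accuracy(y_true, [0.9, 0.2])  # → 1.0, accuracy does not notice the change
Surrogate.log_loss(y_true, [0.6, 0.2]) > Surrogate.log_loss(y_true, [0.9, 0.2])  # → true
```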
Likelihood Estimation
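The likelihood function describes the probability of the observed data as a function of the model's parameters. Maximum likelihood estimation (MLE) chooses the parameters that maximize this likelihood. Equivalently, it minimizes a loss that measures the dissimilarity between the empirical distribution of the training data and the distribution defined by the model.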
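For n independent training examples, the likelihood factorizes into a product. Taking the negative logarithm, a monotonic transformation, turns maximization of that product into minimization of an average loss. The notation here is ours:

```latex
\theta_{\text{ML}}
  = \arg\max_{\theta} \prod_{i=1}^{n} p(y_i \mid x_i; \theta)
  = \arg\min_{\theta} \; -\frac{1}{n} \sum_{i=1}^{n} \log p(y_i \mid x_i; \theta)
```

When binary labels are modeled as Bernoulli outcomes with predicted probability p(y = 1 | x; θ), this average negative log-likelihood is exactly binary cross-entropy.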
Cross-entropy
Cross-entropy is typically used as a loss function for classification tasks.
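defn binary_cross_entropy(y_true, y_pred) do
  # Negated Bernoulli log-likelihood, averaged over the last axis
  (y_true * Nx.log(y_pred) + (1 - y_true) * Nx.log(1 - y_pred))
  |> Nx.mean(axes: [-1])
  |> Nx.negate()
end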
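The defn above passes y_pred straight into Nx.log, so a prediction of exactly 0.0 or 1.0 for the wrong class produces an infinite loss. A common remedy is to clip predictions into [eps, 1 - eps] before taking logarithms. It is shown here as a plain-Elixir sketch, not Nx code. The value eps = 1.0e-7 is an assumption, since libraries pick their own constant.

```elixir
defmodule BCE do
  # Hypothetical clipping constant.
  @eps 1.0e-7

  # Mean binary cross-entropy over one example's outputs.
  # Predictions are clipped into [eps, 1 - eps] so :math.log/1 never receives 0.0.
  def binary_cross_entropy(y_true, y_pred) do
    y_true
    |> Enum.zip(y_pred)
    |> Enum.map(fn {y, p} ->
      p = min(max(p, @eps), 1.0 - @eps)
      -(y * :math.log(p) + (1.0 - y) * :math.log(1.0 - p))
    end)
    |> Enum.sum()
    |> Kernel./(length(y_true))
  end
end

BCE.binary_cross_entropy([1.0, 0.0], [0.9, 0.1])  # ≈ 0.1054, i.e. -ln(0.9)
BCE.binary_cross_entropy([1.0], [0.0])            # large but finite, ≈ 16.12
```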
Mean Squared Error
Mean squared error measures per-example loss as the average squared difference between the true labels and the predicted labels.
defn mean_squared_error(y_true, y_pred) do
  y_true
  |> Nx.subtract(y_pred)
  |> Nx.pow(2)
  |> Nx.mean(axes: [-1])
end
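Averaging over the last axis yields one loss per example rather than a single number for the whole batch. The sketch below shows this with nested lists standing in for a {batch, outputs} tensor. The module name MSE is hypothetical.

```elixir
defmodule MSE do
  # Per-example mean squared error: each example is a list of outputs, and the
  # squared differences are averaged within each example (the "last axis").
  def mean_squared_error(y_true, y_pred) do
    Enum.zip_with(y_true, y_pred, fn row_true, row_pred ->
      squared = Enum.zip_with(row_true, row_pred, fn t, p -> (t - p) * (t - p) end)
      Enum.sum(squared) / length(squared)
    end)
  end
end

# Two examples with two outputs each; only the second example has an error.
MSE.mean_squared_error([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 6.0]])
# → [0.0, 2.0]
```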
Regularize to Generalize
The objective of training an ML model is generalization.
Consider two models:
- A performs noticeably better on the training set than B.
- B performs noticeably better on the test set than A.
Which model is preferred?
Given that the primary objective is performance on unseen data, B is the preferred choice from a model performance perspective. But why would a model that performs noticeably worse on the training set perform noticeably better on the test set? The answer is often regularization.
Overfitting, Underfitting and Capacity
Overfitting is a scenario in which a trained model has low training error but high generalization error. Both ERM and MLE optimize model parameters with respect to errors on a training set, which means they explicitly fit functions to match the training data. As a result, both techniques are prone to overfitting.
Underfitting is a scenario in which a model fails to achieve even a low training error.
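Both overfitting and underfitting are typically a function of a model's capacity. A model's capacity is its ability to fit a wide variety of functions. The set of functions a model can choose from is fixed when the model is designed and is known as its hypothesis space. A model with too little capacity tends to underfit. A model with more capacity than the task and the amount of training data require tends to overfit.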
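The following toy example illustrates overfitting in its most extreme form. A hypothetical high-capacity "model" memorizes every training pair in a lookup table. It achieves zero training error and generalizes badly. For contrast, a line with the slope and intercept that generated the noiseless data is evaluated on the same test inputs. The data, names, and the fallback prediction of 0.0 are all made up for illustration.

```elixir
defmodule Capacity do
  # Hypothetical, deliberately extreme high-capacity "model": a lookup table that
  # memorizes every training pair and predicts 0.0 for any input it has not seen.
  def memorize(train), do: Map.new(train)
  def memorizer_predict(table, x), do: Map.get(table, x, 0.0)

  # Mean squared error of a prediction function over {x, y} pairs.
  def mse(pairs, predict) do
    pairs
    |> Enum.map(fn {x, y} ->
      d = predict.(x) - y
      d * d
    end)
    |> Enum.sum()
    |> Kernel./(length(pairs))
  end
end

# Toy data from y = 2x + 1, split into training and test inputs.
train_pairs = for x <- [1.0, 2.0, 3.0, 4.0], do: {x, 2.0 * x + 1.0}
test_pairs = for x <- [5.0, 6.0], do: {x, 2.0 * x + 1.0}

table = Capacity.memorize(train_pairs)
Capacity.mse(train_pairs, &Capacity.memorizer_predict(table, &1))  # → 0.0
Capacity.mse(test_pairs, &Capacity.memorizer_predict(table, &1))   # → 145.0

# A line has far lower capacity but matches how the data was generated.
Capacity.mse(test_pairs, fn x -> 2.0 * x + 1.0 end)               # → 0.0
```

Training error alone cannot tell these two models apart. Only held-out data reveals that the memorizer overfits.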
Defining Regularization
Regularization is any technique used to combat overfitting or, more generally, any technique used to reduce generalization error.
Because ERM and MLE are prone to overfitting, regularization is a necessary step for many ML algorithms.
Complexity Penalties
Complexity penalties are a commonly used regularizer for training ML models that generalize. They add a penalty term, scaled by a penalty weight, to the objective function that is minimized during training. Weight decay is a common complexity penalty whose penalty term is the squared L2 norm of the model's parameters. The L2 norm of a parameter vector can be interpreted as its distance from the origin.
Weight decay expresses a preference for smaller model weights. Mathematically, it pushes the optimizer toward parameters that lie closer to the origin, acting as a soft constraint on the feasible parameter space. Intuitively, weight decay can be thought of as penalizing a model that becomes too confident in, or relies too heavily on, particular weights.
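Here is a minimal sketch of a weight-decay objective. It assumes parameters are a flat list of floats and that the data loss has already been computed. The penalty weight lambda is a hypothetical hyperparameter you would tune. The L2 norm is the distance from the origin, and the penalty uses its square.

```elixir
defmodule WeightDecay do
  # Squared L2 norm of the parameters: the sum of squared weights.
  def l2_penalty(params), do: params |> Enum.map(&(&1 * &1)) |> Enum.sum()

  # L2 norm: the Euclidean distance of the parameter vector from the origin.
  def l2_norm(params), do: :math.sqrt(l2_penalty(params))

  # Penalized objective: data loss plus the weighted complexity penalty.
  # `lambda` is the penalty weight, a hypothetical hyperparameter.
  def penalized_objective(data_loss, params, lambda) do
    data_loss + lambda * l2_penalty(params)
  end
end

WeightDecay.l2_norm([3.0, 4.0])                         # → 5.0
WeightDecay.penalized_objective(0.5, [3.0, 4.0], 0.01)  # 0.5 + 0.01 * 25.0 → 0.75
```

The name "decay" comes from the penalty's gradient, 2 * lambda * w. With learning rate eta, each gradient step multiplies every weight by (1 - 2 * eta * lambda) before applying the data-loss gradient, shrinking weights toward zero.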
Early-stopping
Early stopping is a regularization technique that halts training when overfitting is detected. Overfitting cannot be detected perfectly, so the typical approach is to monitor performance on a validation set.
A validation set is a portion of the original training data that is not used for training but is instead used to periodically monitor model performance. When validation performance stops improving while training performance keeps improving, the model is likely beginning to overfit.
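Here is a minimal sketch of early stopping with a "patience" counter. A precomputed list stands in for the validation losses you would compute after each training epoch. The patience parameter and the loss values are hypothetical. Patience trades off stopping too early against wasting epochs on a model that has begun to overfit.

```elixir
defmodule EarlyStopping do
  # Scans per-epoch validation losses and returns the index of the best epoch seen
  # before `patience` consecutive epochs pass without improving on the best loss.
  # `patience` is a hypothetical parameter name, not taken from a specific library.
  def best_epoch(val_losses, patience) do
    val_losses
    |> Enum.with_index()
    |> Enum.reduce_while({nil, nil, 0}, fn {loss, epoch}, {best_idx, best_loss, waited} ->
      cond do
        # First epoch, or a new best: remember it and reset the counter
        best_loss == nil or loss < best_loss -> {:cont, {epoch, loss, 0}}
        # Patience exhausted: stop training here
        waited + 1 >= patience -> {:halt, {best_idx, best_loss, waited + 1}}
        # No improvement yet, keep waiting
        true -> {:cont, {best_idx, best_loss, waited + 1}}
      end
    end)
    |> elem(0)
  end
end

losses = [1.0, 0.8, 0.7, 0.72, 0.75, 0.71, 0.69]
EarlyStopping.best_epoch(losses, 3)  # → 2, stops before ever seeing 0.69
EarlyStopping.best_epoch(losses, 5)  # → 6, waits long enough to find 0.69
```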
Descending Gradients
Gradient descent is an iterative optimization routine that minimizes a function by repeatedly evaluating the function's gradient at the current point and taking a small step in the opposite direction.
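Here is a minimal one-dimensional sketch of the update x <- x - learning_rate * grad(x). The gradient is supplied by hand rather than computed automatically. The function f(x) = (x - 3)^2, its minimum at x = 3, and the learning rate of 0.1 are illustrative choices.

```elixir
defmodule GradientDescent do
  # Repeats x <- x - learning_rate * grad(x) for a fixed number of steps.
  def minimize(grad, x0, learning_rate, steps) do
    Enum.reduce(1..steps, x0, fn _step, x -> x - learning_rate * grad.(x) end)
  end
end

# f(x) = (x - 3)^2 has gradient f'(x) = 2 * (x - 3) and its minimum at x = 3.
grad = fn x -> 2.0 * (x - 3.0) end
GradientDescent.minimize(grad, 0.0, 0.1, 100)
# converges toward 3.0
```

For this function each step multiplies the distance to the minimum by (1 - 2 * learning_rate). Any learning rate above 1.0 therefore makes the iterates diverge, which is why the step size matters.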
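Implementing stochastic gradient descent (SGD) with Nx is simple thanks to its automatic differentiation capabilities. SGD follows the same update rule as gradient descent but estimates the gradient from a single example or a small batch rather than from the entire training set.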
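To show the mechanics of per-example updates without Nx, here is a plain-Elixir sketch. It fits y = m * x + b under squared error. Because plain Elixir has no automatic differentiation, the gradients are derived by hand; in Nx they would come from autodiff instead. The data, learning rate, and epoch count are hypothetical. Real implementations usually shuffle the data each epoch, which this deterministic sketch skips.

```elixir
defmodule TinySGD do
  # One SGD step on a single example {x, y} for the model y_hat = m * x + b
  # under squared error (y_hat - y)^2, whose gradients are
  # d/dm = 2 * err * x and d/db = 2 * err.
  def step({m, b}, {x, y}, lr) do
    err = m * x + b - y
    {m - lr * 2.0 * err * x, b - lr * 2.0 * err}
  end

  # Sweeps over the data in order for a number of epochs,
  # updating the parameters after every single example.
  def fit(data, params, lr, epochs) do
    Enum.reduce(1..epochs, params, fn _epoch, acc ->
      Enum.reduce(data, acc, fn example, p -> step(p, example, lr) end)
    end)
  end
end

data = for x <- [0.0, 0.25, 0.5, 0.75, 1.0], do: {x, 2.0 * x + 1.0}
TinySGD.fit(data, {0.0, 0.0}, 0.1, 500)
# approaches {2.0, 1.0}
```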
key = Nx.Random.key(42)
#Nx.Tensor<
u32[2]
[0, 42]
>
{true_params, new_key} = Nx.Random.uniform(key, shape: {32, 1})
true_function = fn params, x ->
  Nx.dot(x, params)
end
{train_x, new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
train_y = true_function.(true_params, train_x)
train_data = Enum.zip(Nx.to_batched(train_x, 1), Nx.to_batched(train_y, 1))
{test_x, _new_key} = Nx.Random.uniform(new_key, shape: {10000, 32})
test_y = true_function.(true_params, test_x)
test_data = Enum.zip(Nx.to_batched(test_x, 1), Nx.to_batched(test_y, 1))
[
{#Nx.Tensor<
f32[1][32]
[
[0.9657397270202637, 0.9266661405563354, 0.2524207830429077, 0.506806492805481, 0.03272294998168945, 0.6381621360778809, 0.4016733169555664, 0.4144333600997925, 0.8692346811294556, 0.19583988189697266, 0.4356701374053955, 0.037007689476013184, 0.4367654323577881, 0.9086041450500488, 0.4730778932571411, 0.29556596279144287, 0.49315857887268066, 0.3683987855911255, 0.8670364618301392, 0.527277946472168, 0.028360843658447266, 0.13743293285369873, 0.8709059953689575, 0.1861327886581421, 0.4181276559829712, 0.9427480697631836, 0.4339343309402466, 0.8707499504089355, 0.6826666593551636, 0.528895378112793, 0.17522680759429932, 0.4048128128051758]
]
>,
#Nx.Tensor<
f32[1][1]
[
[6.652977466583252]
]
>},
...
]
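Nx.Random is stateless. Each call takes a key and returns the sampled tensor together with a new key, and passing the same key twice yields the same numbers. That is presumably why the test inputs above are drawn from new_key, a key returned by an earlier draw, so the test set does not simply repeat the training inputs. The sketch below shows the same key-threading idea without Nx, using Erlang's functional :rand API. It is an analogy only, not how Nx.Random is implemented.

```elixir
# Erlang's :rand has a functional API: uniform_s/1 takes an explicit state
# and returns {value, new_state}, much like Nx.Random returns {tensor, new_key}.
state0 = :rand.seed_s(:exsss, 42)

# Thread the state: each draw uses the state returned by the previous one.
{first, state1} = :rand.uniform_s(state0)
{second, _state2} = :rand.uniform_s(state1)

# Reusing an already-consumed state reproduces the same value.
{first_again, _} = :rand.uniform_s(state0)

IO.inspect({first, second, first_again}, label: "draws")
```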
Each example is a tuple of {input, label}, where input is a uniform random vector and label is obtained by evaluating the true data-generating function, true_function, on that input.
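The pairing can be sketched in plain Elixir. This is a hypothetical stand-in: the notebook's real true_function and true_params are defined earlier, and the linear form used here (label = input · true_params), the 4 features, and the 10 examples are assumptions chosen for illustration.

```elixir
# Plain-Elixir sketch of building {input, label} pairs (no Nx).
# HYPOTHETICAL: true_params and true_function stand in for the notebook's
# own definitions; the linear form below is an assumption.

# Seed the process-level :rand generator so the sketch is reproducible.
:rand.seed(:exsss, 42)

dims = 4

# Hypothetical true parameters, one weight per feature.
true_params = for _ <- 1..dims, do: :rand.uniform()

# Hypothetical linear data-generating function: label = input · params
true_function = fn params, input ->
  params
  |> Enum.zip(input)
  |> Enum.map(fn {w, x} -> w * x end)
  |> Enum.sum()
end

# Ten uniform random input vectors with `dims` features each.
inputs = for _ <- 1..10, do: for(_ <- 1..dims, do: :rand.uniform())

# Evaluate the true function on each input to get its label.
labels = Enum.map(inputs, fn input -> true_function.(true_params, input) end)

# Pair inputs with labels, mirroring Enum.zip over Nx.to_batched(x, 1).
data = Enum.zip(inputs, labels)

IO.inspect(hd(data), label: "first example")
```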
defmodule SGD do
  import Nx.Defn

  # Draws a {32, 1} parameter tensor from a uniform distribution.
  # Nx.Random.uniform/2 returns {tensor, new_key}.
  defn init_random_params(key) do
    Nx.Random.uniform(key, shape: {32, 1})
  end

  # Linear model: each prediction is the dot product of an input row with params.
  defn model(params, inputs) do
    labels = Nx.dot(inputs, params)
    labels
  end

  # Mean of the squared differences along the last axis.
  defn mean_squared_error(y_true, y_pred) do
    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
  end

  defn loss(actual_label, predicted_label) do
    loss_value = mean_squared_error(actual_label, predicted_label)
    loss_value
  end

  # objective/3 and step/3 must be defn: value_and_grad/2 and the tensor
  # arithmetic in step/3 only work inside numerical definitions.
  defn objective(params, actual_inputs, actual_labels) do
    predicted_labels = model(params, actual_inputs)
    loss(actual_labels, predicted_labels)
  end

  # One gradient-descent update with a fixed learning rate of 1.0e-2.
  defn step(params, actual_inputs, actual_labels) do
    {loss, params_grad} =
      value_and_grad(params, fn params ->
        objective(params, actual_inputs, actual_labels)
      end)

    new_params = params - 1.0e-2 * params_grad
    {loss, new_params}
  end

  # Sums the per-example losses over the whole test set.
  def evaluate(trained_params, test_data) do
    test_data
    |> Enum.map(fn {x, y} ->
      prediction = model(trained_params, x)
      loss(y, prediction)
    end)
    |> Enum.reduce(0, &Nx.add/2)
  end

  def train(data, iterations, key) do
    {params, _key} = init_random_params(key)
    loss = Nx.tensor(0.0)

    {_, trained_params} =
      for i <- 1..iterations, reduce: {loss, params} do
        {loss, params} ->
          for {{x, y}, j} <- Enum.with_index(data), reduce: {loss, params} do
            {loss, params} ->
              {batch_loss, new_params} = step(params, x, y)

              # Running mean of the batch losses seen so far in this epoch:
              # weight the previous average by the j values it covers.
              avg_loss =
                loss
                |> Nx.multiply(j)
                |> Nx.add(Nx.mean(batch_loss))
                |> Nx.divide(j + 1)

              IO.write("\rEpoch: #{i}, Loss: #{Nx.to_number(avg_loss)}")
              {avg_loss, new_params}
          end
      end

    trained_params
  end
end
{:module, SGD, <<70, 79, 82, 49, 0, 0, 28, ...>>, {:train, 3}}
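For this linear model and a batch holding one example with a single output, the mean squared error in SGD reduces to the squared error (y - x·w)^2, whose gradient with respect to each weight w_i is -2(y - x·w)x_i. The plain-Elixir sketch below computes that gradient by hand and applies the same update rule as step/3 (w - 1.0e-2 * grad). The input, label, and starting weights are made-up numbers; in the notebook, value_and_grad/2 derives the gradient automatically.

```elixir
# Plain-Elixir sketch of one gradient-descent step for a linear model with
# squared-error loss on a single example. All numbers are made up.

# Dot product of two equal-length lists of numbers.
dot = fn a, b ->
  a |> Enum.zip(b) |> Enum.map(fn {p, q} -> p * q end) |> Enum.sum()
end

x = [0.5, 0.2, 0.9]   # one input row (hypothetical)
y = 2.0               # its label (hypothetical)
w = [0.1, 0.4, 0.3]   # current parameters (hypothetical)
lr = 1.0e-2           # same learning rate as SGD.step/3

# Squared error of the model's prediction for this example.
loss = fn params -> :math.pow(y - dot.(x, params), 2) end

# Analytic gradient: d/dw_i (y - x·w)^2 = -2 * (y - x·w) * x_i
pred = dot.(x, w)
grad = Enum.map(x, fn xi -> -2.0 * (y - pred) * xi end)

# Update rule: w <- w - lr * grad
new_w =
  w
  |> Enum.zip(grad)
  |> Enum.map(fn {wi, gi} -> wi - lr * gi end)

IO.puts("loss before: #{loss.(w)}, loss after: #{loss.(new_w)}")
```

With these numbers the loss drops from roughly 2.56 to roughly 2.449 after a single step, because the update moves each weight against its gradient.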
key = Nx.Random.key(0)
trained_params = SGD.train(train_data, 1, key)
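During training, the printed loss is meant to be a running mean of the batch losses seen so far in the epoch. An incremental mean has to weight the previous average by the number of values it already summarizes: avg_j = (avg_{j-1} * j + x_j) / (j + 1) for zero-based index j. Simply adding the new value to the previous average and dividing by j + 1 does not produce the mean. The sketch below checks the incremental formula against a direct mean in plain Elixir, using made-up batch losses.

```elixir
# Incremental mean over a list of made-up batch losses.
batch_losses = [4.0, 3.0, 2.5, 2.0, 1.5]

# Reduce with {average so far, count so far}; j is the zero-based index
# of the value being added, so the previous average covers j values.
{running_avg, _count} =
  Enum.reduce(batch_losses, {0.0, 0}, fn loss, {avg, j} ->
    {(avg * j + loss) / (j + 1), j + 1}
  end)

direct_avg = Enum.sum(batch_losses) / length(batch_losses)

IO.puts("running mean: #{running_avg}, direct mean: #{direct_avg}")
```

Both values should come out at 2.6, up to floating-point rounding.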