Voter–candidate distances with Nx
# Notebook dependencies: Nx, the EMLX backend (fetched from GitHub) and a local `kenneth` checkout.
# NOTE(review): the `:kenneth` path is absolute, so it only resolves on a machine with that directory.
Mix.install(
  [
    :nx,
    {:emlx, github: "elixir-nx/emlx"},
    {:kenneth, path: "/Users/hauleth/Workspace/hauleth/kenneth"}
  ],
  # Use EMLX as the default Nx backend for the notebook.
  config: [nx: [default_backend: EMLX.Backend]]
)
Voters, candidates and distances
# Seed the PRNG once. Every `Nx.Random` call returns `{result, new_key}`; the new key must
# replace the old one so that later draws do not reuse the same key.
key = Nx.Random.key(2137)

# Draw a 2x2 tensor of voter positions (axes :voter and :xy) from a normal distribution with
# mean 0 and standard deviation 0.2.
#
# The tuple is destructured BEFORE vectorizing: `|>` binds tighter than `=`, so piping the
# `Nx.Random.normal/4` call straight into `Nx.vectorize/2` would pass the whole
# `{tensor, key}` tuple to it instead of the tensor.
{voters, key} = Nx.Random.normal(key, 0, 0.2, shape: {2, 2}, names: [:voter, :xy])

# Turn the leading :voter axis into a vectorized axis; each voter is then a tensor over :xy.
voters = Nx.vectorize(voters, :voter)
# Four candidate positions, one row per candidate with columns named :xy.
# The leading :candidate axis is turned into a vectorized axis.
candidates =
  [[0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5], [0.5, -0.5]]
  |> Nx.tensor(names: [:candidate, :xy])
  |> Nx.vectorize(:candidate)
#Nx.Tensor<
vectorized[candidate: 4]
f32[xy: 2]
EMLX.Backend
[
[0.5, 0.5],
[-0.5, 0.5],
[-0.5, -0.5],
[0.5, -0.5]
]
>
defmodule Voters do
  @moduledoc """
  Numerical helpers for the voter/candidate experiments in this notebook.
  """

  import Nx.Defn

  @doc """
  Computes the Euclidean distance between `a` and `b`.

  The element-wise difference is squared and summed with `Nx.sum/2`, then the square root of
  that sum is returned. `opts` is handed to `Nx.sum/2` unchanged (for example `axes: [0]`).
  """
  defn dist(a, b, opts \\ []) do
    squared_diff = Nx.pow(a - b, 2)

    squared_diff
    |> Nx.sum(opts)
    |> Nx.sqrt()
  end
end
{:module, Voters, <<70, 79, 82, 49, 0, 0, 11, ...>>, true}
# Hand-written 2x2 sample tensor with axes named :voter and :xy (not vectorized).
demo =
  Nx.tensor(
    [
      [0.0, 0.0],
      [1.0, 1.0]
    ],
    names: [:voter, :xy]
  )
#Nx.Tensor<
f32[voter: 2][xy: 2]
EMLX.Backend
[
[0.0, 0.0],
[1.0, 1.0]
]
>
# Distance from each voter to each candidate. Both inputs are meant to be vectorized (over
# :voter and :candidate), which leaves :xy as the only regular axis; `axes: [0]` is forwarded
# to `Nx.sum` inside `Voters.dist`, so the squared differences are summed over that axis.
dists = Voters.dist(voters, candidates, axes: [0])
#Nx.Tensor<
vectorized[voter: 2][candidate: 4]
f32
EMLX.Backend
[
[0.943080723285675, 0.8635784387588501, 0.4958480894565582, 0.6240983605384827],
[0.9802231788635254, 0.6120692491531372, 0.4717492163181305, 0.8993085622787476]
]
>
# Number of voters, read from the vectorized axes of `dists`.
voters_count = Keyword.get(dists.vectorized_axes, :voter)

# One normal sample (mean 0, standard deviation 1.0) per voter; `key` is rebound to the key
# returned by the call.
{vot_dist, key} = Nx.Random.normal(key, 0, 1.0, shape: {voters_count})

# Keep only magnitudes and turn the single axis into the vectorized :voter axis.
vot_dist = Nx.vectorize(Nx.abs(vot_dist), :voter)
#Nx.Tensor<
vectorized[voter: 2]
f32
EMLX.Backend
[0.5919870138168335, 0.692608118057251]
>
# Fold the vectorized :candidate axis back into a regular tensor axis, so every voter holds a
# vector with one distance per candidate.
# The vector length is read from `dists` rather than hard-coded to 4, so this cell keeps
# working when the number of candidates changes. `fetch!` raises if `dists` has no
# :candidate vectorized axis instead of passing `nil` as a dimension.
candidates_count = Keyword.fetch!(dists.vectorized_axes, :candidate)
d = Nx.revectorize(dists, [voter: :auto], target_shape: {candidates_count})
#Nx.Tensor<
vectorized[voter: 2]
f32[4]
EMLX.Backend
[
[0.943080723285675, 0.8635784387588501, 0.4958480894565582, 0.6240983605384827],
[0.9802231788635254, 0.6120692491531372, 0.4717492163181305, 0.8993085622787476]
]
>
# Scratch experiment: build a 10x5 tensor whose rows each count 0..4 (iota along axis 1) and
# shuffle it along axis 1 with `independent: true`, so every row gets its own permutation.
# The call returns `{shuffled, new_key}`; the result is not bound here, so `key` is unchanged.
Nx.Random.shuffle(key, Nx.iota({10, 5}, axis: 1), independent: true, axis: 1)
{#Nx.Tensor<
s32[10][5]
EMLX.Backend
[
[4, 1, 0, 2, 3],
[1, 4, 2, 0, 3],
[3, 0, 4, 1, 2],
[4, 0, 3, 1, 2],
[4, 3, 1, 0, 2],
[2, 4, 0, 1, 3],
[3, 1, 4, 0, 2],
[2, 4, 1, 0, 3],
[4, 1, 3, 0, 2],
[4, 3, 0, 2, ...]
]
>,
#Nx.Tensor<
u32[2]
EMLX.Backend
[1702733630, 3143582473]
>}