Nx Pipeline Example
Mix.install([
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:exla, github: "elixir-nx/nx", sparse: "exla"},
{:handoff, "~> 0.1"}
])
Section
defmodule NxRunner do
  @moduledoc """
  Helpers for running a split Nx expression as a graph of stages: a sample numerical
  definition, a function that evaluates an already-built expression, and the
  serialization callbacks applied to values passed between stages.
  """

  import Nx.Defn

  @doc """
  Adds the cosine of `x` to the sine of `y`.
  """
  defn my_function(x, y) do
    Nx.cos(x) + Nx.sin(y)
  end

  @doc """
  Evaluates the pre-built expression `expr` against `args`.

  `args` is a list. It is converted to a tuple and passed to `Nx.Defn.jit_apply/3`
  as the single argument of a function that ignores its input and returns `expr`.
  """
  def compile_and_call_expr(args, expr) do
    constant_fun = fn _ignored -> expr end
    packed_args = List.to_tuple(args)
    Nx.Defn.jit_apply(constant_fun, [packed_args])
  end

  @doc """
  Reads `filename`, deserializes its contents with `Nx.deserialize`, and returns the
  element at position `index` of the deserialized tuple.

  Raises if the file cannot be read (`File.read!/1`).
  """
  def load_from_file(filename, index) do
    contents = File.read!(filename)
    container = Nx.deserialize(contents)
    elem(container, index)
  end

  @doc """
  Selects element `idx` of `tuple` and prepares it for transfer.

  When `source == target` the element is returned untouched; otherwise it is passed
  through `Nx.serialize`.
  """
  # NOTE(review): `source`/`target` are presumably node identifiers supplied by the
  # caller of this callback — confirm against the Handoff documentation.
  def serialize(tuple, source, target, idx) when source == target, do: elem(tuple, idx)

  def serialize(tuple, _source, _target, idx) do
    tuple
    |> elem(idx)
    |> Nx.serialize()
  end

  @doc """
  Counterpart of `serialize/4`.

  Returns `tensor` unchanged when `source == target`; otherwise passes it through
  `Nx.deserialize`.
  """
  def deserialize(tensor, source, target) when source == target, do: tensor
  def deserialize(tensor, _source, _target), do: Nx.deserialize(tensor)
end
# Build the expression for `NxRunner.my_function/2`, called with a `{10}`-shaped f32
# template and the integer 1.
build_expr = Nx.Defn.debug_expr(&NxRunner.my_function/2)
expr = build_expr.(Nx.template({10}, :f32), 1)

# Split the expression into stages; the predicate is true for nodes whose operation
# is `:add`.
split_on_add? = fn %{data: %{op: op}} -> op == :add end
nx_graph = Nx.Defn.Graph.split(expr, split_on_add?)
# Inputs for `NxRunner.my_function/2`, kept together in one tuple: an f32 iota tensor
# of shape `{10}` and the integer 1.
initial_inputs = {Nx.iota({10}, type: :f32), 1}

# The `:arguments` function has no upstream arguments; its code is the identity
# function and the input tuple is its only extra argument.
# NOTE(review): presumably Handoff applies `code` to `extra_args`, so this node's
# result is the tuple itself — confirm against the Handoff documentation.
args =
  %Handoff.Function{
    id: :arguments,
    type: :inline,
    code: &Function.identity/1,
    args: [],
    extra_args: [initial_inputs]
  }
# Converts one stage of the split graph into a `%Handoff.Function{}` whose code is
# `NxRunner.compile_and_call_expr/2` and whose extra argument is the stage expression.
nx_stage_to_function = fn stage ->
  # Each stage argument names its producer and the index of the value it needs from
  # that producer's output. A nil producer falls back to the `:arguments` function.
  to_argument = fn %{source: {producer_id, idx}} ->
    %Handoff.Function.Argument{
      id: producer_id || :arguments,
      serialization_fn: {NxRunner, :serialize, [idx]},
      deserialization_fn: {NxRunner, :deserialize, []}
    }
  end

  stage_args = Enum.map(stage.arguments, to_argument)

  # NOTE(review): `argument_inclusion: :as_list` presumably makes Handoff pass the
  # resolved arguments as one list, matching `compile_and_call_expr(args, expr)` —
  # confirm against the Handoff documentation.
  %Handoff.Function{
    id: stage.id,
    args: stage_args,
    code: &NxRunner.compile_and_call_expr/2,
    extra_args: [stage.expr],
    argument_inclusion: :as_list
  }
end
# Turn every stage into a Handoff function, add them all to a fresh DAG, then add
# the `:arguments` function that supplies the initial inputs.
stage_functions = Enum.map(nx_graph, nx_stage_to_function)

stages_dag =
  Enum.reduce(stage_functions, Handoff.DAG.new(), fn function, acc ->
    Handoff.DAG.add_function(acc, function)
  end)

dag = Handoff.DAG.add_function(stages_dag, args)

# Bare expression so the DAG is shown as this cell's result.
dag
# Run the DAG; anything other than `{:ok, %{results: _}}` raises a MatchError here.
{:ok, %{results: results}} = Handoff.execute(dag)

# Look up the result stored under the id of the last stage of the split graph.
final_stage = List.last(nx_graph)
results[final_stage.id]