Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Nx Pipeline Example

livebooks/nx_pipeline.livemd

Nx Pipeline Example

Mix.install([
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla"},
  {:handoff, "~> 0.1"}
])

Building and running the pipeline

defmodule NxRunner do
  @moduledoc """
  Helpers for the pipeline example: a sample numerical definition, a runner
  for an already-built expression, and tensor (de)serialization callbacks.
  """

  import Nx.Defn

  @doc """
  Adds the cosine of `x` to the sine of `y`.
  """
  defn my_function(x, y) do
    Nx.cos(x) + Nx.sin(y)
  end

  @doc """
  Evaluates the pre-built expression `expr` via `Nx.Defn.jit_apply`.

  `args` must be a list; it is converted to a tuple and passed as the single
  argument of a function that ignores its parameter and returns `expr`.
  """
  def compile_and_call_expr(args, expr) do
    packed_args = List.to_tuple(args)
    Nx.Defn.jit_apply(fn _ignored -> expr end, [packed_args])
  end

  @doc """
  Reads `filename`, deserializes its contents with `Nx.deserialize/1`, and
  returns the element at position `index` of the result.

  Raises if the file cannot be read (`File.read!/1`). The deserialized value
  must be a tuple, since it is indexed with `elem/2`.
  """
  def load_from_file(filename, index) do
    deserialized = Nx.deserialize(File.read!(filename))
    elem(deserialized, index)
  end

  @doc """
  Picks element `idx` out of `tuple` and prepares it for hand-over.

  When `source == target` the element is returned untouched; otherwise it is
  passed through `Nx.serialize/1`.

  NOTE(review): `source` and `target` presumably identify the producing and
  consuming nodes - confirm against the caller.
  """
  def serialize(tuple, source, target, idx) when source == target, do: elem(tuple, idx)
  def serialize(tuple, _source, _target, idx), do: Nx.serialize(elem(tuple, idx))

  @doc """
  Counterpart of `serialize/4`.

  Returns `tensor` untouched when `source == target`; otherwise returns the
  result of `Nx.deserialize/1` applied to it.
  """
  def deserialize(tensor, source, target) when source == target, do: tensor
  def deserialize(tensor, _source, _target), do: Nx.deserialize(tensor)
end

# Build the expression for `NxRunner.my_function/2` from a 10-element f32
# template and the constant 1.
expr = apply(Nx.Defn.debug_expr(&NxRunner.my_function/2), [Nx.template({10}, :f32), 1])

# Split the expression into stages, using nodes whose op is `:add` as the
# split points. Any node without the `data.op` shape raises, as no clause
# matches it.
nx_graph =
  Nx.Defn.Graph.split(expr, fn
    %{data: %{op: :add}} -> true
    %{data: %{op: _other_op}} -> false
  end)

# Handoff function registered under the id `:arguments`. It has no upstream
# arguments and carries the concrete inputs (a 10-element iota tensor and 1)
# in `extra_args`.
# NOTE(review): presumably Handoff calls `code` with the `extra_args` entry,
# so the identity function simply yields the input tuple - confirm.
args =
  %Handoff.Function{
    id: :arguments,
    type: :inline,
    code: &Function.identity/1,
    args: [],
    extra_args: [{Nx.iota({10}, type: :f32), 1}]
  }

# Converts one stage of the split graph into a `%Handoff.Function{}`.
#
# Each stage argument becomes a `%Handoff.Function.Argument{}` pointing at
# its producer; a falsy producer id falls back to `:arguments`.
nx_stage_to_function = fn stage ->
  to_argument = fn %{source: {producer_id, idx}} ->
    %Handoff.Function.Argument{
      id: if(producer_id, do: producer_id, else: :arguments),
      # NOTE(review): only `idx` is supplied here while `NxRunner.serialize/4`
      # and `NxRunner.deserialize/3` take more parameters - presumably Handoff
      # fills in the value plus source/target itself; confirm in its docs.
      serialization_fn: {NxRunner, :serialize, [idx]},
      deserialization_fn: {NxRunner, :deserialize, []}
    }
  end

  stage_args = Enum.map(stage.arguments, to_argument)

  %Handoff.Function{
    id: stage.id,
    args: stage_args,
    code: &NxRunner.compile_and_call_expr/2,
    extra_args: [stage.expr],
    argument_inclusion: :as_list
  }
end

# Turn every stage into a Handoff function, add them all to a fresh DAG, and
# finally add the `:arguments` function that provides the inputs.
dag =
  nx_graph
  |> Enum.map(fn stage -> nx_stage_to_function.(stage) end)
  |> Enum.reduce(Handoff.DAG.new(), fn function, acc ->
    Handoff.DAG.add_function(acc, function)
  end)
  |> Handoff.DAG.add_function(args)

dag

# Execute the DAG; anything other than `{:ok, %{results: _}}` raises a MatchError.
{:ok, %{results: results}} = Handoff.execute(dag)

# Look up the result stored under the id of the last stage in the graph.
final_stage = List.last(nx_graph)
results[final_stage.id]