
Torchx notebook

torchx_cuda.livemd

Mix.install(
  [
    {:nx, "~> 0.8"},
    # {:exla, "~> 0.6.1"},
    {:torchx, "~> 0.8"},
    {:scidata, "~> 0.1.11"},
    {:benchee, "~> 1.3"}
  ],
  config: [
    nx: [
      default_backend: {Torchx.Backend, device: :cuda}
    ]
  ],
  system_env: %{"LIBTORCH_TARGET" => "cu121", "LIBTORCH_VERSION" => "2.4.1"}
)
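The config above assumes a CUDA build of libtorch (LIBTORCH_TARGET "cu121"). As an optional sketch, assuming you also want the notebook to run on machines without a GPU, the default device can be switched back to the CPU using Torchx.device_available?/1 and Nx.default_backend/1:

# Optional fallback (sketch): use the CPU device when no CUDA device is visible
unless Torchx.device_available?(:cuda) do
  Nx.default_backend({Torchx.Backend, device: :cpu})
end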

Section

{train_images, train_labels} = Scidata.MNIST.download()
{test_images, test_labels} = Scidata.MNIST.download_test()
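# Each images value is a {binary, tensor_type, shape} tuple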
{train_images_binary, train_tensor_type, train_shape} = train_images
{test_images_binary, test_tensor_type, test_shape} = test_images
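# Flatten each 28x28 image into a 784-element row and scale pixel values into [0, 1]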
train_tensors =
  train_images_binary
  |> Nx.from_binary(train_tensor_type)
  |> Nx.reshape({60000, 28 * 28})
  |> Nx.divide(255)
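# First 50_000 rows for training, the remaining 10_000 for validation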
x_train = train_tensors[0..49_999]
x_valid = train_tensors[50_000..59_999]
{x_train.shape, x_valid.shape}
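# => {{50000, 784}, {10000, 784}}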
mean = 0.0
variance = 1.0
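# Confirm the CUDA device is visible to Torchx and how many devices are available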
IO.puts(Torchx.device_available?(:cuda))
IO.puts(Torchx.device_count(:cuda))
# Random {784, 10} weight matrix; Nx.Random.normal needs the shape passed explicitly
key = Nx.Random.key(12)
{weights, _new_key} =
  Nx.Random.normal(key, mean, variance, shape: {784, 10}, type: {:f, 32})
# weights = Torchx.normal(mean, variance, {784, 10}, {:f, 32}, :cuda)
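# Matrix multiply on the default backend configured in Mix.install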
large_nx_mult_fn = fn -> Nx.dot(x_valid, weights) end

x_valid_cuda = Nx.backend_transfer(x_valid, {Torchx.Backend, device: :cuda})
weights_cuda = Nx.backend_transfer(weights, {Torchx.Backend, device: :cuda})
torchx_gpu_mult_fn = fn -> Nx.dot(x_valid_cuda, weights_cuda) end
repeat = fn timed_fn, times -> Enum.each(1..times, fn _x -> timed_fn.() end) end
repeat_times = 5
# Warmup
{elapsed_time_micro, _} = :timer.tc(repeat, [torchx_gpu_mult_fn, repeat_times])
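# Timed run (the warmup result above is discarded)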
{elapsed_time_micro, _} = :timer.tc(repeat, [torchx_gpu_mult_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

{backend, [device: device]} = Nx.default_backend()

"#{backend} #{device} avg time in milliseconds #{avg_elapsed_time_ms} total_time #{elapsed_time_micro / 1000}"