Torchx notebook
Mix.install(
  [
    {:nx, "~> 0.8"},
    # {:exla, "~> 0.6.1"},
    {:torchx, "~> 0.8"},
    {:scidata, "~> 0.1.11"},
    {:benchee, "~> 1.3"}
  ],
  # Route all Nx operations through the Torchx (libtorch) backend on the
  # CUDA device by default.
  # NOTE: the original had a trailing comma after the `nx:` entry
  # (`config: [nx: [...],]`), which is a syntax error in Elixir.
  config: [
    nx: [
      default_backend: {Torchx.Backend, device: :cuda}
    ]
  ],
  # Pin the libtorch build Torchx downloads: CUDA 12.1, libtorch 2.4.1.
  system_env: %{"LIBTORCH_TARGET" => "cu121", "LIBTORCH_VERSION" => "2.4.1"}
)
MNIST data preparation and Torchx CUDA matmul benchmark
# Pull down MNIST (training and held-out test sets) via Scidata.
# Each `*_images` value is a `{raw_binary, tensor_type, shape}` tuple;
# destructure its pieces in the same match that receives the download.
{{train_images_binary, train_tensor_type, train_shape} = train_images, train_labels} =
  Scidata.MNIST.download()

{{test_images_binary, test_tensor_type, test_shape} = test_images, test_labels} =
  Scidata.MNIST.download_test()
# Decode the raw image bytes into an Nx tensor, flatten each image into a
# single row of pixels, and rescale intensities from 0..255 to 0.0..1.0.
# Image count and pixels-per-image are derived from `train_shape` instead
# of hard-coding 60_000 and 28 * 28, so this cell tracks the data source.
image_count = elem(train_shape, 0)
pixels_per_image = div(Tuple.product(train_shape), image_count)

train_tensors =
  train_images_binary
  |> Nx.from_binary(train_tensor_type)
  |> Nx.reshape({image_count, pixels_per_image})
  |> Nx.divide(255)

# First 50k images for training, the remaining rows held out for validation.
x_train = train_tensors[0..49_999]
x_valid = train_tensors[50_000..(image_count - 1)]
{x_train.shape, x_valid.shape}
# Sanity-check that the CUDA device Torchx was configured for is usable.
IO.puts(Torchx.device_available?(:cuda))
IO.puts(Torchx.device_count(:cuda))

# Draw an initial {784, 10} weight matrix (one column per digit class).
# NOTE: Nx.Random.normal/4 takes a *standard deviation*, not a variance —
# the two coincide here only because the value is 1.0.
mean = 0.0
variance = 1.0
key = Nx.Random.key(12)

# BUG FIX: the original omitted the `:shape` option, so `weights` came back
# as a scalar (shape {}) instead of the {784, 10} matrix the surrounding
# comments intended.
{weights, _new_key} =
  Nx.Random.normal(key, mean, variance, shape: {784, 10}, type: {:f, 32})
# Matrix multiply on whatever backend the tensors currently live on.
large_nx_mult_fn = fn -> Nx.dot(x_valid, weights) end

# Place explicit copies on the CUDA device and multiply there.
x_valid_cuda = Nx.backend_transfer(x_valid, {Torchx.Backend, device: :cuda})
weights_cuda = Nx.backend_transfer(weights, {Torchx.Backend, device: :cuda})
torchx_gpu_mult_fn = fn -> Nx.dot(x_valid_cuda, weights_cuda) end

# Run `timed_fn` back-to-back `times` times; results are discarded.
repeat = fn timed_fn, times ->
  Enum.each(1..times, fn _run -> timed_fn.() end)
end

repeat_times = 5

# The first timed pass is a warmup so one-time initialization cost is not
# counted; only the second pass's total is kept for reporting.
{_warmup_micro, _} = :timer.tc(repeat, [torchx_gpu_mult_fn, repeat_times])
{elapsed_time_micro, _} = :timer.tc(repeat, [torchx_gpu_mult_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

# Report against the configured default backend/device.
{backend, [device: device]} = Nx.default_backend()
"#{backend} #{device} avg time in milliseconds #{avg_elapsed_time_ms} total_time #{elapsed_time_micro / 1000}"