Matrix multiplication on GPU - SHORT - XLA

books/ai/gpu/exla_short.livemd

Mix.install(
  [
    {:nx, "~> 0.6"},
    {:scidata, "~> 0.1"},
    {:axon, "~> 0.6"},
    {:exla, "~> 0.6"}
  ],
  system_env: %{
    "XLA_TARGET" => "cuda120",
    "ELIXIR_ERL_OPTIONS" => "+sssdio 128"
  }
)

Before running notebook

This notebook depends on EXLA. XLA supports systems with direct access to an NVIDIA GPU, an AMD ROCm GPU, or a Google TPU. According to the documentation (https://github.com/elixir-nx/nx/tree/main/exla#readme), EXLA will try to find a precompiled version that matches your system. If it doesn’t find a match, you will need to install CUDA and cuDNN for your system.

The notebook is currently configured for an NVIDIA GPU via

system_env: %{"XLA_TARGET" => "cuda120"}

Review the configuration documentation for more options. https://hexdocs.pm/exla/EXLA.html#module-configuration

We had to install CUDA and cuDNN, but that was several months ago. Your experience may vary from ours.

We’ll pull down the MNIST data

{train_images, train_labels} = Scidata.MNIST.download()
{train_images_binary, train_tensor_type, train_shape} = train_images
train_tensor_type

Convert into Tensors and normalize to between 0 and 1

train_tensors =
  train_images_binary
  |> Nx.from_binary(train_tensor_type)
  |> Nx.reshape({60000, 28 * 28})
  |> Nx.divide(255)
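
To see what the normalization step does, here is a minimal sketch on three hand-picked pixel bytes rather than the real MNIST binary. The names `pixels` and `scaled` are illustrative only, and the cell assumes the Nx dependency from the setup cell.

```elixir
# Three raw pixel bytes: darkest, mid-gray, brightest.
pixels = Nx.from_binary(<<0, 128, 255>>, {:u, 8})

# Nx.divide/2 always produces a floating-point tensor,
# so the unsigned 8-bit input becomes 32-bit floats.
scaled = Nx.divide(pixels, 255)

{Nx.type(scaled), Nx.to_flat_list(scaled)}
```

Every value ends up between 0.0 and 1.0, which is the range the full `train_tensors` pipeline produces for all 60,000 images.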

We’ll separate the data into 50,000 train images and 10,000 validation images.

x_train_cpu = train_tensors[0..49_999]
x_valid_cpu = train_tensors[50_000..59_999]
{x_train_cpu.shape, x_valid_cpu.shape}
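
Range access on a tensor slices along the first axis, and Elixir ranges include both endpoints. A small sketch with a 6-row stand-in tensor (the `tiny_*` names are hypothetical, chosen to mirror the split above):

```elixir
# A 6x2 tensor stands in for the 60,000x784 image tensor.
tiny = Nx.iota({6, 2})

# Rows 0 through 3 (inclusive) play the role of the training set,
# rows 4 and 5 the validation set.
tiny_train = tiny[0..3]
tiny_valid = tiny[4..5]

{Nx.shape(tiny_train), Nx.shape(tiny_valid)}
```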

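Training is more stable when the weights are initialized from random numbers with a mean of 0.0 and a variance of 1.0.

mean = 0.0
# Nx.Random.normal/4 takes a standard deviation; for a variance of 1.0 it is also 1.0
std_dev = 1.0
# Nx.random_normal/4 is deprecated in favor of Nx.Random, which needs an explicit key (the seed 42 is arbitrary)
key = Nx.Random.key(42)
{weights_cpu, _new_key} = Nx.Random.normal(key, mean, std_dev, shape: {784, 10}, type: {:f, 32})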

To simplify timing Nx.dot/2, we’ll wrap the call in a zero-arity anonymous function. The function closes over x_valid_cpu and weights_cpu, so every invocation multiplies the same two tensors.

large_nx_mult_fn = fn -> Nx.dot(x_valid_cpu, weights_cpu) end
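
For two rank-2 tensors, Nx.dot/2 is an ordinary matrix product: an {n, k} tensor times a {k, m} tensor gives an {n, m} tensor. So the {10000, 784} validation tensor times the {784, 10} weights yields a {10000, 10} result. A small sketch with made-up values (the names `a`, `b`, `product`, and `shape_only` are illustrative):

```elixir
a = Nx.tensor([[1, 2], [3, 4]])
b = Nx.tensor([[5, 6], [7, 8]])

# Row-by-column products:
# [[1*5 + 2*7, 1*6 + 2*8],
#  [3*5 + 4*7, 3*6 + 4*8]]
product = Nx.dot(a, b)

# Only the shapes matter here: {4, 3} dot {3, 2} gives {4, 2}.
shape_only = Nx.iota({4, 3}) |> Nx.dot(Nx.iota({3, 2})) |> Nx.shape()

{Nx.to_flat_list(product), shape_only}
```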

The following anonymous function takes a function and the number of times to call it.

repeat = fn timed_fn, times -> Enum.each(1..times, fn _x -> timed_fn.() end) end

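The next cell times Nx.dot/2 on the CPU and reports the average and total elapsed time. The timing code is commented out in this short version of the notebook, so the cell only returns :ok; uncomment it to run the measurement.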

# repeat_times = 5
# {elapsed_time_micro, _} = :timer.tc(repeat, [large_nx_mult_fn, repeat_times])
# avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

# {backend, _device} = Nx.default_backend()

# "#{backend} CPU avg time in #{avg_elapsed_time_ms} milliseconds, total_time #{elapsed_time_micro / 1000} milliseconds"
:ok
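
The timing arithmetic can be tried without any tensors. The sketch below folds the repeat loop and :timer.tc/1 into one helper; `time_avg_ms` and the `Enum.sum/1` workload are hypothetical stand-ins, not part of the original notebook.

```elixir
# Hypothetical helper: runs timed_fn `times` times and reports
# the total and average wall-clock time in milliseconds.
time_avg_ms = fn timed_fn, times ->
  # :timer.tc/1 returns {elapsed_microseconds, result_of_fun};
  # Enum.each/2 always returns :ok.
  {elapsed_micro, :ok} =
    :timer.tc(fn -> Enum.each(1..times, fn _ -> timed_fn.() end) end)

  total_ms = elapsed_micro / 1000
  %{total_ms: total_ms, avg_ms: total_ms / times}
end

# A cheap stand-in workload instead of Nx.dot/2.
stats = time_avg_ms.(fn -> Enum.sum(1..100_000) end, 5)
```

Passing `large_nx_mult_fn` instead of the stand-in workload would measure the CPU matrix multiplication in the same way.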

XLA using GPU

We’ll switch to the EXLA backend and use the :cuda client. If you have a different device, replace every :cuda in the cells below with the client for your device, for example :rocm or :tpu.

Nx.default_backend({EXLA.Backend, client: :cuda})
Nx.default_backend()
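
If you are unsure which client your machine offers, one option is to ask EXLA before hard-coding :cuda. This is a sketch under the assumption that EXLA.Client.get_supported_platforms/0 returns a map keyed by platform name (such as :host or :cuda); the fallback to :host is our choice for illustration, not something the notebook prescribes.

```elixir
# Map of platforms EXLA can see on this machine.
platforms = EXLA.Client.get_supported_platforms()

# Prefer the GPU when it is listed, otherwise fall back to the CPU client.
client = if Map.has_key?(platforms, :cuda), do: :cuda, else: :host

{platforms, client}
```

The resulting `client` value could then be passed as the `client:` option wherever the notebook uses `client: :cuda`.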

In the following cell, we transfer the target data onto the GPU.

x_valid_cuda = Nx.backend_transfer(x_valid_cpu, {EXLA.Backend, client: :cuda})
weights_cuda = Nx.backend_transfer(weights_cpu, {EXLA.Backend, client: :cuda})

An anonymous function that calls Nx.dot/2 with data on the GPU

exla_gpu_mult_fn = fn -> Nx.dot(x_valid_cuda, weights_cuda) end

We’ll warm up the GPU by running the function 5 times. The timed run of the next 5 calls is commented out in this short version of the notebook; uncomment it to measure the GPU.

repeat_times = 5
EXLA.Client.get_supported_platforms()
# Warm up one epoch
{elapsed_time_micro, _} = :timer.tc(repeat, [exla_gpu_mult_fn, repeat_times])
# The real timing starts here
# {elapsed_time_micro, _} = :timer.tc(repeat, [exla_gpu_mult_fn, repeat_times])
# avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

# {backend, [client: device]} = Nx.default_backend()

# "#{backend} #{device} avg time in #{avg_elapsed_time_ms} milliseconds total_time #{elapsed_time_micro / 1000} milliseconds"
# Move the tensors back to the CPU (Nx.BinaryBackend)
x_valid_cpu = Nx.backend_transfer(x_valid_cuda, Nx.BinaryBackend)
weights_cpu = Nx.backend_transfer(weights_cuda, Nx.BinaryBackend)