11785 - HW1PT2

Mix.install(
  [
    {:arrays, "~> 2.1"},
    {:exla, ">= 0.0.0"},
    {:npy, "~> 0.1.1"},
    {:axon, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"},
    {:kino, "~> 0.7.0"},
    {:finch, "~> 0.19.0"},
    {:jason, "~> 1.4"},
    {:polaris, "~> 0.1.0"},
    {:eflame, "~> 1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Configuration

config = %{
  context: 30,
  epochs: 5,
  batch_size: 512,
  data_root: "/home/livebook"
  #data_root: "/Users/darren/dev/11785-project"
}

Kaggle Raw Data

The RawData module retrieves and unpacks the raw MFCC and transcript data. There were several challenges in making the Kaggle data available in an Elixir Livebook:

  1. There is seemingly no simple way to download a competition's data from Kaggle. They offer a Python kaggle-api, but no unauthenticated HTTP endpoint that accepts a plain GET request for the competition .zip file. Instead, I uploaded a copy of the Kaggle data to a personal S3 bucket that I had already created (for hosting a course project for a French course that I took a few years ago at CMU!).
  2. When running a Livebook on a remote Elixir node on Fly.io, I could not unzip the Kaggle data zip file. Erlang's built-in :zip.unzip would raise an error complaining that the zip file was invalid. To circumvent this, I created a .tar.gz archive of the data, which could be unpacked on the remote server.
  3. Finally, I ran into problems using Nx's built-in support for reading .npy files via Nx.load_numpy!/1. While I could use that function to read the MFCC data files, it would raise an error when attempting to read the transcript files. It turns out that these files in the Kaggle dataset encode the transcript data as NumPy arrays of strings, but the Nx library only supports loading numerical NumPy arrays. To work around this, I wrote a small Python script to rewrite the transcript files as integer NumPy arrays, with each integer being the index of the phoneme in the list of phonemes.
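The integer encoding described in item 3 can be illustrated in Elixir. This is only a minimal sketch of the idea, not the Python script that was actually used: the PhonemeIndex module and the three-entry phoneme list are hypothetical stand-ins for the full list of phonemes.

```elixir
defmodule PhonemeIndex do
  # Hypothetical helper: builds a phoneme -> integer index lookup table.
  def build(phonemes) do
    phonemes
    |> Enum.with_index()
    |> Map.new()
  end

  # Encodes a transcript (a list of phoneme strings) as a list of indices.
  # Map.fetch!/2 raises if a phoneme is not in the table.
  def encode(transcript, index) do
    Enum.map(transcript, &Map.fetch!(index, &1))
  end
end

# Shortened, illustrative phoneme list
index = PhonemeIndex.build(["[SIL]", "AA", "AE"])

encoded = PhonemeIndex.encode(["AA", "[SIL]", "AE"], index)
IO.inspect(encoded)
# => [1, 0, 2]
```
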
defmodule RawData do
  def ensure_available(config) do
    if File.exists?("#{config.data_root}/data") do
      :exists
    else
      download(config) |> untar(config)
    end
  end

  defp download(config) do
    url = "https://s3.amazonaws.com/votre-montreal.com/data.tar.gz"
    destination = "#{config.data_root}/data.tar.gz"

    # {_, exit_status} = System.cmd("curl", ["-o", destination, url])

    {_, exit_status} = System.cmd("wget", [url, "-O", destination])

    if exit_status == 0 do
      :ok
    else
      :error
    end
  end

  defp untar(:error, _), do: :error

  defp untar(:ok, config) do
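    # Extract into data_root (rather than the current working directory) so that
    # ensure_available/1 finds the unpacked data directory on the next run
    {_, exit_status} =
      System.cmd("tar", ["-xzf", "#{config.data_root}/data.tar.gz", "-C", config.data_root])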

    if exit_status == 0 do
      :ok
    else
      :error
    end
  end
end

RawData.ensure_available(config)

This sets EXLA as the default compiler for Nx numerical definitions (defn), so that they are compiled by EXLA and run on the GPU (if one is available).

Nx.Defn.global_default_options(compiler: EXLA)
defmodule Globals do
  def phonemes(),
    do: [
      "[SIL]",
      "AA",
      "AE",
      "AH",
      "AO",
      "AW",
      "AY",
      "B",
      "CH",
      "D",
      "DH",
      "EH",
      "ER",
      "EY",
      "F",
      "G",
      "HH",
      "IH",
      "IY",
      "JH",
      "K",
      "L",
      "M",
      "N",
      "NG",
      "OW",
      "OY",
      "P",
      "R",
      "S",
      "SH",
      "T",
      "TH",
      "UH",
      "UW",
      "V",
      "W",
      "Y",
      "Z",
      "ZH",
      "[SOS]",
      "[EOS]"
    ]
end

Transformations

Here I have replicated the two transformations that I used in HW1PT2: the cepstral mean transform and a combined time and frequency masking transform.

defmodule Masks do
  def freq(mask, low \\ 0.0, high \\ 28.0, f \\ 4) do
    # Randomly choose the band width in [3, f); with the default f = 4 this is always 3
    f_value = :rand.uniform() * (f - 3.0) + 3.0
    f_value = Float.floor(f_value) |> trunc()

    # Randomly choose the starting frequency
    f0_value = :rand.uniform() * (high - f_value - low) + low
    f0_value = Float.floor(f0_value) |> trunc()

    # Create the zero mask along the frequency dimension
    mask_shape = Nx.shape(mask)
    mask_slice = Nx.slice(mask, [0, f0_value], [elem(mask_shape, 0), f_value])

    # Multiply the slice by zero and write it back, zeroing out that frequency band
    mask = Nx.put_slice(mask, [0, f0_value], Nx.multiply(mask_slice, Nx.tensor(0.0)))

    mask
  end

  def time(mask, f \\ 8) do
    # Get the number of time steps (rows)
    num_time_steps = Nx.shape(mask) |> elem(0)

    # Randomly choose the number of masked time steps in [3, f)
    t_value = :rand.uniform() * (f - 3.0) + 3.0
    t_value = Float.floor(t_value) |> trunc()

    # Randomly choose the starting time step
    t0_value = :rand.uniform() * (num_time_steps - t_value)
    t0_value = Float.floor(t0_value) |> trunc()

    # Create the zero mask along the time dimension
    mask_shape = Nx.shape(mask)
    mask_slice = Nx.slice(mask, [t0_value, 0], [t_value, elem(mask_shape, 1)])

    # Multiply the slice by zero and write it back, zeroing out those time steps
    Nx.put_slice(mask, [t0_value, 0], Nx.multiply(mask_slice, Nx.tensor(0.0)))
  end
end

defmodule AudioTransforms do
  # Function to apply the cepstral mean and variance normalization
  def cepstral_mean_transform(data) do
    # Calculate the mean along axis 0 (mean for each column)
    mean = Nx.mean(data, axes: [0])

    # Subtract the mean from the data (center the data)
    centered_data = Nx.subtract(data, mean)

    # Calculate the standard deviation along axis 0
    std_dev = Nx.standard_deviation(centered_data, axes: [0])

    # Normalize the data by dividing by the standard deviation
    Nx.divide(centered_data, std_dev)
  end

  def freq_time_masking_transform(batch) do
    # Get the shape of a single entry in the batch (2D tensor)
    {batch_size, height, width} = Nx.shape(batch)

    # Create an initial mask of ones with shape (height, width) for a single entry
    mask = Nx.broadcast(Nx.tensor(1.0), {height, width})

    # Apply 0, 1, or 2 random frequency masks. The explicit //1 step keeps the
    # range empty when the count is 0 (a bare 1..0 counts down and runs twice).
    num_freq_xforms = :rand.uniform(3) - 1

    mask =
      Enum.reduce(1..num_freq_xforms//1, mask, fn _, mask_acc -> Masks.freq(mask_acc) end)

    # Apply 0, 1, or 2 random time masks
    num_time_xforms = :rand.uniform(3) - 1

    mask =
      Enum.reduce(1..num_time_xforms//1, mask, fn _, mask_acc -> Masks.time(mask_acc) end)

    # Broadcast the 2D mask to match the batch size (apply to every 2D tensor in the batch)
    broadcasted_mask = Nx.broadcast(mask, {batch_size, height, width})

    # Apply the broadcasted mask to the batch (element-wise multiplication)
    Nx.multiply(batch, broadcasted_mask)
  end
end
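
In equation form, the two transforms above can be summarized as follows. This is a sketch read off the code, with symbols introduced here only for illustration: $x_{t,k}$ is coefficient $k$ of frame $t$ in an utterance with $T$ frames, $X_b$ is one entry of a batch, and $M$ is the mask of ones and zeros built by Masks.freq and Masks.time. It assumes that Nx.standard_deviation is used with its default (population) form.

$$
\mu_k = \frac{1}{T}\sum_{t=1}^{T} x_{t,k}, \qquad
\sigma_k = \sqrt{\frac{1}{T}\sum_{t=1}^{T}\left(x_{t,k}-\mu_k\right)^2}, \qquad
\hat{x}_{t,k} = \frac{x_{t,k}-\mu_k}{\sigma_k}
$$

$$
\tilde{X}_b = X_b \odot M, \qquad M \in \{0,1\}^{H \times W}
$$

Here $H$ is the number of frames per sample (2 × context + 1), $W = 28$ is the number of MFCC coefficients, and the same mask $M$ is shared by every entry of the batch.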

DataSets

An implementation of a dataset for MFCC audio data.

This generally mimics what I did in HW1PT2, with some additional enhancements for performing batch transforms during a collate function (called by the DataLoader).

One challenge here was the runtime of ingest_partition. On my local dev machine (MacBook Pro, Apple Silicon, SSD) I can read both the train and dev partitions in around 100 seconds. On a remote Fly.io machine (4 performance CPUs, 32 GB RAM), this originally took closer to 1500 seconds. I added some concurrent processing via Task.async_stream (in ingest_partition) and was able to reduce the runtime to around 300 seconds.

The most significant challenge, however, was the lack of a list-like data structure with constant-time random access in Elixir. My HW1PT2 implementation in Python used a binary search (via bisect.bisect_right) to determine the in-memory location of the requested index in the __getitem__ method of the Dataset. I could not replicate this approach efficiently with Elixir lists: they are immutable, singly linked lists, so random element access is O(n). I experimented with third-party libraries that emulate a list but provide O(log n) random element access. Even with that approach, creating batches of phonemes was too expensive to even start training the model. The best that I could do was to eliminate the per-file context padding and concatenate all of the MFCC files together. Even so, I was never able to complete a single training epoch, since the runtime per epoch appears to be on the order of hours.
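
For reference, the bisect_right lookup described above can be sketched in plain Elixir. This is not what the notebook does, and it was not benchmarked here: it is a minimal sketch that stores the cumulative frame offsets in a tuple, whose elem/2 reads take constant time. The Bisect module name and the utterance lengths are hypothetical.

```elixir
defmodule Bisect do
  # Returns the number of elements in the sorted tuple that are <= x, which is
  # the insertion point that Python's bisect.bisect_right would return.
  def bisect_right(sorted, x) when is_tuple(sorted) do
    search(sorted, x, 0, tuple_size(sorted))
  end

  defp search(_sorted, _x, lo, hi) when lo >= hi, do: lo

  defp search(sorted, x, lo, hi) do
    mid = div(lo + hi, 2)

    if x < elem(sorted, mid) do
      search(sorted, x, lo, mid)
    else
      search(sorted, x, mid + 1, hi)
    end
  end
end

# Hypothetical utterance lengths, in frames
lengths = [3, 5, 2]

# Cumulative end offsets: file 0 holds frames 0..2, file 1 holds 3..7, file 2 holds 8..9
offsets = lengths |> Enum.scan(&(&1 + &2)) |> List.to_tuple()

# Which file contains absolute frame index 4?
file_index = Bisect.bisect_right(offsets, 4)
IO.inspect(file_index)
```

With these lengths, absolute frame 4 resolves to file 1, and its position within that file is 4 minus the previous offset (3), i.e. local frame 1.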

defmodule AudioDataset do
  @moduledoc """
  A dataset module for ingesting and padding MFCC data.
  """

  @max_context 60

  defstruct [
    :context,
    :phonemes,
    :augment,
    :mfccs,
    :transcripts,
    :index_mapping,
    :length,
    :total
  ]

  # Helper function to read data (assuming .npy files)
  defp read_file(path) do
    File.read!(path)
    |> Nx.load_numpy!()
  end

  # Create a zero padding tensor of shape {@max_context, 28}
  defp padding(_context) do
    Nx.broadcast(Nx.tensor(0.0, type: :f32), {@max_context, 28})
  end

  # Ingest a partition of MFCC and transcript data
  defp ingest_partition(acc, root, _context, partition) do
    mfcc_dir = Path.join([root, partition, "mfcc"])
    transcript_dir = Path.join([root, partition, "transcript"])

    mfcc_names = Path.wildcard(Path.join(mfcc_dir, "*.npy"))

    transcript_names = Path.wildcard(Path.join(transcript_dir, "*.npy"))

    if length(mfcc_names) != length(transcript_names) do
      IO.inspect(length(mfcc_names))
      IO.inspect(length(transcript_names))

      raise "Mismatch between MFCC and transcript counts"
    end

    # Only the first 2500 MFCC files of the partition are ingested
    Enum.take(mfcc_names, 2500)
    |> Task.async_stream(
      fn mfcc_path ->
        mfcc = read_file(mfcc_path)
        mfcc = AudioTransforms.cepstral_mean_transform(mfcc)

        transcript_path = Path.join(transcript_dir, Path.basename(mfcc_path))
        transcript = read_file(transcript_path)
        transcript = Nx.slice(transcript, [1], [Nx.size(transcript) - 2])

        {mfcc, transcript}
      end,
      max_concurrency: System.schedulers_online() * 2,
      ordered: true
    )
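    |> Enum.reduce(acc, fn {:ok, {mfcc, transcript}}, acc ->
      # MFCC tensors have shape {num_frames, 28}: frames run along axis 0
      {num_frames, _freq} = Nx.shape(mfcc)

      %{
        acc
        | transcripts: [transcript | acc.transcripts],
          mfccs: [mfcc | acc.mfccs],
          # Number of MFCC files ingested so far
          length: acc.length + 1,
          # Number of frames (samples) ingested so far
          total: acc.total + num_frames
      }
    end)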
  end

  # Initialize the dataset with the specified partitions
  def new(root, context \\ 30, partition \\ "train-clean-100", augment \\ true) do
    acc = %AudioDataset{
      context: context,
      phonemes: Globals.phonemes(),
      augment: augment,
      mfccs: [],
      transcripts: [],
      index_mapping: Arrays.new([]),
      length: 0,
      total: 0
    }

    acc = ingest_partition(acc, root, context, partition)

    mfccs = Enum.reverse(acc.mfccs)
    transcripts = Enum.reverse(acc.transcripts)

    mfccs = Nx.concatenate(mfccs)
    transcripts = Nx.concatenate(transcripts)

    mfccs = Nx.concatenate([padding(context), mfccs, padding(context)])

    %{acc | mfccs: mfccs, transcripts: transcripts}
  end

  def len(dataset), do: dataset.total

  def get_item(dataset, ind) do
    actual_index = ind + @max_context

    before_context = dataset.context
    after_context = dataset.context

    lower_offset = actual_index - before_context
    upper_offset = actual_index + after_context + 1

    frames =
      Nx.slice(dataset.mfccs, [lower_offset, 0], [
        upper_offset - lower_offset,
        28
      ])

    phoneme = Nx.slice(dataset.transcripts, [actual_index - @max_context], [1])

    {frames, phoneme}
  end

  def collate(dataset, batch) do
    {batch_size, height, width} = Nx.shape(batch)

    batch =
      case dataset.augment do
        true -> AudioTransforms.freq_time_masking_transform(batch)
        _ -> batch
      end

    Nx.reshape(batch, {batch_size, height * width})
  end
end

DataLoader

A DataLoader implementation that supports an arbitrary batch size, asynchronous processing via a number of background “workers”, and optional shuffling and data subsetting.

Here we see the power of the Elixir/Erlang concurrency model in action. It takes one line of code to turn this serial, synchronous data loader into a more performant concurrent implementation. The call to Task.async_stream, with the :max_concurrency option set to our desired number of workers, gives us a remarkably simple yet performant equivalent to PyTorch’s DataLoader.
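
The pattern can be sketched in isolation as follows. This is a simplified illustration rather than the loader itself: the prepare function is a hypothetical stand-in for the real per-batch work (fetching and collating items), and the index range and batch size are arbitrary.

```elixir
# Hypothetical stand-in for the per-batch work
prepare = fn batch -> Enum.sum(batch) end

batches =
  0..9
  # Batches of 4 indices; the incomplete final batch [8, 9] is discarded
  |> Stream.chunk_every(4, 4, :discard)
  # Process up to 2 batches concurrently, keeping the results in input order
  |> Task.async_stream(prepare, max_concurrency: 2, ordered: true)
  # Each successful task yields an {:ok, result} tuple
  |> Enum.map(fn {:ok, result} -> result end)

IO.inspect(batches)
# => [6, 22]
```

Removing the Task.async_stream line and mapping prepare directly over the chunks gives the serial version, which is the one-line difference described above.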

defmodule DataLoader do
  @moduledoc """
  A DataLoader module for asynchronously fetching and processing dataset batches.

  Supports asynchronous processing with a specified number of workers.
  """

  defstruct [:dataset_module, :dataset, :batch_size, :num_workers, :shuffle, :subset]

  # Create a new DataLoader with a dataset, batch size, and number of async workers
  def new(dataset_module, dataset, opts \\ []) do
    # Set default values for options
    opts = Keyword.merge([batch_size: 32, num_workers: 4, shuffle: true, subset: 0], opts)

    %DataLoader{
      dataset_module: dataset_module,
      dataset: dataset,
      batch_size: Keyword.get(opts, :batch_size),
      num_workers: Keyword.get(opts, :num_workers),
      shuffle: Keyword.get(opts, :shuffle),
      subset: Keyword.get(opts, :subset)
    }
  end

  # Public function to return a stream of batches
  def data(%DataLoader{
        dataset_module: dataset_module,
        dataset: dataset,
        batch_size: batch_size,
        num_workers: num_workers,
        subset: subset,
        shuffle: shuffle
      }) do
    # Create the list of sample indices, then subset and shuffle if requested
    indices = 0..(dataset.total - 1)

    indices =
      case subset do
        0 -> indices
        n -> Enum.take_every(indices, n)
      end

    indices =
      if shuffle do
        Enum.shuffle(indices)
      else
        # Keep the indices in order (returning `shuffle` here would yield `false`)
        indices
      end

    # Support async processing by a number of workers, or synchronous
    # processing when num_workers is 0. Synchronous is helpful for debugging.
    case num_workers do
      0 ->
        Enum.chunk_every(indices, batch_size, batch_size, :discard)
        |> Enum.map(fn batch -> prepare_batch(dataset_module, dataset, batch) end)

      n ->
        Stream.chunk_every(indices, batch_size, batch_size, :discard)
        |> Task.async_stream(fn batch -> prepare_batch(dataset_module, dataset, batch) end,
          max_concurrency: n,
          ordered: !shuffle,
          timeout: :infinity
        )
        |> Stream.map(fn
          # Extract the result from the {:ok, result} tuple
          {:ok, result} -> result
          # Task.async_stream reports a failed task as {:exit, reason}
          {:exit, reason} -> raise "Task failed: #{inspect(reason)}"
        end)
    end
  end

  defp prepare_batch(dataset_module, dataset, batch_indices) do
    # Map the indices that we need for this batch to 
    # frame and phoneme tuples, unzipping them to rearrange
    # as a tuple of two separate lists
    {x, y} =
      batch_indices
      |> Enum.map(fn idx -> apply(dataset_module, :get_item, [dataset, idx]) end)
      |> Enum.unzip()

    # Collate the batch and move it to the GPU
    batch_x =
      apply(dataset_module, :collate, [dataset, Nx.stack(x)])
      |> Nx.backend_transfer(EXLA.Backend)

    # Stack the y tensors and move to GPU
    batch_y =
      Nx.stack(y)
      |> Nx.backend_transfer(EXLA.Backend)

    {batch_x, batch_y}
  end
end
train_data_set =
  AudioDataset.new("#{config.data_root}/data", config.context, "train-clean-100", false)

IO.inspect(
  "Train dataset loaded, #{train_data_set.length} mfccs, #{train_data_set.total} samples"
)

dev_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "dev-clean", false)
IO.inspect("Dev dataset loaded, #{dev_data_set.length} mfccs, #{dev_data_set.total} samples")

train_data_loader =
  DataLoader.new(AudioDataset, train_data_set, batch_size: config.batch_size, num_workers: 16)

dev_data_loader =
  DataLoader.new(AudioDataset, dev_data_set, batch_size: config.batch_size, num_workers: 0)

Model Definition

defmodule Model do
  def tiny(input_size, output_size) do
    Axon.input("data", shape: {nil, input_size})
    |> linear(512)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.35)
    |> linear(512)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.3)
    |> linear(512)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.25)
    |> Axon.dense(output_size, activation: :softmax)
  end

  def full(input_size, output_size) do
    Axon.input("data", shape: {nil, input_size})
    |> linear(2048)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.35)
    |> linear(2048)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.3)
    |> linear(2048)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.25)
    |> linear(2048)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.2)
    |> linear(1024)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.15)
    |> linear(512)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.1)
    |> linear(256)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.05)
    |> Axon.dense(output_size, activation: :softmax)
  end

  defp linear(model, size) do
    Axon.dense(model, size,
      kernel_initializer: Axon.Initializers.glorot_uniform(),
      bias_initializer: Axon.Initializers.zeros()
    )
  end

  def summarize(model, input_size) do
    Axon.Display.as_graph(model, Nx.template({1, input_size}, :f32))
  end
end

input_size = (2 * config.context + 1) * 28
output_size = Enum.count(Globals.phonemes())

model = Model.full(input_size, output_size)

{init_fn, pred_fn} = Axon.build(model)

Model.summarize(model, input_size)

Training

eval_loop =
  Axon.Loop.evaluator(model)
  |> Axon.Loop.metric(:accuracy, "accuracy")

Nx.global_default_backend({EXLA.Backend, client: :cuda})

run_eval = fn state ->
  model_state = state.step_state.model_state
  eval_accuracy = Axon.Loop.run(eval_loop, DataLoader.data(dev_data_loader), model_state)
  IO.inspect(eval_accuracy)
  {:continue, state}
end

# Define and run the training loop
trainer_loop =
  Axon.Loop.trainer(model, :categorical_cross_entropy, Polaris.Optimizers.adamw())
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.handle_event(:epoch_completed, run_eval)
  |> Axon.Loop.run(DataLoader.data(train_data_loader), %{},
    epochs: config.epochs,
    compiler: EXLA,
    client: :cuda
  )

The FlameGraph module was used to profile the runtime of the DataSet and DataLoader.

defmodule FlameGraph do
  @stack_files_dir "_flame_graph_stacks"
  @svg_dir "./"
  @url_base "./flame_graphs"

  def create(func, tag) do
    File.mkdir(@stack_files_dir)
    filename = "#{@stack_files_dir}/#{tag}.out"
    :eflame.apply(:normal_with_children, filename, func, [])
  end

  def list() do
    Path.wildcard("#{@stack_files_dir}/*.out")
  end

  def to_svg(stack_filename) do
    File.mkdir(@svg_dir)

    svg_filename = Path.basename(stack_filename, ".out") <> ".svg"

    "deps/eflame/stack_to_flame.sh < #{stack_filename} > #{@svg_dir}/#{svg_filename}"
    |> String.to_charlist()
    |> :os.cmd()

    "#{@url_base}/#{svg_filename}"
  end
end

With the above flame graph generation module, I profiled the runtime of the original DataLoader collation. The runtime was dominated by the calls to bisect_right, which binary searched through a large Nx tensor to find the correct MFCC file for a given absolute index.

Results / Commentary

Pros:

  1. Livebook’s direct support for deploying and running on a Fly.io GPU is fantastic. I rarely had to wait more than 10 seconds to obtain a machine with a high-end GPU.

  2. The Axon library syntax is pleasant to work with and to read. The pipe operator |> seems a natural way to express the flow of data through the components and layers of a network.

Cons:

  1. The lack of a list-like data structure with constant-time random access posed an insurmountable challenge. (The built-in list in Elixir is a singly linked list with O(n) element access.) Even third-party libraries with “fast array” support offer only O(log n) random access, which also wasn’t efficient enough to allow random batches to be collated. Even at a tiny batch size of 16, the projected time to train a single epoch on the full training data set is on the order of days.

  2. There is no support for Wandb, Kaggle, or Google Drive mounting. This makes doing serious DL work difficult.

  3. Livebook is still immature, and clearly has bugs. Numerous times I faced problems with cells re-running automatically when they didn’t need to (this was problematic when the cell was reloading the entire dataset). I also encountered strange code editing phenomena, with my cursor seemingly jumping to other positions, backspace refusing to delete characters, etc.