11785 - HW1PT2
Mix.install(
  [
    {:arrays, "~> 2.1"},
    {:exla, ">= 0.0.0"},
    {:npy, "~> 0.1.1"},
    {:axon, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"},
    {:kino, "~> 0.7.0"},
    {:finch, "~> 0.19.0"},
    {:jason, "~> 1.4"},
    {:polaris, "~> 0.1.0"},
    {:eflame, "~> 1.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
Configuration
config = %{
  context: 30,
  epochs: 5,
  batch_size: 512,
  data_root: "/home/livebook"
  # data_root: "/Users/darren/dev/11785-project"
}
Kaggle Raw Data
The `RawData` module retrieves and unpacks the raw MFCC and transcript data. There were a number of challenges in getting the Kaggle data available in an Elixir LiveBook:

- There is seemingly no simple method for downloading a competition's data from the Kaggle site. They have a Python `kaggle-api`, but lack an unsecured HTTP endpoint that one could use to issue a `GET` request to download a competition's `.zip` file. Instead, I uploaded a copy of the Kaggle data to a personal S3 bucket that I had already created (for hosting a course project for a French course that I took a few years ago at CMU!)
- When running a LiveBook on a remote Elixir node on Fly.io, I could not successfully unzip the Kaggle data zip file. The built-in Erlang unzip, `:zip.unzip/1`, would raise an error complaining that the zip file was invalid. To circumvent this, I instead created a `.tar.gz` representation of the data that was able to be unpacked on the remote server.
- Finally, I ran into problems using Elixir's built-in support for reading `.npy` files via `Nx.load_numpy!/1`. While I could use that function to read the MFCC data files, it would raise an error when attempting to read the transcript files. It turns out that these files in the Kaggle dataset encoded the transcript data as numpy arrays of strings, but the `Nx` Elixir library only supports loading numpy numerical arrays. To work around this, I wrote a small Python script to rewrite these transcript files as integer numpy arrays, with each integer being the index of the phoneme (from the list of phonemes).
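For reference, the failure mode looked roughly like this from the Elixir side (a sketch; the file paths are hypothetical):

# Reading an MFCC file works: it contains a numerical numpy array
mfcc = File.read!("data/train-clean-100/mfcc/0001.npy") |> Nx.load_numpy!()

# Reading an original transcript file raises, because it contains a numpy
# array of strings, which Nx cannot represent as a tensor:
# transcript = File.read!("data/train-clean-100/transcript/0001.npy") |> Nx.load_numpy!()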
defmodule RawData do
  def ensure_available(config) do
    if File.exists?("#{config.data_root}/data") do
      :exists
    else
      download(config) |> untar(config)
    end
  end

  defp download(config) do
    url = "https://s3.amazonaws.com/votre-montreal.com/data.tar.gz"
    destination = "#{config.data_root}/data.tar.gz"

    # {_, exit_status} = System.cmd("curl", ["-o", destination, url])
    {_, exit_status} = System.cmd("wget", [url, "-O", destination])

    if exit_status == 0 do
      :ok
    else
      :error
    end
  end

  defp untar(:error, _), do: :error

  defp untar(:ok, config) do
    # Extract into the data root so that ensure_available/1 finds the data
    {_, exit_status} =
      System.cmd("tar", ["-xzvf", "#{config.data_root}/data.tar.gz", "-C", config.data_root])

    if exit_status == 0 do
      :ok
    else
      :error
    end
  end
end
RawData.ensure_available(config)
This sets the default `Nx.Defn` compiler to EXLA, so that numerical computations (and the tensors that flow through them) run on the GPU if one is available:
Nx.Defn.global_default_options(compiler: EXLA)
defmodule Globals do
  def phonemes(),
    do: [
      "[SIL]",
      "AA",
      "AE",
      "AH",
      "AO",
      "AW",
      "AY",
      "B",
      "CH",
      "D",
      "DH",
      "EH",
      "ER",
      "EY",
      "F",
      "G",
      "HH",
      "IH",
      "IY",
      "JH",
      "K",
      "L",
      "M",
      "N",
      "NG",
      "OW",
      "OY",
      "P",
      "R",
      "S",
      "SH",
      "T",
      "TH",
      "UH",
      "UW",
      "V",
      "W",
      "Y",
      "Z",
      "ZH",
      "[SOS]",
      "[EOS]"
    ]
end
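As a quick illustration (not part of the original pipeline), the integer encoding in the rewritten transcript files is simply each phoneme's position in this list:

phonemes = Globals.phonemes()

# Phoneme -> index (what the Python rewrite script emitted)
Enum.find_index(phonemes, &(&1 == "AA"))  # => 1

# Index -> phoneme (useful when decoding model predictions)
Enum.at(phonemes, 1)  # => "AA"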
Transformations
Here I have replicated the two transformations that I used in HW1PT2: the cepstral mean transform and a combined time- and frequency-masking transform.
defmodule Masks do
  def freq(mask, low \\ 0.0, high \\ 28.0, f \\ 4) do
    # Randomly choose the mask width between 3 and f
    f_value = :rand.uniform() * (f - 3.0) + 3.0
    f_value = Float.floor(f_value) |> trunc()

    # Randomly choose the starting frequency
    f0_value = :rand.uniform() * (high - f_value - low) + low
    f0_value = Float.floor(f0_value) |> trunc()

    # Take the slice of the mask along the frequency dimension...
    mask_shape = Nx.shape(mask)
    mask_slice = Nx.slice(mask, [0, f0_value], [elem(mask_shape, 0), f_value])

    # ...and write it back multiplied by zero, zeroing out that portion
    Nx.put_slice(mask, [0, f0_value], Nx.multiply(mask_slice, Nx.tensor(0.0)))
  end

  def time(mask, f \\ 8) do
    # Get the number of time steps (rows)
    num_time_steps = Nx.shape(mask) |> elem(0)

    # Randomly choose the mask width between 3 and f
    t_value = :rand.uniform() * (f - 3.0) + 3.0
    t_value = Float.floor(t_value) |> trunc()

    # Randomly choose the starting time step
    t0_value = :rand.uniform() * (num_time_steps - t_value)
    t0_value = Float.floor(t0_value) |> trunc()

    # Take the slice of the mask along the time dimension...
    mask_shape = Nx.shape(mask)
    mask_slice = Nx.slice(mask, [t0_value, 0], [t_value, elem(mask_shape, 1)])

    # ...and write it back multiplied by zero, zeroing out that portion
    Nx.put_slice(mask, [t0_value, 0], Nx.multiply(mask_slice, Nx.tensor(0.0)))
  end
end
defmodule AudioTransforms do
  # Apply cepstral mean and variance normalization
  def cepstral_mean_transform(data) do
    # Calculate the mean along axis 0 (mean for each column)
    mean = Nx.mean(data, axes: [0])

    # Subtract the mean from the data (center the data)
    centered_data = Nx.subtract(data, mean)

    # Calculate the standard deviation along axis 0
    std_dev = Nx.standard_deviation(centered_data, axes: [0])

    # Normalize the data by dividing by the standard deviation
    Nx.divide(centered_data, std_dev)
  end

  def freq_time_masking_transform(batch) do
    # Get the shape of the batch (a 3D tensor of 2D entries)
    {batch_size, height, width} = Nx.shape(batch)

    # Create an initial mask of ones with shape {height, width} for a single entry
    mask = Nx.broadcast(Nx.tensor(1.0), {height, width})

    # Apply 0, 1, or 2 random frequency transformations. The //1 step makes
    # the range empty (rather than decreasing) when the count is 0.
    num_freq_xforms = :rand.uniform(3) - 1
    mask = Enum.reduce(1..num_freq_xforms//1, mask, fn _, mask_acc -> Masks.freq(mask_acc) end)

    # Apply 0, 1, or 2 random time transformations
    num_time_xforms = :rand.uniform(3) - 1
    mask = Enum.reduce(1..num_time_xforms//1, mask, fn _, mask_acc -> Masks.time(mask_acc) end)

    # Broadcast the 2D mask across the batch and apply it element-wise
    broadcasted_mask = Nx.broadcast(mask, {batch_size, height, width})
    Nx.multiply(batch, broadcasted_mask)
  end
end
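A quick way to sanity-check these transforms (my own sketch, not from the original notebook) is to push a small random batch through them and confirm that the shape is preserved:

key = Nx.Random.key(42)
{batch, _key} = Nx.Random.uniform(key, shape: {4, 61, 28})

# Normalize a single utterance (2D), then mask the whole batch (3D)
normalized = AudioTransforms.cepstral_mean_transform(batch[0])
masked = AudioTransforms.freq_time_masking_transform(batch)

Nx.shape(masked)  # => {4, 61, 28}, with random time/frequency bands zeroed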
DataSets
An implementation of a dataset for MFCC audio data. This generally mimics what I did in HW1PT2, with some additional enhancements for performing batch transforms during a `collate` function (called by the DataLoader).

One challenge here was the runtime of `ingest_partition`. On my local dev machine (MacBook Pro, Apple Silicon, SSD) I can read in both the train and dev partitions in around 100 seconds. On a remote Fly.io machine (4 performance CPUs, 32 GB RAM), this originally took closer to 1500 seconds. I added concurrent processing via `Task.async_stream` (in `ingest_partition` below) and was able to reduce the runtime to roughly 300 seconds.

The most significant challenge, however, was the complete lack of support for constant-time, random-access data structures in Elixir. My HW1PT2 implementation in Python used a binary search (via `bisect.bisect_right`) to determine the in-memory location of the requested index in the `__getitem__` method of the Dataset. Replicating this approach in Elixir was impractical due to the runtime characteristics of Elixir lists: lists are immutable singly-linked lists, so random element access is O(n). I experimented with third-party libraries that emulate a list but provide O(log n) guarantees for random element access, as sketched below. Even with that approach, creating batches of phonemes was too expensive to even start training the model. The best that I could do was to eliminate the context padding of each MFCC file and concatenate them all together. Even with this approach, I was never able to complete a single training epoch, given that the runtime per epoch is seemingly on the order of hours.
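To make the cost difference concrete, here is a small sketch of my own (assuming `Arrays.get/2` from the `arrays` package pulled in by the `Mix.install` above):

list = Enum.to_list(1..1_000_000)
arr = Arrays.new(list)

# O(n): Enum.at/2 walks the singly-linked list from the head
Enum.at(list, 999_999)

# O(log n): the arrays structure is tree-backed, so random access is
# logarithmic -- better, but still short of numpy's O(1) indexing
Arrays.get(arr, 999_999)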
defmodule AudioDataset do
  @moduledoc """
  A dataset module for ingesting and padding MFCC data.
  """

  @max_context 60

  defstruct [
    :context,
    :phonemes,
    :augment,
    :mfccs,
    :transcripts,
    :index_mapping,
    :length,
    :total
  ]

  # Helper function to read a .npy file into a tensor
  defp read_file(path) do
    File.read!(path)
    |> Nx.load_numpy!()
  end

  # Create a zero padding tensor of shape {MAX_CONTEXT, 28}
  defp padding(_context) do
    Nx.broadcast(Nx.tensor(0.0, type: :f32), {@max_context, 28})
  end

  # Ingest a partition of MFCC and transcript data
  defp ingest_partition(acc, root, _context, partition) do
    mfcc_dir = Path.join([root, partition, "mfcc"])
    transcript_dir = Path.join([root, partition, "transcript"])
    mfcc_names = Path.wildcard(Path.join(mfcc_dir, "*.npy"))
    transcript_names = Path.wildcard(Path.join(transcript_dir, "*.npy"))

    if length(mfcc_names) != length(transcript_names) do
      IO.inspect(length(mfcc_names))
      IO.inspect(length(transcript_names))
      raise "Mismatch between MFCC and transcript counts"
    end

    Enum.take(mfcc_names, 2500)
    |> Task.async_stream(
      fn mfcc_path ->
        mfcc = read_file(mfcc_path)
        mfcc = AudioTransforms.cepstral_mean_transform(mfcc)

        transcript_path = Path.join(transcript_dir, Path.basename(mfcc_path))
        transcript = read_file(transcript_path)

        # Strip the [SOS] and [EOS] markers from the transcript
        transcript = Nx.slice(transcript, [1], [Nx.size(transcript) - 2])

        {mfcc, transcript}
      end,
      max_concurrency: System.schedulers_online() * 2,
      ordered: true
    )
    |> Enum.reduce(acc, fn {:ok, {mfcc, transcript}}, acc ->
      # Each MFCC tensor has shape {num_frames, 28}; total counts frames
      {num_frames, _num_features} = Nx.shape(mfcc)

      %{
        acc
        | transcripts: [transcript | acc.transcripts],
          mfccs: [mfcc | acc.mfccs],
          length: acc.length + 1,
          total: acc.total + num_frames
      }
    end)
  end

  # Initialize the dataset with the specified partition
  def new(root, context \\ 30, partition \\ "train-clean-100", augment \\ true) do
    acc = %AudioDataset{
      context: context,
      phonemes: Globals.phonemes(),
      augment: augment,
      mfccs: [],
      transcripts: [],
      index_mapping: Arrays.new([]),
      length: 0,
      total: 0
    }

    acc = ingest_partition(acc, root, context, partition)

    # Reverse the accumulated lists (they were prepended), concatenate all
    # MFCCs into one long tensor, and pad both ends with zero context
    mfccs = Enum.reverse(acc.mfccs) |> Nx.concatenate()
    transcripts = Enum.reverse(acc.transcripts) |> Nx.concatenate()
    mfccs = Nx.concatenate([padding(context), mfccs, padding(context)])

    %{acc | mfccs: mfccs, transcripts: transcripts}
  end

  def len(dataset), do: dataset.total

  def get_item(dataset, ind) do
    actual_index = ind + @max_context
    before_context = dataset.context
    after_context = dataset.context
    lower_offset = actual_index - before_context
    upper_offset = actual_index + after_context + 1

    # Slice out the window of frames around the requested index
    frames =
      Nx.slice(dataset.mfccs, [lower_offset, 0], [
        upper_offset - lower_offset,
        28
      ])

    phoneme = Nx.slice(dataset.transcripts, [actual_index - @max_context], [1])
    {frames, phoneme}
  end

  def collate(dataset, batch) do
    {batch_size, height, width} = Nx.shape(batch)

    batch =
      case dataset.augment do
        true -> AudioTransforms.freq_time_masking_transform(batch)
        _ -> batch
      end

    Nx.reshape(batch, {batch_size, height * width})
  end
end
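As a quick shape check (my own sketch, assuming a dataset built with the default context of 30):

dataset = AudioDataset.new("#{config.data_root}/data", 30, "dev-clean", false)
{frames, phoneme} = AudioDataset.get_item(dataset, 0)

Nx.shape(frames)   # => {61, 28}: 2 * context + 1 frames of 28 MFCC coefficients
Nx.shape(phoneme)  # => {1}: the integer-encoded phoneme label for the center frame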
DataLoader
A DataLoader implementation which supports arbitrary batch sizes, asynchronous processing via a number of background "workers", and optional shuffling and data subsetting.

Here we see the power of the Elixir/Erlang concurrency model in action. It takes one line of code to turn this serial, synchronous data loader into a more performant concurrent implementation. The call to `Task.async_stream` with the `:max_concurrency` option set to our desired number of workers gives us a remarkably simple yet performant equivalent to Python's `DataLoader`.
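Distilled to its essence, the pattern is the following (a sketch; `indices`, `batch_size`, `num_workers`, and `prepare_batch/1` stand in for the real versions in `data/1` below):

indices
|> Stream.chunk_every(batch_size)
|> Task.async_stream(&prepare_batch/1, max_concurrency: num_workers, timeout: :infinity)
|> Stream.map(fn {:ok, batch} -> batch end)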
defmodule DataLoader do
  @moduledoc """
  A DataLoader module for asynchronously fetching and processing dataset batches.
  Supports asynchronous processing with a specified number of workers.
  """

  defstruct [:dataset_module, :dataset, :batch_size, :num_workers, :shuffle, :subset]

  # Create a new DataLoader with a dataset, batch size, and number of async workers
  def new(dataset_module, dataset, opts \\ []) do
    # Set default values for options
    opts = Keyword.merge([batch_size: 32, num_workers: 4, shuffle: true, subset: 0], opts)

    %DataLoader{
      dataset_module: dataset_module,
      dataset: dataset,
      batch_size: Keyword.get(opts, :batch_size),
      num_workers: Keyword.get(opts, :num_workers),
      shuffle: Keyword.get(opts, :shuffle),
      subset: Keyword.get(opts, :subset)
    }
  end

  # Public function to return a stream of batches
  def data(%DataLoader{
        dataset_module: dataset_module,
        dataset: dataset,
        batch_size: batch_size,
        num_workers: num_workers,
        subset: subset,
        shuffle: shuffle
      }) do
    # Create the full range of indices, subsetting if requested
    indices = 0..(dataset.total - 1)

    indices =
      case subset do
        0 -> indices
        n -> Enum.take_every(indices, n)
      end

    indices =
      if shuffle do
        Enum.shuffle(indices)
      else
        indices
      end

    # Support async processing by a number of workers, or synchronous
    # processing when num_workers is 0. Synchronous is helpful for debugging.
    case num_workers do
      0 ->
        Enum.chunk_every(indices, batch_size, batch_size, :discard)
        |> Enum.map(fn batch -> prepare_batch(dataset_module, dataset, batch) end)

      n ->
        Stream.chunk_every(indices, batch_size, batch_size, :discard)
        |> Task.async_stream(fn batch -> prepare_batch(dataset_module, dataset, batch) end,
          max_concurrency: n,
          ordered: !shuffle,
          timeout: :infinity
        )
        |> Stream.map(fn
          # Extract the result from the {:ok, result} tuple
          {:ok, result} -> result
          # Optional: handle errors if needed
          {:error, _reason} -> raise "Task failed"
        end)
    end
  end

  defp prepare_batch(dataset_module, dataset, batch_indices) do
    # Map the indices that we need for this batch to
    # frame and phoneme tuples, unzipping them to rearrange
    # as a tuple of two separate lists
    {x, y} =
      batch_indices
      |> Enum.map(fn idx -> apply(dataset_module, :get_item, [dataset, idx]) end)
      |> Enum.unzip()

    # Collate the batch and move it to the GPU
    batch_x =
      apply(dataset_module, :collate, [dataset, Nx.stack(x)])
      |> Nx.backend_transfer(EXLA.Backend)

    # Stack the y tensors and move to GPU
    batch_y =
      Nx.stack(y)
      |> Nx.backend_transfer(EXLA.Backend)

    {batch_x, batch_y}
  end
end
train_data_set =
  AudioDataset.new("#{config.data_root}/data", config.context, "train-clean-100", false)

IO.inspect(
  "Train dataset loaded, #{train_data_set.length} mfccs, #{train_data_set.total} samples"
)

dev_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "dev-clean", false)
IO.inspect("Dev dataset loaded, #{dev_data_set.length} mfccs, #{dev_data_set.total} samples")

train_data_loader =
  DataLoader.new(AudioDataset, train_data_set, batch_size: config.batch_size, num_workers: 16)

dev_data_loader =
  DataLoader.new(AudioDataset, dev_data_set, batch_size: config.batch_size, num_workers: 0)
Model Definition
defmodule Model do
  def tiny(input_size, output_size) do
    Axon.input("data", shape: {nil, input_size})
    |> linear(512)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.35)
    |> linear(512)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.3)
    |> linear(512)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.25)
    |> Axon.dense(output_size, activation: :softmax)
  end

  def full(input_size, output_size) do
    Axon.input("data", shape: {nil, input_size})
    |> linear(2048)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.35)
    |> linear(2048)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.3)
    |> linear(2048)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.25)
    |> linear(2048)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.2)
    |> linear(1024)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.15)
    |> linear(512)
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.1)
    |> linear(256)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: 0.05)
    |> Axon.dense(output_size, activation: :softmax)
  end

  defp linear(model, size) do
    Axon.dense(model, size,
      kernel_initializer: Axon.Initializers.glorot_uniform(),
      bias_initializer: Axon.Initializers.zeros()
    )
  end

  def summarize(model, input_size) do
    Axon.Display.as_graph(model, Nx.template({1, input_size}, :f32))
  end
end
# Each input sample is 2 * context + 1 frames of 28 MFCC coefficients
# (with context = 30: 61 * 28 = 1708 inputs); the output is one
# probability per phoneme class
input_size = (2 * config.context + 1) * 28
output_size = Enum.count(Globals.phonemes())

model = Model.full(input_size, output_size)
{init_fn, pred_fn} = Axon.build(model)
Model.summarize(model, input_size)
Training
eval_loop =
  Axon.Loop.evaluator(model)
  |> Axon.Loop.metric(:accuracy, "accuracy")

Nx.global_default_backend({EXLA.Backend, client: :cuda})

run_eval = fn state ->
  model_state = state.step_state.model_state
  eval_accuracy = Axon.Loop.run(eval_loop, DataLoader.data(dev_data_loader), model_state)
  IO.inspect(eval_accuracy)
  {:continue, state}
end

# Define and run the training loop: cross-entropy loss with the AdamW
# optimizer, reporting accuracy and evaluating against the dev set after
# each epoch
trainer_loop =
  Axon.Loop.trainer(model, :categorical_cross_entropy, Polaris.Optimizers.adamw())
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.handle_event(:epoch_completed, run_eval)
  |> Axon.Loop.run(DataLoader.data(train_data_loader), %{},
    epochs: config.epochs,
    compiler: EXLA,
    client: :cuda
  )
The `FlameGraph` module was used to profile the runtime of the `DataSet` and `DataLoader`.
defmodule FlameGraph do
  @stack_files_dir "_flame_graph_stacks"
  @svg_dir "./"
  @url_base "./flame_graphs"

  def create(func, tag) do
    File.mkdir(@stack_files_dir)
    filename = "#{@stack_files_dir}/#{tag}.out"
    :eflame.apply(:normal_with_children, filename, func, [])
  end

  def list() do
    Path.wildcard("#{@stack_files_dir}/*.out")
  end

  def to_svg(stack_filename) do
    File.mkdir(@svg_dir)
    svg_filename = Path.basename(stack_filename, ".out") <> ".svg"

    "deps/eflame/stack_to_flame.sh < #{stack_filename} > #{@svg_dir}/#{svg_filename}"
    |> String.to_charlist()
    |> :os.cmd()

    "#{@url_base}/#{svg_filename}"
  end
end
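Usage looks roughly like this (my own sketch; the profiled function and tag are illustrative):

# Capture stacks while pulling a single batch through the data pipeline
FlameGraph.create(fn -> dev_data_loader |> DataLoader.data() |> Enum.take(1) end, "collate")

# Convert every captured stack file into a browsable SVG flame graph
FlameGraph.list() |> Enum.map(&FlameGraph.to_svg/1)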
With the above flame graph generation module, I profiled the runtime of the original DataLoader collation. The runtime was dominated by the calls to `bisect_right` as it binary searched through a large `Nx` array in order to find the correct MFCC file for the given absolute index.
Results / Commentary

Pros:

- Livebook's direct support for deploying and running on a Fly.io GPU is fantastic. I rarely had to wait more than 10 seconds to obtain a machine with a high-end GPU.
- The `Axon` library syntax is pleasant to work with and to read. The pipe operator `|>` seems a natural way to express the flow of data through the components and layers of a network.

Cons:

- The lack of a data structure with constant-time access to random elements posed an insurmountable challenge. (The built-in `list` in Elixir is a singly-linked list with O(n) element access.) Even third-party libraries with "fast array" support offer only O(log n) random access, which also wasn't efficient enough to allow random batches to be collated. Even at a tiny batch size of 16, the projected time to train a single epoch on the full training data set is on the order of days.
- There is no support for Wandb, Kaggle, or Google Drive mounting. This makes doing serious DL work difficult.
- Livebook is still immature and clearly has bugs. Numerous times I faced problems with cells re-running automatically when they didn't need to (this was problematic when the cell was reloading the entire dataset). I also encountered strange code-editing phenomena, with my cursor seemingly jumping to other positions, backspace refusing to delete characters, etc.