11785 - HW1PT2
Mix.install(
[
{:arrays, "~> 2.1"},
{:exla, ">= 0.0.0"},
{:npy, "~> 0.1.1"},
{:axon, "~> 0.7.0"},
{:table_rex, "~> 3.1.1"},
{:kino, "~> 0.7.0"},
{:finch, "~> 0.19.0"},
{:jason, "~> 1.4"},
{:polaris, "~> 0.1.0"},
{:eflame, "~> 1.0"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Configuration
# Notebook-wide settings read by the data-loading, model and training cells below.
config = %{
# Number of frames taken on each side of the target frame (window is 2 * context + 1)
context: 30,
# Number of passes over the training loader handed to Axon.Loop.run
epochs: 5,
# Number of samples per batch produced by the DataLoader
batch_size: 1024,
# Directory that holds (or will receive) data.tar.gz and the unpacked data/ folder
data_root: "/home/livebook"
#data_root: "/Users/darren/dev/11785-project"
}
Kaggle Raw Data
The RawData
module retrieves and unpacks the raw MFCC and transcript data. There was a
collection of challenges in getting the Kaggle data available in an Elixir LiveBook:
-
There is no simple method for downloading the Kaggle data for a competition from their site. They have a python
kaggle-api
, but lack a simple HTTP endpoint that one could use to issue a GET
request to download a competition.zip
file. Instead, I uploaded a copy of the Kaggle data to a personal S3 bucket that I had already created (for hosting a course project for a French course that I took a few years ago at CMU) -
When running a LiveBook on a remote Elixir node on Fly.io, I could not successfully unzip
the Kaggle data zip file. The builtin Erlang unzip
:zip.unzip
would raise an error complaining that the zip file was invalid. To circumvent this, I created a .tar.gz
representation that was able to be unpacked on the remote server -
Finally, I ran into problems using Elixir’s built in support for reading
.npy
files via Nx.load_numpy!/1
. While I could use that function to read the MFCC data files, it would raise an error when attempting to read the transcript files. These files encoded the transcript data as numpy arrays of strings, but the Nx
Elixir library only supports loading Numpy numerical arrays. To work around this, I wrote a small python script to rewrite these transcript files to be integer numpy arrays, with the integer being the index of the phoneme (from the list of phonemes).
defmodule RawData do
  @moduledoc """
  Downloads and unpacks the course data archive underneath `config.data_root`.
  """

  @archive_url "https://s3.amazonaws.com/votre-montreal.com/data.tar.gz"

  @doc """
  Makes sure `<data_root>/data` exists, downloading and extracting the archive
  when it does not.

  Returns `:exists` when the directory is already present, `:ok` after a
  successful download and extraction, and `:error` when either `wget` or `tar`
  exits with a non-zero status.
  """
  def ensure_available(config) do
    if File.exists?("#{config.data_root}/data") do
      :exists
    else
      config |> download() |> untar(config)
    end
  end

  # Fetches the archive into `<data_root>/data.tar.gz` using `wget`.
  defp download(config) do
    destination = "#{config.data_root}/data.tar.gz"
    # {_, exit_status} = System.cmd("curl", ["-o", destination, @archive_url])
    {_, exit_status} = System.cmd("wget", [@archive_url, "-O", destination])
    result_from_status(exit_status)
  end

  # A failed download short-circuits the extraction step.
  defp untar(:error, _config), do: :error

  defp untar(:ok, config) do
    archive = "#{config.data_root}/data.tar.gz"

    # Extract into data_root (-C) rather than the current working directory, so
    # that the `<data_root>/data` existence check in ensure_available/1 sees the
    # result. Assumes the archive's top-level entry is `data/` -- TODO confirm.
    {_, exit_status} = System.cmd("tar", ["-xzvf", archive, "-C", config.data_root])
    result_from_status(exit_status)
  end

  # Maps a process exit status onto the :ok / :error atoms used by this module.
  defp result_from_status(0), do: :ok
  defp result_from_status(_non_zero), do: :error
end
# Download and unpack the dataset unless <data_root>/data is already present.
# The result (:exists, :ok or :error) is not checked here.
RawData.ensure_available(config)
# Use EXLA as the default compiler for Nx.Defn functions.
Nx.Defn.global_default_options(compiler: EXLA)
defmodule Globals do
  @moduledoc """
  Constants shared across the notebook.
  """

  # Phoneme labels in index order; a label's position in this list is its class index.
  @phonemes ~w(
    [SIL] AA AE AH AO AW AY B CH D DH EH ER EY F G HH IH IY JH K L M N NG
    OW OY P R S SH T TH UH UW V W Y Z ZH [SOS] [EOS]
  )

  @doc """
  Returns the list of phoneme labels, including the `[SIL]`, `[SOS]` and `[EOS]` markers.
  """
  def phonemes, do: @phonemes
end
Utilities
Elixir does not have a random access data structure like a python array, instead it has an immutable list. This was going to create a performance problem in my dataset implementation where I needed to perform a binary search to determine which individual mfcc file an index would be found in. To avoid this, I wrote FastArray
, an Nx
backed fixed size array implementation, which provides support for fast, random access, upon which I replicated python’s bisect_right
function.
Transformations
Here I have replicated the two transformations that I used in HW1PT2: the cepstral mean transform and a combined time and frequency masking transform.
defmodule Masks do
  @moduledoc """
  Random band-zeroing helpers for a two-dimensional mask tensor.

  Each function picks a random band and returns a copy of the mask with that
  band multiplied by zero.
  """

  @doc """
  Zeroes a randomly placed band of columns in `mask`.

  The band width is the floor of a uniform draw from `[3, f)`, and the first
  column of the band is the floor of a uniform draw between `low` and
  `high - width`.
  """
  def freq(mask, low \\ 0.0, high \\ 28.0, f \\ 4) do
    # Band width first, then its starting column (the order of the random draws matters)
    width = floor(:rand.uniform() * (f - 3.0) + 3.0)
    start = floor(:rand.uniform() * (high - width - low) + low)

    # Take the band across every row and scale it by zero
    zeroed =
      mask
      |> Nx.slice([0, start], [elem(Nx.shape(mask), 0), width])
      |> Nx.multiply(Nx.tensor(0.0))

    # Write the zeroed band back over the same position
    Nx.put_slice(mask, [0, start], zeroed)
  end

  @doc """
  Zeroes a randomly placed band of rows in `mask`.

  The band height is the floor of a uniform draw from `[3, f)`, and the first
  row of the band is the floor of a uniform draw from `[0, rows - height)`.
  """
  def time(mask, f \\ 8) do
    # Number of rows in the mask
    rows = mask |> Nx.shape() |> elem(0)

    # Band height first, then its starting row
    height = floor(:rand.uniform() * (f - 3.0) + 3.0)
    start = floor(:rand.uniform() * (rows - height))

    # Take the band across every column and scale it by zero
    zeroed =
      mask
      |> Nx.slice([start, 0], [height, elem(Nx.shape(mask), 1)])
      |> Nx.multiply(Nx.tensor(0.0))

    # Write the zeroed band back over the same position
    Nx.put_slice(mask, [start, 0], zeroed)
  end
end
defmodule AudioTransforms do
  @moduledoc """
  Normalisation and augmentation transforms for MFCC tensors.
  """

  @doc """
  Normalises `data` column-wise: subtracts the mean taken along axis 0 and
  divides by the standard deviation taken along axis 0.

  No epsilon is added to the divisor, so a column with zero variance divides
  by zero.
  """
  def cepstral_mean_transform(data) do
    # Mean of every column (axis 0 is reduced away)
    mean = Nx.mean(data, axes: [0])
    # Center each column on zero
    centered_data = Nx.subtract(data, mean)
    # Per-column spread of the centered data
    std_dev = Nx.standard_deviation(centered_data, axes: [0])
    # Scale each column by its standard deviation
    Nx.divide(centered_data, std_dev)
  end

  @doc """
  Applies one randomly generated frequency/time mask to every entry of `batch`.

  `batch` must be a rank-3 tensor `{batch_size, height, width}`. A single
  `{height, width}` mask of ones receives 0, 1 or 2 `Masks.freq/1` bands and
  then 0, 1 or 2 `Masks.time/1` bands, and is multiplied into each entry.
  """
  def freq_time_masking_transform(batch) do
    {batch_size, height, width} = Nx.shape(batch)

    # Start from an all-ones mask for a single entry and zero out random bands
    mask =
      Nx.tensor(1.0)
      |> Nx.broadcast({height, width})
      |> apply_random_times(&Masks.freq/1)
      |> apply_random_times(&Masks.time/1)

    # Share the same mask across the whole batch
    broadcasted_mask = Nx.broadcast(mask, {batch_size, height, width})
    Nx.multiply(batch, broadcasted_mask)
  end

  # Applies `mask_fun` to `mask` 0, 1 or 2 times, the count being drawn at random.
  defp apply_random_times(mask, mask_fun) do
    count = :rand.uniform(3) - 1

    # The explicit //1 step matters: a bare 1..0 is a descending range with two
    # elements, which would apply the mask twice when the draw asks for zero.
    Enum.reduce(1..count//1, mask, fn _, acc -> mask_fun.(acc) end)
  end
end
DataSets
An implementation of a dataset for MFCC audio data.
This generally mimics what I did in HW1PT2, with some additional enhancements for performing batch
transforms during a collate
function (called by the DataLoader).
A key challenge here is the runtime of the ingest_partition
. On my local dev machine (MacBook Pro, Apple Silicon, SSD) I can read in both the train and dev partitions in around 100 seconds. On a remote Fly.io machine (4 performance CPUs, 32 GB RAM), this takes closer to 1500 seconds.
defmodule FlameGraph do
  @moduledoc """
  Helpers for capturing eflame stack traces and rendering them as SVG flame graphs.
  """

  # Where the raw stack dumps are written
  @stacks_dir "_flame_graph_stacks"
  # Where rendered SVG files are written
  @svg_out_dir "./"
  # Prefix of the path returned by to_svg/1
  @svg_url_prefix "./flame_graphs"

  @doc """
  Runs the zero-arity `func` under eflame and writes the stack dump to
  `_flame_graph_stacks/<tag>.out`.
  """
  def create(func, tag) do
    File.mkdir(@stacks_dir)
    :eflame.apply(:normal_with_children, "#{@stacks_dir}/#{tag}.out", func, [])
  end

  @doc """
  Lists the `.out` stack dumps found in the stack directory.
  """
  def list, do: Path.wildcard("#{@stacks_dir}/*.out")

  @doc """
  Pipes `stack_filename` through eflame's `stack_to_flame.sh` to produce an SVG
  with the same base name, and returns the path `./flame_graphs/<name>.svg`.
  """
  def to_svg(stack_filename) do
    File.mkdir(@svg_out_dir)
    svg_name = Path.basename(stack_filename, ".out") <> ".svg"

    # Runs through the shell because the script reads stdin and writes stdout
    command = "deps/eflame/stack_to_flame.sh < #{stack_filename} > #{@svg_out_dir}/#{svg_name}"
    :os.cmd(String.to_charlist(command))

    "#{@svg_url_prefix}/#{svg_name}"
  end
end
defmodule AudioDataset do
  @moduledoc """
  A dataset module for ingesting and padding MFCC data.

  All MFCC files of a partition are normalised, concatenated along the frame
  axis and padded with `@max_context` zero frames on both ends. Samples are
  addressed by frame index in `0..total - 1`.

  Struct fields:

    * `:context` - frames of context on each side of a sample
    * `:phonemes` - the phoneme label list
    * `:augment` - whether `collate/2` applies frequency/time masking
    * `:mfccs` - the padded `{frames + 2 * @max_context, 28}` feature tensor
    * `:transcripts` - the concatenated per-frame label tensor
    * `:index_mapping` - an empty `Arrays` array (not used by this module)
    * `:length` - number of MFCC files ingested
    * `:total` - number of frames across all ingested files
  """

  # Zero frames added before and after the data; `context` must not exceed this,
  # otherwise get_item/2 would ask for a window that starts before the tensor.
  @max_context 60
  # Width of every MFCC frame
  @num_features 28
  # Upper bound on the number of files read from a partition
  @max_files 2700

  defstruct [
    :context,
    :phonemes,
    :augment,
    :mfccs,
    :transcripts,
    :index_mapping,
    :length,
    :total
  ]

  # Reads a .npy file into an Nx tensor.
  defp read_file(path) do
    path
    |> File.read!()
    |> Nx.load_numpy!()
  end

  # Builds the zero padding tensor of shape {@max_context, @num_features}.
  defp padding(_context) do
    Nx.broadcast(Nx.tensor(0.0, type: :f32), {@max_context, @num_features})
  end

  # Loads one MFCC file and the transcript file with the same base name.
  defp load_pair(mfcc_path, transcript_dir) do
    mfcc =
      mfcc_path
      |> read_file()
      |> AudioTransforms.cepstral_mean_transform()

    transcript = read_file(Path.join(transcript_dir, Path.basename(mfcc_path)))

    # Drop the first and last transcript entries
    # (presumably the [SOS]/[EOS] markers -- TODO confirm).
    transcript = Nx.slice(transcript, [1], [Nx.size(transcript) - 2])

    {mfcc, transcript}
  end

  # Reads up to @max_files MFCC/transcript pairs of `partition` concurrently and
  # accumulates them (in reverse order) together with the file and frame counts.
  defp ingest_partition(acc, root, _context, partition) do
    mfcc_dir = Path.join([root, partition, "mfcc"])
    transcript_dir = Path.join([root, partition, "transcript"])
    mfcc_names = Path.wildcard(Path.join(mfcc_dir, "*.npy"))
    transcript_names = Path.wildcard(Path.join(transcript_dir, "*.npy"))

    if length(mfcc_names) != length(transcript_names) do
      IO.inspect(length(mfcc_names))
      IO.inspect(length(transcript_names))
      raise "Mismatch between MFCC and transcript counts"
    end

    mfcc_names
    |> Enum.take(@max_files)
    |> Task.async_stream(&load_pair(&1, transcript_dir),
      max_concurrency: System.schedulers_online() * 2,
      ordered: true,
      # Reading a file can take longer than the default 5 s on slow storage
      timeout: :infinity
    )
    |> Enum.reduce(acc, fn {:ok, {mfcc, transcript}}, acc ->
      # Frames are on axis 0; axis 1 is the fixed feature width.
      {num_frames, _num_features} = Nx.shape(mfcc)

      %{
        acc
        | mfccs: [mfcc | acc.mfccs],
          transcripts: [transcript | acc.transcripts],
          length: acc.length + 1,
          total: acc.total + num_frames
      }
    end)
  end

  @doc """
  Loads `partition` from `root` and returns a populated `%AudioDataset{}`.

  `context` is the number of frames of context on each side of a sample and
  `augment` controls whether `collate/2` applies masking. Raises when the
  partition's MFCC and transcript file counts differ.
  """
  def new(root, context \\ 30, partition \\ "train-clean-100", augment \\ true) do
    acc = %AudioDataset{
      context: context,
      phonemes: Globals.phonemes(),
      augment: augment,
      mfccs: [],
      transcripts: [],
      index_mapping: Arrays.new([]),
      length: 0,
      total: 0
    }

    acc = ingest_partition(acc, root, context, partition)

    # The accumulator was built by prepending, so restore file order first
    mfccs = acc.mfccs |> Enum.reverse() |> Nx.concatenate()
    transcripts = acc.transcripts |> Enum.reverse() |> Nx.concatenate()

    # Zero frames on both ends let windows near the edges stay in bounds
    mfccs = Nx.concatenate([padding(context), mfccs, padding(context)])

    %{acc | mfccs: mfccs, transcripts: transcripts}
  end

  @doc """
  Returns the number of frames (samples) in the dataset.
  """
  def len(dataset), do: dataset.total

  @doc """
  Returns `{frames, phoneme}` for the sample at frame index `ind`.

  `frames` is the `{2 * context + 1, 28}` window centred on the frame and
  `phoneme` is a one-element tensor holding its label.
  """
  def get_item(dataset, ind) do
    # Shift past the leading padding to address the padded tensor
    actual_index = ind + @max_context
    lower_offset = actual_index - dataset.context
    upper_offset = actual_index + dataset.context + 1

    frames =
      Nx.slice(dataset.mfccs, [lower_offset, 0], [upper_offset - lower_offset, @num_features])

    # Transcripts are not padded, so they are addressed by the raw index
    phoneme = Nx.slice(dataset.transcripts, [ind], [1])

    {frames, phoneme}
  end

  @doc """
  Optionally augments a stacked `{batch_size, height, width}` batch and flattens
  each entry, returning a `{batch_size, height * width}` tensor.
  """
  def collate(dataset, batch) do
    {batch_size, height, width} = Nx.shape(batch)

    batch =
      case dataset.augment do
        true -> AudioTransforms.freq_time_masking_transform(batch)
        _ -> batch
      end

    Nx.reshape(batch, {batch_size, height * width})
  end
end
DataLoader
A DataLoader implementation which supports arbitrary batch size, asynchronous processing via a number of background “workers” and optional shuffling and data subsetting.
Here we see the power of the Elixir/Erlang concurrency model in action. It takes 1 line of code to turn this serial, synchronous data loader into a much more performant concurrent implementation. The call to Task.async_stream
with the :max_concurrency
option set to our desired number of workers gives us a remarkably simple yet performant equivalent to Python’s DataLoader.
defmodule DataLoader do
  @moduledoc """
  A DataLoader module for asynchronously fetching and processing dataset batches.

  Supports asynchronous processing with a specified number of workers, optional
  shuffling and optional subsetting of the sample indices.
  """

  defstruct [:dataset_module, :dataset, :batch_size, :num_workers, :shuffle, :subset]

  @doc """
  Creates a new DataLoader for `dataset`, whose samples are read through
  `dataset_module.get_item/2` and batched through `dataset_module.collate/2`.

  ## Options

    * `:batch_size` - samples per batch (default `32`)
    * `:num_workers` - concurrent batch workers; `0` runs synchronously (default `4`)
    * `:shuffle` - whether to shuffle the sample indices (default `true`)
    * `:subset` - when non-zero, keep only every n-th index (default `0`)
  """
  def new(dataset_module, dataset, opts \\ []) do
    # Caller-supplied options override the defaults
    opts = Keyword.merge([batch_size: 32, num_workers: 4, shuffle: true, subset: 0], opts)

    %DataLoader{
      dataset_module: dataset_module,
      dataset: dataset,
      batch_size: Keyword.get(opts, :batch_size),
      num_workers: Keyword.get(opts, :num_workers),
      shuffle: Keyword.get(opts, :shuffle),
      subset: Keyword.get(opts, :subset)
    }
  end

  @doc """
  Returns the batches of the loader as `{batch_x, batch_y}` tuples.

  With `num_workers: 0` the result is an eagerly built list; otherwise it is a
  lazy stream whose batches are prepared by up to `num_workers` concurrent
  tasks. A trailing batch smaller than `batch_size` is discarded. Raises when
  a worker task exits.
  """
  def data(%DataLoader{
        dataset_module: dataset_module,
        dataset: dataset,
        batch_size: batch_size,
        num_workers: num_workers,
        subset: subset,
        shuffle: shuffle
      }) do
    # The explicit //1 step keeps the range empty when the dataset has no samples
    # (a bare 0..-1 would be a descending range containing 0 and -1).
    indices = 0..(dataset.total - 1)//1

    indices =
      case subset do
        0 -> indices
        n -> Enum.take_every(indices, n)
      end

    # Leave the indices in their natural order when shuffling is disabled
    indices = if shuffle, do: Enum.shuffle(indices), else: indices

    # Support async processing by a number of workers, or synchronous
    # when num_workers is 0. Synchronous is helpful for debugging.
    case num_workers do
      0 ->
        indices
        |> Enum.chunk_every(batch_size, batch_size, :discard)
        |> Enum.map(fn batch -> prepare_batch(dataset_module, dataset, batch) end)

      n ->
        indices
        |> Stream.chunk_every(batch_size, batch_size, :discard)
        |> Task.async_stream(
          fn batch -> prepare_batch(dataset_module, dataset, batch) end,
          max_concurrency: n,
          # Batch order only needs preserving when the indices were not shuffled
          ordered: !shuffle,
          timeout: :infinity
        )
        |> Stream.map(fn
          # Unwrap the successful task result
          {:ok, result} -> result
          # Task.async_stream reports failed tasks as {:exit, reason}
          {:exit, reason} -> raise "Task failed: #{inspect(reason)}"
        end)
    end
  end

  # Fetches every sample of the batch, stacks inputs and targets separately and
  # transfers both to the EXLA backend.
  defp prepare_batch(dataset_module, dataset, batch_indices) do
    # Map the indices to {frames, phoneme} tuples, then unzip them into a
    # tuple of two separate lists
    {x, y} =
      batch_indices
      |> Enum.map(fn idx -> dataset_module.get_item(dataset, idx) end)
      |> Enum.unzip()

    # Collate the stacked inputs and move them to the EXLA backend
    batch_x =
      dataset
      |> dataset_module.collate(Nx.stack(x))
      |> Nx.backend_transfer(EXLA.Backend)

    # Stack the targets and move them to the EXLA backend
    batch_y =
      y
      |> Nx.stack()
      |> Nx.backend_transfer(EXLA.Backend)

    {batch_x, batch_y}
  end
end
# Load the training partition; the final `false` turns off masking augmentation in collate/2.
train_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "train-clean-100", false)
IO.inspect("Train dataset loaded, #{train_data_set.length} mfccs, #{train_data_set.total} samples")
# Load the validation partition the same way.
dev_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "dev-clean", false)
IO.inspect("Dev dataset loaded, #{dev_data_set.length} mfccs, #{dev_data_set.total} samples")
# num_workers: 0 selects the synchronous batching path; shuffle keeps its default (true).
train_data_loader = DataLoader.new(AudioDataset, train_data_set, batch_size: config.batch_size, num_workers: 0)
dev_data_loader = DataLoader.new(AudioDataset, dev_data_set, batch_size: config.batch_size, num_workers: 0)
Model Definition
defmodule Model do
  @moduledoc """
  Fully connected Axon classifiers over flattened MFCC context windows.

  Every hidden block is dense -> (optional batch norm) -> GELU -> dropout, and
  the network ends in a softmax dense layer of `output_size` units.
  """

  # Hidden blocks of the small network as {units, dropout_rate, normalisation}
  @tiny_blocks [
    {512, 0.35, :batch_norm},
    {512, 0.3, :plain},
    {512, 0.25, :batch_norm}
  ]

  # Hidden blocks of the large network as {units, dropout_rate, normalisation}
  @full_blocks [
    {2048, 0.35, :batch_norm},
    {2048, 0.3, :plain},
    {2048, 0.25, :batch_norm},
    {2048, 0.2, :plain},
    {1024, 0.15, :batch_norm},
    {512, 0.1, :plain},
    {256, 0.05, :batch_norm}
  ]

  @doc """
  Builds the three-block, 512-unit network.
  """
  def tiny(input_size, output_size), do: build(input_size, output_size, @tiny_blocks)

  @doc """
  Builds the seven-block network that narrows from 2048 to 256 units.
  """
  def full(input_size, output_size), do: build(input_size, output_size, @full_blocks)

  @doc """
  Renders `model` as a graph for a single `{1, input_size}` f32 input.
  """
  def summarize(model, input_size) do
    template = Nx.template({1, input_size}, :f32)
    Axon.Display.as_graph(model, template)
  end

  # Chains the hidden blocks onto the "data" input and appends the softmax head.
  defp build(input_size, output_size, blocks) do
    input = Axon.input("data", shape: {nil, input_size})

    blocks
    |> Enum.reduce(input, fn block, net -> hidden_block(net, block) end)
    |> Axon.dense(output_size, activation: :softmax)
  end

  # One hidden block, with batch normalisation between the dense layer and GELU.
  defp hidden_block(net, {units, rate, :batch_norm}) do
    net
    |> linear(units)
    |> Axon.batch_norm()
    |> Axon.gelu()
    |> Axon.dropout(rate: rate)
  end

  # One hidden block without batch normalisation.
  defp hidden_block(net, {units, rate, :plain}) do
    net
    |> linear(units)
    |> Axon.gelu()
    |> Axon.dropout(rate: rate)
  end

  # Dense layer with Glorot-uniform kernel and zero bias initialisation.
  defp linear(model, size) do
    initializers = [
      kernel_initializer: Axon.Initializers.glorot_uniform(),
      bias_initializer: Axon.Initializers.zeros()
    ]

    Axon.dense(model, size, initializers)
  end
end
# Flattened size of one sample: (2 * context + 1) frames of 28 features each.
input_size = (2 * config.context + 1) * 28
# One output class per phoneme label.
output_size = Enum.count(Globals.phonemes())
model = Model.full(input_size, output_size)
#Model.summarize(model, input_size)
# Build the init/predict function pair for the model.
{init_fn, pred_fn} = Axon.build(model)
Training
# Evaluation loop that reports accuracy for the model.
eval_loop = Axon.Loop.evaluator(model)
|> Axon.Loop.metric(:accuracy, "accuracy")
# Switch the default Nx backend to EXLA on the :cuda client for everything that follows.
Nx.global_default_backend({EXLA.Backend, client: :cuda})
# {train_data_loader, dev_data_loader} = Data.get(config)
# Event handler: runs the evaluation loop over the dev loader with the current
# model state, prints the result and lets training continue.
# It is only used if the handle_event line below is re-enabled.
run_eval = fn state ->
model_state = state.step_state.model_state
eval_accuracy = Axon.Loop.run(eval_loop, DataLoader.data(dev_data_loader), model_state)
IO.inspect(eval_accuracy)
{:continue, state}
end
# Define the training loop
# NOTE: the pipeline ends in Axon.Loop.run, so `trainer_loop` is bound to the
# value returned by the run, not to the loop definition.
# NOTE(review): the loader's targets are stacked one-element index tensors, while
# :categorical_cross_entropy is used here -- confirm the expected label format.
trainer_loop =
Axon.Loop.trainer(model, :categorical_cross_entropy, Polaris.Optimizers.adamw())
|> Axon.Loop.metric(:accuracy, "accuracy")
#|> Axon.Loop.handle_event(:epoch_completed, run_eval)
|> Axon.Loop.run(DataLoader.data(train_data_loader), %{}, epochs: config.epochs, compiler: EXLA, client: :cuda)
#first_batch = Enum.at(images, 0)