11785 - HW1PT2
{:exla, ">= 0.0.0"},
{:npy, "~> 0.1.1"},
{:axon, "~> 0.7.0"},
{:table_rex, "~> 3.1.1"},
{:kino, "~> 0.7.0"},
{:finch, "~> 0.19.0"},
{:jason, "~> 1.4"},
{:polaris, "~> 0.1.0"},
{:eflame, "~> 1.0"}
config: [nx: [default_backend: EXLA.Backend]]
config = %{
context: 30,
epochs: 5,
batch_size: 1024,
#data_root: "/home/livebook"
data_root: "/Users/darren/dev/11785-project"
Kaggle Raw Data
The RawData
module retrieves and unpacks the MFCC and transcripts raw data. There were a
collection of challenges in getting the Kaggle data available in an Elixir LiveBook:
There is no simple method for downloading the Kaggle data for a competition from their site. They have a Python API, but lack a simple HTTP endpoint that one could use to issue a `GET` request to download a competition `.zip` file.
Instead, I uploaded a copy of the Kaggle data to a personal S3 bucket that I had already created (for hosting a course project for a French course that I took a few years ago at CMU) -
When running a LiveBook on a remote Elixir node on Fly.io, I could not successfully `unzip` the Kaggle data zip file. The built-in Erlang `unzip` would raise an error complaining that the zip file was invalid. To circumvent this, I created a `.tar.gz` representation that could be unpacked on the remote server -
Finally, I ran into problems using Elixir’s built-in support for reading NumPy files via `Nx.load_numpy!/1`. While I could use that function to read the MFCC data files, it would raise an error when attempting to read the transcript files. These files encoded the transcript data as NumPy arrays of strings, but the `Nx` Elixir library only supports loading numerical NumPy arrays. To work around this, I wrote a small Python script to rewrite these transcript files as integer NumPy arrays, with each integer being the index of the phoneme (from the list of phonemes).
# Makes the raw MFCC/transcript data available locally by downloading a
# tarball from S3 and unpacking it with the system `tar`.
defmodule RawData do
# Entry point: checks for "#{config.data_root}/data" and pipes download/1 into untar/2.
# NOTE(review): the branch bodies of this `if` are not fully visible here — confirm
# that the download only runs when the data directory is missing.
def ensure_available(config) do
if File.exists?("#{config.data_root}/data") do
download(config) |> untar(config)
# Fetches the tarball to "#{config.data_root}/data.tar.gz" by shelling out to `wget`
# (which therefore must be on the PATH). untar/2 matches on :ok / :error, so this
# presumably returns one of those based on the exit status — TODO confirm.
defp download(config) do
url = "https://s3.amazonaws.com/votre-montreal.com/data.tar.gz"
destination = "#{config.data_root}/data.tar.gz"
# curl equivalent of the wget call below, kept for environments without wget
#{_, exit_status} = System.cmd("curl", ["-o", destination, url])
{_, exit_status} = System.cmd("wget", [url, "-O", destination])
if exit_status == 0 do
# A failed download short-circuits: nothing is unpacked and :error is passed through.
defp untar(:error, _), do: :error
# Unpacks the downloaded tarball after a successful download.
# NOTE(review): no `-C` flag or `cd:` option is given, so `tar` presumably extracts
# into the current working directory rather than config.data_root — confirm this
# matches the "#{config.data_root}/data" path that ensure_available/1 checks.
defp untar(:ok, config) do
{_, exit_status} = System.cmd("tar", ["-xzvf", "#{config.data_root}/data.tar.gz"])
if exit_status == 0 do
defmodule Globals do
def phonemes(), do: [
Elixir does not have a random access data structure like a python array, instead it has an immutable list. This was going to create a performance problem in my dataset implementation where I needed to perform a binary search to determine which individual mfcc file an index would be found in. To avoid this, I wrote FastArray
, an Nx
backed fixed size array implementation, which provides support for fast, random access, upon which I replicated python’s bisect_right
defmodule FastArray do
@moduledoc """
A module that provides a fast, fixed-size array backed by a 1D Nx tensor.
defstruct [:storage, :type]
# Create a new FastArray of the given size, with every slot set to default_value.
# `type` is the Nx element type of the backing tensor (signed 64-bit by default).
# NOTE(review): the @spec below describes new/2 only; the optional `type`
# argument (new/3) has no spec.
@spec new(pos_integer(), integer()) :: %__MODULE__{}
def new(size, default_value, type \\ {:s, 64}) when is_integer(default_value) and is_integer(size) and size > 0 do
default_value_tensor = Nx.tensor(default_value, type: type) # Create a tensor with the appropriate type
storage: Nx.broadcast(default_value_tensor, {size}) ,
type: type
# Get the value at the specified index, returned as a plain Elixir number.
# Callers in this file pass a %FastArray{} struct as the first argument.
# NOTE(review): there is no upper-bound check on `index`; Nx.slice appears to clamp
# out-of-range start indices rather than raise — confirm against the Nx docs.
def get(tensor, index) when is_integer(index) and index >= 0 do
|> Nx.slice([index], [1]) # Slice out one element starting at `index`
|> Nx.to_flat_list() # Convert the slice to a flat list
|> hd()
# Put a value into the specified index.
# Returns a new struct; the array passed in is not modified.
def put(%__MODULE__{storage: storage, type: type} = fast_array, index, value) when is_integer(index) and index >= 0 do
storage = Nx.put_slice(storage, [index], Nx.tensor([value], type: type))
%{fast_array | storage: storage}
# Bisect right function - find the insertion point to maintain sorted order.
# Returns the index of the first element strictly greater than `value`
# (equal elements are skipped), or the array size when no such element exists.
# Assumes the storage is sorted in ascending order.
def bisect_right(%FastArray{storage: tensor}, value) do
size = Nx.size(tensor) # Get the size of the tensor
bisect_right_helper(tensor, value, 0, size - 1)
# Recursive helper function for bisect_right.
# `low..high` is the inclusive range still under consideration; every probe reads a
# single element through get/2.
defp bisect_right_helper(tensor, value, low, high) when low <= high do
mid = div(low + high, 2)
mid_value = get(%FastArray{storage: tensor}, mid)
cond do
mid_value <= value -> bisect_right_helper(tensor, value, mid + 1, high)
true -> bisect_right_helper(tensor, value, low, mid - 1)
# Range exhausted (low > high): `low` is the insertion point.
defp bisect_right_helper(_, _, low, _), do: low
Here I have replicated the two transformations that I used in HW1PT2: the cepstral mean transform and a combined time and frequency masking transform.
defmodule Masks do
  @moduledoc """
  Helpers that zero out one randomly placed band of a 2D mask tensor, either
  along the frequency axis (columns) or along the time axis (rows).
  """

  @doc """
  Zeroes a random band of columns in `mask` and returns the updated mask.

  The band width is `floor(u * (f - 3.0) + 3.0)` and its first column is
  `floor(u' * (high - width - low) + low)`, where `u` and `u'` are independent
  draws from `:rand.uniform/0`. The band spans every row of the mask.
  """
  def freq(mask, low \\ 0.0, high \\ 28.0, f \\ 4) do
    band_width = random_floor(f - 3.0, 3.0)
    first_col = random_floor(high - band_width - low, low)
    row_count = mask |> Nx.shape() |> elem(0)

    zero_band(mask, [0, first_col], [row_count, band_width])
  end

  @doc """
  Zeroes a random band of rows (time steps) in `mask` and returns the updated mask.

  The band length is `floor(u * (f - 3.0) + 3.0)` and its first row is
  `floor(u' * (rows - length))`. The band spans every column of the mask.
  """
  def time(mask, f \\ 8) do
    dims = Nx.shape(mask)
    row_count = elem(dims, 0)

    band_length = random_floor(f - 3.0, 3.0)
    first_row = random_floor(row_count - band_length, 0.0)

    zero_band(mask, [first_row, 0], [band_length, elem(dims, 1)])
  end

  # Draws u from :rand.uniform/0 and returns floor(u * span + offset) as an integer.
  defp random_floor(span, offset) do
    (:rand.uniform() * span + offset)
    |> Float.floor()
    |> trunc()
  end

  # Overwrites the region of `mask` starting at `start` with the given `lengths`
  # by multiplying the existing values by zero (which keeps the slice's type).
  defp zero_band(mask, start, lengths) do
    band = Nx.slice(mask, start, lengths)
    Nx.put_slice(mask, start, Nx.multiply(band, Nx.tensor(0.0)))
  end
end
defmodule AudioTransforms do
  @moduledoc """
  Feature transforms for MFCC data: per-column mean/variance normalisation and a
  combined random frequency/time masking augmentation for batches.
  """

  @doc """
  Applies cepstral mean and variance normalisation to `data`.

  The mean over axis 0 is subtracted from every row, and the centred data is then
  divided by its standard deviation over axis 0. Returns a tensor with the same
  shape as `data`.
  """
  def cepstral_mean_transform(data) do
    # Centre each column on zero.
    mean = Nx.mean(data, axes: [0])
    centered_data = Nx.subtract(data, mean)

    # Scale each column by its standard deviation.
    std_dev = Nx.standard_deviation(centered_data, axes: [0])
    Nx.divide(centered_data, std_dev)
  end

  @doc """
  Applies one shared random frequency/time mask to every entry of `batch`.

  `batch` must be a rank-3 tensor `{batch_size, height, width}`. A `{height, width}`
  mask of ones is created, 0 to 2 frequency bands (`Masks.freq/1`) and then 0 to 2
  time bands (`Masks.time/1`) are zeroed in it, and the mask is multiplied into
  every entry of the batch. Returns a tensor with the same shape as `batch`.
  """
  def freq_time_masking_transform(batch) do
    {batch_size, height, width} = Nx.shape(batch)

    mask =
      Nx.tensor(1.0)
      |> Nx.broadcast({height, width})
      |> apply_random_count(&Masks.freq/1)
      |> apply_random_count(&Masks.time/1)

    # Broadcast the 2D mask across the batch and apply it element-wise.
    Nx.multiply(batch, Nx.broadcast(mask, {batch_size, height, width}))
  end

  # Applies `mask_fun` to `mask` 0, 1 or 2 times (count drawn uniformly).
  defp apply_random_count(mask, mask_fun) do
    count = :rand.uniform(3) - 1

    # The explicit `//1` step matters: a plain `1..0` is a descending range with two
    # elements, which would apply two masks when zero were requested.
    Enum.reduce(1..count//1, mask, fn _, acc -> mask_fun.(acc) end)
  end
end
An implementation of a dataset for MFCC audio data.
This generally mimics what I did in HW1PT2, with some additional enhancements for performing batch
transforms during a collate
function (called by the DataLoader).
A key challenge here is the runtime of `ingest_partition`. On my local dev machine (MacBook Pro, Apple Silicon, SSD) I can read in both the train and dev partitions in around 100 seconds. On a remote Fly.io machine (4 performance CPUs, 32 GB RAM), this takes closer to 1500 seconds.
# Thin wrapper around :eflame for recording call stacks and rendering them as SVG.
defmodule FlameGraph do
# Directory the raw stack files (`<tag>.out`) are written to.
@stack_files_dir "_flame_graph_stacks"
# Directory SVG output is written to. It already ends in "/", so the path built in
# to_svg/1 contains a double slash.
@svg_dir "./"
# Not referenced by any code visible in this module.
@url_base "./flame_graphs"
# Runs the zero-arity `func` under :eflame and writes the collected stacks to
# "#{@stack_files_dir}/#{tag}.out".
# NOTE(review): nothing visible here creates @stack_files_dir — confirm it exists
# before the first call.
def create(func, tag) do
filename = "#{@stack_files_dir}/#{tag}.out"
:eflame.apply(:normal_with_children, filename, func, [])
# NOTE(review): the body of list/0 is not visible here.
def list() do
# Converts a stack file into an SVG named after it (".out" replaced by ".svg") by
# piping it through eflame's stack_to_flame.sh.
# NOTE(review): the file name is interpolated into a shell command line run through
# :os.cmd/1 — only pass trusted paths without shell metacharacters.
def to_svg(stack_filename) do
svg_filename = Path.basename(stack_filename, ".out") <> ".svg"
"deps/eflame/stack_to_flame.sh < #{stack_filename} > #{@svg_dir}/#{svg_filename}"
|> String.to_charlist()
|> :os.cmd()
defmodule AudioDataset do
@moduledoc """
A dataset module for ingesting and padding MFCC data.
# Number of zero rows placed before and after every utterance. get_item/2 adds the
# same value to skip over the leading padding.
@max_context 60
defstruct [
# Helper function to read data (assuming .npy files)
defp read_file(path) do
|> Nx.load_numpy!()
# Create a zero padding tensor of shape {MAX_CONTEXT, 28}
# The `_context` argument is ignored: the pad length is always @max_context.
# NOTE(review): get_item/2 reads `context` rows on each side of a frame, so `context`
# must not exceed @max_context; nothing visible here enforces that.
defp padding(_context) do
Nx.tensor(Nx.broadcast(0.0, {@max_context, 28}), type: :f32)
# Pad the input array on both sides with padding
defp pad(arr, context) do
padding = padding(context)
Nx.concatenate([padding, arr, padding])
# Add MFCC data to dataset, padding it appropriately
# Stores the padded tensor under key `i` and advances the running total by the
# number of unpadded rows in `data`.
defp add(i, data, context, total, acc) do
length = Nx.shape(data) |> elem(0)
padded = pad(data, context)
%{acc | mfccs: Map.put(acc.mfccs, i, padded), total: total + length}
# Ingest a partition of MFCC and transcript data
# Reads every "*.npy" file under <root>/<partition>/mfcc together with the
# transcript of the same basename, normalises the MFCCs, and folds the results
# into `acc`.
defp ingest_partition(acc, root, context, partition) do
mfcc_dir = Path.join([root, partition, "mfcc"])
transcript_dir = Path.join([root, partition, "transcript"])
mfcc_names = Path.wildcard(Path.join(mfcc_dir, "*.npy"))
transcript_names = Path.wildcard(Path.join(transcript_dir, "*.npy"))
if length(mfcc_names) != length(transcript_names) do
raise "Mismatch between MFCC and transcript counts"
# One slot per file; slot i is meant to hold the cumulative frame count through file i.
index_mapping = FastArray.new(length(mfcc_names), 0)
# Files are loaded and normalised concurrently; `ordered: true` keeps results in
# input order so the running total below is accumulated file by file.
|> Task.async_stream(fn {mfcc_path, i} ->
mfcc = read_file(mfcc_path)
mfcc = AudioTransforms.cepstral_mean_transform(mfcc)
transcript_path = Path.join(transcript_dir, Path.basename(mfcc_path))
transcript = read_file(transcript_path)
# Drops the first and last transcript entries (presumably start/end markers — TODO confirm).
transcript = Nx.slice(transcript, [1], [Nx.size(transcript) - 2])
{i, mfcc, transcript}
end, max_concurrency: System.schedulers_online() * 2, ordered: true)
|> Enum.reduce(acc, fn {:ok, {i, mfcc, transcript}}, acc ->
acc = add(i, mfcc, context, acc.total, acc)
| transcripts: Map.put(acc.transcripts, i, transcript),
# NOTE(review): `index_mapping` below is the binding created before the stream, not
# `acc.index_mapping`. FastArray.put/3 returns a new struct, so each iteration
# appears to start again from the all-zero array and only the last file's offset
# would survive — confirm, and thread the accumulator's mapping through if so.
index_mapping: FastArray.put(index_mapping, i, acc.total)
# Initialize the dataset with the specified partitions
# `length` ends up as the number of ingested MFCC files; `total` (maintained during
# ingestion) is the number of frames across all of them.
def new(root, context \\ 30, partition \\ "train-clean-100", augment \\ true) do
acc = %AudioDataset{
context: context,
phonemes: Globals.phonemes(),
augment: augment,
mfccs: %{},
transcripts: %{},
index_mapping: nil,
length: 0,
total: 0
acc = ingest_partition(acc, root, context, partition)
%{acc | length: Map.keys(acc.mfccs) |> Enum.count()}
# Number of MFCC files in the dataset (not the number of frames).
def len(dataset), do: dataset.length
# Returns {frames, phoneme} for the global frame index `ind`:
#   * `line` is the file holding the frame, found by bisecting the cumulative counts;
#   * `actual_index` is the frame's row inside that file's padded tensor;
#   * `frames` is the window of `context` rows on either side of that row;
#   * `phoneme` is the matching one-element slice of the file's transcript.
def get_item(dataset, ind) do
line = locate_line(dataset.index_mapping, ind)
actual_index = case line do
0 -> ind + @max_context
_n -> ind - FastArray.get(dataset.index_mapping, line - 1) + @max_context
before_context = dataset.context
after_context = dataset.context
lower_offset = actual_index - before_context
upper_offset = actual_index + after_context + 1
frames = Map.get(dataset.mfccs, line)
frames = Nx.slice(frames, [lower_offset, 0], [
upper_offset - lower_offset,
phoneme_line = Map.get(dataset.transcripts, line)
# Subtract the padding offset again to index the (unpadded) transcript.
phoneme = Nx.slice(phoneme_line, [actual_index - @max_context], [1])
{frames, phoneme}
# Batch-level post-processing: applies the frequency/time masking augmentation when
# `dataset.augment` is true, then flattens each {height, width} entry into one row.
def collate(dataset, batch) do
{batch_size, height, width} = Nx.shape(batch)
batch = case dataset.augment do
true -> AudioTransforms.freq_time_masking_transform(batch)
_ -> batch
Nx.reshape(batch, {batch_size, height * width})
# Helper function to find the correct line
defp locate_line(index_mapping, ind) do
FastArray.bisect_right(index_mapping, ind)
A DataLoader implementation which supports arbitrary batch sizes, asynchronous processing via a number of background “workers”, and optional shuffling and data subsetting.
Here we see the power of the Elixir/Erlang concurrency model in action. It takes 1 line of code to turn this serial, synchronous data loader into a much more performant concurrent implementation. The call to `Task.async_stream` with the `:max_concurrency` option set to our desired number of workers gives us a remarkably simple yet performant equivalent to Python’s DataLoader.
defmodule DataLoader do
@moduledoc """
A DataLoader module for asynchronously fetching and processing dataset batches.
Supports asynchronous processing with a specified number of workers.
defstruct [:dataset_module, :dataset, :batch_size, :num_workers, :shuffle, :subset]
# Create a new DataLoader with a dataset, batch size, and number of async workers
#
# `dataset_module` is the module whose get_item/2 and collate/2 are invoked for
# `dataset`. Recognised options (with defaults):
#   * :batch_size  - number of samples per batch (32)
#   * :num_workers - concurrent workers; 0 selects the synchronous path (4)
#   * :shuffle     - whether to shuffle the sample indices (true)
#   * :subset      - 0 uses every index, n keeps every nth index (0)
def new(dataset_module, dataset, opts \\ []) do
# Set default values for options
opts = Keyword.merge([batch_size: 32, num_workers: 4, shuffle: true, subset: 0], opts)
dataset_module: dataset_module,
dataset: dataset,
batch_size: Keyword.get(opts, :batch_size),
num_workers: Keyword.get(opts, :num_workers),
shuffle: Keyword.get(opts, :shuffle),
subset: Keyword.get(opts, :subset),
# Returns a lazy stream of batches for the given loader.
#
# Sample indices run over `0..dataset.total - 1`, are optionally thinned to every
# nth index (`subset`), optionally shuffled, and grouped into chunks of exactly
# `batch_size` (a trailing partial chunk is discarded). Each chunk is turned into a
# batch by prepare_batch/3, either inline (`num_workers == 0`, handy for debugging)
# or concurrently by up to `num_workers` tasks.
#
# Raises a RuntimeError if a worker task exits while the caller is trapping exits.
def data(%DataLoader{
      dataset_module: dataset_module,
      dataset: dataset,
      batch_size: batch_size,
      num_workers: num_workers,
      subset: subset,
      shuffle: shuffle
    }) do
  # The explicit `//1` step keeps the range empty for a dataset without samples;
  # a plain `0..-1` would be a descending range yielding 0 and -1.
  all_indices = 0..(dataset.total - 1)//1

  indices =
    case subset do
      0 -> all_indices
      n -> Enum.take_every(all_indices, n)
    end

  # Without shuffling the indices must pass through unchanged (binding the flag
  # itself here would hand `false` to Stream.chunk_every/4 below).
  indices = if shuffle, do: Enum.shuffle(indices), else: indices

  batches = Stream.chunk_every(indices, batch_size, batch_size, :discard)
  build_batch = fn batch -> prepare_batch(dataset_module, dataset, batch) end

  case num_workers do
    0 ->
      Stream.map(batches, build_batch)

    n ->
      batches
      |> Task.async_stream(build_batch,
        max_concurrency: n,
        # Batch order only matters when the caller asked for unshuffled data.
        ordered: !shuffle,
        timeout: :infinity
      )
      |> Stream.map(fn
        {:ok, result} -> result
        # Task.async_stream/3 reports failures as {:exit, reason}.
        {:exit, reason} -> raise "Task failed: #{inspect(reason)}"
      end)
  end
end
# Builds one batch from a list of sample indices: fetches every sample through
# `dataset_module.get_item/2`, stacks inputs and targets, runs the inputs through
# `dataset_module.collate/2`, and transfers both tensors to EXLA.Backend.
# Returns {batch_x, batch_y}.
defp prepare_batch(dataset_module, dataset, batch_indices) do
# Map the indices that we need for this batch to
# frame and phoneme tuples, unzipping them to rearrange
# as a tuple of two separate lists
{x, y} = batch_indices
|> Enum.map(fn idx -> apply(dataset_module, :get_item, [dataset, idx]) end)
|> Enum.unzip()
# Collate the batch and move it to the GPU
# NOTE(review): the transfer targets EXLA.Backend; whether that is a GPU depends on
# the EXLA client configuration, which is not visible here.
batch_x = apply(dataset_module, :collate, [dataset, Nx.stack(x)])
|> Nx.backend_transfer(EXLA.Backend)
# Stack the y tensors and move to GPU
batch_y = Nx.stack(y)
|> Nx.backend_transfer(EXLA.Backend)
# NOTE(review): the trailing `end, String.to_atom(...)` closes a wrapping call whose
# opening is not visible here (presumably FlameGraph.create/2 — confirm). It creates a
# brand-new atom from a random float on every call; atoms are never garbage
# collected, so a string tag would be safer.
{batch_x, batch_y} end, String.to_atom("#{:rand.uniform()}"))
# Load the training partition (with augmentation) and the dev partition (without),
# reporting the number of files and frames in each.
train_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "train-clean-100", true)
IO.inspect("Train dataset loaded, #{train_data_set.length} mfccs, #{train_data_set.total} samples")
dev_data_set = AudioDataset.new("#{config.data_root}/data", config.context, "dev-clean", false)
# This line reports on the dev partition, so it is labelled accordingly.
IO.inspect("Dev dataset loaded, #{dev_data_set.length} mfccs, #{dev_data_set.total} samples")
# Both loaders use the configured batch size and 8 concurrent workers.
train_data_loader = DataLoader.new(AudioDataset, train_data_set, batch_size: config.batch_size, num_workers: 8)
dev_data_loader = DataLoader.new(AudioDataset, dev_data_set, batch_size: config.batch_size, num_workers: 8)
# Smoke test: pull a single batch through the training loader.
Enum.take(DataLoader.data(train_data_loader), 1)
Model Definition
defmodule Model do
  @moduledoc """
  Axon model definitions: stacks of dense -> (batch norm) -> GELU -> dropout
  blocks followed by a softmax output layer.
  """

  # Hidden blocks as {units, batch_norm?, dropout_rate}, in network order.
  @tiny_blocks [
    {512, true, 0.35},
    {512, false, 0.3},
    {512, true, 0.25}
  ]

  @full_blocks [
    {2048, true, 0.35},
    {2048, false, 0.3},
    {2048, true, 0.25},
    {2048, false, 0.2},
    {1024, true, 0.15},
    {512, false, 0.1},
    {256, true, 0.05}
  ]

  @doc """
  Builds the small three-block network mapping `input_size` features to
  `output_size` softmax outputs.
  """
  def tiny(input_size, output_size), do: assemble(@tiny_blocks, input_size, output_size)

  @doc """
  Builds the seven-block network mapping `input_size` features to `output_size`
  softmax outputs.
  """
  def full(input_size, output_size), do: assemble(@full_blocks, input_size, output_size)

  @doc """
  Renders `model` as a graph, using a `{1, input_size}` f32 template as input.
  """
  def summarize(model, input_size) do
    template = Nx.template({1, input_size}, :f32)
    Axon.Display.as_graph(model, template)
  end

  # Chains the hidden blocks onto the "data" input and appends the softmax output layer.
  defp assemble(blocks, input_size, output_size) do
    input = Axon.input("data", shape: {nil, input_size})

    blocks
    |> Enum.reduce(input, &hidden_block/2)
    |> Axon.dense(output_size, activation: :softmax)
  end

  # One hidden block: dense layer, optional batch norm, GELU, dropout.
  defp hidden_block({units, batch_norm?, rate}, model) do
    dense = linear(model, units)
    normalized = if batch_norm?, do: Axon.batch_norm(dense), else: dense

    normalized
    |> Axon.gelu()
    |> Axon.dropout(rate: rate)
  end

  # Dense layer with Glorot-uniform kernel and zero bias initialisation.
  defp linear(model, size) do
    Axon.dense(model, size,
      kernel_initializer: Axon.Initializers.glorot_uniform(),
      bias_initializer: Axon.Initializers.zeros()
    )
  end
end
# Each sample is a window of `context` frames on either side of the centre frame,
# 28 coefficients per frame, flattened into one vector.
input_size = (2 * config.context + 1) * 28
# One output per phoneme.
output_size = Enum.count(Globals.phonemes())
model = Model.full(input_size, output_size)
#Model.summarize(model, input_size)
# NOTE(review): init_fn / pred_fn are not used by any code visible below.
{init_fn, pred_fn} = Axon.build(model)
# Evaluation loop reporting accuracy, used by run_eval below.
eval_loop = Axon.Loop.evaluator(model)
|> Axon.Loop.metric(:accuracy, "accuracy")
# NOTE(review): `Data` is not defined in the visible part of this file — confirm it
# returns the {train, dev} DataLoader pair.
{train_data_loader, dev_data_loader} = Data.get(config)
# Event handler that evaluates the current model state on the dev set.
# NOTE(review): `eval_accuracy` is computed but neither logged nor stored in `state`.
run_eval = fn state ->
model_state = state.step_state.model_state
eval_accuracy = Axon.Loop.run(eval_loop, DataLoader.data(dev_data_loader), model_state)
{:continue, state}
# Define the training loop
# The pipeline ends in Axon.Loop.run/4, so `trainer_loop` is bound to the result of
# the training run rather than to a loop struct. The epoch-completed handler is
# commented out, so run_eval is currently never invoked.
# NOTE(review): the targets are slices of the integer-encoded transcripts; if they
# are class ids rather than one-hot rows, :categorical_cross_entropy would need its
# sparse form — confirm against the Axon.Losses docs.
trainer_loop =
Axon.Loop.trainer(model, :categorical_cross_entropy, Polaris.Optimizers.adamw())
|> Axon.Loop.metric(:accuracy, "accuracy")
#|> Axon.Loop.handle_event(:epoch_completed, run_eval)
|> Axon.Loop.run(DataLoader.data(train_data_loader), %{}, epochs: config.epochs, compiler: EXLA)
