Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Audio classification with TensorFlow Lite

notebooks/audio_classification.livemd

Audio classification with TensorFlow Lite

Mix.install([
  {:nx_signal, "~> 0.1"},
  {:tflite_elixir, "~> 0.3.0"},
  {:nx, "~> 0.5"},
  {:kino, "~> 0.9.0"},
  {:req, "~> 0.3.6"}
])

Introduction

The task of identifying what an audio clip represents is called audio classification. An audio classification model is trained to recognize various audio events.

For example, you may train a model to recognize three different events: clapping, finger snapping, and typing. TensorFlow Lite provides optimized pre-trained models that you can deploy in your mobile applications.

https://www.tensorflow.org/lite/examples/audio_classification/overview

Download model file

# Directory the downloaded files are stored in (the system temp dir by default).
downloads_dir = System.tmp_dir!()
# for nerves demo user
# change to a directory with write-permission
# downloads_dir = "/data/livebook"

# Downloads `url` into `downloads_dir` and returns the local path.
# The file name is the URL-encoded address; when a file with that name is
# already present the request is skipped.
download = fn url ->
  target = Path.join(downloads_dir, URI.encode_www_form(url))

  if not File.exists?(target) do
    Req.get!(url, output: target)
  end

  target
end

# Map of model key => local path of the downloaded file.
data_files =
  Map.new(
    [
      cpu_model:
        "https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite"
    ],
    fn {key, url} -> {key, download.(url)} end
  )

Load model and embedded labels

# Read the whole model file into memory; the label lookup and the interpreter
# below are both built from this buffer.
model_buffer = File.read!(data_files.cpu_model)

# The labels come from the file "yamnet_label_list.txt" associated with the
# model, split on newlines into a list.
labels =
  model_buffer
  |> TFLiteElixir.FlatBufferModel.get_associated_file("yamnet_label_list.txt")
  |> String.split("\n")

# Crashes on purpose (match error) if the interpreter cannot be created.
{:ok, interpreter} = TFLiteElixir.Interpreter.new_from_buffer(model_buffer)

Record audio

# Create an audio input labelled "Audio".
input = Kino.Input.audio("Audio")
# Read the input's current value; the cells below use its `data` field as the
# raw recording bytes.
# NOTE(review): presumably a recording must be made before this read returns
# usable data — confirm against the Kino documentation.
value = Kino.Input.read(input)

Downsampling to 16 kHz

# Interpret the raw recording bytes as a flat tensor of 32-bit floats.
recording = Nx.from_binary(value.data, :f32)
{recording_length} = Nx.shape(recording)

# Keep every third sample of the recording.
# NOTE(review): the fixed stride of 3 presumably assumes a 48 kHz mono
# recording (48 kHz / 3 = 16 kHz) — confirm against the recording's metadata.
downsampled =
  recording
  |> Nx.slice([0], [recording_length], strides: [3])

{downsampled_length} = Nx.shape(downsampled)

Audio classification with TensorFlow Lite

# Number of best-scoring labels printed for each window.
top_k = 1
# Number of samples passed to the model per prediction (one window).
downsample_rate = 15600
# Length of one window in seconds; only used to print the time range.
# (15600 samples presumably correspond to 0.975 s at 16 kHz.)
sample_duration = 0.975
# Only whole windows are classified; leftover samples at the end are ignored.
num_samples = div(downsampled_length, downsample_rate)

# The explicit `//1` step makes the range empty when `num_samples` is 0.
# Without it `0..-1` is a descending range, so a recording shorter than one
# window would still enter the loop and fail when slicing past the tensor end.
for sample_index <- 0..(num_samples - 1)//1 do
  # Cut the window that starts at this index out of the downsampled recording.
  sample = Nx.slice(downsampled, [sample_index * downsample_rate], [downsample_rate])

  [out_tensor] = TFLiteElixir.Interpreter.predict(interpreter, sample)

  # Flatten the scores, order their indices best-first and keep the top `top_k`.
  top_k_pred =
    out_tensor
    |> Nx.reshape({:auto})
    |> Nx.argsort(direction: :desc)
    |> Nx.take(Nx.iota({top_k}))
    |> Nx.to_flat_list()

  # The time range is the same for every label of this window.
  start_time = Float.round(sample_index * sample_duration, 3)
  end_time = Float.round((sample_index + 1) * sample_duration, 3)

  # Printing is a side effect only, hence `Enum.each/2`.
  Enum.each(top_k_pred, fn pred_index ->
    IO.puts("[#{start_time}-#{end_time}]: #{Enum.at(labels, pred_index)}")
  end)
end

:ok