Powered by AppSignal & Oban Pro

SAM

SAM.livemd

SAM

Mix.install([
  {:pythonx, "~> 0.4.2"},
  {:kino_pythonx, "~> 0.1.0"},
  {:kino, "~> 0.19.0"}
])
[project]
name = "project"
version = "0.0.0"
requires-python = "==3.13.*"
dependencies = [
    "sam_audio @ git+https://github.com/facebookresearch/sam-audio.git",
    "torch",
    "torchaudio",
    "torchcodec>=0.2,<0.3",
    "transformers>=4.54,<5",
    "huggingface_hub>=0.34,<1.0",
    "hf_transfer",
    "pydub",
    "audioop-lts",
    "soundfile",
]

Monkey Patches

from pathlib import Path
import site
import numpy as np
import torch
import soundfile as sf
import torchaudio

# Monkey-patch torchaudio.load and torchaudio.save to use soundfile
# torchaudio 2.12+ hardcodes torchcodec and ignores the backend parameter

def _soundfile_load(uri, frame_offset=0, num_frames=-1, normalize=True,
                    channels_first=True, format=None, buffer_size=4096, backend=None):
    """Stand-in for ``torchaudio.load`` that decodes with ``soundfile``.

    Args:
        uri: Path (``str`` or ``os.PathLike``) or an already-open file-like
            object. Path-likes are converted with ``os.fspath``; anything
            else is handed to ``soundfile`` untouched.
        frame_offset: Index of the first frame to read.
        num_frames: Number of frames to read; ``-1`` reads to the end.
        channels_first: When true the result is ``(channels, frames)``,
            otherwise ``(frames, channels)``.
        normalize, format, buffer_size, backend: Accepted only so the
            signature matches ``torchaudio.load``; they are not used.

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` is a contiguous
        float32 tensor.
    """
    import os  # local import: os is not among the imports preceding this function

    # Stringifying a file-like object would turn it into its repr, so only
    # real paths are converted.
    source = os.fspath(uri) if isinstance(uri, (str, os.PathLike)) else uri
    stop = None if num_frames == -1 else frame_offset + num_frames
    data, sample_rate = sf.read(source, start=frame_offset, stop=stop,
                                dtype="float32", always_2d=True)
    frames_first = torch.from_numpy(data)  # (frames, channels)
    if channels_first:
        # Materialise the transpose so downstream .view() calls keep working.
        return frames_first.T.contiguous(), sample_rate
    return frames_first, sample_rate

def _soundfile_save(uri, src, sample_rate, channels_first=True, format=None,
                    encoding=None, bits_per_sample=None, buffer_size=4096,
                    compression=None, backend=None):
    """Stand-in for ``torchaudio.save`` that encodes with ``soundfile``.

    Args:
        uri: Destination path (``str`` or ``os.PathLike``) or a writable
            file-like object.
        src: Audio as a ``torch.Tensor`` or anything ``np.asarray`` accepts.
        sample_rate: Sample rate written to the file.
        channels_first: When true ``src`` is ``(channels, frames)`` and is
            transposed to the ``(frames, channels)`` layout soundfile expects.
        format: Optional container format forwarded to ``soundfile.write``;
            ``None`` lets soundfile infer it from the file name.
        encoding, bits_per_sample, buffer_size, compression, backend:
            Accepted only so the signature matches ``torchaudio.save``; they
            are not used.
    """
    import os  # local import: os is not among the imports preceding this function

    if isinstance(src, torch.Tensor):
        tensor = src.detach()  # .numpy() refuses tensors that require grad
        if tensor.is_floating_point() and tensor.dtype not in (torch.float32, torch.float64):
            # Half-precision dtypes (fp16 / bf16) cannot be written as-is.
            tensor = tensor.to(torch.float32)
        data = tensor.cpu().numpy()
    else:
        data = np.asarray(src)
    if channels_first:
        data = data.T  # (frames, channels)

    # Stringifying a file-like object would turn it into its repr, so only
    # real paths are converted.
    target = os.fspath(uri) if isinstance(uri, (str, os.PathLike)) else uri
    sf.write(target, data, sample_rate, format=format)

# Swap the module attributes so that any later call made through
# `torchaudio.load` / `torchaudio.save` in this process runs the
# soundfile-backed functions defined above.
torchaudio.load = _soundfile_load
torchaudio.save = _soundfile_save

# Candidate site-packages directories; the files patched below are looked up here.
site_paths = [Path(base) for base in site.getsitepackages()]

# Original torchcodec-based loader in transforms.py and its torchaudio replacement.
_TORCHCODEC_LOAD_AUDIO = (
    "    def _load_audio(self, path: str):\n"
    "        ad = AudioDecoder(path, sample_rate=self.sampling_rate, num_channels=1)\n"
    "        return ad.get_all_samples().data\n"
)
_TORCHAUDIO_LOAD_AUDIO = (
    "    def _load_audio(self, path: str):\n"
    "        wav, sample_rate = torchaudio.load(path)\n"
    "        if wav.size(0) > 1:\n"
    "            wav = wav.mean(dim=0, keepdim=True)\n"
    "        if sample_rate != self.sampling_rate:\n"
    "            wav = torchaudio.functional.resample(wav, sample_rate, self.sampling_rate)\n"
    "        return wav\n"
)


def _patch_site_file(relative_path, replacements):
    """Apply literal text replacements to a file inside site-packages.

    Args:
        relative_path: Path of the file relative to a site-packages directory.
        replacements: Iterable of ``(old, new)`` pairs applied in order with
            ``str.replace``.

    Returns:
        The ``Path`` of the file that was inspected.

    Raises:
        SystemExit: If the file exists in none of ``site_paths``.
    """
    for base in site_paths:
        candidate = base / relative_path
        if candidate.exists():
            break
    else:
        raise SystemExit(f"Could not find {relative_path}")

    before = candidate.read_text(encoding="utf-8")
    after = before
    for old, new in replacements:
        after = after.replace(old, new)

    if after == before:
        # Nothing matched: the file was patched on an earlier run, or the
        # upstream source no longer contains the expected text. Do not claim
        # success and do not rewrite the file.
        print(f"no changes needed for {candidate}")
    else:
        candidate.write_text(after, encoding="utf-8")
        print(f"patched {candidate}")
    return candidate


# Patch core/audio_visual_encoder/transforms.py
_patch_site_file(
    "core/audio_visual_encoder/transforms.py",
    [
        ("from torchcodec.decoders import AudioDecoder, VideoDecoder\n", "import torchaudio\n"),
        ("from torchcodec.decoders import VideoDecoder\n", ""),
        (_TORCHCODEC_LOAD_AUDIO, _TORCHAUDIO_LOAD_AUDIO),
    ],
)

# Patch sam_audio/processor.py
_patch_site_file(
    "sam_audio/processor.py",
    [
        ("from torchcodec.decoders import AudioDecoder, VideoDecoder\n", ""),
        ("from torchcodec.decoders import VideoDecoder\n", ""),
    ],
)

Setup

# Show the GPU status; the match on exit status 0 aborts the cell if
# nvidia-smi reports a failure.
{gpu_report, 0} = System.cmd("nvidia-smi", [])
IO.puts(gpu_report)

# Install ffmpeg through apt only when it is not already on PATH.
ensure_ffmpeg = "which ffmpeg || (apt update && apt install -y ffmpeg)"
{_, 0} = System.cmd("bash", ["-c", ensure_ffmpeg], stderr_to_stdout: true)

HF Setup

# Raise immediately when the secret is not set, rather than letting a nil
# reach the Python cell that calls `.decode` on this value.
hf_token = System.fetch_env!("LB_HF_TOKEN")
import os
# Export the Hugging Face token and switch on hf_transfer through the process
# environment. `hf_token` is decoded as UTF-8, i.e. it is expected to be bytes.
os.environ.update({
    "HF_TOKEN": hf_token.decode("utf-8"),
    "HF_HUB_ENABLE_HF_TRANSFER": "1",
})

Load Model

import gc
import types
from sam_audio import SAMAudio, SAMAudioProcessor

# Turn cuDNN off for every torch operation that follows.
# NOTE(review): the motivation is not stated here — presumably a workaround; confirm.
torch.backends.cudnn.enabled = False

# Target device and dtype; used below when moving the model and by the
# autocast contexts that run inference.
device = "cuda"
dtype = torch.bfloat16

# Load the "facebook/sam-audio-large" checkpoint.
# NOTE(review): proxies / resume_download are passed explicitly — presumably
# required by the installed huggingface_hub version; confirm.
model = SAMAudio.from_pretrained("facebook/sam-audio-large", proxies=None, resume_download=False)

# Remove the vision encoder to save VRAM (audio-only usage)
# Its `dim` is read first and kept on the model, because the audio-only
# replacement for `_get_video_features` needs it to size its output.
vision_dim = model.vision_encoder.dim
del model.vision_encoder
model._vision_encoder_dim = vision_dim

def _get_video_features_audio_only(self, video, audio_features):
    if video is not None:
        raise ValueError("This service is audio-only; video inputs are disabled.")
    batch_size, time_steps, _ = audio_features.shape
    return audio_features.new_zeros(batch_size, self._vision_encoder_dim, time_steps)

# Bind the audio-only stub to this model instance so it is used in place of
# the model's own `_get_video_features`.
model._get_video_features = types.MethodType(_get_video_features_audio_only, model)
# Run the Python garbage collector and ask torch to release cached CUDA
# memory before the model is moved to the device.
gc.collect()
torch.cuda.empty_cache()

# Move the model to the target device/dtype and switch it to eval mode.
model = model.to(device, dtype).eval()
# Processor belonging to the same checkpoint.
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")

print(f"Model loaded on {device} with dtype {dtype}")
print(f"Sample rate: {processor.audio_sampling_rate}")

Register Worker

defmodule SAM.Worker do
  @moduledoc """
  Callable API for SAM audio separation via distributed Erlang.

  The worker is registered globally as `:sam_worker` and runs one separation
  at a time. Its state is the map of Python globals (`"model"`, `"processor"`,
  `"device"`, `"dtype"`) that the separation script needs; pass it with the
  `:globals` option of `start_link/1`.

  Request arguments are handed to Python as variables and are never
  interpolated into Python source, so quotes, backslashes or newlines in a
  path or description cannot alter the executed code.
  """

  use GenServer

  @type result :: %{target: String.t(), residual: String.t(), sample_rate: integer()}

  # Python executed for every request. It expects these globals: model,
  # processor, device, dtype, audio_path, description, max_seconds, output_dir.
  # The script ends with an expression so that Pythonx.eval returns its value.
  @separate_script ~S"""
  import os
  import soundfile
  import torch
  import torchaudio


  def _text(value):
      # Values sent from Elixir as binaries may arrive as bytes.
      return value.decode("utf-8") if isinstance(value, (bytes, bytearray)) else value


  audio_path = _text(audio_path)
  description = _text(description)
  output_dir = _text(output_dir)
  device = _text(device)

  # Check duration
  duration = soundfile.info(audio_path).duration
  if duration > max_seconds:
      raise ValueError(f"Audio is {duration:.1f}s, max is {max_seconds}s")

  inputs = processor(audios=[audio_path], descriptions=[description]).to(device)

  with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
      result = model.separate(inputs)

  target = result.target[0] if isinstance(result.target, (list, tuple)) else result.target
  residual = result.residual[0] if isinstance(result.residual, (list, tuple)) else result.residual

  target_cpu = target.detach().to(torch.float32).cpu()
  residual_cpu = residual.detach().to(torch.float32).cpu()
  if target_cpu.ndim == 1:
      target_cpu = target_cpu.unsqueeze(0)
  if residual_cpu.ndim == 1:
      residual_cpu = residual_cpu.unsqueeze(0)

  sample_rate = int(processor.audio_sampling_rate)

  base = os.path.splitext(os.path.basename(audio_path))[0]
  target_path = os.path.join(output_dir, f"{base}_target.wav")
  residual_path = os.path.join(output_dir, f"{base}_residual.wav")

  torchaudio.save(target_path, target_cpu, sample_rate)
  torchaudio.save(residual_path, residual_cpu, sample_rate)

  (target_path, residual_path, sample_rate)
  """

  @doc """
  Starts the worker and registers it globally as `:sam_worker`.

  ## Options

    * `:globals` - map with string keys holding the Python objects the
      separation script reads (`"model"`, `"processor"`, `"device"`,
      `"dtype"`). Defaults to `%{}`, in which case requests fail in Python
      because those names are undefined.
  """
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    globals = Keyword.get(opts, :globals, %{})
    GenServer.start_link(__MODULE__, globals, name: {:global, :sam_worker})
  end

  @doc """
  Separates the sound described by `description` from the file at `audio_path`.

  Blocks until the worker replies (no timeout).

  ## Options

    * `:max_seconds` - longest accepted input duration in seconds (default `35`).
    * `:output_dir` - directory the result files are written to
      (default `"/tmp/sam_output"`); created if missing.

  Returns a map with the `:target` and `:residual` file paths and the
  `:sample_rate` of the written files.
  """
  @spec separate(String.t(), String.t(), keyword()) :: result()
  def separate(audio_path, description, opts \\ []) do
    GenServer.call({:global, :sam_worker}, {:separate, audio_path, description, opts}, :infinity)
  end

  @impl true
  def init(globals) do
    {:ok, globals}
  end

  @impl true
  def handle_call({:separate, audio_path, description, opts}, _from, globals) do
    max_seconds = Keyword.get(opts, :max_seconds, 35)
    output_dir = Keyword.get(opts, :output_dir, "/tmp/sam_output")
    File.mkdir_p!(output_dir)

    request_globals =
      Map.merge(globals, %{
        "audio_path" => audio_path,
        "description" => description,
        "max_seconds" => max_seconds,
        "output_dir" => output_dir
      })

    # Pythonx.eval returns {result, updated_globals}; only the result is needed.
    {py_result, _updated_globals} = Pythonx.eval(@separate_script, request_globals)
    {target_path, residual_path, sample_rate} = Pythonx.decode(py_result)

    reply = %{target: target_path, residual: residual_path, sample_rate: sample_rate}

    {:reply, reply, globals}
  end
end

# Hand the worker the Python objects created by the earlier Python cells.
# NOTE(review): assumes Livebook exposes those Python globals as Elixir
# variables of the same name in this cell — confirm.
{:ok, _pid} =
  SAM.Worker.start_link(
    globals: %{
      "model" => model,
      "processor" => processor,
      "device" => device,
      "dtype" => dtype
    }
  )

IO.puts("SAM worker registered as :sam_worker on #{node()}")

Test Separation

import urllib.request
import subprocess

# Fetch a sample clip into the working directory.
output_dir = "/tmp/sam_output"
test_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/videoplayback_superman.wav"
os.makedirs(output_dir, exist_ok=True)

raw_audio = os.path.join(output_dir, "test_full.wav")
test_audio = os.path.join(output_dir, "test_30s.wav")
urllib.request.urlretrieve(test_url, raw_audio)

# Keep only the first 30 seconds (SAM max is 35s, longer audio causes OOM).
ffmpeg_cmd = ["ffmpeg", "-y", "-i", raw_audio, "-t", "30", test_audio]
trim_result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
if trim_result.returncode:
    raise RuntimeError(f"ffmpeg trim failed: {trim_result.stderr}")
print(f"Trimmed test audio to 30s: {test_audio}")

# Separate the sound matching the text prompt from the trimmed clip.
description = "speech"

inputs = processor(audios=[test_audio], descriptions=[description]).to(device)

with torch.inference_mode():
    with torch.autocast(device_type=device, dtype=dtype):
        result = model.separate(inputs)

# The result fields may be a list/tuple; in that case take the first entry.
target = result.target
if isinstance(target, (list, tuple)):
    target = target[0]
residual = result.residual
if isinstance(residual, (list, tuple)):
    residual = residual[0]

print(f"Target shape: {target.shape}")
print(f"Residual shape: {residual.shape}")

# Write both stems as WAV files into the working directory.
sample_rate = int(processor.audio_sampling_rate)
target_path = os.path.join(output_dir, "target_speech.wav")
residual_path = os.path.join(output_dir, "residual.wav")

# float32 copies on the CPU; a 1-D waveform gets a leading channel axis.
target_cpu = target.detach().to(torch.float32).cpu()
target_cpu = target_cpu.unsqueeze(0) if target_cpu.ndim == 1 else target_cpu
residual_cpu = residual.detach().to(torch.float32).cpu()
residual_cpu = residual_cpu.unsqueeze(0) if residual_cpu.ndim == 1 else residual_cpu

torchaudio.save(target_path, target_cpu, sample_rate)
torchaudio.save(residual_path, residual_cpu, sample_rate)

print(f"Saved target to {target_path}")
print(f"Saved residual to {residual_path}")
print("DONE")

Listen to Results

# Audio player for the separated target track.
IO.puts("Target (isolated speech):")
"/tmp/sam_output/target_speech.wav" |> File.read!() |> Kino.Audio.new(:wav)
# Audio player for the remaining (residual) track.
IO.puts("Residual (everything else):")
"/tmp/sam_output/residual.wav" |> File.read!() |> Kino.Audio.new(:wav)