SAM
# Notebook dependencies:
#   * pythonx      - embeds a Python interpreter in the BEAM
#   * kino_pythonx - Livebook integration for Python cells
#   * kino         - rich outputs (audio players, etc.)
Mix.install([
  {:pythonx, "~> 0.4.2"},
  {:kino_pythonx, "~> 0.1.0"},
  {:kino, "~> 0.19.0"}
])
# Python project definition for the notebook's Python cells.
[project]
name = "project"
version = "0.0.0"
requires-python = "==3.13.*"
dependencies = [
  # Model code, installed straight from the upstream repository.
  "sam_audio @ git+https://github.com/facebookresearch/sam-audio.git",
  # Tensor / audio stack.
  "torch",
  "torchaudio",
  "torchcodec>=0.2,<0.3",
  # Hugging Face tooling used to download the checkpoint.
  "transformers>=4.54,<5",
  "huggingface_hub>=0.34,<1.0",
  "hf_transfer",
  # Audio file helpers.
  "pydub",
  "audioop-lts",
  "soundfile",
]
Monkey Patches
from pathlib import Path
import site
import numpy as np
import torch
import soundfile as sf
import torchaudio
# Monkey-patch torchaudio.load and torchaudio.save to use soundfile.
# Newer torchaudio releases hardcode torchcodec and ignore the backend
# parameter (the original note cites 2.12+ -- TODO confirm the exact version).
def _soundfile_load(uri, frame_offset=0, num_frames=-1, normalize=True,
                    channels_first=True, format=None, buffer_size=4096, backend=None):
    """Drop-in replacement for ``torchaudio.load`` backed by ``soundfile``.

    Args:
        uri: Path (``str`` / ``os.PathLike``) or an open file-like object.
        frame_offset: Index of the first frame to read.
        num_frames: Number of frames to read; ``-1`` reads to the end.
        channels_first: When true the result is ``(channels, samples)``,
            otherwise ``(samples, channels)``.
        normalize, format, buffer_size, backend: Accepted only for signature
            compatibility and ignored; samples are always read as float32.

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` is a float32 tensor.
    """
    import os

    # Only path-like inputs are converted to a string; file-like objects are
    # handed to soundfile untouched (str() on them would produce a bogus path).
    source = os.fspath(uri) if isinstance(uri, (str, os.PathLike)) else uri
    stop = None if num_frames == -1 else frame_offset + num_frames
    data, sample_rate = sf.read(source, start=frame_offset, stop=stop,
                                dtype="float32", always_2d=True)
    # soundfile yields (samples, channels).
    if channels_first:
        # Materialise the transpose so callers receive a contiguous tensor
        # (safe for .view() and similar) rather than a strided view.
        waveform = torch.from_numpy(data.T).contiguous()  # (channels, samples)
    else:
        waveform = torch.from_numpy(data)  # (samples, channels)
    return waveform, sample_rate
def _soundfile_save(uri, src, sample_rate, channels_first=True, format=None,
                    encoding=None, bits_per_sample=None, buffer_size=4096,
                    compression=None, backend=None):
    """Drop-in replacement for ``torchaudio.save`` backed by ``soundfile``.

    Args:
        uri: Destination path; converted with ``str()``.
        src: Audio as a tensor or array-like. Interpreted as
            ``(channels, samples)`` when ``channels_first`` is true, otherwise
            ``(samples, channels)``.
        sample_rate: Sample rate passed to ``soundfile.write``.
        format, encoding, bits_per_sample, buffer_size, compression, backend:
            Accepted only for signature compatibility and ignored.
    """
    if isinstance(src, torch.Tensor):
        # detach(): tensors that require grad cannot be converted to numpy.
        data = src.detach().cpu()
        # numpy has no bfloat16 and soundfile does not take float16, so
        # reduced-precision floats are widened before conversion.
        if data.is_floating_point() and data.dtype not in (torch.float32, torch.float64):
            data = data.to(torch.float32)
        data = data.numpy()
    else:
        data = np.asarray(src)
    if channels_first:
        data = data.T  # soundfile expects (samples, channels)
    sf.write(str(uri), data, sample_rate)
# Install the soundfile-backed replacements on the torchaudio module object.
torchaudio.load, torchaudio.save = _soundfile_load, _soundfile_save
site_paths = [Path(base) for base in site.getsitepackages()]


def _find_in_site_packages(relative_path):
    """Return the first existing ``<site-packages>/<relative_path>`` file.

    Raises SystemExit when no site-packages directory contains the file.
    """
    for base in site_paths:
        candidate = base / relative_path
        if candidate.exists():
            return candidate
    raise SystemExit(f"Could not find {relative_path}")


def _patch_source(relative_path, replacements):
    """Apply ``(old, new)`` text replacements to an installed source file.

    The file is only rewritten when at least one replacement changed it, and
    the printed message says whether anything was changed, so a silent no-op
    (already patched, or upstream source drifted) is visible.
    """
    path = _find_in_site_packages(relative_path)
    original = path.read_text()
    patched = original
    for old, new in replacements:
        patched = patched.replace(old, new)
    if patched != original:
        path.write_text(patched)
        print(f"patched {path}")
    else:
        print(f"no changes made to {path} (already patched, or upstream source changed)")
    return path


# Patch core/audio_visual_encoder/transforms.py: drop the torchcodec imports
# and load audio through torchaudio (downmix to mono, resample if needed).
# NOTE(review): the method indentation inside these literals is assumed to be
# the conventional 4/8 spaces -- confirm against the installed file.
_patch_source(
    "core/audio_visual_encoder/transforms.py",
    [
        (
            "from torchcodec.decoders import AudioDecoder, VideoDecoder\n",
            "import torchaudio\n",
        ),
        (
            "from torchcodec.decoders import VideoDecoder\n",
            "",
        ),
        (
            """    def _load_audio(self, path: str):
        ad = AudioDecoder(path, sample_rate=self.sampling_rate, num_channels=1)
        return ad.get_all_samples().data
""",
            """    def _load_audio(self, path: str):
        wav, sample_rate = torchaudio.load(path)
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sample_rate != self.sampling_rate:
            wav = torchaudio.functional.resample(wav, sample_rate, self.sampling_rate)
        return wav
""",
        ),
    ],
)

# Patch sam_audio/processor.py: only the torchcodec imports are removed.
_patch_source(
    "sam_audio/processor.py",
    [
        ("from torchcodec.decoders import AudioDecoder, VideoDecoder\n", ""),
        ("from torchcodec.decoders import VideoDecoder\n", ""),
    ],
)
Setup
# Print the GPU status; the match asserts that nvidia-smi exited with 0.
{output, 0} = System.cmd("nvidia-smi", [])
IO.puts(output)

# Install ffmpeg through apt only when it is not already on the PATH.
ensure_ffmpeg = "which ffmpeg || (apt update && apt install -y ffmpeg)"
{_, 0} = System.cmd("bash", ["-c", ensure_ffmpeg], stderr_to_stdout: true)
HF Setup
# Fail fast with a clear error when the secret is missing, instead of passing
# nil on to the Python cell that calls `.decode` on it.
hf_token = System.fetch_env!("LB_HF_TOKEN")
import os

# hf_token is decoded from bytes before being exported for the HF libraries.
os.environ.update(
    HF_TOKEN=hf_token.decode("utf-8"),
    HF_HUB_ENABLE_HF_TRANSFER="1",
)
Load Model
import gc
import types

from sam_audio import SAMAudio, SAMAudioProcessor

torch.backends.cudnn.enabled = False

device = "cuda"
dtype = torch.bfloat16

model = SAMAudio.from_pretrained(
    "facebook/sam-audio-large",
    proxies=None,
    resume_download=False,
)

# Remove the vision encoder to save VRAM (audio-only usage). Its output
# dimension is kept on the model so placeholder features can still be built.
vision_dim = model.vision_encoder.dim
del model.vision_encoder
model._vision_encoder_dim = vision_dim


def _get_video_features_audio_only(self, video, audio_features):
    """Audio-only stand-in for the model's video feature extraction.

    Rejects any video input and otherwise returns zeros shaped
    ``(batch, vision_encoder_dim, time_steps)``, created with the dtype and
    device of ``audio_features``.
    """
    if video is not None:
        raise ValueError("This service is audio-only; video inputs are disabled.")
    n_batch, n_steps, _ = audio_features.shape
    return audio_features.new_zeros(n_batch, self._vision_encoder_dim, n_steps)


model._get_video_features = types.MethodType(_get_video_features_audio_only, model)

# Release the deleted encoder before the remaining weights move to the GPU.
gc.collect()
torch.cuda.empty_cache()

model = model.to(device, dtype)
model = model.eval()
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")

print(f"Model loaded on {device} with dtype {dtype}")
print(f"Sample rate: {processor.audio_sampling_rate}")
Register Worker
defmodule SAM.Worker do
  @moduledoc """
  Callable API for SAM audio separation via distributed Erlang.

  The worker registers itself globally as `:sam_worker` and serialises
  separation requests through a single process.

  Every `Pythonx.eval/3` call runs with the globals it is given, so the Python
  objects the separation script needs (`model`, `processor`, `device`,
  `dtype`) must be handed to `start_link/1` via the `:globals` option.
  """

  use GenServer

  @default_max_seconds 35
  @default_output_dir "/tmp/sam_output"

  # Caller-supplied values are passed in as Python globals instead of being
  # interpolated into the source, so quotes or newlines in a path or
  # description cannot break (or inject into) the script. Elixir binaries
  # arrive as Python `bytes`, hence the decode calls. The final tuple
  # expression is the value returned by the evaluation.
  @separate_script ~S"""
  import os
  import soundfile as sf
  import torch
  import torchaudio

  audio_path = audio_path.decode("utf-8")
  description = description.decode("utf-8")
  output_dir = output_dir.decode("utf-8")

  # Check duration
  info = sf.info(audio_path)
  if info.duration > max_seconds:
      raise ValueError(f"Audio is {info.duration:.1f}s, max is {max_seconds}s")

  inputs = processor(audios=[audio_path], descriptions=[description]).to(device)
  with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
      result = model.separate(inputs)

  target = result.target[0] if isinstance(result.target, (list, tuple)) else result.target
  residual = result.residual[0] if isinstance(result.residual, (list, tuple)) else result.residual

  target_cpu = target.detach().to(torch.float32).cpu()
  residual_cpu = residual.detach().to(torch.float32).cpu()
  if target_cpu.ndim == 1:
      target_cpu = target_cpu.unsqueeze(0)
  if residual_cpu.ndim == 1:
      residual_cpu = residual_cpu.unsqueeze(0)

  sample_rate = int(processor.audio_sampling_rate)
  base = os.path.splitext(os.path.basename(audio_path))[0]
  target_path = os.path.join(output_dir, f"{base}_target.wav")
  residual_path = os.path.join(output_dir, f"{base}_residual.wav")
  torchaudio.save(target_path, target_cpu, sample_rate)
  torchaudio.save(residual_path, residual_cpu, sample_rate)

  (target_path, residual_path, sample_rate)
  """

  @type result :: %{target: String.t(), residual: String.t(), sample_rate: integer()}

  ## Client API

  @doc """
  Starts the worker and registers it globally as `:sam_worker`.

  ## Options

    * `:globals` - map of Python global name (string) to value, made available
      to the separation script. Defaults to `%{}`.
  """
  @spec start_link(term()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    # Any non-keyword argument is tolerated, as the previous version ignored it.
    globals = if Keyword.keyword?(opts), do: Keyword.get(opts, :globals, %{}), else: %{}
    GenServer.start_link(__MODULE__, globals, name: {:global, :sam_worker})
  end

  @doc """
  Separates the sound matching `description` from the audio at `audio_path`.

  Returns a map with the `:target` and `:residual` file paths and the
  `:sample_rate`, or `{:error, message}` when the Python side raises (for
  example when the audio is longer than `:max_seconds`).

  ## Options

    * `:max_seconds` - maximum accepted duration (default `#{@default_max_seconds}`)
    * `:output_dir` - directory for the output files (default `#{inspect(@default_output_dir)}`)
  """
  @spec separate(String.t(), String.t(), keyword()) :: result() | {:error, String.t()}
  def separate(audio_path, description, opts \\ []) do
    GenServer.call({:global, :sam_worker}, {:separate, audio_path, description, opts}, :infinity)
  end

  ## Server callbacks

  @impl true
  def init(globals) when is_map(globals), do: {:ok, globals}

  @impl true
  def handle_call({:separate, audio_path, description, opts}, _from, python_globals) do
    max_seconds = Keyword.get(opts, :max_seconds, @default_max_seconds)
    output_dir = Keyword.get(opts, :output_dir, @default_output_dir)
    File.mkdir_p!(output_dir)

    call_globals =
      Map.merge(python_globals, %{
        "audio_path" => audio_path,
        "description" => description,
        "max_seconds" => max_seconds,
        "output_dir" => output_dir
      })

    reply =
      try do
        # eval returns {result, globals}; the per-call globals are dropped so
        # intermediate tensors are not kept alive between requests.
        {result, _globals} = Pythonx.eval(@separate_script, call_globals)
        {target_path, residual_path, sample_rate} = Pythonx.decode(result)
        %{target: target_path, residual: residual_path, sample_rate: sample_rate}
      rescue
        # A Python-side failure is reported to the caller rather than taking
        # down the globally registered worker.
        error in Pythonx.Error -> {:error, Exception.message(error)}
      end

    {:reply, reply, python_globals}
  end
end

# NOTE(review): this assumes the globals defined in the Python cells (`model`,
# `processor`, `device`, `dtype`) are available here as Elixir bindings holding
# Pythonx objects -- confirm in the running notebook.
{:ok, _pid} =
  SAM.Worker.start_link(
    globals: %{
      "model" => model,
      "processor" => processor,
      "device" => device,
      "dtype" => dtype
    }
  )

IO.puts("SAM worker registered as :sam_worker on #{node()}")
Test Separation
import subprocess
import urllib.request


def _first_if_sequence(value):
    """Return the first element of a list/tuple, or the value itself."""
    return value[0] if isinstance(value, (list, tuple)) else value


def _as_saveable(wave):
    """Detach, cast to float32 on the CPU and make sure there is a channel axis."""
    wave = wave.detach().to(torch.float32).cpu()
    return wave.unsqueeze(0) if wave.ndim == 1 else wave


# Download a test audio file
test_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/videoplayback_superman.wav"
output_dir = "/tmp/sam_output"
os.makedirs(output_dir, exist_ok=True)
raw_audio = os.path.join(output_dir, "test_full.wav")
test_audio = os.path.join(output_dir, "test_30s.wav")
urllib.request.urlretrieve(test_url, raw_audio)

# Trim to 30 seconds (SAM max is 35s, longer audio causes OOM)
ffmpeg_cmd = ["ffmpeg", "-y", "-i", raw_audio, "-t", "30", test_audio]
trim_result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
if trim_result.returncode != 0:
    raise RuntimeError(f"ffmpeg trim failed: {trim_result.stderr}")
print(f"Trimmed test audio to 30s: {test_audio}")

# Run separation
description = "speech"
inputs = processor(audios=[test_audio], descriptions=[description]).to(device)
with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
    result = model.separate(inputs)

target = _first_if_sequence(result.target)
residual = _first_if_sequence(result.residual)
print(f"Target shape: {target.shape}")
print(f"Residual shape: {residual.shape}")

# Save outputs
target_path = os.path.join(output_dir, "target_speech.wav")
residual_path = os.path.join(output_dir, "residual.wav")
target_cpu = _as_saveable(target)
residual_cpu = _as_saveable(residual)
sample_rate = int(processor.audio_sampling_rate)
for destination, waveform in ((target_path, target_cpu), (residual_path, residual_cpu)):
    torchaudio.save(destination, waveform, sample_rate)

print(f"Saved target to {target_path}")
print(f"Saved residual to {residual_path}")
print("DONE")
Listen to Results
# Player for the separated target track written by the test cell.
IO.puts("Target (isolated speech):")

"/tmp/sam_output/target_speech.wav"
|> File.read!()
|> Kino.Audio.new(:wav)
# Player for the residual track written by the test cell.
IO.puts("Residual (everything else):")

"/tmp/sam_output/residual.wav"
|> File.read!()
|> Kino.Audio.new(:wav)