Flamingo
# Notebook setup: Pythonx embeds a Python interpreter in the BEAM, kino_pythonx adds
# Python cells (and the pyproject cell below) to Livebook, and Kino provides the
# Livebook runtime helpers used by later cells.
Mix.install([
  {:pythonx, "~> 0.4.2"},
  {:kino_pythonx, "~> 0.1.0"},
  {:kino, "~> 0.19.0"}
])
# Python project for this notebook's Python cells.
[project]
name = "project"
version = "0.0.0"
requires-python = "==3.13.*"
# NOTE(review): every dependency is unpinned; musicflamingo presumably needs a recent
# transformers release — consider pinning known-good versions for reproducibility.
dependencies = [
    "torch",
    "torchaudio",
    "transformers",
    # Used by transformers for device_map="auto" in the model-loading cell.
    "accelerate",
    "huggingface_hub",
    # Faster Hub downloads; switched on via HF_HUB_ENABLE_HF_TRANSFER in the HF Setup cell.
    "hf_transfer",
    "librosa",
    "safetensors",
    "soundfile",
    "pydub",
    # Backport of the stdlib `audioop` module, which is gone in Python 3.13
    # (presumably needed here by pydub).
    "audioop-lts",
]
Setup
# Show the GPU status. Matching on exit status 0 stops the notebook right here if
# `nvidia-smi` fails (e.g. no visible NVIDIA driver).
{gpu_status, 0} = System.cmd("nvidia-smi", [])
IO.puts(gpu_status)

# Install ffmpeg through apt only when `which` cannot find it already; stderr is folded
# into the captured output, which is discarded on success.
ffmpeg_script = "which ffmpeg || (apt update && apt install -y ffmpeg)"
{_ffmpeg_log, 0} = System.cmd("bash", ["-c", ffmpeg_script], stderr_to_stdout: true)
HF Setup
# Livebook secret holding the Hugging Face token; nil when the secret is not set.
hf_token = System.get_env("LB_HF_TOKEN", nil)
import os

# `hf_token` is the Elixir binding from the previous cell. Pythonx hands Elixir binaries
# to Python as bytes, and an unset secret (nil) arrives as None. Only export a token when
# one was actually provided, so public models can still be downloaded anonymously
# instead of failing with AttributeError on None.decode.
token = hf_token.decode("utf-8") if isinstance(hf_token, bytes) else hf_token
if token:
    os.environ["HF_TOKEN"] = token

# Use the hf_transfer download backend (listed in the project dependencies).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
Patches
import importlib.util
import sys
from pathlib import Path

# Patch transformers' musicflamingo model so the audio embeddings are cast to the text
# embeddings' dtype (not only their device) before masked_scatter. The model-loading
# cell keeps the audio tower and projector in float32 while the rest of the model is
# bfloat16, so without the cast masked_scatter receives mismatched dtypes.
#
# The patch rewrites the installed source file on disk, so it only affects imports that
# happen after this cell runs.

module_name = "transformers.models.musicflamingo.modeling_musicflamingo"

# Resolve transformers through the import system (find_spec does not import a top-level
# package) so the file patched is exactly the one `import transformers` will load,
# wherever it lives on sys.path — not just whatever site.getsitepackages() reports.
transformers_spec = importlib.util.find_spec("transformers")
if transformers_spec is None or transformers_spec.origin is None:
    raise FileNotFoundError("Could not find the installed transformers package")

path = Path(transformers_spec.origin).parent / "models" / "musicflamingo" / "modeling_musicflamingo.py"
if not path.exists():
    raise FileNotFoundError(f"Could not find modeling_musicflamingo.py (looked at {path})")

text = path.read_text(encoding="utf-8")
old = "inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))"
new = "inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype))"

if new in text:
    print("musicflamingo already patched")
elif old in text:
    path.write_text(text.replace(old, new), encoding="utf-8")
    print("patched musicflamingo masked_scatter")
    if module_name in sys.modules:
        # The old code is already loaded in this interpreter; the on-disk fix is not live.
        print("warning: modeling_musicflamingo was already imported; restart the runtime for the patch to take effect")
else:
    # Upstream source changed: make this visible instead of failing later with a dtype error.
    print("warning: musicflamingo masked_scatter pattern not found; patch NOT applied")
Flamingo
import torch

# NOTE(review): cuDNN is disabled for the whole process, presumably to work around a
# cuDNN failure in this model — confirm it is still required. It applies to every torch
# op run in this runtime afterwards.
torch.backends.cudnn.enabled = False

# Patch missing FP8 dtypes
# Alias any float8 dtype attribute this torch build lacks to torch.uint8, so code that
# merely references it (presumably transformers while importing — confirm) does not hit
# AttributeError. Kept before the transformers import below so the aliases already exist.
for name in ("float8_e8m0fnu", "float8_e4m3fnuz", "float8_e5m2fnuz"):
    if not hasattr(torch, name):
        setattr(torch, name, torch.uint8)

from transformers import AutoProcessor, AutoModelForSeq2SeqLM

model_id = "nvidia/audio-flamingo-next-think-hf"

# Turns a chat conversation (text + audio) into model inputs; see apply_chat_template
# in the worker and test cells.
processor = AutoProcessor.from_pretrained(model_id)

# Weights in bfloat16, spread across available devices by device_map="auto", with the
# eager attention implementation (NOTE(review): reason not stated — confirm whether
# sdpa fails for this model). .eval() puts the model in inference mode (no dropout).
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
).eval()

# The audio tower and projector require float32 internally (LayerNorm).
# Keep them in float32; the patched masked_scatter casts output to bf16 for the LLM.
model.model.audio_tower.to(torch.float32)
model.model.multi_modal_projector.to(torch.float32)
Register Worker
defmodule Flamingo.Worker do
  @moduledoc """
  Callable API for Audio Flamingo via distributed Erlang.

  The worker registers itself globally as `:flamingo_worker`, so any connected node can
  reach it through `ask/3`. Requests are handled one at a time by this single process,
  which serializes access to the model.

  Each `Pythonx.eval/2` call only sees the Python globals it is handed, so the `model`
  and `processor` objects loaded in the notebook must be given to the worker at start-up:

      Flamingo.Worker.start_link(model: model, processor: processor)

  An exception raised while serving a request crashes the worker process; the caller of
  `ask/3` receives the exit.
  """
  use GenServer

  # Python executed for every request. All request data is passed in as globals instead
  # of being interpolated into the source, so quotes, backslashes or triple quotes in a
  # prompt or path can neither break nor alter the code. Elixir binaries arrive in
  # Python as bytes, hence the decodes. The snippet ends with an expression so that
  # Pythonx.eval/2 returns the decoded text as its result.
  @ask_code ~S"""
  import torch

  audio_path = audio_path.decode("utf-8")
  prompt_text = prompt_text.decode("utf-8")

  conversation = [
      [
          {
              "role": "user",
              "content": [
                  {"type": "text", "text": prompt_text},
                  {"type": "audio", "path": audio_path},
              ],
          }
      ]
  ]

  batch = processor.apply_chat_template(
      conversation,
      tokenize=True,
      add_generation_prompt=True,
      return_dict=True,
  ).to(model.device)

  # The audio tower is kept in float32, so its input features must match.
  if "input_features" in batch:
      batch["input_features"] = batch["input_features"].to(torch.float32)

  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
      generated = model.generate(
          **batch,
          max_new_tokens=max_new_tokens,
          repetition_penalty=repetition_penalty,
      )

  # Drop the echoed prompt tokens; decode only the completion.
  prompt_len = batch["input_ids"].shape[1]
  completion = generated[:, prompt_len:]

  result_text = processor.batch_decode(
      completion,
      skip_special_tokens=True,
      clean_up_tokenization_spaces=False,
  )[0]

  result_text
  """

  @doc """
  Starts the worker and registers it globally as `:flamingo_worker`.

  ## Options

    * `:model` - the Python model object to generate with
    * `:processor` - the matching Python processor object

  Both must be provided before `ask/3` can be served.
  """
  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts, name: {:global, :flamingo_worker})
  end

  @doc """
  Asks the model about the audio at `audio_path` (a path or URL handed to the processor)
  with `prompt`, and returns the generated text with surrounding whitespace trimmed.

  ## Options

    * `:max_new_tokens` - generation budget (default `1024`)
    * `:repetition_penalty` - passed to `generate` (default `1.2`)

  Waits without a timeout until generation finishes.
  """
  def ask(audio_path, prompt, opts \\ []) do
    GenServer.call({:global, :flamingo_worker}, {:ask, audio_path, prompt, opts}, :infinity)
  end

  @impl true
  def init(opts) when is_list(opts) do
    {:ok, %{model: Keyword.get(opts, :model), processor: Keyword.get(opts, :processor)}}
  end

  # Non-keyword start arguments (e.g. a bare child spec) start without Python objects.
  def init(_other), do: init([])

  @impl true
  def handle_call({:ask, audio_path, prompt, opts}, _from, state) do
    globals = %{
      "model" => python_object!(state, :model),
      "processor" => python_object!(state, :processor),
      # to_string/1 keeps accepting any String.Chars value, as interpolation did.
      "audio_path" => to_string(audio_path),
      "prompt_text" => to_string(prompt),
      "max_new_tokens" => Keyword.get(opts, :max_new_tokens, 1024),
      "repetition_penalty" => Keyword.get(opts, :repetition_penalty, 1.2)
    }

    # Pythonx.eval/2 returns {last_expression_result, globals}, not an exit status.
    {result, _globals} = Pythonx.eval(@ask_code, globals)
    reply = result |> Pythonx.decode() |> String.trim()
    {:reply, reply, state}
  end

  # Returns the Python object stored under `key`, raising a descriptive error when the
  # worker was started without it.
  defp python_object!(state, key) do
    case state do
      %{^key => object} when not is_nil(object) ->
        object

      _ ->
        raise ArgumentError,
              "Flamingo.Worker has no #{inspect(key)}; start it with " <>
                "Flamingo.Worker.start_link(model: model, processor: processor)"
    end
  end
end
# Start the worker under Kino's supervisor instead of linking it to this cell's
# evaluator: re-evaluating the cell then terminates the previous worker (releasing the
# global :flamingo_worker name) rather than failing with {:error, {:already_started, pid}}.
# The Python `model` and `processor` objects from the model-loading cell are handed to
# the worker, since each Pythonx.eval call inside it only sees the globals it is given.
{:ok, _pid} = Kino.start_child({Flamingo.Worker, model: model, processor: processor})
IO.puts("Flamingo worker registered as :flamingo_worker on #{node()}")
Test
# Smoke test: run one request end-to-end directly in Python (bypassing the Elixir worker).
# `conversation` is a batch holding a single chat whose user turn combines a text
# instruction with an audio clip (here given as a URL).
conversation = [
    [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "Transcribe the speech, identify important background sounds, "
                        "and mention approximate timestamps for key events."
                    ),
                },
                {
                    "type": "audio",
                    "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/videoplayback_superman.wav",
                },
            ],
        }
    ]
]
# Tokenize the chat with the assistant generation prompt appended, returning a dict of
# tensors that is then moved to the model's device.
batch = processor.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
).to(model.device)
# The audio tower was converted to float32 in the model-loading cell; match its input.
if "input_features" in batch:
    batch["input_features"] = batch["input_features"].to(torch.float32)
# inference_mode disables autograd tracking; autocast runs eligible CUDA ops in bfloat16.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    generated = model.generate(
        **batch,
        max_new_tokens=1024,
        repetition_penalty=1.2,
    )
# The slice assumes generate echoes the prompt tokens before the new ones; keep only
# the newly generated part.
prompt_len = batch["input_ids"].shape[1]
completion = generated[:, prompt_len:]
text = processor.batch_decode(
    completion,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]
print(text)