Powered by AppSignal & Oban Pro

Flamingo

Flamingo.livemd

Flamingo

Mix.install([
  {:pythonx, "~> 0.4.2"},
  {:kino_pythonx, "~> 0.1.0"},
  {:kino, "~> 0.19.0"}
])
# Python environment for the notebook's Python cells.
[project]
name = "project"
version = "0.0.0"
# Pins the interpreter to the 3.13 series.
requires-python = "==3.13.*"
dependencies = [
    "torch",
    "torchaudio",
    "transformers",
    "accelerate",
    "huggingface_hub",
    # Used together with HF_HUB_ENABLE_HF_TRANSFER=1, which the HF Setup cell exports.
    "hf_transfer",
    "librosa",
    "safetensors",
    "soundfile",
    "pydub",
    # NOTE(review): presumably backfills the `audioop` module, which is no longer in the
    # Python 3.13 standard library and which pydub imports — confirm before removing.
    "audioop-lts",
]

Setup

# Print the GPU report. A non-zero exit status from nvidia-smi fails the match
# and stops the notebook here.
{output, 0} = System.cmd("nvidia-smi", [])
IO.puts(output)

# Keep an ffmpeg that is already on PATH; otherwise install it through apt.
# stderr is folded into stdout so the whole log is part of the returned tuple.
ensure_ffmpeg = "which ffmpeg || (apt update && apt install -y ffmpeg)"
{_, 0} = System.cmd("bash", ["-c", ensure_ffmpeg], stderr_to_stdout: true)

HF Setup

# Read the Hugging Face token from the Livebook secret.
# fetch_env!/1 raises right here, naming the missing variable, instead of handing
# nil to the Python cell below, which calls `.decode` on the value unconditionally.
hf_token = System.fetch_env!("LB_HF_TOKEN")
import os

# hf_token is defined by the preceding Elixir cell; it is decoded from bytes here
# before being exported for huggingface_hub.
token = hf_token.decode("utf-8")

# Opt in to hf_transfer downloads (the package is listed in the project dependencies).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_TOKEN"] = token

Patches

import site
from pathlib import Path

# Patch musicflamingo so the audio embeddings are cast to the dtype (not only the
# device) of inputs_embeds before masked_scatter.
# The file is located dynamically instead of hardcoding the path.
relative_path = Path("transformers") / "models" / "musicflamingo" / "modeling_musicflamingo.py"
site_paths = [Path(base) for base in site.getsitepackages()]
for path in [base / relative_path for base in site_paths]:
    if path.exists():
        break
else:
    raise SystemExit("Could not find modeling_musicflamingo.py")

# Read and write with an explicit encoding so the result does not depend on the
# locale of the machine running the notebook.
text = path.read_text(encoding="utf-8")

old = "inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))"
new = "inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype))"

# Three distinct outcomes: a re-run after a successful patch is a quiet no-op, while
# a changed upstream line is reported loudly because the patch was NOT applied.
if new in text:
    print("musicflamingo already patched")
elif old in text:
    text = text.replace(old, new)
    path.write_text(text, encoding="utf-8")
    print("patched musicflamingo masked_scatter")
else:
    print(
        "WARNING: musicflamingo masked_scatter line not found; patch NOT applied "
        "(the installed transformers version may have changed this code)"
    )

Flamingo

import torch

# cuDNN is switched off for the rest of the session.
torch.backends.cudnn.enabled = False

# Alias FP8 dtypes that this torch build does not define to uint8, leaving any
# dtype that already exists untouched.
# NOTE(review): presumably needed so that importing transformers below does not
# fail on the missing names — confirm.
for name in ("float8_e8m0fnu", "float8_e4m3fnuz", "float8_e5m2fnuz"):
    setattr(torch, name, getattr(torch, name, torch.uint8))

# Imported only after the dtype aliases above are in place.
from transformers import AutoModelForSeq2SeqLM, AutoProcessor

model_id = "nvidia/audio-flamingo-next-think-hf"

processor = AutoProcessor.from_pretrained(model_id)

# Load in bfloat16 with automatic device placement and eager attention, then
# switch the model to evaluation mode.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# The audio tower and projector need float32 internally (LayerNorm), so both are
# moved back to float32; the patched masked_scatter casts their output to bf16
# for the LLM.
model.model.audio_tower.to(torch.float32)
model.model.multi_modal_projector.to(torch.float32)

Register Worker

defmodule Flamingo.Worker do
  @moduledoc """
  Callable API for Audio Flamingo via distributed Erlang.

  The worker registers itself globally as `:flamingo_worker`, so any connected node
  can reach it through `ask/3`. Requests are served one at a time by the worker
  process.

  The Python `model` and `processor` objects have to be handed to `start_link/1`.
  They are kept in the process state and supplied as globals to every evaluation,
  because each `Pythonx.eval/2` call only sees the globals it is given.
  """

  use GenServer

  @name {:global, :flamingo_worker}

  # Static Python source. Request data is never interpolated into it; it arrives as
  # globals instead, so quotes, backslashes or newlines in a prompt or path cannot
  # break (or inject into) the code. The final line is a bare expression so that
  # `Pythonx.eval/2` returns the decoded text as its result.
  @inference_code ~S"""
  import torch

  # Elixir binaries arrive as Python bytes; accept str as well to stay tolerant.
  if isinstance(audio_path, bytes):
      audio_path = audio_path.decode("utf-8")
  if isinstance(prompt_text, bytes):
      prompt_text = prompt_text.decode("utf-8")

  conversation = [
      [
          {
              "role": "user",
              "content": [
                  {"type": "text", "text": prompt_text},
                  {"type": "audio", "path": audio_path},
              ],
          }
      ]
  ]

  batch = processor.apply_chat_template(
      conversation,
      tokenize=True,
      add_generation_prompt=True,
      return_dict=True,
  ).to(model.device)

  if "input_features" in batch:
      batch["input_features"] = batch["input_features"].to(torch.float32)

  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
      generated = model.generate(
          **batch,
          max_new_tokens=max_new_tokens,
          repetition_penalty=repetition_penalty,
      )

  prompt_len = batch["input_ids"].shape[1]
  completion = generated[:, prompt_len:]
  processor.batch_decode(
      completion,
      skip_special_tokens=True,
      clean_up_tokenization_spaces=False,
  )[0]
  """

  @doc """
  Starts the worker and registers it globally as `:flamingo_worker`.

  ## Options

    * `:model` - the loaded Python model object
    * `:processor` - the matching Python processor object

  Both should be given; a request evaluated without them fails inside Python
  because the names are undefined.
  """
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts, name: @name)
  end

  @doc """
  Asks the model about the audio at `audio_path` and returns the trimmed answer.

  Blocks without a timeout until generation finishes.

  ## Options

    * `:max_new_tokens` - generation length limit (default `1024`)
    * `:repetition_penalty` - repetition penalty passed to `generate` (default `1.2`)
  """
  @spec ask(String.t(), String.t(), keyword()) :: String.t()
  def ask(audio_path, prompt, opts \\ []) do
    GenServer.call(@name, {:ask, audio_path, prompt, opts}, :infinity)
  end

  @impl true
  def init(opts) do
    # The state is the map of Python globals shared by every request.
    shared_globals =
      opts
      |> Keyword.take([:model, :processor])
      |> Map.new(fn {key, object} -> {Atom.to_string(key), object} end)

    {:ok, shared_globals}
  end

  @impl true
  def handle_call({:ask, audio_path, prompt, opts}, _from, state) do
    request_globals = %{
      "audio_path" => audio_path,
      "prompt_text" => prompt,
      "max_new_tokens" => Keyword.get(opts, :max_new_tokens, 1024),
      "repetition_penalty" => Keyword.get(opts, :repetition_penalty, 1.2)
    }

    # Pythonx.eval/2 returns {result, globals}; only the result is needed here.
    {result, _globals} = Pythonx.eval(@inference_code, Map.merge(state, request_globals))

    reply = result |> Pythonx.decode() |> String.trim()
    {:reply, reply, state}
  end
end

# `model` and `processor` are the objects created by the Python cell above.
{:ok, _pid} = Flamingo.Worker.start_link(model: model, processor: processor)
IO.puts("Flamingo worker registered as :flamingo_worker on #{node()}")

Test

# End-to-end smoke test: one user turn with a text instruction and a remote audio
# clip, run through the same steps the worker uses.
question = (
    "Transcribe the speech, identify important background sounds, "
    "and mention approximate timestamps for key events."
)
audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/videoplayback_superman.wav"

user_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": question},
        {"type": "audio", "path": audio_url},
    ],
}
# A batch holding a single conversation with a single turn.
conversation = [[user_turn]]

batch = processor.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
)
batch = batch.to(model.device)

# The audio features are fed in float32 (the audio tower is kept in float32).
if "input_features" in batch:
    batch["input_features"] = batch["input_features"].to(torch.float32)

generation_options = {"max_new_tokens": 1024, "repetition_penalty": 1.2}
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    generated = model.generate(**batch, **generation_options)

# Drop the prompt tokens so only the newly generated completion is decoded.
prompt_len = batch["input_ids"].shape[1]
completion = generated[:, prompt_len:]
decoded = processor.batch_decode(
    completion,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
text = decoded[0]

print(text)