UniRig 3D Model Rigging

elixir/unirig_generation.livemd

# Livebook setup - copy this entire cell to run
Mix.install([
  {:pythonx, "~> 0.4.7"},
  {:jason, "~> 1.4.4"},
  {:req, "~> 0.5.0"},
  {:opentelemetry_api, "~> 1.3"},
  {:opentelemetry, "~> 1.3"},
  {:opentelemetry_exporter, "~> 1.0"}
])

# Configure OpenTelemetry: batch span processor, exporters disabled to keep notebook output quiet
Application.put_env(:opentelemetry, :span_processor, :batch)
Application.put_env(:opentelemetry, :traces_exporter, :none)
Application.put_env(:opentelemetry, :metrics_exporter, :none)
Application.put_env(:opentelemetry, :logs_exporter, :none)

Logger.configure(level: :info)

Setup Python Environment

# Initialize Python environment with UniRig dependencies
Pythonx.uv_init("""
[project]
name = "unirig-generation"
version = "0.0.0"
requires-python = "==3.11.*"
dependencies = [
  "numpy<2.0",
  "bpy==4.5.*",
  "pillow",
  "opencv-python",
  "torch==2.7.0",
  "torchvision",
  "pytorch-lightning",
  "lightning",
  "huggingface-hub",
  "einops",
  "tqdm",
  "trimesh",
  "cumm-cu118",
  "spconv-cu118==2.3.8",
  "torch_scatter @ https://data.pyg.org/whl/torch-2.7.0%2Bcu118/torch_scatter-2.1.2%2Bpt27cu118-cp311-cp311-win_amd64.whl ; sys_platform == 'win32'",
  "torch_scatter @ https://data.pyg.org/whl/torch-2.7.0%2Bcu118/torch_scatter-2.1.2%2Bpt27cu118-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux'",
  "torch-cluster @ https://data.pyg.org/whl/torch-2.7.0%2Bcu118/torch_cluster-1.6.3%2Bpt27cu118-cp311-cp311-win_amd64.whl ; sys_platform == 'win32'",
  "torch-cluster @ https://data.pyg.org/whl/torch-2.7.0%2Bcu118/torch_cluster-1.6.3%2Bpt27cu118-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux'",
  "scipy",
  "pyyaml",
  "omegaconf",
  "hydra-core",
  "fvcore",
  "point-cloud-utils",
  "transformers==4.51.3",
  "python-box",
  "addict",
  "timm",
  "fast-simplification",
  "open3d",
  "pyrender",
  "wandb",
  "torch",
  "libigl",
]

[tool.uv.sources]
torch = { index = "pytorch-cu118" }
torchvision = { index = "pytorch-cu118" }

[[tool.uv.index]]
name = "pytorch-cu118"
url = "https://download.pytorch.org/whl/cu118"
explicit = true
""")

IO.puts("✓ Python environment initialized with UniRig dependencies")

Configuration

# Configure the rigging parameters
config = %{
  mesh_path: "path/to/your/model.obj",  # Replace with actual 3D model path
  output_format: "usdc",
  seed: 42,
  skeleton_only: false,
  skin_only: false,
  skeleton_task: nil,  # Use default
  skin_task: nil       # Use default
}

IO.puts("Configuration:")
IO.inspect(config, pretty: true)
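
The rigging cell below fails deep inside the Python pipeline if the mesh path is wrong, so it is worth catching the most common mistakes up front. A minimal guard, with the accepted extensions taken from the extraction step further down:

# Fail fast on an obviously bad mesh_path before starting the heavy rigging cell
unless File.exists?(config.mesh_path) do
  IO.puts("⚠ mesh_path does not exist yet: #{config.mesh_path}")
end

allowed = ~w(.obj .fbx .dae .glb .gltf .vrm .usd .usda .usdc)

unless String.downcase(Path.extname(config.mesh_path)) in allowed do
  IO.puts("⚠ unexpected file extension: #{Path.extname(config.mesh_path)}")
end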

Model Download (Optional)

# Pre-download the UniRig models (optional). The pretrained weights are fetched
# automatically during inference, so this cell only saves time on the first run.

# Uncomment to pre-download (assumes a HuggingFaceDownloader helper module is
# defined elsewhere in your project; it is not part of this notebook):
# repo_id = "VAST-AI/UniRig"
# IO.puts("Downloading UniRig models...")
# HuggingFaceDownloader.download_repo(repo_id, "pretrained_weights/UniRig", "UniRig", true)
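
If you would rather not rely on that helper, one alternative is to call huggingface_hub (already part of the Python environment above) through Pythonx. A hedged sketch; the local_dir value is only an assumption about where you want the weights, not something UniRig requires:

# Optional pre-download via huggingface_hub.snapshot_download (uncomment to run)
# Pythonx.eval(
#   """
#   from huggingface_hub import snapshot_download
#   # Fetch the full model repository into a local folder
#   snapshot_download(repo_id="VAST-AI/UniRig", local_dir="pretrained_weights/UniRig")
#   """,
#   %{}
# )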

Run the Rigging Pipeline

# Save config to JSON for Python
config_json = Jason.encode!(config)
config_file = "/tmp/unirig_config_#{System.system_time(:millisecond)}.json"
File.write!(config_file, config_json)

# Run UniRig rigging
try do
  Pythonx.eval(~S"""
import json
import sys
import os
import warnings
from pathlib import Path
import torch
import yaml
from box import Box
import lightning as L
from math import ceil

# Suppress verbose warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*flash_attn.*')
warnings.filterwarnings('ignore', message='.*flash-attn.*')
warnings.filterwarnings('ignore', message='.*flash attention.*')
warnings.filterwarnings('ignore', message='.*BatchNorm.*')

# Create filtered stdout/stderr wrapper
class FilteredOutput:
    def __init__(self, original_stream):
        self.original_stream = original_stream
        self.filtered_patterns = [
            'flash_attn is disabled',
            'flash-attn is disabled',
            'flash attention is disabled',
            'use BatchNorm in ptv3obj',
            'WARNING: use BatchNorm',
        ]
        self.buffer = ''

    def write(self, text):
        if not text:
            return
        self.buffer += text
        if '\n' in self.buffer:
            lines = self.buffer.split('\n')
            self.buffer = lines[-1]
            complete_lines = lines[:-1]
            filtered_lines = []
            for line in complete_lines:
                should_filter = False
                line_lower = line.lower()
                for pattern in self.filtered_patterns:
                    if pattern.lower() in line_lower:
                        should_filter = True
                        break
                if not should_filter:
                    filtered_lines.append(line)
            if filtered_lines:
                self.original_stream.write('\n'.join(filtered_lines) + '\n')

    def flush(self):
        if self.buffer:
            should_filter = False
            buffer_lower = self.buffer.lower()
            for pattern in self.filtered_patterns:
                if pattern.lower() in buffer_lower:
                    should_filter = True
                    break
            if not should_filter:
                self.original_stream.write(self.buffer)
            self.buffer = ''
        self.original_stream.flush()

# Replace stdout and stderr with filtered versions
sys.stdout = FilteredOutput(sys.stdout)
sys.stderr = FilteredOutput(sys.stderr)

# Fix for PyTorch 2.7+ weights_only loading
torch.serialization.add_safe_globals([Box])

# Helper function to load configs
def load_config(task: str, path: str) -> Box:
    # Normalize to a single .yaml suffix, then load the config into a Box
    if path.endswith('.yaml'):
        path = path.removesuffix('.yaml')
    path += '.yaml'
    print(f"load {task} config: {path}")
    with open(path, 'r') as f:
        return Box(yaml.safe_load(f))

# Helper function to run UniRig inference
def run_unirig_inference(
    task_path: str,
    seed: int,
    input_path: str,
    output_path: str,
    npz_dir: str,
    data_name: str = None,
    unirig_base_path: str = None
):
    # Patch UniRig's Exporter to use USD instead of FBX
    from src.data.exporter import Exporter
    original_export_fbx = Exporter._export_fbx

    def patched_export_fbx(
        self, path, vertices=None, joints=None, skin=None, parents=None,
        names=None, faces=None, extrude_size=0.03, group_per_vertex=-1,
        add_root=False, do_not_normalize=False, use_extrude_bone=True,
        use_connect_unique_child=True, extrude_from_parent=True, tails=None,
    ):
        import bpy
        import os
        if path.lower().endswith('.fbx'):
            path = path[:-4] + '.usdc'
        elif not path.lower().endswith(('.usd', '.usda', '.usdc')):
            path = path + '.usdc'

        self._safe_make_dir(path)
        self._clean_bpy()
        self._make_armature(
            vertices=vertices,
            joints=joints,
            skin=skin,
            parents=parents,
            names=names,
            faces=faces,
            extrude_size=extrude_size,
            group_per_vertex=group_per_vertex,
            add_root=add_root,
            do_not_normalize=do_not_normalize,
            use_extrude_bone=use_extrude_bone,
            use_connect_unique_child=use_connect_unique_child,
            extrude_from_parent=extrude_from_parent,
            tails=tails,
        )

        bpy.ops.wm.usd_export(
            filepath=path,
            export_materials=True,
            export_textures=True,
            relative_paths=False,
            export_uvmaps=True,
            export_armatures=True,
            selected_objects_only=False,
            visible_objects_only=False,
            use_instancing=False,
            evaluation_mode='RENDER'
        )

    Exporter._export_fbx = patched_export_fbx

    try:
        torch.set_float32_matmul_precision('high')
        L.seed_everything(seed, workers=True)

        task = load_config('task', task_path)
        mode = task.mode
        assert mode in ['train', 'predict', 'validate']

        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            from src.data.extract import get_files
            from src.data.datapath import Datapath
            from src.data.dataset import UniRigDatasetModule, DatasetConfig
            from src.data.transform import TransformConfig
            from src.tokenizer.spec import TokenizerConfig
            from src.tokenizer.parse import get_tokenizer
            from src.model.parse import get_model
            from src.system.parse import get_system, get_writer
            from src.inference.download import download
            from lightning.pytorch.callbacks import ModelCheckpoint

        data_config = load_config('data', os.path.join('configs/data', task.components.data))
        transform_config = load_config('transform', os.path.join('configs/transform', task.components.transform))

        input_path_abs = os.path.abspath(input_path)
        input_basename = os.path.basename(input_path_abs)
        input_stem = os.path.splitext(input_basename)[0]

        files = get_files(
            data_name=task.components.data_name,
            inputs=input_path_abs,
            input_dataset_dir=None,
            output_dataset_dir=str(npz_dir),
            require_suffix=['obj', 'fbx', 'FBX', 'dae', 'glb', 'gltf', 'vrm', 'usd', 'usda', 'usdc'],
            force_override=True,
            warning=False,
        )

        files = [(input_file, os.path.join(str(npz_dir), input_stem)) for input_file, output_dir in files]

        from src.data.extract import extract_builtin
        import time
        timestamp = time.strftime("%Y_%m_%d_%H_%M_%S")

        data_name_actual = task.components.get('data_name', 'raw_data.npz')
        if data_name is not None:
            data_name_actual = data_name

        files_to_extract = []
        for input_file, output_dir in files:
            os.makedirs(output_dir, exist_ok=True)
            raw_data_npz = os.path.join(output_dir, data_name_actual)
            if not os.path.exists(raw_data_npz):
                files_to_extract.append((input_file, output_dir))

        if files_to_extract:
            print(f"\n=== Extracting {len(files_to_extract)} mesh file(s) ===")
            target_count = data_config.get('faces_target_count', 50000)
            try:
                extract_builtin(
                    output_folder=str(npz_dir),
                    target_count=target_count,
                    num_runs=1,
                    id=0,
                    time=timestamp,
                    files=files_to_extract,
                )
                print("✓ Mesh extraction complete")
            except Exception as e:
                print(f"✗ Error during extraction: {e}")
                raise

        print("\n=== Verifying extracted files ===")
        all_verified = True
        for input_file, output_dir in files:
            raw_data_npz = os.path.join(output_dir, data_name_actual)
            if os.path.exists(raw_data_npz):
                print(f"  ✓ Found: {raw_data_npz}")
            else:
                print(f"  ✗ Missing: {raw_data_npz}")
                all_verified = False

        if not all_verified:
            raise FileNotFoundError("Mesh extraction verification failed: expected .npz files are missing")

        files = [f[1] for f in files]
        datapath = Datapath(files=files, cls=None)

        tokenizer_config = task.components.get('tokenizer', None)
        if tokenizer_config is not None:
            tokenizer_config = load_config('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
            from src.tokenizer.spec import TokenizerConfig
            tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)

        predict_dataset_config = data_config.get('predict_dataset_config', None)
        if predict_dataset_config is not None:
            predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()

        predict_transform_config = transform_config.get('predict_transform_config', None)
        if predict_transform_config is not None:
            predict_transform_config = TransformConfig.parse(config=predict_transform_config)

        model_config = task.components.get('model', None)
        if model_config is not None:
            model_config = load_config('model', os.path.join('configs/model', model_config))
            if tokenizer_config is not None:
                tokenizer = get_tokenizer(config=tokenizer_config)
            else:
                tokenizer = None
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = get_model(tokenizer=tokenizer, **model_config)
        else:
            model = None

        data = UniRigDatasetModule(
            process_fn=None if model is None else model._process_fn,
            predict_dataset_config=predict_dataset_config,
            predict_transform_config=predict_transform_config,
            tokenizer_config=tokenizer_config,
            debug=False,
            data_name=data_name_actual,
            datapath=datapath,
            cls=None,
        )

        writer_config = task.get('writer', None)
        callbacks = []
        if writer_config is not None:
            assert predict_transform_config is not None, 'missing predict_transform_config in transform'
            writer_config['npz_dir'] = npz_dir
            if writer_config.get('export_npz') == 'predict_skeleton':
                writer_config['output_dir'] = npz_dir
                writer_config['user_mode'] = False
            else:
                writer_config['output_dir'] = None
                writer_config['user_mode'] = True
            writer_config['output_name'] = output_path
            callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))

        system_config = task.components.get('system', None)
        if system_config is not None:
            system_config = load_config('system', os.path.join('configs/system', system_config))
            optimizer_config = task.get('optimizer', None)
            loss_config = task.get('loss', None)
            scheduler_config = task.get('scheduler', None)

            train_dataset_config = data_config.get('train_dataset_config', None)
            if train_dataset_config is not None:
                train_dataset_config = DatasetConfig.parse(config=train_dataset_config)

            system = get_system(
                **system_config,
                model=model,
                optimizer_config=optimizer_config,
                loss_config=loss_config,
                scheduler_config=scheduler_config,
                steps_per_epoch=1 if train_dataset_config is None else
                ceil(len(data.train_dataloader()) // 1 // 1),
            )
        else:
            system = None

        trainer_config = task.get('trainer', {})

        resume_from_checkpoint = task.get('resume_from_checkpoint', None)
        resume_from_checkpoint = download(resume_from_checkpoint)

        try:
            torch.serialization.add_safe_globals([Box])
        except Exception:
            pass

        trainer = L.Trainer(
            callbacks=callbacks,
            logger=None,
            **trainer_config,
        )

        assert resume_from_checkpoint is not None, 'expect resume_from_checkpoint in task'
        trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)

    finally:
        Exporter._export_fbx = original_export_fbx

# Load configuration
""" <> """
config_file_path = r"#{String.replace(config_file, "\\", "\\\\")}"
print(f"Loading config from: {config_file_path}")
with open(config_file_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

mesh_path = config['mesh_path']
output_format = config.get('output_format', 'usdc')
seed = config.get('seed', 42)
skeleton_only = config.get('skeleton_only', False)
skin_only = config.get('skin_only', False)
skeleton_task = config.get('skeleton_task')
skin_task = config.get('skin_task')

mesh_path = str(Path(mesh_path).resolve())
original_cwd = os.getcwd()

output_dir = Path(original_cwd) / "output"
output_dir.mkdir(exist_ok=True, parents=True)

import time
tag = time.strftime("%Y%m%d_%H_%M_%S")
export_dir = output_dir / tag
export_dir.mkdir(exist_ok=True, parents=True)

intermediate_dir = export_dir / "intermediate"
intermediate_dir.mkdir(exist_ok=True, parents=True)

print("\\n=== Setup UniRig Environment ===")

unirig_path = Path.cwd() / ".." / "thirdparty" / "UniRig"

if not unirig_path.exists():
    raise FileNotFoundError("UniRig repository not found in ../thirdparty/UniRig")

unirig_path = unirig_path.resolve()
print(f"✓ Using UniRig from: {unirig_path}")

if str(unirig_path) not in sys.path:
    sys.path.insert(0, str(unirig_path))

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

npz_dir = intermediate_dir
npz_dir.mkdir(exist_ok=True, parents=True)

if unirig_path and unirig_path.exists():
    os.chdir(str(unirig_path))

if skin_only:
    print("\\n=== Generate Skinning Weights ===")
    skeleton_path = export_dir / "skeleton.usdc"
    if not skeleton_path.exists():
        raise FileNotFoundError(f"Skeleton file not found: {skeleton_path}")

    skin_output = export_dir / "skin.usdc"
    skin_task_path = skin_task or "configs/task/quick_inference_unirig_skin.yaml"

    run_unirig_inference(
        task_path=skin_task_path,
        seed=seed,
        input_path=str(skeleton_path),
        output_path=str(skin_output),
        npz_dir=str(npz_dir),
        data_name="predict_skeleton.npz",
        unirig_base_path=str(unirig_path) if unirig_path else None
    )

    final_output = export_dir / f"rigged.{output_format}"
    from src.inference.merge import transfer
    transfer(
        source=str(skin_output),
        target=str(Path(mesh_path).resolve()),
        output=str(final_output),
        add_root=False
    )

elif skeleton_only:
    print("\\n=== Generate Skeleton ===")
    skeleton_output = export_dir / f"skeleton.{output_format}"
    skeleton_task_path = skeleton_task or "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"

    run_unirig_inference(
        task_path=skeleton_task_path,
        seed=seed,
        input_path=str(Path(mesh_path).resolve()),
        output_path=str(skeleton_output),
        npz_dir=str(npz_dir),
        data_name=None,
        unirig_base_path=str(unirig_path) if unirig_path else None
    )

else:
    print("\\n=== Generate Skeleton ===")
    skeleton_output = export_dir / "skeleton.usdc"
    skeleton_task_path = skeleton_task or "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"

    run_unirig_inference(
        task_path=skeleton_task_path,
        seed=seed,
        input_path=str(Path(mesh_path).resolve()),
        output_path=str(skeleton_output),
        npz_dir=str(npz_dir),
        data_name=None,
        unirig_base_path=str(unirig_path) if unirig_path else None
    )

    print("\\n=== Generate Skinning Weights ===")
    mesh_stem = Path(mesh_path).stem
    mesh_dir = Path(mesh_path).parent

    possible_paths = [
        Path(npz_dir) / mesh_stem / "predict_skeleton.npz",
        mesh_dir / mesh_stem / "predict_skeleton.npz",
        mesh_dir / "predict_skeleton.npz",
    ]

    skeleton_npz_path = None
    for path in possible_paths:
        if path.exists():
            skeleton_npz_path = path
            break

    if skeleton_npz_path is None:
        found_files = list(Path(npz_dir).rglob("predict_skeleton.npz"))
        found_files.extend(list(mesh_dir.rglob("predict_skeleton.npz")))
        if found_files:
            skeleton_npz_path = found_files[0]

    if skeleton_npz_path:
        expected_path = Path(npz_dir) / mesh_stem / "predict_skeleton.npz"
        expected_path.parent.mkdir(parents=True, exist_ok=True)
        import shutil
        shutil.copy2(skeleton_npz_path, expected_path)

    skin_output = export_dir / "skin.usdc"
    skin_task_path = skin_task or "configs/task/quick_inference_unirig_skin.yaml"

    run_unirig_inference(
        task_path=skin_task_path,
        seed=seed,
        input_path=str(Path(mesh_path).resolve()),
        output_path=str(skin_output),
        npz_dir=str(npz_dir),
        data_name="predict_skeleton.npz",
        unirig_base_path=str(unirig_path) if unirig_path else None
    )

    print("\\n=== Merge Skeleton and Skin ===")
    final_output = export_dir / f"rigged.{output_format}"

    from src.inference.merge import clean_bpy, load, process_mesh, get_arranged_bones, process_armature, merge as merge_func, get_skin
    import bpy
    import numpy as np

    clean_bpy()
    armature = load(filepath=str(skin_output), return_armature=True)
    if armature is None:
        raise ValueError("Failed to load skeleton from USD")

    vertices_skin, faces_skin, skin = process_mesh()
    arranged_bones = get_arranged_bones(armature)
    if skin is None:
        skin = get_skin(arranged_bones)

    joints, tails, parents, names, matrix_local = process_armature(armature, arranged_bones)

    clean_bpy()
    load(str(Path(mesh_path).resolve()))

    for c in bpy.data.armatures:
        bpy.data.armatures.remove(c)

    merge_func(
        path=str(Path(mesh_path).resolve()),
        output_path=str(final_output),
        vertices=vertices_skin,
        joints=joints,
        skin=skin,
        parents=parents,
        names=names,
        tails=tails,
        add_root=False,
    )

os.chdir(original_cwd)

print("\\n=== Complete ===")
print("3D model rigging completed successfully!")
""", %{})

  IO.puts("✓ UniRig rigging completed successfully!")

rescue
  e ->
    IO.puts("❌ Error during rigging: #{inspect(e)}")
after
  # Cleanup
  File.rm(config_file)
end

Usage Instructions

  1. Setup: Run the first cell to install dependencies
  2. Configure: Update the config map with your 3D model path
  3. Run: Execute the rigging cell to generate skeleton and skinning
  4. Results: Check the output directory for the rigged models (one way to list them is sketched below)
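
One way to see what a run produced, assuming the default output/<timestamp>/ layout created by the rigging cell:

# List everything written under output/ (rigged model, skeleton, skin, intermediate .npz files)
"output/**/*"
|> Path.wildcard()
|> Enum.reject(&File.dir?/1)
|> Enum.each(&IO.puts/1)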

Notes

  • Requires CUDA-compatible GPU for best performance
  • UniRig models will be downloaded automatically (~10GB)
  • Supports common 3D formats (OBJ, FBX, DAE, GLB, GLTF, VRM, USD/USDA/USDC)
  • Output is written as USDC (the exporter is patched to use Blender's USD export instead of FBX); see the sketch below if you need another format
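
If a downstream tool cannot read USD, the bpy module installed above can convert the rigged file. A rough sketch, assuming a rigged.usdc from a previous run (replace the placeholder path); importer and exporter options may need tuning for your asset:

# Hypothetical USDC-to-GLB conversion using Blender's Python module
rigged_path = "output/<timestamp>/rigged.usdc"  # replace with an actual run directory

Pythonx.eval(
  """
  import bpy

  # Start from an empty scene, import the USD file, then export it as GLB
  bpy.ops.wm.read_factory_settings(use_empty=True)
  bpy.ops.wm.usd_import(filepath=r"#{rigged_path}")
  bpy.ops.export_scene.gltf(filepath=r"#{Path.rootname(rigged_path)}.glb", export_format="GLB")
  """,
  %{}
)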