Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us
Notesclub

Stable Diffusion with EXLA CUDA

stable_diffusion_exla_cuda.livemd

Stable Diffusion with EXLA CUDA

# Notebook setup: installs the dependencies and applies application config
# for this notebook session.
Mix.install(
  [
    {:bumblebee, "~> 0.4"},
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:kino, "~> 0.12"}
  ],
  # Environment variable set for the dependency build.
  # NOTE(review): confirm that "cuda114" matches the locally installed CUDA
  # toolkit and is a target offered by the xla version that EXLA resolves to.
  system_env: [{"XLA_TARGET", "cuda114"}],
  # Make EXLA the default Nx backend.
  config: [nx: [default_backend: EXLA.Backend]]
)

設定

# Evaluates to the Nx backend currently in effect (the setup call configures
# EXLA.Backend as the default).
Nx.default_backend()
# Hugging Face repository that the model components are loaded from.
repository_id = "CompVis/stable-diffusion-v1-4"
# Directory passed as :cache_dir to every Bumblebee load call in this notebook.
cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

# Load every component needed by the text-to-image serving. Each call matches
# on {:ok, _}, so a failed load raises a MatchError at that line.

# The tokenizer is taken from the CLIP repository, not from `repository_id`.
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14", cache_dir: cache_dir})

# Text encoder.
{:ok, clip} =
  Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder", cache_dir: cache_dir})

# UNet; the parameter file name is given explicitly.
{:ok, unet} =
  Bumblebee.load_model(
    {:hf, repository_id, subdir: "unet", cache_dir: cache_dir},
    params_filename: "diffusion_pytorch_model.bin"
  )

# VAE, loaded with the :decoder architecture and an explicit parameter file name.
{:ok, vae} =
  Bumblebee.load_model(
    {:hf, repository_id, subdir: "vae", cache_dir: cache_dir},
    architecture: :decoder,
    params_filename: "diffusion_pytorch_model.bin"
  )

# Scheduler.
{:ok, scheduler} =
  Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler", cache_dir: cache_dir})

# Featurizer, handed to the serving as :safety_checker_featurizer.
{:ok, featurizer} =
  Bumblebee.load_featurizer(
    {:hf, repository_id, subdir: "feature_extractor", cache_dir: cache_dir}
  )

# Safety checker model.
{:ok, safety_checker} =
  Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker", cache_dir: cache_dir})
# Build the text-to-image serving from the components loaded above.
serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(
    clip,
    unet,
    vae,
    tokenizer,
    scheduler,
    # Number of steps per generation.
    num_steps: 20,
    # Two images are produced for each prompt.
    num_images_per_prompt: 2,
    # Safety checker model plus the featurizer that goes with it.
    safety_checker: safety_checker,
    safety_checker_featurizer: featurizer,
    # NOTE(review): presumably fixes the compiled input shapes to one prompt
    # of up to 60 tokens — confirm against the Bumblebee docs.
    compile: [batch_size: 1, sequence_length: 60],
    # Compile the numerical definitions with EXLA.
    defn_options: [compiler: EXLA]
  )

画像生成

# Text input for the prompt; `prompt_input` is read again by the timing code.
prompt_input = Kino.Input.text("PROMPT")

# Run the serving on the current input value.
# NOTE(review): Kino inputs normally have to be rendered before they can be
# read, so these statements were presumably separate notebook cells — confirm.
output = Nx.Serving.run(serving, Kino.Input.read(prompt_input))

# Wrap each generated image for display and lay them out in two columns.
output.results
|> Enum.map(&Kino.Image.new(&1.image))
|> Kino.Layout.grid(columns: 2)

時間計測

# Time ten runs of the serving and return the mean duration in microseconds
# (the unit reported by :timer.tc/3).
#
# The prompt is read while the argument list is built, i.e. before the timer
# starts, so only Nx.Serving.run/2 is measured.
1..10
|> Enum.map(fn _ ->
  {time, _} = :timer.tc(Nx.Serving, :run, [serving, Kino.Input.read(prompt_input)])
  time
end)
# Divide by the number of samples actually collected instead of repeating the
# literal 10, so the mean stays correct if the range above is changed.
|> then(&(Enum.sum(&1) / length(&1)))