Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion invokeai/app/invocations/qwen_image_image_to_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image


@invocation(
Expand All @@ -44,7 +45,13 @@ class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info.model_on_device() as (_, vae):
assert isinstance(vae_info.model, AutoencoderKLQwenImage)
estimated_working_memory = estimate_vae_working_memory_qwen_image(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoencoderKLQwenImage)

vae.disable_tiling()
Expand Down
8 changes: 7 additions & 1 deletion invokeai/app/invocations/qwen_image_latents_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image


@invocation(
Expand All @@ -41,9 +42,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKLQwenImage)
estimated_working_memory = estimate_vae_working_memory_qwen_image(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device() as (_, vae),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
context.util.signal_progress("Running VAE")
assert isinstance(vae, AutoencoderKLQwenImage)
Expand Down
52 changes: 52 additions & 0 deletions invokeai/backend/util/vae_working_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
Expand Down Expand Up @@ -92,6 +93,57 @@ def estimate_vae_working_memory_flux(
return int(working_memory)


def estimate_vae_working_memory_qwen_image(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKLQwenImage
) -> int:
    """Estimate, in bytes, the working memory a Qwen Image VAE encode/decode will need.

    The Qwen Image VAE is a video-style autoencoder working on 5D tensors shaped
    (B, C, num_frames, H, W). It runs untiled, so peak working memory grows with the full
    spatial output size. As with the other estimators in this module, only the last two
    dimensions of `image_tensor` are used: latent-space H/W when decoding, pixel-space H/W
    when encoding.

    Args:
        operation: "decode" scales the spatial dims by LATENT_SCALE_FACTOR and uses the decode
            constant; any other value is handled as "encode".
        image_tensor: The tensor about to be encoded/decoded. Only `shape[-2]` and `shape[-1]`
            are read.
        vae: The VAE model. Only the element size of its first parameter is read.

    Returns:
        The estimate `h * w * element_size * scaling_constant`, as an int.
    """
    decoding = operation == "decode"

    # Decoding upsamples latents to pixels, so size the estimate on the *output* resolution.
    # Encoding already receives pixel-space dimensions.
    spatial_multiplier = LATENT_SCALE_FACTOR if decoding else 1
    out_height = spatial_multiplier * image_tensor.shape[-2]
    out_width = spatial_multiplier * image_tensor.shape[-1]
    bytes_per_element = next(vae.parameters()).element_size()

    # This VAE is far heavier than the SD/SDXL VAE, hence the much larger constants. They were
    # calibrated from peak *reserved* memory growth (reserved, not merely allocated, because that
    # is what the cache's `free >= estimate` check is compared against) over a grid of
    # resolutions in fp16, on an AMD W7900 (ROCm) and on an NVIDIA card (CUDA). See
    # scripts/calibrate_qwen_vae_working_memory.py.
    #
    # Implied constant = reserved_bytes / (h * w * element_size). Per-point maxima (fp16):
    #                512^2  768^2  1024^2  1536^2  1792^2  2048^2  -> shipped (max + ~8% headroom)
    #   ROCm decode   5132   4596    4570    3273    3735    4813  -> 5500
    #   ROCm encode   5864   5858    5858    3532    4364   (OOM)  -> 6300
    #   CUDA decode   2660   2519    2690    2671    2281   (OOM)  -> 2900
    #   CUDA encode   1456   1451    1458    1456    1455    1455  -> 1600
    #
    # Why the constant depends on the backend (no other estimator here does this):
    #   - The Qwen VAE is attention-heavy. Flash/efficient attention (CUDA) keeps attention memory
    #     at O(area), giving a flat/linear curve. The ROCm build falls back to math attention,
    #     which is O(area^2), so ROCm reserves ~2x (decode) to ~4x (encode) more and turns
    #     super-linear above ~1792^2. That gap dwarfs any headroom: one shared constant would
    #     either under-estimate on ROCm (OOM) or heavily over-budget on CUDA (needless eviction).
    #   - The sibling estimators' "encoding is half of decoding" assumption only holds on CUDA.
    #     On ROCm encode reserves >= decode, so its encode constant is sized to match -- this is
    #     the path Qwen Image Edit exercises.
    #   - On ROCm the linear model under-estimates decodes well above 2048^2, but those OOM on a
    #     48GB card anyway; on CUDA the curve stays linear, so no extra term is needed.
    on_rocm = torch.version.hip is not None
    decode_constant, encode_constant = (5500, 6300) if on_rocm else (2900, 1600)
    scaling_constant = decode_constant if decoding else encode_constant

    return int(out_height * out_width * bytes_per_element * scaling_constant)


def estimate_vae_working_memory_sd3(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
Expand Down
2 changes: 1 addition & 1 deletion invokeai/version/invokeai_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "6.13.0.post1"
__version__ = "6.13.5.rc1"
Loading