Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
158a586
feat(pid): vendor PiD decoder backend (phase A of integration)
Pfannkuchensack May 29, 2026
94a180d
feat(pid): wire PiD + Gemma-2 into model-manager and add decode nodes
Pfannkuchensack May 29, 2026
dd1e76e
feat(pid): end-to-end PiD pixel-diffusion decoder integration
Pfannkuchensack May 29, 2026
6b72ef0
Merge branch 'main' into feat/pid-decoder
Pfannkuchensack Jun 9, 2026
bc79446
Merge branch 'invoke-ai:main' into feat/pid-decoder
Pfannkuchensack Jun 21, 2026
47c7513
Chore Ruff
Pfannkuchensack Jun 21, 2026
51e6a93
Chore Typegen
Pfannkuchensack Jun 21, 2026
9029c12
Chore Knip
Pfannkuchensack Jun 21, 2026
2aeaa97
fix(pid): remove unused vendored models/utils.py (broken easy_io import)
Pfannkuchensack Jun 21, 2026
27f1455
feat(pid): identify decoder backbone from weight shapes, not filename
Pfannkuchensack Jun 21, 2026
82dd9ea
feat(ui): PiD decode (Fit mode) for FLUX text-to-image
Pfannkuchensack Jun 29, 2026
ab11359
feat(ui): PiD Native 4x mode for FLUX text-to-image
Pfannkuchensack Jun 29, 2026
9ffdf99
feat(ui): PiD Fit decode for FLUX image-to-image
Pfannkuchensack Jun 29, 2026
3bc19d3
feat(ui): PiD Native 4x decode for FLUX image-to-image
Pfannkuchensack Jun 29, 2026
3742be3
feat(ui): add informational popover to PiD Decode setting
Pfannkuchensack Jun 30, 2026
4dac81d
feat(models): add PiD decoder + Gemma-2 encoder to starter models
Pfannkuchensack Jun 30, 2026
063827d
feat(pid): add FLUX.2 Klein PiD 4x-SR decode support
Pfannkuchensack Jul 1, 2026
10ed5f4
feat(pid): add SD3 PiD 4x-SR decode support
Pfannkuchensack Jul 1, 2026
2954c14
feat(pid): add SDXL PiD 4x-SR decode support
Pfannkuchensack Jul 1, 2026
fe9987b
feat(pid): add Z-Image PiD 4x-SR decode support
Pfannkuchensack Jul 1, 2026
9f3bb20
feat(pid): add Qwen-Image PiD 4x-SR decode support
Pfannkuchensack Jul 1, 2026
51a3dca
Merge remote-tracking branch 'upstream/main' into feat/pid-decoder
Pfannkuchensack Jul 1, 2026
4f417a3
Chore Ruff
Pfannkuchensack Jul 1, 2026
3d8c84c
Add Docs
Pfannkuchensack Jul 1, 2026
bf619a1
fix(pid): green up frontend tests and knip for the PiD branch
Pfannkuchensack Jul 1, 2026
dd41494
Chore openapi + typegen
Pfannkuchensack Jul 1, 2026
c2e7d50
Merge branch 'main' into feat/pid-decoder
Pfannkuchensack Jul 1, 2026
b70aa86
Merge branch 'main' into feat/pid-decoder
Pfannkuchensack Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions LICENSE-PiD.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
PiD (Pixel Diffusion Decoder) — License notice

Upstream project: https://github.com/nv-tlabs/PiD
Vendored under: invokeai/backend/pid/

================================================================================
CODE (Apache License 2.0)
================================================================================

The PiD source code, including the `pid/_src/` subtree and the `pid/_ext/imaginaire/`
framework subset, is licensed under the Apache License, Version 2.0.

Copyright 2026 NVIDIA CORPORATION & AFFILIATES.

Portions of the framework (pid/_ext/imaginaire/) were originally adapted from
the cosmos-predict2.5 project (https://github.com/nvidia-cosmos/cosmos-predict2.5/).

Files vendored into invokeai/backend/pid/ retain their original SPDX-License-Identifier
headers. The Apache 2.0 license text is available at:

http://www.apache.org/licenses/LICENSE-2.0

================================================================================
MODEL WEIGHTS (NVIDIA Source Code License v1 — non-commercial)
================================================================================

The pre-trained PiD decoder checkpoints distributed by NVIDIA at

https://huggingface.co/nvidia/PiD

are released under the NSCLv1 license. Per NSCLv1, the weights may only be used
for non-commercial (research or evaluation) purposes:

https://huggingface.co/nvidia/PixelDiT-1300M-1024px/blob/main/LICENSE

This restriction applies to the weights only, not to the InvokeAI source code
or the vendored PiD source code (which remain Apache 2.0). Users are responsible
for ensuring their use of the PiD weights complies with NSCLv1.

================================================================================
LOCAL MODIFICATIONS
================================================================================

The following changes were applied to the upstream PiD subset when vendoring:

* All `pid.*` imports were rewritten to `invokeai.backend.pid.*`.
* `pid/_src/configs/`, `pid/_src/tokenizers/`, `pid/_src/checkpointer/`,
`pid/_src/inference/_demo_*.py`, `from_*.py`, `create_dataset.py`,
`rae_generation.py`, and `scale_rae_generation.py` were dropped (not needed
for the decoder-only inference subset).
* `pid/_ext/imaginaire/checkpointer/`, `trainer.py`, `visualize/`, `flags.py`,
`config.py`, `types/`, `utils/easy_io/`, `utils/callback.py`,
`utils/config_helper.py`, `utils/validator{,_params}.py` and the
`lazy_config/omegaconf_patch.py` were dropped.
* The upstream `utils/log.py` (loguru-based) and `utils/misc.py` were replaced
with stdlib-based stubs covering only the API surface used by the decoder.
* `lazy_config/file_io.py` (iopath PathManager) and `lazy_config/registry.py`
(fvcore Registry) were replaced with stdlib-only implementations.
* `lazy_config/lazy.py` was reduced to a minimal `LazyCall`/`LazyConfig` stub;
the upstream yaml/cloudpickle/dill/detectron2 config save/load paths are
intentionally not supported.
* `lazy_config/instantiate.py` was reduced to a stdlib-only implementation;
the upstream omegaconf `DictConfig`/`ListConfig` branches were dropped, so
no `omegaconf` dependency is required.
* `_src/utils/model_loader.py` (which depended on Imaginaire's distributed
checkpointer + easy_io) and `_src/inference/inference_utils.py` (S3 / video
helpers) were removed; their decode-path equivalents are reimplemented in
`invokeai/backend/pid/decode.py`.
76 changes: 76 additions & 0 deletions docs/src/content/docs/features/pid-decode.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
---
title: PiD Super-Resolution Decode
lastUpdated: 2026-07-01
sidebar:
order: 5
---

import { Steps, Aside, Tabs, TabItem } from '@astrojs/starlight/components'

**PiD** (Pixel Diffusion Decoder) is an alternative way to turn a model's latents into an image. Instead of the usual VAE decode, it runs a short pixel-space diffusion that produces a **4× super-resolved** result in a single, few-step pass — so a 512×512 generation comes out as a detailed 2048×2048 image.

Because it decodes in pixel space and is conditioned on your prompt, PiD often recovers finer texture and edge detail than a plain VAE decode followed by an upscaler.

<Aside type="note">
PiD replaces the VAE decode step of a single generation. It is not a separate "upscale" pass you run on an existing image — you enable it before generating.
</Aside>

## Supported models

PiD works with these base models:

| Base model | PiD decoder to install |
|---|---|
| FLUX.1 | PiD Decoder FLUX |
| FLUX.2 Klein (4B / 9B) | PiD Decoder FLUX.2 |
| Stable Diffusion 3 | PiD Decoder SD3 |
| SDXL | PiD Decoder SDXL |
| Z-Image / Z-Image Turbo | **PiD Decoder FLUX** (Z-Image shares FLUX.1's VAE) |
| Qwen-Image | PiD Decoder Qwen-Image |

<Aside type="tip">
Z-Image has no PiD decoder of its own — when a Z-Image model is active, the decoder picker lists the **FLUX** decoders. That's expected.
</Aside>

## What you need to install

PiD needs two extra models, both available in **Model Manager → Starter Models**:

<Steps>
1. A **PiD Decoder** for your base model (e.g. *PiD Decoder FLUX (2K)*). Some bases offer a *2K* and a *2K-to-4K* preset; SDXL and Qwen-Image ship only the *2K-to-4K* preset.
2. The **Gemma 2 2B (PiD caption encoder)** — PiD uses it to condition the decode on your prompt. It installs automatically as a dependency of any PiD decoder, and is shared across all of them.
</Steps>

Each PiD decoder is roughly 5 GB and the shared Gemma-2 encoder is roughly 5 GB.

## Enabling PiD

Open the **Generation** settings for a supported model and expand the advanced options. You'll find a **PiD** control with three modes:

<Tabs>
<TabItem label="Off">
Standard VAE decode. No PiD models required.
</TabItem>
<TabItem label="Fit">
Generate at the requested size, decode 4× with PiD, then downscale the result back to the requested size. This is the safe default and works everywhere — the output matches your bounding box exactly, so it composites cleanly on the Canvas.
</TabItem>
<TabItem label="Native">
Treat the requested dimensions as the **4× target**: the image is generated at target ÷ 4 and PiD's full 4× output is used directly (no downscale), preserving all of the added detail. Great when you want a large, highly-detailed result.
</TabItem>
</Tabs>

When PiD mode is not *Off*, pick your **PiD Decoder** and **Gemma-2 Encoder** below the mode selector. The **PiD Steps** control (default 4) sets how many decode steps run — the released checkpoints are trained for 4.

PiD is available in both the **Generate** tab (text-to-image) and on the **Canvas** (image-to-image), in both Fit and Native modes.

## Tips & limitations

- **Turn off "Scale Before Processing"** on the Canvas when using PiD — PiD already decodes at 4×, so pre-scaling would inflate the work and is blocked.
- **Inpaint / Outpaint** are not supported with PiD yet; use text-to-image or image-to-image.
- **SDXL Refiner** cannot be combined with PiD — disable one of them.
- PiD's memory use scales with the *output* resolution. A 2048px output needs only a little more headroom than a normal decode, but Native mode at large target sizes (e.g. a 4096px result) is significantly heavier.
- Turbo variants (e.g. Z-Image Turbo) work as usual — the low step count / no-CFG only affects generation; PiD's own step count is separate.

<Aside type="caution">
The PiD decoder *code* is Apache-2.0, but NVIDIA releases the pretrained *weights* under a non-commercial / research license. Keep that in mind if you redistribute the checkpoints.
</Aside>
223 changes: 223 additions & 0 deletions invokeai/app/invocations/flux2_pid_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""FLUX.2 Klein PiD decode invocation.

Replaces the regular FLUX.2 VAE decode with the PiD pixel-diffusion super-res
decoder (``PiD_res2k[to4k]_sr4x_official_flux2_distill_4step``). Produces a 4x
super-resolved image from a FLUX.2 latent in a single 4-step distill pass. The
4B and 9B FLUX.2 Klein variants share the same 32-channel VAE, so this one node
covers both.

Latent layout (the important difference from the FLUX.1 node):

* ``flux2_denoise`` stores an *unpacked* ``(B, 32, H/8, W/8)`` latent that is
already **BN-denormalized** (``x * bn_std + bn_mean`` is applied before the
unpack, see ``flux2_denoise.py``). That is exactly the raw latent the FLUX.2
VAE's conv decoder consumes.
* PiD's FLUX.2 backbone expects the **packed** ``(B, 128, H/16, W/16)``
representation (``lq_latent_channels=128``, ``latent_spatial_down_factor=16``
in ``backend/pid/decode.py``). We therefore patchify the stored latent
(2x2 spatial patches folded into channels: 32*4 = 128) *before* handing it to
PiD - mirroring ``pack_flux2`` but keeping a spatial ``(B, C, h, w)`` layout
instead of the transformer's ``(B, seq, C)`` sequence layout.

Denormalization: unlike FLUX.1 (single ``scale``/``shift``) and Z-Image
(checkpoint-specific ``scaling_factor``/``shift_factor``), the FLUX.2 VAE
(``AutoencoderKLFlux2``) exposes **no** scalar ``scaling_factor``/``shift_factor``
at all - its only normalization is the per-channel BatchNorm applied/inverted
*outside* the VAE in ``flux2_denoise``. So the packed latent is already in PiD's
expected raw space and no further scaling is needed (identity fallbacks below).
We still accept an optional ``vae`` input and read the constants at runtime (like
the Z-Image node) so any future FLUX.2 VAE variant that does expose scalar
constants is honored automatically.
"""

from contextlib import ExitStack

import torch
from einops import rearrange
from PIL import Image
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
UIComponent,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import Gemma2EncoderField, PiDDecoderField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.pid._src.networks.pid_net import PidNet
from invokeai.backend.pid.decode import (
PiDDecodeConfig,
PiDDecoder,
encode_caption_for_pid,
estimate_pid_decode_working_memory,
)
from invokeai.backend.util.devices import TorchDevice

# FLUX.2 uses per-channel BatchNorm (affine=False) for latent normalization, and
# that BN is already inverted in flux2_denoise before the latent is stored. The
# FLUX.2 VAE (AutoencoderKLFlux2) has no scalar scaling_factor/shift_factor, so
# the identity transform below is the correct default: the stored (packed) latent
# is already the raw representation PiD was trained on.
_FLUX2_VAE_SCALING_FACTOR_FALLBACK: float = 1.0
_FLUX2_VAE_SHIFT_FACTOR_FALLBACK: float = 0.0


@invocation(
"flux2_pid_decode",
title="Latents to Image - FLUX.2 + PiD (4x SR)",
tags=["latents", "image", "pid", "flux2", "klein", "upscale"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2PiDDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Decode a FLUX.2 Klein latent with the PiD pixel-diffusion decoder.

Produces a 4x super-resolved image in a single pass. The stored FLUX.2 latent
is patchified from ``(B, 32, H/8, W/8)`` to the ``(B, 128, H/16, W/16)`` layout
PiD's FLUX.2 backbone expects, then decoded directly (it is already in raw,
BN-denormalized space; see the module docstring).
"""

latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
prompt: str = InputField(
description="Text prompt the latent was generated from. PiD conditions on it.",
ui_component=UIComponent.Textarea,
)
gemma2_encoder: Gemma2EncoderField = InputField(
title="Gemma-2 Encoder",
description="Gemma-2 caption encoder. Required by PiD.",
input=Input.Connection,
)
pid_decoder: PiDDecoderField = InputField(
title="PiD Decoder",
description="PiD FLUX.2 decoder checkpoint.",
input=Input.Connection,
)
vae: VAEField | None = InputField(
default=None,
title="VAE",
description="FLUX.2 VAE, used only to read a scalar scaling_factor / shift_factor if one exists. "
"FLUX.2 normalises latents with BatchNorm (already inverted in flux2_denoise), so this is "
"normally an identity transform and the input can be left unconnected.",
input=Input.Connection,
)
num_inference_steps: int = InputField(
default=4,
ge=1,
le=8,
description="Number of PiD distill steps. The released checkpoints are trained for 4.",
)
seed: int = InputField(default=0, description="Seed for the PiD decoder's noise.")

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

# 1) Patchify the stored FLUX.2 latent into PiD's expected layout.
# flux2_denoise stores an unpacked (B, 32, H/8, W/8) latent; PiD's
# FLUX.2 backbone wants the packed (B, 128, H/16, W/16) form (32*4=128
# channels, spatial halved). This mirrors pack_flux2's 2x2 patchify but
# keeps a spatial (B, C, h, w) layout rather than a (B, seq, C) sequence.
if latents.shape[-3] != 32:
raise ValueError(
f"FLUX.2 PiD decode expected a 32-channel latent from flux2_denoise, got shape "
f"{tuple(latents.shape)}. The upstream node must output the unpacked FLUX.2 latent."
)
packed = rearrange(latents, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2)
context.logger.info(
f"FLUX.2 PiD decode: stored latent shape={tuple(latents.shape)} -> packed for PiD "
f"shape={tuple(packed.shape)} (expect [B, 128, H/16, W/16]) dtype={packed.dtype}"
)

# 2) Resolve the scalar scaling/shift (identity for current FLUX.2 VAEs).
scaling_factor = _FLUX2_VAE_SCALING_FACTOR_FALLBACK
shift_factor = _FLUX2_VAE_SHIFT_FACTOR_FALLBACK
if self.vae is not None:
vae_info = context.models.load(self.vae.vae)
with vae_info.model_on_device() as (_, vae):
config = getattr(vae, "config", None)
if config is not None and hasattr(config, "scaling_factor"):
scaling_factor = float(config.scaling_factor)
shift_factor = float(getattr(config, "shift_factor", None) or 0.0)
else:
scaling_factor = float(getattr(vae, "scale_factor", scaling_factor))
shift_factor = float(getattr(vae, "shift_factor", shift_factor))
del vae_info
TorchDevice.empty_cache()

# 3) Encode caption with Gemma-2.
gemma_text_encoder_info = context.models.load(self.gemma2_encoder.text_encoder)
gemma_tokenizer_info = context.models.load(self.gemma2_encoder.tokenizer)
with ExitStack() as stack:
(_, gemma_encoder) = stack.enter_context(gemma_text_encoder_info.model_on_device())
(_, gemma_tokenizer) = stack.enter_context(gemma_tokenizer_info.model_on_device())
if not isinstance(gemma_encoder, PreTrainedModel):
raise TypeError(f"Expected PreTrainedModel for Gemma encoder, got {type(gemma_encoder).__name__}.")
if not isinstance(gemma_tokenizer, PreTrainedTokenizerBase):
raise TypeError(
f"Expected PreTrainedTokenizerBase for Gemma tokenizer, got {type(gemma_tokenizer).__name__}."
)

device = TorchDevice.choose_torch_device()
encode_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
context.util.signal_progress("Encoding caption with Gemma-2")
caption_embs, caption_mask = encode_caption_for_pid(
[self.prompt],
tokenizer=gemma_tokenizer,
encoder=gemma_encoder,
device=device,
dtype=encode_dtype,
)
caption_embs = caption_embs.detach().to("cpu")
caption_mask = caption_mask.detach().to("cpu")
del gemma_encoder, gemma_tokenizer
# Gemma is only needed for the one-shot caption encode above. Offload it from VRAM (keeping it in the RAM
# cache) so its ~5GB is freed before the PiD decoder loads. The cache offloads anything else it needs to
# fit the decode on its own, so we deliberately do NOT evict every other model here.
context.models.offload_from_vram(self.gemma2_encoder.text_encoder)
TorchDevice.empty_cache()

# 4) Run PiD decode (the loader already returns a live PidNet).
pid_info = context.models.load(self.pid_decoder.decoder)
# The working-memory estimate scales with the OUTPUT pixel count, so it must see the PACKED latent
# (spatial H/16), not the unpacked one - otherwise it over-reserves by 4x.
estimated_working_memory = estimate_pid_decode_working_memory(packed, BaseModelType.Flux2)
with pid_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, pid_net):
if not isinstance(pid_net, PidNet):
raise TypeError(f"Expected PidNet for PiD decoder, got {type(pid_net).__name__}.")
device = TorchDevice.choose_torch_device()
dtype = next(iter(pid_net.parameters())).dtype

# The packed latent is already BN-denormalized (raw VAE-input space); the scalar transform below is
# identity for current FLUX.2 VAEs and only bites if a VAE ever exposes real scalar constants.
denorm_latent = packed.to(device=device, dtype=dtype) / scaling_factor + shift_factor
context.logger.info(
f"FLUX.2 PiD denorm_latent stats[min={denorm_latent.min().item():.3f} "
f"max={denorm_latent.max().item():.3f} mean={denorm_latent.mean().item():.3f}] "
f"using scale={scaling_factor:.4f} shift={shift_factor:.4f}"
)
caption_embs = caption_embs.to(device=device, dtype=dtype)

context.util.signal_progress("Running PiD decoder")
decoder = PiDDecoder(pid_net, backbone=BaseModelType.Flux2)
x0 = decoder.decode(
latent=denorm_latent,
caption_embs=caption_embs,
caption_mask=caption_mask,
config=PiDDecodeConfig(num_inference_steps=self.num_inference_steps, seed=self.seed),
)

TorchDevice.empty_cache()

img = rearrange(x0[0].clamp(-1, 1), "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)
Loading
Loading