Skip to content

Feat: Add PiD (Pixel Diffusion Decoder) 4× super-resolution decode for FLUX / FLUX.2 / SD3 / SDXL / Z-Image / Qwen-Image#9281

Draft
Pfannkuchensack wants to merge 28 commits into
invoke-ai:mainfrom
Pfannkuchensack:feat/pid-decoder
Draft

Feat: Add PiD (Pixel Diffusion Decoder) 4× super-resolution decode for FLUX / FLUX.2 / SD3 / SDXL / Z-Image / Qwen-Image#9281
Pfannkuchensack wants to merge 28 commits into
invoke-ai:mainfrom
Pfannkuchensack:feat/pid-decoder

Conversation

@Pfannkuchensack

@Pfannkuchensack Pfannkuchensack commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds PiD (Pixel Diffusion Decoder) support to InvokeAI — NVIDIA's few-step pixel-diffusion decoder that replaces the regular VAE decode with a caption-conditioned, 4× super-resolution decode (512→2048 in a single 4-step distill pass).

This PR vendors a minimal, inference-only subset of PiD at invokeai/backend/pid/ (upstream https://github.com/nv-tlabs/PiD, code Apache-2.0) and wires it end-to-end into the model manager, invocation nodes, starter models, and the generation UI.

What you get

  • A generic PiD backend (backend/pid/decode.py): per-backbone net config (_PER_BACKBONE), build_pid_net / load_pid_decoder / PiDDecoder, Gemma-2 caption encoding, and a working-memory estimator for the cache.
  • Generic loader nodes: pid_decoder_loader (→ PiDDecoderField) and gemma2_encoder_loader (→ Gemma2EncoderField), plus model-manager configs/loaders for PiD checkpoints and the shared Gemma-2 caption encoder.
  • PiD decode nodes for six base models, each replacing that base's VAE decode:
    Base Node Notes
    FLUX.1 flux_pid_decode 16ch / down-8
    FLUX.2 Klein (4B/9B) flux2_pid_decode packs the stored 32ch latent → 128ch / down-16; BN-normalized (no scalar denorm)
    SD3 sd3_pid_decode 16ch / down-8; fixed VAE constants
    SDXL sdxl_pid_decode 4ch / down-8; VAE scaling_factor read at runtime
    Z-Image (+ Turbo) z_image_pid_decode reuses the FLUX decoder (shared 16ch VAE)
    Qwen-Image qwen_image_pid_decode 16ch / down-8; per-channel latents_mean/latents_std denorm + 5D→4D temporal squeeze
  • Generation UI: a new PiD mode selector (Off / Fit / Native) with PiD-decoder + Gemma-2-encoder pickers and a PiD-steps control, shown for any base that supports PiD. Decoder pickers are base-filtered (Z-Image shows FLUX decoders). Both txt2img and img2img (Canvas) are supported in both modes:
    • Fit: generate at the target size, PiD decodes 4×, downscale back (compositing-safe).
    • Native: the requested size is the 4× target — generate at target/4 and use PiD's full 4× output directly.
  • Starter models: NVIDIA PiD decoders (nvidia/PiD, per backbone; FLUX/FLUX.2/SD3 ship 2K + 2K-to-4K, SDXL/Qwen-Image ship 2K-to-4K only) plus the shared gemma-2-2b-it caption encoder.

Robustness details

  • Backbone identification is driven primarily by the checkpoint's latent channel count (4/16/128), with the filename/directory name as a tie-breaker. Because FLUX.1 / SD3 / Qwen-Image all share 16 channels, the config probe additionally trusts an explicit base override (which the starter installer sends) when the directory name is ambiguous — so single-file HF downloads are still identified correctly.
  • Readiness checks gate each supported base (decoder + Gemma-2 encoder present, "Scale Before Processing" off, SDXL refiner disabled with PiD).
  • The standard (non-PiD) FLUX/FLUX.2/SD3/SDXL/Z-Image/Qwen-Image paths are unchanged.

License note: the vendored PiD code is Apache-2.0. The pretrained PiD weights are released by NVIDIA under NSCLv1 (non-commercial / research) — relevant for anyone redistributing the checkpoints.

Related Issues / Discussions

Closes #9240

QA Instructions

  1. Install models (Model Manager → Starter Models): a PiD Decoder for your base (e.g. "PiD Decoder FLUX (2K)") — its dependency, "Gemma 2 2B (PiD caption encoder)", installs automatically. Confirm the decoder is registered with the correct base (e.g. a Qwen-Image decoder shows as qwen-image, not FLUX).
  2. Select a supported main model (FLUX, FLUX.2 Klein, SD3, SDXL, Z-Image/Turbo, or Qwen-Image). In the Generation settings expander, a PiD control appears; pick the PiD decoder + Gemma-2 encoder.
  3. txt2img – Fit: set PiD = Fit, generate. Output is the requested size, refined via the 4× decode.
  4. txt2img – Native: set PiD = Native, generate. The requested dimensions become the 4× target (image is generated at target/4).
  5. Canvas img2img – Fit / Native: with "Scale Before Processing" = None, run a raster-layer generation in both modes.
  6. Guards: verify inpaint/outpaint and (SDXL) an active refiner surface a clear "unsupported" toast / disabled Invoke with a reason.

Automated gates already green on the branch: backend imports (starter_models, config factory, every *_pid_decode node), and the frontend pnpm lint:tsc / lint:eslint / lint:knip / lint:dpdm.

Not yet hardware-verified (needs a GPU + downloaded checkpoints): full end-to-end image output per base, measured VRAM peak (the working-memory constant is calibrated to a 2048px output ≈ 4.3 GB), and Qwen-Image Edit-mode (reference image) + PiD.

Merge Plan

Self-contained feature; no redux migration beyond the already-included params slice bump.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • ❗Changes to a redux slice have a corresponding migration (params slice _version bump)
  • Documentation added / updated (if applicable)
  • Updated What's New copy (if doing a release after this PR)

Adds a vendored subset of NVIDIA's PiD (Pixel Diffusion Decoder)
at invokeai/backend/pid/ as the foundation for upcoming
FLUX / FLUX.2 / SD3 / Z-Image PiD decode nodes plus a future
PiD-based 4x upscale node.

Upstream: https://github.com/nv-tlabs/PiD (Apache 2.0).

Vendor scope:
* _src/{networks,models,modules}: PidNet, PixDiT_T2I, LQProjection2D,
  PidModel, PidDistillModel, PixelDiTModel, GeneralConditioner.
* _ext/imaginaire: minimal Imaginaire framework subset
  (lazy_config, model, utils/{log,misc,distributed,device,count_params}).
* configs/, tokenizers/, checkpointer/, trainer.py, visualize/,
  _demo_*, from_*, easy_io/, S3/wandb training helpers were
  intentionally excluded.

Dependency stripping (no new hard deps introduced):
* loguru, termcolor -> stdlib logging shim
* iopath PathManager -> stdlib pathlib stub
* fvcore Registry -> minimal stdlib Registry
* lazy_config/lazy.py: yaml/dill/cloudpickle/detectron2 save/load
  paths replaced with a minimal LazyCall stub
* lazy_config/instantiate.py: omegaconf DictConfig/ListConfig branches
  removed; configs are plain dict / LazyCall mappings
* megatron, pynvml, boto3/wandb imports are try/except-guarded
  or local to functions and stay inert in our inference path

All pid.* imports rewritten to invokeai.backend.pid.*; SPDX-Apache-2.0
headers retained on vendored files; attribution and detailed list
of local modifications added in LICENSE-PiD.txt.

The pre-trained PiD checkpoints distributed by NVIDIA remain under
NSCLv1 (non-commercial); this commit only vendors code.

Smoke test: PidNet, PidModel, PidDistillModel, GeneralConditioner
import cleanly; LazyCall -> instantiate round-trip resolves to the
expected nn.Module. ruff check passes.
Adds the model-manager plumbing and workflow nodes needed to use the
vendored PiD decoder (phase A) end-to-end with FLUX, SD3 and Z-Image.

Model manager (Phase B + B.5):
* taxonomy: ModelType.PiDDecoder, PiDDecoderVariantType
  (Res2k_Sr4x / Res2kTo4k_Sr4x), ModelType.Gemma2Encoder +
  ModelFormat.Gemma2Encoder, both added to AnyVariant +
  variant_type_adapter.
* configs/pid_decoder.py: per-backbone PiD configs
  (FLUX / FLUX.2 / SD3) with state-dict probing on 'lq_proj' substring
  and backbone/variant detection from the official NVIDIA filenames.
* configs/gemma2_encoder.py: Gemma-2 directory probing on
  Gemma2ForCausalLM architecture + tokenizer files.
* AnyModelConfig union updated.
* model_loaders/pid_decoder.py: loads .pth / .safetensors, strips
  the upstream 'net.' prefix, supports torch.load(weights_only=True).
* model_loaders/gemma2_encoder.py: SubModelType.{Tokenizer,
  TextEncoder} dispatch; returns the causal LM's inner Gemma2Model
  (transformers 4.56's get_decoder() returns None for Gemma2).

Decode pipeline (Phase C):
* backend/pid/decode.py: build_pid_net + load_pid_decoder
  (per-backbone PixDiT_T2I hyperparams derived from PiD's pid_sr4x
  base + per-experiment overrides), encode_caption_for_pid (chi-prompt
  + Gemma encoding, mirrors PixelDiTModel._encode_text_raw), and a
  PiDDecoder wrapper with a reimplemented few-step distill sampler
  (no autocast / no distributed / no PixelDiTModel init paths from
  upstream).

Invocations (Phase 6.x):
* Gemma2EncoderField + PiDDecoderField in invocations/model.py.
* gemma2_encoder_loader / pid_decoder_loader: thin
  ModelIdentifierField pickers that emit the corresponding fields.
* z_image_pid_decode (pilot), flux_pid_decode, sd3_pid_decode:
  caption encode -> Gemma offload -> PiD state dict load ->
  PidNet construct -> decode. Per-backbone latent denormalisation
  (FLUX1 ae_params, SD3 hardcoded 1.5305/0.0609, Z-Image piggybacks
  on FLUX VAE).

End-to-end validated with the released
PiD_res2k_sr4x_official_flux_distill_4step.pth checkpoint and
gemma-2-2b-it: PidNet rebuilds at exactly 456 keys / 1.36B params,
sampler runs at ~5 GB VRAM peak (Gemma dominates), output shape and
range match.

FLUX.2 PiD decode is deliberately deferred: it needs BN-based
latent denormalisation and 32->128 channel packing, and we have no
FLUX.2 checkpoint to validate against yet.
Adds the NVIDIA PiD decoder as a 4x super-resolution alternative to the
regular VAE/RAE decode path. Includes model-manager configs and loaders
for both the PiD checkpoints and the Gemma-2 caption encoder they require,
plus four invocations: latent-in decode for FLUX / SD3 / Z-Image and an
image-in pid_upscale node.

- Decode pipeline keeps PidNet params in fp32 and uses bf16 autocast only
  for matmuls; caption embeddings have outliers that overflow bf16 RMSNorm.
- encode_caption_for_pid forces tokenizer padding_side="right" (Gemma
  defaults to left, PiD trained with right) and returns the attention mask
  as bool so it stays compatible with SDPA.
- Z-Image reuses the FLUX-trained checkpoint and reads scale/shift from the
  VAE config at runtime (PiD upstream notes they are checkpoint-specific).
- TextLLM config now excludes Gemma2ForCausalLM so it falls through to the
  dedicated Gemma2 encoder config instead of being misclassified.
- Frontend: new model_type / model_format / variant enums, type guards and
  category metadata; schema.ts regenerated via pnpm typegen.
@github-actions github-actions Bot added python PRs that change python files Root invocations PRs that change invocations backend PRs that change backend files services PRs that change app services frontend PRs that change frontend files labels Jun 8, 2026
@lstein lstein added the 6.14.x label Jun 17, 2026
@lstein lstein moved this to 6.14.x Theme: USER EXPERIENCE in Invoke - Community Roadmap Jun 17, 2026
Pfannkuchensack and others added 15 commits June 21, 2026 23:31
Read latent channel count from lq_proj.latent_proj.0.weight (FLUX.2=128,
FLUX.1/SD3=16) as the primary discriminator; fall back to filename/dir name
only to disambiguate the architecturally identical FLUX.1/SD3 pair. Fixes
FLUX.2 checkpoints (model_ema_bf16.pth) not being recognised, and correctly
rejects unsupported backbones (RAE/dinov2, 768ch). Fix Flux2 docstring 32->128.
Add a "PiD Decode" mode select (Off / Fit / Native) to the FLUX advanced
settings with PiD decoder + Gemma-2 encoder pickers. In Fit mode the FLUX
graph swaps the VAE decode for a PiD 4x super-resolution decode and
downscales back to the requested size. Adds params state (pidMode, decoder,
encoder, steps) with a v3->v4 migration, model hooks, readiness checks, and
graph guards for the not-yet-wired Native and non-txt2img paths.
Make the generation dimension helpers PiD-aware via an optional pidScale:
in Native mode the user-facing dimensions are the 4x target (grid 64,
optimal 2048), generation runs at target/4, and PiD's 4x output is used
directly with no downscale. Thread pidScale through the params dimension
reducers and the optimal-dimension/grid-size selectors, resync dimensions
when toggling Native, and wire the Native path in the FLUX graph builder.

Add working_mem_bytes for PiD Decode
Extract the PiD decode chain into buildPidDecodeChain (loaders + decode +
fit-downscale, no denoise setup) so it can substitute for the VAE decode
across generation modes. Widen addImageToImage's l2i param to ImageOutputNodes
(it only consumes .image) and wire the PiD chain into the img2img branch in
Fit mode. Native stays txt2img-only (a 4x result can't composite onto the
bbox); inpaint/outpaint remain gated off for now.
Add addPidImageToImageNative: the canvas bbox is the 4x target, so the init
image is downscaled to bbox/4, denoised at that resolution, and PiD decodes
straight back up to the full bbox with no post-decode downscale - preserving
all PiD detail while still compositing cleanly onto the region. Wire it into
the img2img branch of buildFLUXGraph (native vs fit vs off) and drop the
native-txt2img-only guard. Make the canvas FLUX grid check PiD-aware so a
native bbox must be a multiple of 64 (16 * 4) for bbox/4 to land on the grid.
Explain PiD usage on hover, mirroring the DyPE popover: what the decoder is
(NVIDIA Pixel Diffusion Decoder, 4x SR, needs a PiD decoder + Gemma-2 encoder),
Fit vs Native modes, the 2K / 2K-to-4K target resolutions, that Steps can be
lowered, and that Scale Before Processing must be off. Links to nv-tlabs/PiD.
Register NVIDIA's PiD FLUX decoders (2K and 2K-to-4K presets, from
nvidia/PiD) and the Efficient-Large-Model/gemma-2-2b-it caption encoder as
starter models so they can be installed from the Model Manager. The Gemma-2
encoder is wired as a dependency of each decoder (and offered standalone).
Add a flux2_pid_decode node that packs the stored FLUX.2 latent
(32ch @ H/8) into PiD's 128ch @ H/16 layout before decoding; FLUX.2's
BatchNorm denormalization is already applied in flux2_denoise, so no
scalar denorm is needed (optional vae input reads identity constants).

Generalize the frontend PiD decode chain (decodeNodeType, optional
vaeSource) and wire the isFlux2 graph path for txt2img/img2img (Fit &
Native). Base-aware PiD gating/decoder-filter, FLUX.2 readiness checks,
and two nvidia/PiD FLUX.2 starter decoders (2K, 2Kto4K). Standard FLUX
PiD path unchanged.
Wire the existing sd3_pid_decode node into the SD3 graph builder
(txt2img and img2img, Fit & Native) with a PiD guard, base-aware
gating/decoder-filter (sd-3), and SD3 readiness checks. Add two
nvidia/PiD SD3 starter decoders (2K, 2Kto4K).

Harden the PiD config probe against the 16-channel FLUX.1/SD3
ambiguity: when the checkpoint's directory name is silent (the HF
single-file download renames it), trust an explicit base override so
SD3 checkpoints are not misidentified as FLUX.1. Also benefits Qwen.
FLUX / FLUX.2 identification is unchanged.
Build the full SDXL PiD backend stack: _PER_BACKBONE[SDXL] (4ch/down8),
PiDDecoder_Checkpoint_SDXL_Config with a 4-channel latent-map entry,
factory union + loader registration, and a new sdxl_pid_decode node
(reads the VAE's scaling_factor/shift at runtime; SDXL fallbacks
0.13025/0.0). 4-channel latents are unambiguous, so no directory-name
disambiguation is needed.

Generalize the shared PiD decode chain to support SD-family denoise:
denoise_latents has no width/height, so thread an optional noise node
for sizing and round to the model's native grid (8 for SDXL, 16 for
FLUX). Wire buildSDXLGraph (txt2img + img2img, Fit & Native) with the
VAE as the decode's scaling source, base-aware gating/readiness, and a
starter decoder (SDXL 2Kto4K only). PiD + SDXL refiner is blocked for
now via a graph guard and a readiness reason. FLUX/FLUX.2/SD3 paths are
unchanged.
Wire the existing z_image_pid_decode node into the Z-Image graph builder
(txt2img and img2img, Fit & Native) with a PiD guard and readiness
checks. Z-Image shares FLUX.1's 16-channel VAE and has no PiD
checkpoints of its own, so it reuses the FLUX decoder: the decoder
filter maps z-image -> flux, showing FLUX PiD decoders when a Z-Image
model is active. The Z-Image VAE is passed to the decode node so it
reads the real scaling_factor / shift instead of the fallback constants.

No backend, schema, or starter-model changes. FLUX/FLUX.2/SD3/SDXL
paths are unchanged.
Build the full Qwen-Image PiD backend stack: _PER_BACKBONE[QwenImage]
(16ch/down8), PiDDecoder_Checkpoint_QwenImage_Config (added to the
16-channel latent map + filename heuristic), factory union + loader
registration, and a new qwen_image_pid_decode node.

Unlike the scalar-scaling bases, the Qwen-Image VAE normalizes per
channel (latents_mean / latents_std) and stores a 5D video-style latent,
so the node denormalizes per-channel (z * std + mean, read from the VAE
config) and drops the singleton temporal frame before decoding -
matching qwen_image_l2i.

Wire buildQwenImageGraph (txt2img + img2img, Fit & Native) with the
Qwen-Image VAE as the decode's normalization source, base-aware
gating/readiness, and a starter decoder (Qwen-Image 2Kto4K only). The
16-channel FLUX/SD3/Qwen ambiguity is handled by the existing trusted-
base-override probe hardening. FLUX/FLUX.2/SD3/SDXL/Z-Image paths are
unchanged.
@Pfannkuchensack Pfannkuchensack changed the title Feat: pid decoder Feat: Add PiD (Pixel Diffusion Decoder) 4× super-resolution decode for FLUX / FLUX.2 / SD3 / SDXL / Z-Image / Qwen-Image Jul 1, 2026
@github-actions github-actions Bot added the docs PRs that change docs label Jul 1, 2026
- graph-builder tests: set pidMode 'off' in the FLUX / Qwen-Image /
  SDXL+SD3 param fixtures so the PiD guard doesn't fire on an undefined
  pidMode and call the (unmocked) size helpers
- paramsSlice migration test: expect _version 4 (v3→v4 adds the PiD fields)
- remove the unused setPidSteps action and selectPidSteps selector flagged
  by knip; the pidSteps state field stays at its default of 4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

6.14.x backend PRs that change backend files DO NOT MERGE docs PRs that change docs frontend PRs that change frontend files invocations PRs that change invocations python PRs that change python files Root services PRs that change app services

Projects

Status: 6.14.x Theme: USER EXPERIENCE

Development

Successfully merging this pull request may close these issues.

[enhancement]: Add support for PiD - Pixel Diffusion Decoder

3 participants