Feat: Add PiD (Pixel Diffusion Decoder) 4× super-resolution decode for FLUX / FLUX.2 / SD3 / SDXL / Z-Image / Qwen-Image#9281
Draft
Pfannkuchensack wants to merge 28 commits into
Draft
Conversation
Adds a vendored subset of NVIDIA's PiD (Pixel Diffusion Decoder) at invokeai/backend/pid/ as the foundation for upcoming FLUX / FLUX.2 / SD3 / Z-Image PiD decode nodes plus a future PiD-based 4x upscale node. Upstream: https://github.com/nv-tlabs/PiD (Apache 2.0). Vendor scope: * _src/{networks,models,modules}: PidNet, PixDiT_T2I, LQProjection2D, PidModel, PidDistillModel, PixelDiTModel, GeneralConditioner. * _ext/imaginaire: minimal Imaginaire framework subset (lazy_config, model, utils/{log,misc,distributed,device,count_params}). * configs/, tokenizers/, checkpointer/, trainer.py, visualize/, _demo_*, from_*, easy_io/, S3/wandb training helpers were intentionally excluded. Dependency stripping (no new hard deps introduced): * loguru, termcolor -> stdlib logging shim * iopath PathManager -> stdlib pathlib stub * fvcore Registry -> minimal stdlib Registry * lazy_config/lazy.py: yaml/dill/cloudpickle/detectron2 save/load paths replaced with a minimal LazyCall stub * lazy_config/instantiate.py: omegaconf DictConfig/ListConfig branches removed; configs are plain dict / LazyCall mappings * megatron, pynvml, boto3/wandb imports are try/except-guarded or local to functions and stay inert in our inference path All pid.* imports rewritten to invokeai.backend.pid.*; SPDX-Apache-2.0 headers retained on vendored files; attribution and detailed list of local modifications added in LICENSE-PiD.txt. The pre-trained PiD checkpoints distributed by NVIDIA remain under NSCLv1 (non-commercial); this commit only vendors code. Smoke test: PidNet, PidModel, PidDistillModel, GeneralConditioner import cleanly; LazyCall -> instantiate round-trip resolves to the expected nn.Module. ruff check passes.
Adds the model-manager plumbing and workflow nodes needed to use the
vendored PiD decoder (phase A) end-to-end with FLUX, SD3 and Z-Image.
Model manager (Phase B + B.5):
* taxonomy: ModelType.PiDDecoder, PiDDecoderVariantType
(Res2k_Sr4x / Res2kTo4k_Sr4x), ModelType.Gemma2Encoder +
ModelFormat.Gemma2Encoder, both added to AnyVariant +
variant_type_adapter.
* configs/pid_decoder.py: per-backbone PiD configs
(FLUX / FLUX.2 / SD3) with state-dict probing on 'lq_proj' substring
and backbone/variant detection from the official NVIDIA filenames.
* configs/gemma2_encoder.py: Gemma-2 directory probing on
Gemma2ForCausalLM architecture + tokenizer files.
* AnyModelConfig union updated.
* model_loaders/pid_decoder.py: loads .pth / .safetensors, strips
the upstream 'net.' prefix, supports torch.load(weights_only=True).
* model_loaders/gemma2_encoder.py: SubModelType.{Tokenizer,
TextEncoder} dispatch; returns the causal LM's inner Gemma2Model
(transformers 4.56's get_decoder() returns None for Gemma2).
Decode pipeline (Phase C):
* backend/pid/decode.py: build_pid_net + load_pid_decoder
(per-backbone PixDiT_T2I hyperparams derived from PiD's pid_sr4x
base + per-experiment overrides), encode_caption_for_pid (chi-prompt
+ Gemma encoding, mirrors PixelDiTModel._encode_text_raw), and a
PiDDecoder wrapper with a reimplemented few-step distill sampler
(no autocast / no distributed / no PixelDiTModel init paths from
upstream).
Invocations (Phase 6.x):
* Gemma2EncoderField + PiDDecoderField in invocations/model.py.
* gemma2_encoder_loader / pid_decoder_loader: thin
ModelIdentifierField pickers that emit the corresponding fields.
* z_image_pid_decode (pilot), flux_pid_decode, sd3_pid_decode:
caption encode -> Gemma offload -> PiD state dict load ->
PidNet construct -> decode. Per-backbone latent denormalisation
(FLUX1 ae_params, SD3 hardcoded 1.5305/0.0609, Z-Image piggybacks
on FLUX VAE).
End-to-end validated with the released
PiD_res2k_sr4x_official_flux_distill_4step.pth checkpoint and
gemma-2-2b-it: PidNet rebuilds at exactly 456 keys / 1.36B params,
sampler runs at ~5 GB VRAM peak (Gemma dominates), output shape and
range match.
FLUX.2 PiD decode is deliberately deferred: it needs BN-based
latent denormalisation and 32->128 channel packing, and we have no
FLUX.2 checkpoint to validate against yet.
Adds the NVIDIA PiD decoder as a 4x super-resolution alternative to the regular VAE/RAE decode path. Includes model-manager configs and loaders for both the PiD checkpoints and the Gemma-2 caption encoder they require, plus four invocations: latent-in decode for FLUX / SD3 / Z-Image and an image-in pid_upscale node. - Decode pipeline keeps PidNet params in fp32 and uses bf16 autocast only for matmuls; caption embeddings have outliers that overflow bf16 RMSNorm. - encode_caption_for_pid forces tokenizer padding_side="right" (Gemma defaults to left, PiD trained with right) and returns the attention mask as bool so it stays compatible with SDPA. - Z-Image reuses the FLUX-trained checkpoint and reads scale/shift from the VAE config at runtime (PiD upstream notes they are checkpoint-specific). - TextLLM config now excludes Gemma2ForCausalLM so it falls through to the dedicated Gemma2 encoder config instead of being misclassified. - Frontend: new model_type / model_format / variant enums, type guards and category metadata; schema.ts regenerated via pnpm typegen.
Read latent channel count from lq_proj.latent_proj.0.weight (FLUX.2=128, FLUX.1/SD3=16) as the primary discriminator; fall back to filename/dir name only to disambiguate the architecturally identical FLUX.1/SD3 pair. Fixes FLUX.2 checkpoints (model_ema_bf16.pth) not being recognised, and correctly rejects unsupported backbones (RAE/dinov2, 768ch). Fix Flux2 docstring 32->128.
Add a "PiD Decode" mode select (Off / Fit / Native) to the FLUX advanced settings with PiD decoder + Gemma-2 encoder pickers. In Fit mode the FLUX graph swaps the VAE decode for a PiD 4x super-resolution decode and downscales back to the requested size. Adds params state (pidMode, decoder, encoder, steps) with a v3->v4 migration, model hooks, readiness checks, and graph guards for the not-yet-wired Native and non-txt2img paths.
Make the generation dimension helpers PiD-aware via an optional pidScale: in Native mode the user-facing dimensions are the 4x target (grid 64, optimal 2048), generation runs at target/4, and PiD's 4x output is used directly with no downscale. Thread pidScale through the params dimension reducers and the optimal-dimension/grid-size selectors, resync dimensions when toggling Native, and wire the Native path in the FLUX graph builder. Add working_mem_bytes for PiD Decode
Extract the PiD decode chain into buildPidDecodeChain (loaders + decode + fit-downscale, no denoise setup) so it can substitute for the VAE decode across generation modes. Widen addImageToImage's l2i param to ImageOutputNodes (it only consumes .image) and wire the PiD chain into the img2img branch in Fit mode. Native stays txt2img-only (a 4x result can't composite onto the bbox); inpaint/outpaint remain gated off for now.
Add addPidImageToImageNative: the canvas bbox is the 4x target, so the init image is downscaled to bbox/4, denoised at that resolution, and PiD decodes straight back up to the full bbox with no post-decode downscale - preserving all PiD detail while still compositing cleanly onto the region. Wire it into the img2img branch of buildFLUXGraph (native vs fit vs off) and drop the native-txt2img-only guard. Make the canvas FLUX grid check PiD-aware so a native bbox must be a multiple of 64 (16 * 4) for bbox/4 to land on the grid.
Explain PiD usage on hover, mirroring the DyPE popover: what the decoder is (NVIDIA Pixel Diffusion Decoder, 4x SR, needs a PiD decoder + Gemma-2 encoder), Fit vs Native modes, the 2K / 2K-to-4K target resolutions, that Steps can be lowered, and that Scale Before Processing must be off. Links to nv-tlabs/PiD.
Register NVIDIA's PiD FLUX decoders (2K and 2K-to-4K presets, from nvidia/PiD) and the Efficient-Large-Model/gemma-2-2b-it caption encoder as starter models so they can be installed from the Model Manager. The Gemma-2 encoder is wired as a dependency of each decoder (and offered standalone).
Add a flux2_pid_decode node that packs the stored FLUX.2 latent (32ch @ H/8) into PiD's 128ch @ H/16 layout before decoding; FLUX.2's BatchNorm denormalization is already applied in flux2_denoise, so no scalar denorm is needed (optional vae input reads identity constants). Generalize the frontend PiD decode chain (decodeNodeType, optional vaeSource) and wire the isFlux2 graph path for txt2img/img2img (Fit & Native). Base-aware PiD gating/decoder-filter, FLUX.2 readiness checks, and two nvidia/PiD FLUX.2 starter decoders (2K, 2Kto4K). Standard FLUX PiD path unchanged.
Wire the existing sd3_pid_decode node into the SD3 graph builder (txt2img and img2img, Fit & Native) with a PiD guard, base-aware gating/decoder-filter (sd-3), and SD3 readiness checks. Add two nvidia/PiD SD3 starter decoders (2K, 2Kto4K). Harden the PiD config probe against the 16-channel FLUX.1/SD3 ambiguity: when the checkpoint's directory name is silent (the HF single-file download renames it), trust an explicit base override so SD3 checkpoints are not misidentified as FLUX.1. Also benefits Qwen. FLUX / FLUX.2 identification is unchanged.
Build the full SDXL PiD backend stack: _PER_BACKBONE[SDXL] (4ch/down8), PiDDecoder_Checkpoint_SDXL_Config with a 4-channel latent-map entry, factory union + loader registration, and a new sdxl_pid_decode node (reads the VAE's scaling_factor/shift at runtime; SDXL fallbacks 0.13025/0.0). 4-channel latents are unambiguous, so no directory-name disambiguation is needed. Generalize the shared PiD decode chain to support SD-family denoise: denoise_latents has no width/height, so thread an optional noise node for sizing and round to the model's native grid (8 for SDXL, 16 for FLUX). Wire buildSDXLGraph (txt2img + img2img, Fit & Native) with the VAE as the decode's scaling source, base-aware gating/readiness, and a starter decoder (SDXL 2Kto4K only). PiD + SDXL refiner is blocked for now via a graph guard and a readiness reason. FLUX/FLUX.2/SD3 paths are unchanged.
Wire the existing z_image_pid_decode node into the Z-Image graph builder (txt2img and img2img, Fit & Native) with a PiD guard and readiness checks. Z-Image shares FLUX.1's 16-channel VAE and has no PiD checkpoints of its own, so it reuses the FLUX decoder: the decoder filter maps z-image -> flux, showing FLUX PiD decoders when a Z-Image model is active. The Z-Image VAE is passed to the decode node so it reads the real scaling_factor / shift instead of the fallback constants. No backend, schema, or starter-model changes. FLUX/FLUX.2/SD3/SDXL paths are unchanged.
Build the full Qwen-Image PiD backend stack: _PER_BACKBONE[QwenImage] (16ch/down8), PiDDecoder_Checkpoint_QwenImage_Config (added to the 16-channel latent map + filename heuristic), factory union + loader registration, and a new qwen_image_pid_decode node. Unlike the scalar-scaling bases, the Qwen-Image VAE normalizes per channel (latents_mean / latents_std) and stores a 5D video-style latent, so the node denormalizes per-channel (z * std + mean, read from the VAE config) and drops the singleton temporal frame before decoding - matching qwen_image_l2i. Wire buildQwenImageGraph (txt2img + img2img, Fit & Native) with the Qwen-Image VAE as the decode's normalization source, base-aware gating/readiness, and a starter decoder (Qwen-Image 2Kto4K only). The 16-channel FLUX/SD3/Qwen ambiguity is handled by the existing trusted- base-override probe hardening. FLUX/FLUX.2/SD3/SDXL/Z-Image paths are unchanged.
- graph-builder tests: set pidMode 'off' in the FLUX / Qwen-Image / SDXL+SD3 param fixtures so the PiD guard doesn't fire on an undefined pidMode and call the (unmocked) size helpers - paramsSlice migration test: expect _version 4 (v3→v4 adds the PiD fields) - remove the unused setPidSteps action and selectPidSteps selector flagged by knip; the pidSteps state field stays at its default of 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds PiD (Pixel Diffusion Decoder) support to InvokeAI — NVIDIA's few-step pixel-diffusion decoder that replaces the regular VAE decode with a caption-conditioned, 4× super-resolution decode (512→2048 in a single 4-step distill pass).
This PR vendors a minimal, inference-only subset of PiD at
invokeai/backend/pid/(upstream https://github.com/nv-tlabs/PiD, code Apache-2.0) and wires it end-to-end into the model manager, invocation nodes, starter models, and the generation UI.What you get
backend/pid/decode.py): per-backbone net config (_PER_BACKBONE),build_pid_net/load_pid_decoder/PiDDecoder, Gemma-2 caption encoding, and a working-memory estimator for the cache.pid_decoder_loader(→PiDDecoderField) andgemma2_encoder_loader(→Gemma2EncoderField), plus model-manager configs/loaders for PiD checkpoints and the shared Gemma-2 caption encoder.flux_pid_decodeflux2_pid_decodesd3_pid_decodesdxl_pid_decodescaling_factorread at runtimez_image_pid_decodeqwen_image_pid_decodelatents_mean/latents_stddenorm + 5D→4D temporal squeezenvidia/PiD, per backbone; FLUX/FLUX.2/SD3 ship 2K + 2K-to-4K, SDXL/Qwen-Image ship 2K-to-4K only) plus the sharedgemma-2-2b-itcaption encoder.Robustness details
baseoverride (which the starter installer sends) when the directory name is ambiguous — so single-file HF downloads are still identified correctly.Related Issues / Discussions
Closes #9240
QA Instructions
base(e.g. a Qwen-Image decoder shows asqwen-image, not FLUX).Automated gates already green on the branch: backend imports (
starter_models, config factory, every*_pid_decodenode), and the frontendpnpm lint:tsc / lint:eslint / lint:knip / lint:dpdm.Not yet hardware-verified (needs a GPU + downloaded checkpoints): full end-to-end image output per base, measured VRAM peak (the working-memory constant is calibrated to a 2048px output ≈ 4.3 GB), and Qwen-Image Edit-mode (reference image) + PiD.
Merge Plan
Self-contained feature; no redux migration beyond the already-included
paramsslice bump.Checklist
_versionbump)What's Newcopy (if doing a release after this PR)