Spatial Attention: XCD-aware spatial workgroup mapping for MHA and GQA (SWIZZLE=1) #3936
Open
mc186 wants to merge 2 commits into
Introduces remap_workgroup_spatial() in pid_preprocessing.py and wires it into the flash-attention forward kernel behind AITER_SWIZZLE=1.

MHA path (spatial_mha): groups ceil(NUM_Q_HEADS/NUM_XCDS) consecutive Q heads onto each XCD so that each head's full KV tensor is processed by one XCD, keeping it hot in that XCD's 4 MB L2 partition.

GQA path (spatial_gqa): assigns each KV head exclusively to one XCD so that all Q heads sharing a KV head run on the same XCD.

Ordering within each XCD is block-first, which avoids the causal load imbalance that head-first ordering creates. Three regimes are handled: HK==NXCD (aligned), HK>NXCD, and HK<NXCD.

AITER_SWIZZLE=0 (default) preserves the existing remap_xcd behaviour; AITER_SWIZZLE=1 selects the new spatial mapping. The mode can also be set programmatically via mha_set_swizzle().

Correctness verified bit-identical (rel=0.00e+00) across 17 configs covering GQA ratios 2-16, MHA head counts 16-128, causal/non-causal, and batch sizes 1-4 on MI355X (thor-4).
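The GQA mapping in the aligned regime (HK==NXCD) can be illustrated with a small host-side Python model. This is only a sketch under stated assumptions, not the Triton code from this PR: the function names `hw_xcd`, `mha_owner_xcd`, and `gqa_aligned_remap` are hypothetical, the hardware is assumed to place workgroup `wid` on XCD `wid % NUM_XCDS`, and only the aligned case is covered.

```python
import math

NUM_XCDS = 8  # XCD count on the GPUs this PR targets


def hw_xcd(wid: int, num_xcds: int = NUM_XCDS) -> int:
    """Assumed hardware dispatch rule: round-robin by workgroup id."""
    return wid % num_xcds


def mha_owner_xcd(q_head: int, num_q_heads: int, num_xcds: int = NUM_XCDS) -> int:
    """MHA grouping: ceil(NUM_Q_HEADS / NUM_XCDS) consecutive heads per XCD."""
    heads_per_xcd = math.ceil(num_q_heads / num_xcds)
    return q_head // heads_per_xcd


def gqa_aligned_remap(wid: int, num_blocks: int, group_size: int,
                      num_xcds: int = NUM_XCDS):
    """Hypothetical model of the aligned GQA regime (KV heads == XCDs).

    Returns (batch, q_head, block) for a hardware workgroup id.
    """
    kv_head = wid % num_xcds            # the XCD index doubles as the KV head
    local = wid // num_xcds             # position in this XCD's own queue
    q_in_group = local % group_size     # block-first: Q heads vary fastest
    block = (local // group_size) % num_blocks
    batch = local // (group_size * num_blocks)
    q_head = kv_head * group_size + q_in_group
    return batch, q_head, block


if __name__ == "__main__":
    group_size, num_blocks, batch = 4, 3, 2
    total = NUM_XCDS * group_size * num_blocks * batch
    for wid in range(total):
        _, q_head, _ = gqa_aligned_remap(wid, num_blocks, group_size)
        # Every Q head runs on the XCD that owns its KV head.
        assert q_head // group_size == hw_xcd(wid)
    print("first workgroups on XCD 0:",
          [gqa_aligned_remap(w, num_blocks, group_size)
           for w in range(0, 5 * NUM_XCDS, NUM_XCDS)])
```

In this model the Q heads of a group are visited before the block index advances, which is one reading of "block-first"; the HK>NXCD and HK<NXCD regimes need extra handling that the sketch does not attempt.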
Spatial Attention: XCD-aware spatial workgroup mapping for MHA and GQA (SWIZZLE=1)
Motivation
AMD CDNA 3 and CDNA 4 GPUs (MI300X, MI350X, MI355X) are built from 8 XCDs (chiplets), each with a dedicated 4 MB L2 cache. The hardware assigns workgroup wid to an XCD by wid % NUM_XCDS (round-robin, determ…

In standard attention scheduling (SWIZZLE=0), consecutive workgroups cycle through all Q heads before advancing the sequence block. For MHA with many heads, this causes each head's KV data to be spread across…

Spatial scheduling (SWIZZLE=1) re-orders workgroups so that all computation touching a given KV head is pinned to one XCD. Each head's KV tensor stays hot in a single 4 MB slice rather than being replicated…

Why GQA is unaffected: models with 8 KV heads (e.g. Llama3, GQA ratio 16) already achieve this locality "accidentally": round-robin workgroup dispatch across 8 XCDs naturally maps one KV head per XCD. Spati…
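The round-robin rule above can be read in both directions: given a workgroup id it tells you the XCD, and given an XCD it tells you which workgroup ids will land there. A minimal Python sketch of both views follows; the helper names `xcd_of` and `wid_for` are hypothetical and do not appear in the PR.

```python
NUM_XCDS = 8  # chiplet count stated in the Motivation text


def xcd_of(wid: int, num_xcds: int = NUM_XCDS) -> int:
    """Forward view: the XCD that receives workgroup `wid`."""
    return wid % num_xcds


def wid_for(xcd: int, slot: int, num_xcds: int = NUM_XCDS) -> int:
    """Inverse view: the slot-th workgroup id that the hardware sends to `xcd`."""
    return xcd + slot * num_xcds


if __name__ == "__main__":
    for xcd in range(NUM_XCDS):
        wids = [wid_for(xcd, slot) for slot in range(4)]
        # Every id produced by the inverse view maps back to the same XCD.
        assert all(xcd_of(w) == xcd for w in wids)
        print(f"XCD {xcd}: workgroups {wids}")
```

Because the hardware rule is fixed, a remapping scheme cannot choose the XCD for a given workgroup id; it can only choose which (head, block) pair each id computes. The inverse view shows which ids are available for that choice on each XCD.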
Changes
aiter/ops/triton/utils/_triton/pid_preprocessing.py

Adds remap_workgroup_spatial(wid, NUM_Q_HEADS, NUM_BLOCKS, BATCH, NUM_QUERIES_PER_KV, NUM_XCDS), a unified Triton JIT function with two specialised paths selected at compile time by NUM_QUERIES_PER_KV:

- MHA path (NUM_QUERIES_PER_KV == 1): groups ceil(NUM_Q_HEADS / NUM_XCDS) consecutive Q heads onto each XCD. Ordering within each XCD is head-first (all blocks for head i before head i+1).
- GQA path (NUM_QUERIES_PER_KV > 1): assigns each KV head exclusively to one XCD so that all Q heads sharing a KV head run on the same XCD. Ordering within each XCD is block-first (all Q heads in the group…

aiter/ops/triton/_triton_kernels/attention/mha.py

Adds a SWIZZLE: tl.constexpr kernel parameter.

- SWIZZLE=0: preserves the existing remap_xcd behaviour (backward-compatible default).
- SWIZZLE=1: calls remap_workgroup_spatial; NUM_QUERIES_PER_KV is derived from NUM_Q_HEADS // NUM_K_HEADS at compile time, so GQA and MHA specialise independently with zero runtime overhead.

aiter/ops/triton/attention/mha.py

- Reads the AITER_SWIZZLE environment variable at module import (default 0).
- Adds mha_set_swizzle(value: int) for programmatic control.

Usage

    # environment variable (affects all subsequent flash_attn_func calls)
    AITER_SWIZZLE=1 python your_script.py

Kernel-level benchmark results (AMD MI355X, BF16, batch=1)
MHA — SWIZZLE=0 vs SWIZZLE=1, non-causal
End-to-end model results (AMD MI355X, BF16, batch=1, random-init weights)
Whole-model forward-pass latency and socket energy (idle-baseline subtracted) on representative architectures:
Correctness
Verified bit-identical output (rel_max = 0.00e+00) vs SWIZZLE=0 across 17 configurations on MI355X (thor-4):

Test script: op_tests/test_mha_spatial_swizzle.py (included in this branch).
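As an illustration of what a bit-identical check with a relative-error figure might look like, here is a minimal pure-Python sketch. It is not taken from op_tests/test_mha_spatial_swizzle.py: the definition of `rel_max` (largest absolute difference divided by the largest reference magnitude) is an assumption, both helper names are hypothetical, and plain Python floats stand in for the BF16 tensors.

```python
import struct


def rel_max(out, ref):
    """Assumed metric: max |out - ref| divided by max |ref|."""
    max_err = max(abs(o - r) for o, r in zip(out, ref))
    scale = max(abs(r) for r in ref)
    return max_err / scale if scale > 0 else max_err


def bit_identical(out, ref):
    """Compare raw IEEE-754 bytes, so 0.0 vs -0.0 counts as a mismatch."""
    if len(out) != len(ref):
        return False
    fmt = f"{len(out)}d"  # pack every value as a 64-bit double
    return struct.pack(fmt, *out) == struct.pack(fmt, *ref)


if __name__ == "__main__":
    baseline = [0.5, -1.25, 3.0]   # stand-in for the SWIZZLE=0 output
    candidate = list(baseline)     # stand-in for the SWIZZLE=1 output
    print(f"rel_max = {rel_max(candidate, baseline):.2e}")
    print("bit-identical:", bit_identical(candidate, baseline))
```

The byte-level comparison is stricter than a zero relative error: two outputs can have rel_max equal to zero and still differ in the sign of a zero.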