Skip to content

Spatial Attention: XCD-aware spatial workgroup mapping for MHA and GQA (SWIZZLE=1)#3936

Open
mc186 wants to merge 2 commits into
ROCm:mainfrom
mc186:spatial-attention
Open

Spatial Attention: XCD-aware spatial workgroup mapping for MHA and GQA (SWIZZLE=1)#3936
mc186 wants to merge 2 commits into
ROCm:mainfrom
mc186:spatial-attention

Conversation

@mc186

@mc186 mc186 commented Jun 25, 2026

Copy link
Copy Markdown

Spatial Attention: XCD-aware spatial workgroup mapping for MHA and GQA (SWIZZLE=1)

Motivation

AMD CDNA3/3.5 GPUs (MI300X, MI350X, MI355X) are built from 8 XCDs (chiplets), each with a dedicated 4 MB L2 cache. The hardware assigns workgroup wid to an XCD by wid % NUM_XCDS — round-robin, determ

In standard attention scheduling (SWIZZLE=0), consecutive workgroups cycle through all Q heads before advancing the sequence block. For MHA with many heads, this causes each head's KV data to be spread acros

Spatial scheduling (SWIZZLE=1) re-orders workgroups so that all computation touching a given KV head is pinned to one XCD. Each head's KV tensor stays hot in a single 4 MB slice rather than being replicated

Why GQA is unaffected: models with 8 KV heads (e.g. Llama3, GQA ratio 16) already achieve this locality "accidentally" — round-robin workgroup dispatch across 8 XCDs naturally maps one KV head per XCD. Spati

Changes

aiter/ops/triton/utils/_triton/pid_preprocessing.py

Adds remap_workgroup_spatial(wid, NUM_Q_HEADS, NUM_BLOCKS, BATCH, NUM_QUERIES_PER_KV, NUM_XCDS), a unified Triton JIT function with two specialised paths selected at compile time by NUM_QUERIES_PER_KV:

  • MHA path (NUM_QUERIES_PER_KV == 1): groups ceil(NUM_Q_HEADS / NUM_XCDS) consecutive Q heads onto each XCD. Ordering within each XCD is head-first (all blocks for head i before head i+1).

  • GQA path (NUM_QUERIES_PER_KV > 1): assigns each KV head exclusively to one XCD so that all Q heads sharing a KV head run on the same XCD. Ordering within each XCD is block-first (all Q heads in the group

aiter/ops/triton/_triton_kernels/attention/mha.py

  • Adds SWIZZLE: tl.constexpr kernel parameter.
  • SWIZZLE=0: preserves the existing remap_xcd behaviour (backward-compatible default).
  • SWIZZLE=1: calls remap_workgroup_spatial; NUM_QUERIES_PER_KV is derived from NUM_Q_HEADS // NUM_K_HEADS at compile time, so GQA and MHA specialise independently with zero runtime overhead.

aiter/ops/triton/attention/mha.py

  • Reads AITER_SWIZZLE environment variable at module import (default 0).
  • Adds mha_set_swizzle(value: int) for programmatic control.

Usage

# environment variable (affects all subsequent flash_attn_func calls)
AITER_SWIZZLE=1 python your_script.py
# programmatic
from aiter.ops.triton.attention.mha import mha_set_swizzle
mha_set_swizzle(1)   # enable spatial
mha_set_swizzle(0)   # restore default

Kernel-level benchmark results (AMD MI355X, BF16, batch=1)

MHA — SWIZZLE=0 vs SWIZZLE=1, non-causal

Heads (HQ=HK) Seq len SWIZZLE=0 (TFLOPS) SWIZZLE=1 (TFLOPS) Delta
128 32K 681 924 +35.7%
128 64K 636 895 +40.7%
128 128K 605 850 +40.5%
64 64K 742 811 +9.2%
16 64K 823 831 +1.0%

End-to-end model results (AMD MI355X, BF16, batch=1, random-init weights)

Whole-model forward-pass latency and socket energy (idle-baseline subtracted) on representative architectures:

Model Heads Seq Speedup Energy savings
Wan2.2-14B DiT (video diffusion, non-causal) 40 32K +4.0% +4.1%
Wan2.2-14B DiT 40 64K +14.3% +16.7%
Wan2.2-14B DiT 40 131K +20.3% +23.3%
Dense MHA 128h (causal) 128 128K +10.0% +10.6%
Dense MHA 128h (causal) 128 512K +18.0% +19.8%
Llama3-8B GQA (negative control) 32Q/8KV 32K +/-0% +/-0%

Correctness

Verified bit-identical output (rel_max = 0.00e+00) vs SWIZZLE=0 across 17 configurations on MI355X (thor-4):

  • GQA: HK in {2, 4, 8, 16, 32}, causal + non-causal, batch 1 and 4
  • MHA: HQ=HK in {16, 32, 40, 64, 128}, causal + non-causal, batch 1 and 2

Test script: op_tests/test_mha_spatial_swizzle.py (included in this branch).

mc186 added 2 commits June 25, 2026 18:55
Introduces remap_workgroup_spatial() in pid_preprocessing.py and wires
it into the flash-attention forward kernel behind AITER_SWIZZLE=1.

MHA path (spatial_mha): groups ceil(NUM_Q_HEADS/NUM_XCDS) consecutive Q
heads onto each XCD so each head's full KV tensor is processed by one
XCD, keeping it hot in that XCD's 4 MB L2 partition.

GQA path (spatial_gqa): assigns each KV head exclusively to one XCD so
all Q heads sharing a KQA head run on the same XCD. Ordering within each
XCD is block-first, which avoids the causal load-imbalance that
head-first ordering creates. Three regimes are handled: HK==NXCD
(aligned), HK>NXCD, and HK<NXCD.

AITER_SWIZZLE=0 (default) preserves the existing remap_xcd behaviour;
AITER_SWIZZLE=1 selects the new spatial mapping. The mode can also be
set programmatically via mha_set_swizzle().

Correctness verified bit-identical (rel=0.00e+00) across 17 configs
covering GQA ratios 2-16, MHA head counts 16-128, causal/non-causal,
and batch sizes 1-4 on MI355X (thor-4).
@mc186 mc186 requested a review from a team June 25, 2026 20:00
@github-actions

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3936 --add-label <label>

@zufayu zufayu requested a review from amd-ruitang3 June 26, 2026 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant