Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
12ef21e
fmha f16 aiter integration
Apr 29, 2026
f1d631b
update fmha fwd f16 integration
May 1, 2026
37805ac
fmha_fwd_f16: refine API surface, fix perf timing on gfx1250, polish …
May 4, 2026
edb9b77
reformat python files
May 7, 2026
ce1268e
reformat
May 7, 2026
b9af1ac
reformat
May 7, 2026
3866ad1
move module import to top of file
May 7, 2026
9741a1b
reformat
May 7, 2026
31f77ff
reformat
May 7, 2026
dd98467
reformat
May 7, 2026
a5bcd1c
remove unused import
May 7, 2026
82e9a29
sync 3rdparty/composable_kernel submodule pin with main (fdf4bb7f)
May 7, 2026
db4b417
fmha_fwd_f16: gfx1250-only check
May 8, 2026
b684c05
add arch id guard
May 9, 2026
7cc452e
runtime guard gfx1250
May 9, 2026
279b557
add arch guard in main
May 10, 2026
54a5c25
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 11, 2026
60d3555
reorg fmha_fwd_f16 integration
May 12, 2026
ad4f975
Merge branch 'tingchen_fmha_f16' of https://github.com/ROCm/aiter int…
May 12, 2026
e14c5e0
reformat
May 12, 2026
b2db8b6
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 13, 2026
e0af956
ENABLE_CK=0 for module_aiter_core to bypass ck compile
May 14, 2026
0426023
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 14, 2026
869fb63
Revert CK submodule pin to match main
HaonanWang98 May 14, 2026
8587a40
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 14, 2026
4a59221
update kernel and .cu to v8
HaonanWang98 May 15, 2026
a9958a4
Merge remote-tracking branch 'origin/main' into tingchen_fmha_f16
May 15, 2026
48c01fd
reformat
May 16, 2026
e46202a
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 19, 2026
ff93809
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 20, 2026
37dcedd
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 20, 2026
c1dcc8c
fix .cu issue and replace .co
HaonanWang98 May 24, 2026
e94f34d
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 24, 2026
f0d942e
reformat
HaonanWang98 May 24, 2026
5775c00
rename api
May 25, 2026
26a049b
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 25, 2026
73377ae
add .co
May 26, 2026
8d00c08
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 26, 2026
46b3ffc
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 26, 2026
389e9f1
rename f16 to bf16
May 27, 2026
5853b59
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 27, 2026
dcfbe7b
resolve issue pointed by copilot
Boss2002n May 30, 2026
8924874
reformat
Boss2002n May 30, 2026
04b6732
Merge branch 'main' into tingchen_fmha_f16
tingchen988 May 30, 2026
0ff866e
Merge branch 'main' into tingchen_fmha_f16
tingchen988 Jun 1, 2026
50ec1b4
Merge branch 'main' into tingchen_fmha_f16
tingchen988 Jun 1, 2026
32a611f
Merge branch 'main' into tingchen_fmha_f16
tingchen988 Jun 3, 2026
292c9ef
revert kernel to non kargs preload
Jun 4, 2026
262dedc
Merge branch 'main' into tingchen_fmha_f16
tingchen988 Jun 4, 2026
22ab61c
passthrough sink; add varlen kernel
Jun 7, 2026
87027af
Merge branch 'main' into tingchen_fmha_f16
tingchen988 Jun 7, 2026
bd35371
reformat
Jun 7, 2026
e6a818f
set opt=0 for varlen
Jun 8, 2026
f692877
connect to public api
Jun 9, 2026
204f4b3
reformat
Jun 9, 2026
dc4812f
enhance test
Jun 11, 2026
d15d5f0
set double_q=0
ahmed-bsod Jun 11, 2026
86b05b1
Merge branch main into tingchen_fmha_f16
ahmed-bsod Jun 11, 2026
0104bce
Route gfx1250 prefill varlen before FlyDSL
yhl-amd Jun 11, 2026
169af67
reformat
junxiaguo Jun 12, 2026
6e27400
add kargs preload mha bf16 kernel into aiter
ahmed-bsod Jun 23, 2026
64b369d
Merge branch 'main' of https://github.com/ROCm/aiter into tingchen_bf…
ahmed-bsod Jun 23, 2026
6488178
asm bf16 mha: add non causal kernel
ahmed-bsod Jun 26, 2026
214aad6
Merge remote-tracking branch 'origin/main' into tingchen_mha_mask0
junxiaguo Jun 26, 2026
4eac00a
asm bf16 mha: add mask kernel .co binaries
junxiaguo Jun 26, 2026
86d97cb
reformat
ahmed-bsod Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,8 +1594,6 @@ def can_impl_fmha_fwd_with_sink_asm():
# (per-Q-head fp32) supported; sink-token (sink_size) not supported.
ret = get_gfx() == "gfx1250"
ret = ret and (q.dtype == dtypes.bf16)
# Only causal gfx1250 binaries are registered in fmha_fwd_bf16*.csv.
ret = ret and bool(causal)
ret = ret and (hdim_q in (64, 128))
ret = ret and (hdim_v == hdim_q)
ret = ret and (nhead_q % nhead_k == 0)
Expand Down Expand Up @@ -2505,8 +2503,6 @@ def can_impl_fmha_fwd_with_sink_varlen_asm():
# logits (per-Q-head fp32) supported; sink-token (sink_size) not.
ret = get_gfx() == "gfx1250"
ret = ret and (q.dtype == dtypes.bf16)
# Only causal gfx1250 binaries are registered in fmha_fwd_bf16*.csv.
ret = ret and bool(causal)
ret = ret and (hdim_q in (64, 128))
ret = ret and (hdim_v == hdim_q)
ret = ret and (nhead_q % nhead_k == 0)
Expand Down
8 changes: 4 additions & 4 deletions csrc/py_itfs_cu/asm_fmha_fwd_with_sink.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ static_assert(sizeof(KernelArgs) == 0x84,

// ---- helpers ---------------------------------------------------------------

// Kernel selection: only (dtype, hdim_q, hdim_v, mask) — we always use the
// _brd (border) kernel variants which are a strict superset (handle aligned
// + unaligned q_seq_len/kv_seq_len uniformly). The csv schema therefore has
// no `border` column.
// Kernel selection: (dtype, hdim_q, hdim_v, mask). mask = is_causal: the csv
// registers both mask=0 (non-causal, _rxy_pfnr source) and mask=1 (causal,
// _rxy_pfnr_cas_brd source -> *_mask.co) variants, so is_causal picks the
// matching .co at launch time.
static std::string get_heuristic_kernel_fmha_fwd_bf16(const std::string& dtype,
int hdim_q,
int hdim_v,
Expand Down
6 changes: 4 additions & 2 deletions csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ static_assert(sizeof(FmhaFwdVarlenKernelArgs) == 0x58,

// ---- helpers ---------------------------------------------------------------

// Kernel selection: only (dtype, hdim_q, hdim_v, mask). Only the _brd (border)
// causal kernels are shipped, so mask is always 1.
// Kernel selection: (dtype, hdim_q, hdim_v, mask). mask = is_causal: the csv
// registers both mask=0 (non-causal, _rxy[_sink]_pfnr source) and mask=1
// (causal, _rxy[_sink]_pfnr_cas_brd source -> *_mask_varlen.co) variants, so
// is_causal picks the matching .co at launch time.
static std::string get_heuristic_kernel_fmha_fwd_bf16_varlen(const std::string& dtype,
int hdim_q,
int hdim_v,
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 4 additions & 2 deletions hsa/gfx1250/fmha_fwd_bf16/fmha_fwd_bf16.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
dtype,hdim_q,hdim_v,mask,knl_name,co_name
bf16,64,64,1,_ZN5aiter35fmha_bf16_pertokenBf16_hd64_128x256E,fmha_bf16_pertokenBf16_hd64_128x256.co
bf16,128,128,1,_ZN5aiter36fmha_bf16_pertokenBf16_hd128_128x256E,fmha_bf16_pertokenBf16_hd128_128x256.co
bf16,64,64,0,_ZN5aiter35fmha_bf16_pertokenBf16_hd64_128x256E,fmha_bf16_pertokenBf16_hd64_128x256.co
bf16,128,128,0,_ZN5aiter36fmha_bf16_pertokenBf16_hd128_128x256E,fmha_bf16_pertokenBf16_hd128_128x256.co
bf16,64,64,1,_ZN5aiter40fmha_bf16_pertokenBf16_hd64_128x256_maskE,fmha_bf16_pertokenBf16_hd64_128x256_mask.co
bf16,128,128,1,_ZN5aiter41fmha_bf16_pertokenBf16_hd128_128x256_maskE,fmha_bf16_pertokenBf16_hd128_128x256_mask.co
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 4 additions & 2 deletions hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_fwd_bf16_varlen.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
dtype,hdim_q,hdim_v,mask,knl_name,co_name
bf16,64,64,1,_ZN5aiter42fmha_bf16_pertokenBf16_hd64_128x256_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
bf16,128,128,1,_ZN5aiter43fmha_bf16_pertokenBf16_hd128_128x256_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
bf16,64,64,0,_ZN5aiter42fmha_bf16_pertokenBf16_hd64_128x256_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
bf16,128,128,0,_ZN5aiter43fmha_bf16_pertokenBf16_hd128_128x256_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
bf16,64,64,1,_ZN5aiter47fmha_bf16_pertokenBf16_hd64_128x256_mask_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_mask_varlen.co
bf16,128,128,1,_ZN5aiter48fmha_bf16_pertokenBf16_hd128_128x256_mask_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_mask_varlen.co
60 changes: 28 additions & 32 deletions op_tests/test_fmha_fwd_with_sink_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,34 +347,32 @@ def run_ref(q, k, v, *, is_causal: bool, sink: Optional[torch.Tensor] = None):
# ---------------------------------------------------------------------------


# Only causal kernels are shipped on gfx1250 (CSV registers only `mask=1`
# entries — the nocausal `_brd_v8` binaries were removed). is_causal is kept
# as a parameter so the kernel-call sites still receive the (now always-True)
# flag explicitly; if a nocausal binary is re-added, just add `False` back.
@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize(
"head_dim,hq,hk,sq,sk,batch",
[
# ----- Small shapes (cheap, GQA-light) ---------------------------
# Catch unaligned-sq / unaligned-sk corner cases without paying
# the cost of materializing the full [b, h, sq, sk] fp32 attn
# matrix in _ref_attn.
(64, 8, 1, 128, 2048, 1), # D64 aligned
(64, 8, 1, 128, 2048, 2),
(64, 8, 1, 130, 2048, 1), # D64 q-unaligned (sq not mult of 128)
(64, 8, 1, 128, 2300, 1), # D64 kv-unaligned (sk not mult of 256)
(128, 8, 1, 128, 2048, 1), # D128 aligned
(128, 8, 1, 128, 2048, 2),
(128, 8, 1, 130, 2048, 1), # D128 q-unaligned
(128, 8, 1, 128, 2300, 1), # D128 kv-unaligned
# ----- Large shapes aligned to run.sh perf_v4_d64 / perf_v4_d128 -
# Same memory pressure as test_fmha_fwd_with_sink_asm_perf, batch=1 only
# because the reference path's fp32 attn matrix would otherwise
# exceed device memory (D64 batch=2 sq=sk=8192 → 32 GB).
(64, 64, 8, 8192, 8192, 1), # D64 perf-sized, aligned
(128, 64, 4, 4096, 4096, 1), # D128 perf-sized, aligned
],
)
# KV-length constraint (mask=0 only): the non-causal (mask=0) kernels only
# support sk (kv_seqlen) that is a multiple of 256.
_CORRECTNESS_SHAPES = [
# ----- Small shapes (cheap, GQA-light) ---------------------------
(64, 8, 1, 128, 2048, 1), # D64 aligned
(64, 8, 1, 128, 2048, 2),
(64, 8, 1, 130, 2048, 1), # D64 q-unaligned (sq not mult of 128)
(64, 8, 1, 128, 2300, 1), # D64 kv-unaligned (sk not mult of 256) -> causal only
(128, 8, 1, 128, 2048, 1), # D128 aligned
(128, 8, 1, 128, 2048, 2),
(128, 8, 1, 130, 2048, 1), # D128 q-unaligned
(128, 8, 1, 128, 2300, 1), # D128 kv-unaligned -> causal only
(64, 64, 8, 8192, 8192, 1), # D64 perf-sized, aligned
(128, 64, 4, 4096, 4096, 1), # D128 perf-sized, aligned
]


_CORRECTNESS_CASES = [
(head_dim, hq, hk, sq, sk, batch, causal)
for (head_dim, hq, hk, sq, sk, batch) in _CORRECTNESS_SHAPES
for causal in (True, False)
if causal or (sk % 256 == 0)
]


@pytest.mark.parametrize("head_dim,hq,hk,sq,sk,batch,is_causal", _CORRECTNESS_CASES)
def test_fmha_fwd_with_sink_asm_correctness(head_dim, hq, hk, sq, sk, batch, is_causal):
device = "cuda"
torch.manual_seed(0)
Expand Down Expand Up @@ -573,8 +571,7 @@ def test_fmha_fwd_with_sink_asm_layout(layout, head_dim):


@pytest.mark.parametrize("head_dim", [64, 128])
# Only causal kernels are shipped (see test_fmha_fwd_with_sink_asm_correctness comment).
@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize("is_causal", [True, False])
def test_fmha_fwd_with_sink_asm_via_flash_attn_func(head_dim, is_causal):
device = "cuda"
torch.manual_seed(0)
Expand Down Expand Up @@ -805,8 +802,7 @@ def _make_qkv_perf(init: str, *, layout, sq, sk, batch, hq, hk, d, dtype, device

@pytest.mark.parametrize("init", _PERF_INITS)
@pytest.mark.parametrize("head_dim,seqlen", _PERF_SHAPES)
# Only causal kernels are shipped (see test_fmha_fwd_with_sink_asm_correctness comment).
@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize("is_causal", [True, False])
def test_fmha_fwd_with_sink_asm_perf(head_dim, seqlen, is_causal, init):
device = "cuda"
torch.manual_seed(0)
Expand Down
116 changes: 82 additions & 34 deletions op_tests/test_fmha_fwd_with_sink_varlen_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
to the kernel verbatim (no host-side scaling). D64 kernels read it; D128
kernels ignore it (pass None).

Only causal kernels are shipped (CSV registers mask=1 rows), so is_causal=True.
Causal uses bottom-right alignment per sequence (query i attends to key j iff
j <= i + (sk - sq)), matching flash_attn varlen semantics.

KV-length constraint (mask=0 only): the non-causal (mask=0) kernels only
support per-sequence kv_seqlen that is a multiple of 256.
"""

from __future__ import annotations
Expand Down Expand Up @@ -131,11 +131,22 @@ def _ref_varlen(q, k, v, cu_q, cu_k, *, is_causal: bool, sink: Optional[torch.Te


def make_varlen_packed(
seqlens: List[int], hq: int, hk: int, d: int, dv: int, device="cuda", seed=0
seqlens: List[int],
hq: int,
hk: int,
d: int,
dv: int,
device="cuda",
seed=0,
init: str = "randn",
):
"""Build packed THD q/k/v + cu_seqlens for the given per-batch seqlens.

Uses equal q/k seqlens per batch (standard varlen self-attention).

init pattern (mirrors the fixed-batch perf test's `_make_qkv_perf`):
"randn" : standard normal (default; exercises real attention math).
"const0.25" : fill every element with 0.25
"""
torch.manual_seed(seed)
cu = torch.tensor(
Expand All @@ -145,6 +156,12 @@ def make_varlen_packed(
q = torch.randn(total, hq, d, dtype=torch.bfloat16, device=device)
k = torch.randn(total, hk, d, dtype=torch.bfloat16, device=device)
v = torch.randn(total, hk, dv, dtype=torch.bfloat16, device=device)
if init == "const0.25":
q.fill_(0.25)
k.fill_(0.25)
v.fill_(0.25)
elif init != "randn":
raise ValueError(f"unknown perf init pattern: {init!r}")
cu = cu.to(device)
return q, k, v, cu

Expand Down Expand Up @@ -205,23 +222,39 @@ def run_kernel(
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize(
"head_dim,hq,hk,seqlens",
[
# aligned single batch
(64, 8, 1, [256]),
(128, 8, 1, [256]),
# multi-batch, mixed (some unaligned) seqlens
(64, 8, 1, [128, 256, 384]),
(128, 8, 1, [128, 256, 384]),
(64, 8, 2, [100, 200, 300]), # unaligned + GQA
(128, 8, 2, [100, 200, 300]),
# GQA-heavy, larger
(64, 64, 8, [512, 1024]),
(128, 64, 4, [512, 1024]),
],
)
_CORRECTNESS_SHAPES = [
# aligned single batch
(64, 8, 1, [256]),
(128, 8, 1, [256]),
# multi-batch, mixed (some unaligned) seqlens -> causal only
(64, 8, 1, [128, 256, 384]),
(128, 8, 1, [128, 256, 384]),
(64, 8, 2, [100, 200, 300]), # unaligned + GQA
(128, 8, 2, [100, 200, 300]),
# 256-aligned multi-batch (exercised under BOTH causal and mask=0)
(64, 8, 1, [256, 512]),
(128, 8, 1, [256, 512]),
(64, 8, 2, [256, 512, 768]), # aligned 3-batch + GQA
(128, 8, 2, [256, 512, 768]),
# GQA-heavy, larger (256-aligned)
(64, 64, 8, [512, 1024]),
(128, 64, 4, [512, 1024]),
]


def _kv_256_aligned(seqlens) -> bool:
return all(s % 256 == 0 for s in seqlens)


_CORRECTNESS_CASES = [
(hd, hq, hk, sl, causal)
for (hd, hq, hk, sl) in _CORRECTNESS_SHAPES
for causal in (True, False)
if causal or _kv_256_aligned(sl)
]


@pytest.mark.parametrize("head_dim,hq,hk,seqlens,is_causal", _CORRECTNESS_CASES)
def test_fmha_fwd_with_sink_varlen_asm_correctness(
head_dim, hq, hk, seqlens, is_causal
):
Expand Down Expand Up @@ -271,10 +304,10 @@ def test_fmha_fwd_with_sink_varlen_asm_correctness(


@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("is_causal", [True])
@pytest.mark.parametrize("is_causal", [True, False])
def test_fmha_fwd_with_sink_varlen_asm_via_flash_attn_varlen_func(head_dim, is_causal):
device = "cuda"
hq, hk, seqlens = 8, 1, [128, 256, 384]
hq, hk, seqlens = 8, 1, [256, 512, 768]
q, k, v, cu = make_varlen_packed(seqlens, hq, hk, head_dim, head_dim, device=device)
max_seqlen_q = max(seqlens)
scale = 1.0 / math.sqrt(head_dim)
Expand Down Expand Up @@ -338,15 +371,27 @@ def _bench(fn, *args, num_iters=20, num_warmup=10, **kwargs) -> float:
return start.elapsed_time(end) * 1000.0 / num_iters # us per iter


@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("is_causal", [True])
def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, is_causal):
_VARLEN_PERF_SHAPES = [
(64, 64, 8, [4096, 4096]), # D64 multi-batch
(128, 64, 4, [2048, 2048]), # D128 multi-batch
(128, 64, 4, [16384]), # D128 sq=sk=16384 (long context)
(64, 64, 8, [32768]), # D64 sq=sk=32768 (long context)
]

# Perf input init patterns (mirrors the fixed-batch perf test's _PERF_INITS):
# "randn" : standard normal (default; exercises real attention math).
# "const0.25" : constant 0.25 fill (matches the cpp init_pattern=10 baseline).
_VARLEN_PERF_INITS = ["randn", "const0.25"]


@pytest.mark.parametrize("init", _VARLEN_PERF_INITS)
@pytest.mark.parametrize("head_dim,hq,hk,seqlens", _VARLEN_PERF_SHAPES)
@pytest.mark.parametrize("is_causal", [True, False])
def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, hq, hk, seqlens, is_causal, init):
device = "cuda"
if head_dim == 64:
hq, hk, seqlens = 64, 8, [4096, 4096]
else:
hq, hk, seqlens = 64, 4, [2048, 2048]
q, k, v, cu = make_varlen_packed(seqlens, hq, hk, head_dim, head_dim, device=device)
q, k, v, cu = make_varlen_packed(
seqlens, hq, hk, head_dim, head_dim, device=device, init=init
)
max_seqlen_q = max(seqlens)
scale = 1.0 / math.sqrt(head_dim)
sink = _d64_sink(hq, device) if head_dim == 64 else None
Expand All @@ -364,10 +409,13 @@ def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, is_causal):
False,
sink=sink,
)
# Causal FLOPs summed over batches (each ~ 2 * hq * s^2 * 2d / 2).
flops = sum(2.0 * hq * s * s * (2 * head_dim) / 2.0 for s in seqlens)
# FLOPs summed over batches (each ~ 2 * hq * s^2 * 2d); causal halves it.
flops = sum(2.0 * hq * s * s * (2 * head_dim) for s in seqlens)
if is_causal:
flops /= 2.0
tflops = flops / (us * 1e-6) / 1e12
print(
f"[perf varlen] d={head_dim} causal={is_causal} seqlens={seqlens}: {us:.1f}us, {tflops:.2f} TFLOPS"
f"[perf varlen] d={head_dim} causal={is_causal} hq={hq} hk={hk} "
f"seqlens={seqlens} init={init}: {us:.1f}us, {tflops:.2f} TFLOPS"
)
assert us > 0.0 and math.isfinite(tflops)
Loading