diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index df9dbde931..890b81140b 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1594,8 +1594,6 @@ def can_impl_fmha_fwd_with_sink_asm(): # (per-Q-head fp32) supported; sink-token (sink_size) not supported. ret = get_gfx() == "gfx1250" ret = ret and (q.dtype == dtypes.bf16) - # Only causal gfx1250 binaries are registered in fmha_fwd_bf16*.csv. - ret = ret and bool(causal) ret = ret and (hdim_q in (64, 128)) ret = ret and (hdim_v == hdim_q) ret = ret and (nhead_q % nhead_k == 0) @@ -2505,8 +2503,6 @@ def can_impl_fmha_fwd_with_sink_varlen_asm(): # logits (per-Q-head fp32) supported; sink-token (sink_size) not. ret = get_gfx() == "gfx1250" ret = ret and (q.dtype == dtypes.bf16) - # Only causal gfx1250 binaries are registered in fmha_fwd_bf16*.csv. - ret = ret and bool(causal) ret = ret and (hdim_q in (64, 128)) ret = ret and (hdim_v == hdim_q) ret = ret and (nhead_q % nhead_k == 0) diff --git a/csrc/py_itfs_cu/asm_fmha_fwd_with_sink.cu b/csrc/py_itfs_cu/asm_fmha_fwd_with_sink.cu index f118566295..e8537c032b 100644 --- a/csrc/py_itfs_cu/asm_fmha_fwd_with_sink.cu +++ b/csrc/py_itfs_cu/asm_fmha_fwd_with_sink.cu @@ -85,10 +85,10 @@ static_assert(sizeof(KernelArgs) == 0x84, // ---- helpers --------------------------------------------------------------- -// Kernel selection: only (dtype, hdim_q, hdim_v, mask) — we always use the -// _brd (border) kernel variants which are a strict superset (handle aligned -// + unaligned q_seq_len/kv_seq_len uniformly). The csv schema therefore has -// no `border` column. +// Kernel selection: (dtype, hdim_q, hdim_v, mask). mask = is_causal: the csv +// registers both mask=0 (non-causal, _rxy_pfnr source) and mask=1 (causal, +// _rxy_pfnr_cas_brd source -> *_mask.co) variants, so is_causal picks the +// matching .co at launch time. static std::string get_heuristic_kernel_fmha_fwd_bf16(const std::string& dtype, int hdim_q, int hdim_v, diff --git a/csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu b/csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu index 42fdf2f703..70972104c3 100644 --- a/csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu +++ b/csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu @@ -59,8 +59,10 @@ static_assert(sizeof(FmhaFwdVarlenKernelArgs) == 0x58, // ---- helpers --------------------------------------------------------------- -// Kernel selection: only (dtype, hdim_q, hdim_v, mask). Only the _brd (border) -// causal kernels are shipped, so mask is always 1. +// Kernel selection: (dtype, hdim_q, hdim_v, mask). mask = is_causal: the csv +// registers both mask=0 (non-causal, _rxy[_sink]_pfnr source) and mask=1 +// (causal, _rxy[_sink]_pfnr_cas_brd source -> *_mask_varlen.co) variants, so +// is_causal picks the matching .co at launch time. static std::string get_heuristic_kernel_fmha_fwd_bf16_varlen(const std::string& dtype, int hdim_q, int hdim_v, diff --git a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co index 9a15756d86..8cb5743a25 100755 Binary files a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co and b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256_mask.co b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256_mask.co new file mode 100755 index 0000000000..f125b7546d Binary files /dev/null and b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256_mask.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co index 06537d39df..b2fc6f55ab 100755 Binary files a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co and b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256_mask.co b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256_mask.co new file mode 100755 index 0000000000..02204903fe Binary files /dev/null and b/hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256_mask.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16/fmha_fwd_bf16.csv b/hsa/gfx1250/fmha_fwd_bf16/fmha_fwd_bf16.csv index df5f5e2a3f..b4dcd13727 100644 --- a/hsa/gfx1250/fmha_fwd_bf16/fmha_fwd_bf16.csv +++ b/hsa/gfx1250/fmha_fwd_bf16/fmha_fwd_bf16.csv @@ -1,3 +1,5 @@ dtype,hdim_q,hdim_v,mask,knl_name,co_name -bf16,64,64,1,_ZN5aiter35fmha_bf16_pertokenBf16_hd64_128x256E,fmha_bf16_pertokenBf16_hd64_128x256.co -bf16,128,128,1,_ZN5aiter36fmha_bf16_pertokenBf16_hd128_128x256E,fmha_bf16_pertokenBf16_hd128_128x256.co +bf16,64,64,0,_ZN5aiter35fmha_bf16_pertokenBf16_hd64_128x256E,fmha_bf16_pertokenBf16_hd64_128x256.co +bf16,128,128,0,_ZN5aiter36fmha_bf16_pertokenBf16_hd128_128x256E,fmha_bf16_pertokenBf16_hd128_128x256.co +bf16,64,64,1,_ZN5aiter40fmha_bf16_pertokenBf16_hd64_128x256_maskE,fmha_bf16_pertokenBf16_hd64_128x256_mask.co +bf16,128,128,1,_ZN5aiter41fmha_bf16_pertokenBf16_hd128_128x256_maskE,fmha_bf16_pertokenBf16_hd128_128x256_mask.co diff --git a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_mask_varlen.co b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_mask_varlen.co new file mode 100755 index 0000000000..ffdfe43009 Binary files /dev/null and b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_mask_varlen.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co index 61bae611fc..fd4238310e 100755 Binary files a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co and b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_mask_varlen.co b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_mask_varlen.co new file mode 100755 index 0000000000..8992ba4f17 Binary files /dev/null and b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_mask_varlen.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co index 4a4d33430c..efb7b401d1 100755 Binary files a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co and b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co differ diff --git a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_fwd_bf16_varlen.csv b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_fwd_bf16_varlen.csv index ad14180b78..8c5c8bbe88 100644 --- a/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_fwd_bf16_varlen.csv +++ b/hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_fwd_bf16_varlen.csv @@ -1,3 +1,5 @@ dtype,hdim_q,hdim_v,mask,knl_name,co_name -bf16,64,64,1,_ZN5aiter42fmha_bf16_pertokenBf16_hd64_128x256_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_varlen.co -bf16,128,128,1,_ZN5aiter43fmha_bf16_pertokenBf16_hd128_128x256_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_varlen.co +bf16,64,64,0,_ZN5aiter42fmha_bf16_pertokenBf16_hd64_128x256_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_varlen.co +bf16,128,128,0,_ZN5aiter43fmha_bf16_pertokenBf16_hd128_128x256_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_varlen.co +bf16,64,64,1,_ZN5aiter47fmha_bf16_pertokenBf16_hd64_128x256_mask_varlenE,fmha_bf16_pertokenBf16_hd64_128x256_mask_varlen.co +bf16,128,128,1,_ZN5aiter48fmha_bf16_pertokenBf16_hd128_128x256_mask_varlenE,fmha_bf16_pertokenBf16_hd128_128x256_mask_varlen.co diff --git a/op_tests/test_fmha_fwd_with_sink_asm.py b/op_tests/test_fmha_fwd_with_sink_asm.py index 4432d9701c..5ba1e51644 100644 --- a/op_tests/test_fmha_fwd_with_sink_asm.py +++ b/op_tests/test_fmha_fwd_with_sink_asm.py @@ -347,34 +347,32 @@ def run_ref(q, k, v, *, is_causal: bool, sink: Optional[torch.Tensor] = None): # --------------------------------------------------------------------------- -# Only causal kernels are shipped on gfx1250 (CSV registers only `mask=1` -# entries — the nocausal `_brd_v8` binaries were removed). is_causal is kept -# as a parameter so the kernel-call sites still receive the (now always-True) -# flag explicitly; if a nocausal binary is re-added, just add `False` back. -@pytest.mark.parametrize("is_causal", [True]) -@pytest.mark.parametrize( - "head_dim,hq,hk,sq,sk,batch", - [ - # ----- Small shapes (cheap, GQA-light) --------------------------- - # Catch unaligned-sq / unaligned-sk corner cases without paying - # the cost of materializing the full [b, h, sq, sk] fp32 attn - # matrix in _ref_attn. - (64, 8, 1, 128, 2048, 1), # D64 aligned - (64, 8, 1, 128, 2048, 2), - (64, 8, 1, 130, 2048, 1), # D64 q-unaligned (sq not mult of 128) - (64, 8, 1, 128, 2300, 1), # D64 kv-unaligned (sk not mult of 256) - (128, 8, 1, 128, 2048, 1), # D128 aligned - (128, 8, 1, 128, 2048, 2), - (128, 8, 1, 130, 2048, 1), # D128 q-unaligned - (128, 8, 1, 128, 2300, 1), # D128 kv-unaligned - # ----- Large shapes aligned to run.sh perf_v4_d64 / perf_v4_d128 - - # Same memory pressure as test_fmha_fwd_with_sink_asm_perf, batch=1 only - # because the reference path's fp32 attn matrix would otherwise - # exceed device memory (D64 batch=2 sq=sk=8192 → 32 GB). - (64, 64, 8, 8192, 8192, 1), # D64 perf-sized, aligned - (128, 64, 4, 4096, 4096, 1), # D128 perf-sized, aligned - ], -) +# KV-length constraint (mask=0 only): the non-causal (mask=0) kernels only +# support sk (kv_seqlen) that is a multiple of 256. +_CORRECTNESS_SHAPES = [ + # ----- Small shapes (cheap, GQA-light) --------------------------- + (64, 8, 1, 128, 2048, 1), # D64 aligned + (64, 8, 1, 128, 2048, 2), + (64, 8, 1, 130, 2048, 1), # D64 q-unaligned (sq not mult of 128) + (64, 8, 1, 128, 2300, 1), # D64 kv-unaligned (sk not mult of 256) -> causal only + (128, 8, 1, 128, 2048, 1), # D128 aligned + (128, 8, 1, 128, 2048, 2), + (128, 8, 1, 130, 2048, 1), # D128 q-unaligned + (128, 8, 1, 128, 2300, 1), # D128 kv-unaligned -> causal only + (64, 64, 8, 8192, 8192, 1), # D64 perf-sized, aligned + (128, 64, 4, 4096, 4096, 1), # D128 perf-sized, aligned +] + + +_CORRECTNESS_CASES = [ + (head_dim, hq, hk, sq, sk, batch, causal) + for (head_dim, hq, hk, sq, sk, batch) in _CORRECTNESS_SHAPES + for causal in (True, False) + if causal or (sk % 256 == 0) +] + + +@pytest.mark.parametrize("head_dim,hq,hk,sq,sk,batch,is_causal", _CORRECTNESS_CASES) def test_fmha_fwd_with_sink_asm_correctness(head_dim, hq, hk, sq, sk, batch, is_causal): device = "cuda" torch.manual_seed(0) @@ -573,8 +571,7 @@ def test_fmha_fwd_with_sink_asm_layout(layout, head_dim): @pytest.mark.parametrize("head_dim", [64, 128]) -# Only causal kernels are shipped (see test_fmha_fwd_with_sink_asm_correctness comment). -@pytest.mark.parametrize("is_causal", [True]) +@pytest.mark.parametrize("is_causal", [True, False]) def test_fmha_fwd_with_sink_asm_via_flash_attn_func(head_dim, is_causal): device = "cuda" torch.manual_seed(0) @@ -805,8 +802,7 @@ def _make_qkv_perf(init: str, *, layout, sq, sk, batch, hq, hk, d, dtype, device @pytest.mark.parametrize("init", _PERF_INITS) @pytest.mark.parametrize("head_dim,seqlen", _PERF_SHAPES) -# Only causal kernels are shipped (see test_fmha_fwd_with_sink_asm_correctness comment). -@pytest.mark.parametrize("is_causal", [True]) +@pytest.mark.parametrize("is_causal", [True, False]) def test_fmha_fwd_with_sink_asm_perf(head_dim, seqlen, is_causal, init): device = "cuda" torch.manual_seed(0) diff --git a/op_tests/test_fmha_fwd_with_sink_varlen_asm.py b/op_tests/test_fmha_fwd_with_sink_varlen_asm.py index 74887448e5..5b2892034d 100644 --- a/op_tests/test_fmha_fwd_with_sink_varlen_asm.py +++ b/op_tests/test_fmha_fwd_with_sink_varlen_asm.py @@ -18,9 +18,9 @@ to the kernel verbatim (no host-side scaling). D64 kernels read it; D128 kernels ignore it (pass None). -Only causal kernels are shipped (CSV registers mask=1 rows), so is_causal=True. -Causal uses bottom-right alignment per sequence (query i attends to key j iff -j <= i + (sk - sq)), matching flash_attn varlen semantics. + +KV-length constraint (mask=0 only): the non-causal (mask=0) kernels only +support per-sequence kv_seqlen that is a multiple of 256. """ from __future__ import annotations @@ -131,11 +131,22 @@ def _ref_varlen(q, k, v, cu_q, cu_k, *, is_causal: bool, sink: Optional[torch.Te def make_varlen_packed( - seqlens: List[int], hq: int, hk: int, d: int, dv: int, device="cuda", seed=0 + seqlens: List[int], + hq: int, + hk: int, + d: int, + dv: int, + device="cuda", + seed=0, + init: str = "randn", ): """Build packed THD q/k/v + cu_seqlens for the given per-batch seqlens. Uses equal q/k seqlens per batch (standard varlen self-attention). + + init pattern (mirrors the fixed-batch perf test's `_make_qkv_perf`): + "randn" : standard normal (default; exercises real attention math). + "const0.25" : fill every element with 0.25 """ torch.manual_seed(seed) cu = torch.tensor( @@ -145,6 +156,12 @@ def make_varlen_packed( q = torch.randn(total, hq, d, dtype=torch.bfloat16, device=device) k = torch.randn(total, hk, d, dtype=torch.bfloat16, device=device) v = torch.randn(total, hk, dv, dtype=torch.bfloat16, device=device) + if init == "const0.25": + q.fill_(0.25) + k.fill_(0.25) + v.fill_(0.25) + elif init != "randn": + raise ValueError(f"unknown perf init pattern: {init!r}") cu = cu.to(device) return q, k, v, cu @@ -205,23 +222,39 @@ def run_kernel( # --------------------------------------------------------------------------- -@pytest.mark.parametrize("is_causal", [True]) -@pytest.mark.parametrize( - "head_dim,hq,hk,seqlens", - [ - # aligned single batch - (64, 8, 1, [256]), - (128, 8, 1, [256]), - # multi-batch, mixed (some unaligned) seqlens - (64, 8, 1, [128, 256, 384]), - (128, 8, 1, [128, 256, 384]), - (64, 8, 2, [100, 200, 300]), # unaligned + GQA - (128, 8, 2, [100, 200, 300]), - # GQA-heavy, larger - (64, 64, 8, [512, 1024]), - (128, 64, 4, [512, 1024]), - ], -) +_CORRECTNESS_SHAPES = [ + # aligned single batch + (64, 8, 1, [256]), + (128, 8, 1, [256]), + # multi-batch, mixed (some unaligned) seqlens -> causal only + (64, 8, 1, [128, 256, 384]), + (128, 8, 1, [128, 256, 384]), + (64, 8, 2, [100, 200, 300]), # unaligned + GQA + (128, 8, 2, [100, 200, 300]), + # 256-aligned multi-batch (exercised under BOTH causal and mask=0) + (64, 8, 1, [256, 512]), + (128, 8, 1, [256, 512]), + (64, 8, 2, [256, 512, 768]), # aligned 3-batch + GQA + (128, 8, 2, [256, 512, 768]), + # GQA-heavy, larger (256-aligned) + (64, 64, 8, [512, 1024]), + (128, 64, 4, [512, 1024]), +] + + +def _kv_256_aligned(seqlens) -> bool: + return all(s % 256 == 0 for s in seqlens) + + +_CORRECTNESS_CASES = [ + (hd, hq, hk, sl, causal) + for (hd, hq, hk, sl) in _CORRECTNESS_SHAPES + for causal in (True, False) + if causal or _kv_256_aligned(sl) +] + + +@pytest.mark.parametrize("head_dim,hq,hk,seqlens,is_causal", _CORRECTNESS_CASES) def test_fmha_fwd_with_sink_varlen_asm_correctness( head_dim, hq, hk, seqlens, is_causal ): @@ -271,10 +304,10 @@ def test_fmha_fwd_with_sink_varlen_asm_correctness( @pytest.mark.parametrize("head_dim", [64, 128]) -@pytest.mark.parametrize("is_causal", [True]) +@pytest.mark.parametrize("is_causal", [True, False]) def test_fmha_fwd_with_sink_varlen_asm_via_flash_attn_varlen_func(head_dim, is_causal): device = "cuda" - hq, hk, seqlens = 8, 1, [128, 256, 384] + hq, hk, seqlens = 8, 1, [256, 512, 768] q, k, v, cu = make_varlen_packed(seqlens, hq, hk, head_dim, head_dim, device=device) max_seqlen_q = max(seqlens) scale = 1.0 / math.sqrt(head_dim) @@ -338,15 +371,27 @@ def _bench(fn, *args, num_iters=20, num_warmup=10, **kwargs) -> float: return start.elapsed_time(end) * 1000.0 / num_iters # us per iter -@pytest.mark.parametrize("head_dim", [64, 128]) -@pytest.mark.parametrize("is_causal", [True]) -def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, is_causal): +_VARLEN_PERF_SHAPES = [ + (64, 64, 8, [4096, 4096]), # D64 multi-batch + (128, 64, 4, [2048, 2048]), # D128 multi-batch + (128, 64, 4, [16384]), # D128 sq=sk=16384 (long context) + (64, 64, 8, [32768]), # D64 sq=sk=32768 (long context) +] + +# Perf input init patterns (mirrors the fixed-batch perf test's _PERF_INITS): +# "randn" : standard normal (default; exercises real attention math). +# "const0.25" : constant 0.25 fill (matches the cpp init_pattern=10 baseline). +_VARLEN_PERF_INITS = ["randn", "const0.25"] + + +@pytest.mark.parametrize("init", _VARLEN_PERF_INITS) +@pytest.mark.parametrize("head_dim,hq,hk,seqlens", _VARLEN_PERF_SHAPES) +@pytest.mark.parametrize("is_causal", [True, False]) +def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, hq, hk, seqlens, is_causal, init): device = "cuda" - if head_dim == 64: - hq, hk, seqlens = 64, 8, [4096, 4096] - else: - hq, hk, seqlens = 64, 4, [2048, 2048] - q, k, v, cu = make_varlen_packed(seqlens, hq, hk, head_dim, head_dim, device=device) + q, k, v, cu = make_varlen_packed( + seqlens, hq, hk, head_dim, head_dim, device=device, init=init + ) max_seqlen_q = max(seqlens) scale = 1.0 / math.sqrt(head_dim) sink = _d64_sink(hq, device) if head_dim == 64 else None @@ -364,10 +409,13 @@ def test_fmha_fwd_with_sink_varlen_asm_perf(head_dim, is_causal): False, sink=sink, ) - # Causal FLOPs summed over batches (each ~ 2 * hq * s^2 * 2d / 2). - flops = sum(2.0 * hq * s * s * (2 * head_dim) / 2.0 for s in seqlens) + # FLOPs summed over batches (each ~ 2 * hq * s^2 * 2d); causal halves it. + flops = sum(2.0 * hq * s * s * (2 * head_dim) for s in seqlens) + if is_causal: + flops /= 2.0 tflops = flops / (us * 1e-6) / 1e12 print( - f"[perf varlen] d={head_dim} causal={is_causal} seqlens={seqlens}: {us:.1f}us, {tflops:.2f} TFLOPS" + f"[perf varlen] d={head_dim} causal={is_causal} hq={hq} hk={hk} " + f"seqlens={seqlens} init={init}: {us:.1f}us, {tflops:.2f} TFLOPS" ) assert us > 0.0 and math.isfinite(tflops)