bf16 asm mha: add mask=0 kernel #3957
Open
tingchen988 wants to merge 66 commits into
Co-authored-by: Cursor <cursoragent@cursor.com>
…tests
asm_fmha_fwd_f16.cu:
- Set args.opt = 7 = (reverse_kv | double_q | remap_xy) so the packed
s_opt SGPR matches the launch-time gdx/gdy swap and stays compatible
with future _dq variants. Bits 0/1 are compile-time gated off in the
shipped _brd_rxy / _cas_brd_rxy[_sink] builds, so this is a no-op for
current .co files but documents the invariant.
- Always push LSE into the returned vector (caller may ignore it when
return_lse==false). Required to keep a fixed-arity 2-tuple return for
torch.library / compile_ops schema inference.
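The `args.opt = 7` packing described above can be sketched as a bit-flag enum. The bit positions are an assumption: the commit message only states that the three flags OR to 7 and that bits 0/1 are compile-time gated off, so the order below simply follows the order in which the flags are listed. `FmhaOpt` and `pack_opt` are hypothetical names used for illustration.

```python
from enum import IntFlag


class FmhaOpt(IntFlag):
    # Assumed bit layout (not confirmed by the source): flags in listed order.
    REVERSE_KV = 1 << 0  # bit 0, gated off at compile time in shipped builds
    DOUBLE_Q = 1 << 1    # bit 1, gated off at compile time in shipped builds
    REMAP_XY = 1 << 2    # bit 2, paired with the launch-time gdx/gdy swap


def pack_opt(reverse_kv: bool = True, double_q: bool = True, remap_xy: bool = True) -> int:
    """Pack the three option flags into one integer for the s_opt SGPR."""
    opt = FmhaOpt(0)
    if reverse_kv:
        opt |= FmhaOpt.REVERSE_KV
    if double_q:
        opt |= FmhaOpt.DOUBLE_Q
    if remap_xy:
        opt |= FmhaOpt.REMAP_XY
    return int(opt)


print(pack_opt())  # 7
```

Under this layout, a kernel that ignores bits 0/1 sees the same behavior for `opt = 4` and `opt = 7`, which is consistent with the change being a no-op for current .co files.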
aiter/ops/mha.py:
- fmha_fwd_f16_asm + gen_fake_tensors now declare -> Tuple[Tensor, Tensor]
(drops the List import and a conditional return_lse branch in the fake
tensor generator). Variadic Tuple[Tensor, ...] is rejected by torch's
infer_schema; fixed-arity matches mxfp8 / fmha_v3_fwd convention.
- Fix misleading docstrings: softmax_scale is forwarded as-is; only sink
is multiplied by sqrt(qk_head_dim) to convert from AITER post-scale to
the kernel's pre-scale raw-logit domain.
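The sink conversion in the last bullet can be checked with a small numeric sketch. It assumes `softmax_scale == 1 / sqrt(qk_head_dim)` and assumes the kernel applies that scale to the sink together with the raw logits; `sink_to_kernel_domain` is a hypothetical helper, not a function from `aiter/ops/mha.py`.

```python
import math


def sink_to_kernel_domain(sink: float, qk_head_dim: int) -> float:
    """Convert a post-scale sink value to the pre-scale raw-logit domain."""
    return sink * math.sqrt(qk_head_dim)


qk_head_dim = 128
softmax_scale = 1.0 / math.sqrt(qk_head_dim)  # assumption: default scale
raw_scores = [3.0, -1.5, 7.25]                # illustrative q.k values
sink = 0.5                                    # post-scale sink (caller's domain)

# Caller's view: the scale is applied to the scores, the sink is used as-is.
post_scale_logits = [softmax_scale * s for s in raw_scores] + [sink]

# Assumed kernel view: the scale is applied to every raw logit, sink included.
sink_raw = sink_to_kernel_domain(sink, qk_head_dim)
pre_scale_logits = [softmax_scale * s for s in raw_scores + [sink_raw]]

logits_match = all(
    math.isclose(a, b, rel_tol=1e-12)
    for a, b in zip(post_scale_logits, pre_scale_logits)
)
print(logits_match)
```

If a caller passes a `softmax_scale` other than `1 / sqrt(qk_head_dim)`, the two views no longer coincide under this sketch.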
op_tests/test_fmha_fwd_f16_asm.py:
- Mxfp8-style helper extraction: run_kernel / run_ref / run_cli used
consistently by the parametrized tests and the __main__ runner.
- Add _bench (cuda.Event timing) and use it instead of run_perftest in
the perf test + CLI; the torch.profiler / ROCTracer path on gfx1250 +
ROCm 7.x silently drops kernel events ("ROCTracer produced duplicate
flow start") so run_perftest reports 0 us / inf TFLOPS — invisible
when running pytest without -s. _bench bypasses the profiler.
- Add _nrms relative-error metric (matches op_tests/test_mha_mxfp8.py)
printed alongside max-diff in correctness tests / CLI --ref output.
- Add sanity asserts in test_fmha_fwd_f16_perf so a future regression
in timing infra (us == 0 or inf TFLOPS) FAILs explicitly instead of
silently PASSing.
- Parametrize batch on test_fmha_fwd_f16_correctness ([1, 2]) to catch
potential batch-stride bugs.
- Move `import argparse` to the top of the file alongside other imports.
- Keep _ref_attn as the in-file reference: attention_ref(upcast=True)
works numerically but its returned lse is cast back to q.dtype (bf16,
see test_mha_common.py:615), introducing ~1 ULP of bf16 quantization
on lse (~0.03 absolute for sq=8192 d=128) that exceeds tight CLI
thresholds. attention_ref is still imported (with noqa: F401) so the
swap is one line if the upstream API stops casting lse.
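The size of the bf16 quantization error on lse mentioned in the last bullet can be estimated in pure Python. This sketch emulates bf16 rounding (round-to-nearest-even on the top 16 bits of a float32) and assumes lse is near `log(8192)`, which holds only for roughly uniform attention scores; `to_bf16` and `bf16_ulp` are illustrative helpers, not part of the test file.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bf16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)           # round-to-nearest-even
    bits = ((bits + bias) & 0xFFFFFFFF) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bf16 values near x (7 explicit mantissa bits)."""
    return 2.0 ** (math.floor(math.log2(abs(x))) - 7)


lse = math.log(8192)           # assumed magnitude of lse for sq=8192
print(bf16_ulp(lse))           # 0.0625
print(abs(to_bf16(lse) - lse)) # rounding error for this particular value
```

For values in [8, 16) the spacing is 0.0625, so the worst-case rounding error is 0.03125, in line with the roughly 0.03 absolute error quoted above.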
…o tingchen_fmha_f16
Resolve conflicts only (no other local changes):
- optCompilerConfig.json: keep fmha_fwd_with_sink_asm + _varlen_asm modules alongside main pa_decode_bf16_asm module.
- mha.py: keep both can_impl_fmha_fwd_with_sink_asm (gfx1250) and main can_impl_fmha_native (gfx942); order native split-K dispatch before the sink-asm branch.
…16_mha
# Conflicts:
# csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu
# hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co
# hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co
# hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
# hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
Co-authored-by: Cursor <cursoragent@cursor.com>
# Conflicts:
# hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co
# hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co
# hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
# hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
Motivation
[gfx1250] Add a mask=0 (non-causal) kernel to fmha fwd bf16 (hdim 128/64, per tensor bf16).
Technical Details
Replace the asm files of the fmha fwd bf16 kernel.
Test Plan
python3 op_tests/test_fmha_fwd_f16_asm.py
python3 op_tests/test_fmha_fwd_with_sink_varlen_asm.py
Test Result
Function test:
pytest op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_correctness -v -s
pytest op_tests/test_fmha_fwd_with_sink_varlen_asm.py::test_fmha_fwd_with_sink_varlen_asm_correctness -v -s
Tests pass.
Perf test:
pytest op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf -v -s
pytest op_tests/test_fmha_fwd_with_sink_varlen_asm.py::test_fmha_fwd_with_sink_varlen_asm_perf -v -s
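The mask=0 (causal=False) kernel reaches higher throughput in TFLOPS than the mask=1 (causal=True) kernel. Its wall time is longer because the non-causal case performs roughly twice as many FLOPs.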
mask=1
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[True-128-16384-const0.25] [perf] d=128 sq=sk=16384 b=1 hq=64 hk=4 causal=True init=const0.25: 1465.0us, 3002.16 TFLOPS
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[True-64-32768-const0.25] [perf] d=64 sq=sk=32768 b=1 hq=64 hk=8 causal=True init=const0.25: 3844.9us, 2287.72 TFLOPS
mask=0
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[False-128-16384-const0.25] [perf] d=128 sq=sk=16384 b=1 hq=64 hk=4 causal=False init=const0.25: 2704.6us, 3252.23 TFLOPS
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[False-64-32768-const0.25] [perf] d=64 sq=sk=32768 b=1 hq=64 hk=8 causal=False init=const0.25: 7310.7us, 2406.36 TFLOPS
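The TFLOPS figures above are consistent with the common attention FLOP count of `4 * b * hq * sq * sk * d`, halved for the causal case. That formula is inferred from the reported numbers, not taken from the test source, and `attn_tflops` is a hypothetical helper. A minimal sketch that recomputes throughput from the printed timings:

```python
def attn_tflops(b: int, hq: int, sq: int, sk: int, d: int,
                causal: bool, time_us: float) -> float:
    """Forward-attention throughput in TFLOPS from a measured time in microseconds."""
    # Two GEMMs (QK^T and PV), each costing 2 * sq * sk * d FLOPs per head.
    flops = 4 * b * hq * sq * sk * d
    if causal:
        flops //= 2  # assumption: only the lower triangle is counted
    # FLOPs per microsecond equals MFLOPS; divide by 1e6 to get TFLOPS.
    return flops / time_us / 1e6


# Shapes and timings copied from the perf output above.
runs = [
    (1, 64, 16384, 16384, 128, True, 1465.0),
    (1, 64, 32768, 32768, 64, True, 3844.9),
    (1, 64, 16384, 16384, 128, False, 2704.6),
    (1, 64, 32768, 32768, 64, False, 7310.7),
]
for b, hq, sq, sk, d, causal, us in runs:
    print(f"d={d} sq=sk={sq} causal={causal}: {attn_tflops(b, hq, sq, sk, d, causal, us):.2f} TFLOPS")
```

The recomputed values differ from the reported ones only in the last digits, which is what rounding the printed time to 0.1 us would produce.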
Submission Checklist