bf16 asm mha: add mask=0 kernel #3957

Open
tingchen988 wants to merge 66 commits into main from tingchen_mha_mask0

tingchen988 commented Jun 26, 2026

Contributor

Motivation

[gfx1250] fmha fwd bf16 (hdim 128/64) with per-tensor bf16 support: add a mask=0 kernel.

Technical Details

Replace the asm files of the fmha fwd bf16 kernel.

Test Plan

python3 op_tests/test_fmha_fwd_f16_asm.py
python3 op_tests/test_fmha_fwd_with_sink_varlen_asm.py

Test Result

Function test:
pytest op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_correctness -v -s
pytest op_tests/test_fmha_fwd_with_sink_varlen_asm.py::test_fmha_fwd_with_sink_varlen_asm_correctness -v -s
Both tests pass.

Perf test:
pytest op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf -v -s
pytest op_tests/test_fmha_fwd_with_sink_varlen_asm.py::test_fmha_fwd_with_sink_varlen_asm_perf -v -s

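The mask=0 (causal=False) kernel reaches higher throughput than the mask=1 (causal=True) kernel: 3252.23 vs 3002.16 TFLOPS at d=128 and 2406.36 vs 2287.72 TFLOPS at d=64. Its wall-clock times are longer because the non-causal case computes the full sq x sk score matrix rather than roughly half of it.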

mask=1
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[True-128-16384-const0.25] [perf] d=128 sq=sk=16384 b=1 hq=64 hk=4 causal=True init=const0.25: 1465.0us, 3002.16 TFLOPS
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[True-64-32768-const0.25] [perf] d=64 sq=sk=32768 b=1 hq=64 hk=8 causal=True init=const0.25: 3844.9us, 2287.72 TFLOPS

mask=0
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[False-128-16384-const0.25] [perf] d=128 sq=sk=16384 b=1 hq=64 hk=4 causal=False init=const0.25: 2704.6us, 3252.23 TFLOPS
op_tests/test_fmha_fwd_with_sink_asm.py::test_fmha_fwd_with_sink_asm_perf[False-64-32768-const0.25] [perf] d=64 sq=sk=32768 b=1 hq=64 hk=8 causal=False init=const0.25: 7310.7us, 2406.36 TFLOPS
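
The TFLOPS figures in the perf log can be cross-checked with a few lines of arithmetic. The sketch below is not taken from the test source: the function name `attn_fwd_tflops` is hypothetical, and the FLOP model (4 * b * hq * sq * sk * d, halved for causal) is an assumption. Under that assumption it reproduces the reported numbers to within the rounding of the 0.1 us timings.

```python
def attn_fwd_tflops(b, hq, sq, sk, d, causal, time_us):
    """Estimate attention-forward throughput in TFLOPS.

    Assumed FLOP model (not taken from the test source): two matmuls
    (Q @ K^T and P @ V), each 2 * sq * sk * d FLOPs per query head,
    halved when causal masking skips the upper triangle.
    The KV head count (hk) does not enter this model.
    """
    flops = 4 * b * hq * sq * sk * d
    if causal:
        flops /= 2
    return flops / (time_us * 1e-6) / 1e12


# (d, seqlen, causal, time_us, reported TFLOPS), copied from the perf log.
reported = [
    (128, 16384, True, 1465.0, 3002.16),
    (128, 16384, False, 2704.6, 3252.23),
    (64, 32768, True, 3844.9, 2287.72),
    (64, 32768, False, 7310.7, 2406.36),
]

for d, s, causal, time_us, tflops in reported:
    est = attn_fwd_tflops(b=1, hq=64, sq=s, sk=s, d=d, causal=causal, time_us=time_us)
    print(f"d={d} s={s} causal={causal}: estimated {est:.2f} vs reported {tflops:.2f} TFLOPS")
```

Read this way, the mask=0 runs take roughly 1.85x to 1.90x as long as the mask=1 runs while doing twice the modeled work, which is where the higher TFLOPS comes from.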

Submission Checklist

tingchen and others added 30 commits May 9, 2026 20:00
Co-authored-by: Cursor <cursoragent@cursor.com>
…tests

asm_fmha_fwd_f16.cu:
- Set args.opt = 7 = (reverse_kv | double_q | remap_xy) so the packed
  s_opt SGPR matches the launch-time gdx/gdy swap and stays compatible
  with future _dq variants.  Bits 0/1 are compile-time gated off in the
  shipped _brd_rxy / _cas_brd_rxy[_sink] builds, so this is a no-op for
  current .co files but documents the invariant.
- Always push LSE into the returned vector (caller may ignore it when
  return_lse==false).  Required to keep a fixed-arity 2-tuple return for
  torch.library / compile_ops schema inference.
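
A minimal sketch of how the three option flags could pack into the value 7. The bit positions are an assumption: the message only says that bits 0/1 are gated off in the `_rxy` builds, which suggests `remap_xy` sits at bit 2, but which of `reverse_kv` and `double_q` is bit 0 is a guess. The constant names are hypothetical and the real field lives in the C++ `args.opt`, not in Python.

```python
# Hypothetical bit layout for the packed option word (assumption, see above).
OPT_REVERSE_KV = 1 << 0
OPT_DOUBLE_Q = 1 << 1
OPT_REMAP_XY = 1 << 2

# All three flags set gives the value written to args.opt.
opt = OPT_REVERSE_KV | OPT_DOUBLE_Q | OPT_REMAP_XY
print(opt)  # 7

# If a build ignores bits 0 and 1 at compile time, only remap_xy remains
# observable, which is why setting them is a no-op for such a build.
GATED_OFF = OPT_REVERSE_KV | OPT_DOUBLE_Q
effective = opt & ~GATED_OFF
print(effective == OPT_REMAP_XY)  # True
```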

aiter/ops/mha.py:
- fmha_fwd_f16_asm + gen_fake_tensors now declare -> Tuple[Tensor, Tensor]
  (drops the List import and a conditional return_lse branch in the fake
  tensor generator).  Variadic Tuple[Tensor, ...] is rejected by torch's
  infer_schema; fixed-arity matches mxfp8 / fmha_v3_fwd convention.
- Fix misleading docstrings: softmax_scale is forwarded as-is; only sink
  is multiplied by sqrt(qk_head_dim) to convert from AITER post-scale to
  the kernel's pre-scale raw-logit domain.
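
The sink conversion described above can be illustrated with plain arithmetic. This is a sketch under assumptions: the helper name `sink_to_raw_logit` is hypothetical, and `softmax_scale = 1 / sqrt(qk_head_dim)` is assumed as the conventional default. With that default, scaling the converted sink inside the kernel recovers the original post-scale value.

```python
import math


def sink_to_raw_logit(sink: float, qk_head_dim: int) -> float:
    """Convert a post-scale sink value to the pre-scale raw-logit domain.

    Hypothetical helper mirroring the docstring fix: only the sink is
    multiplied by sqrt(qk_head_dim); softmax_scale is left untouched.
    """
    return sink * math.sqrt(qk_head_dim)


qk_head_dim = 128
softmax_scale = 1.0 / math.sqrt(qk_head_dim)  # assumed default scale
sink_post_scale = 0.5

sink_raw = sink_to_raw_logit(sink_post_scale, qk_head_dim)

# Applying softmax_scale to the raw sink gives back the post-scale value,
# but only because softmax_scale equals 1 / sqrt(qk_head_dim) here.
print(math.isclose(sink_raw * softmax_scale, sink_post_scale))  # True
```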

op_tests/test_fmha_fwd_f16_asm.py:
- Mxfp8-style helper extraction: run_kernel / run_ref / run_cli used
  consistently by the parametrized tests and the __main__ runner.
- Add _bench (cuda.Event timing) and use it instead of run_perftest in
  the perf test + CLI; the torch.profiler / ROCTracer path on gfx1250 +
  ROCm 7.x silently drops kernel events ("ROCTracer produced duplicate
  flow start") so run_perftest reports 0 us / inf TFLOPS — invisible
  when running pytest without -s.  _bench bypasses the profiler.
- Add _nrms relative-error metric (matches op_tests/test_mha_mxfp8.py)
  printed alongside max-diff in correctness tests / CLI --ref output.
- Add sanity asserts in test_fmha_fwd_f16_perf so a future regression
  in timing infra (us == 0 or inf TFLOPS) FAILs explicitly instead of
  silently PASSing.
- Parametrize batch on test_fmha_fwd_f16_correctness ([1, 2]) to catch
  potential batch-stride bugs.
- Move `import argparse` to the top of the file alongside other imports.
- Keep _ref_attn as the in-file reference: attention_ref(upcast=True)
  works numerically but its returned lse is cast back to q.dtype (bf16,
  see test_mha_common.py:615), introducing ~1 ULP of bf16 quantization
  on lse (~0.03 absolute for sq=8192 d=128) that exceeds tight CLI
  thresholds.  attention_ref is still imported (with noqa: F401) so the
  swap is one line if the upstream API stops casting lse.
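
The size of the lse quantization error can be estimated without a GPU. The sketch below emulates bf16 round-to-nearest-even in pure Python; `round_to_bf16` and `bf16_spacing` are hypothetical helpers, and the choice of ln(8192) as a representative lse magnitude is an assumption (the real value depends on the inputs). For values in [8, 16) the bf16 spacing is 0.0625, so rounding moves a value by at most 0.03125, consistent with the ~0.03 figure quoted above.

```python
import math
import struct


def round_to_bf16(x: float) -> float:
    """Round a positive finite float to the nearest bfloat16 value."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the 16 low bits that bf16 discards.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bf16 values near x (7 explicit mantissa bits)."""
    return 2.0 ** (math.floor(math.log2(x)) - 7)


lse = math.log(8192)  # assumed representative magnitude for sq=8192
quantized = round_to_bf16(lse)

print(f"lse={lse:.6f} bf16={quantized} error={abs(quantized - lse):.6f}")
print(f"spacing={bf16_spacing(lse)} max rounding error={bf16_spacing(lse) / 2}")
```

A reference that keeps lse in float32 avoids this error entirely, which is the stated reason for keeping the in-file `_ref_attn`.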

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
tingchen988 and others added 25 commits May 26, 2026 23:40
Resolve conflicts only (no other local changes):
- optCompilerConfig.json: keep fmha_fwd_with_sink_asm + _varlen_asm modules
  alongside main pa_decode_bf16_asm module.
- mha.py: keep both can_impl_fmha_fwd_with_sink_asm (gfx1250) and main
  can_impl_fmha_native (gfx942); order native split-K dispatch before the
  sink-asm branch.
…16_mha

# Conflicts:
#	csrc/py_itfs_cu/asm_fmha_fwd_with_sink_varlen.cu
#	hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co
#	hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co
#	hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
#	hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
tingchen988 requested a review from a team June 26, 2026 13:41
@github-actions

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3957 --add-label <label>

junxiaguo and others added 3 commits June 26, 2026 13:46
Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd128_128x256.co
#	hsa/gfx1250/fmha_fwd_bf16/fmha_bf16_pertokenBf16_hd64_128x256.co
#	hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd128_128x256_varlen.co
#	hsa/gfx1250/fmha_fwd_bf16_varlen/fmha_bf16_pertokenBf16_hd64_128x256_varlen.co
Co-authored-by: Cursor <cursoragent@cursor.com>
6 participants