From eacb97c91e7cdac2326aa5bc4bab9d8af0ddabde Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 25 Jun 2026 06:36:04 +0000 Subject: [PATCH 1/4] use flydsl pa reduce default --- aiter/ops/triton/gluon/pa_decode_gluon.py | 102 +++++++++++----------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index 915bd14fb4..d4670cdd44 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -5084,6 +5084,33 @@ def _paged_attention_decode_v2_reduce_kernel_wrapper( All parameters from the reduction kernel plus execution grid configuration """ if PS: + if FLYDSL_PS_REDUCE_AVAILABLE: + try: + launch_pa_decode_ps_reduce_flydsl( + output_ptr, + exp_sums_ptr, + max_logits_ptr, + logits_ptr, + sink_token_ptr, + stride_output_bs, + stride_output_len, + stride_output_kv_head, + stride_output_group_size, + stride_exp_sums_seq, + stride_exp_sums_head, + stride_exp_sums_part, + stride_logits_seq, + stride_logits_head, + stride_logits_part, + stride_logits_group, + query_seq_len=query_seq_len, + query_group_size=query_group_size, + head_size=head_size, + context_partition_num=context_partition_num, + ) + return + except ImportError: + pass if CXX_PS_REDUCE_AVAILABLE: try: launch_pa_decode_ps_reduce_cxx( @@ -5111,56 +5138,31 @@ def _paged_attention_decode_v2_reduce_kernel_wrapper( return except ImportError: pass - try: - launch_pa_decode_ps_reduce_flydsl( - output_ptr, - exp_sums_ptr, - max_logits_ptr, - logits_ptr, - sink_token_ptr, - stride_output_bs, - stride_output_len, - stride_output_kv_head, - stride_output_group_size, - stride_exp_sums_seq, - stride_exp_sums_head, - stride_exp_sums_part, - stride_logits_seq, - stride_logits_head, - stride_logits_part, - stride_logits_group, - query_seq_len=query_seq_len, - query_group_size=query_group_size, - head_size=head_size, - context_partition_num=context_partition_num, - ) - return - except ImportError: - ps_reduce_grid = (grid[0], grid[1], query_seq_len * query_group_size) - paged_attention_decode_ps_reduce_kernel[ps_reduce_grid]( - output_ptr, - exp_sums_ptr, - max_logits_ptr, - logits_ptr, - sink_token_ptr, - stride_output_bs, - stride_output_len, - stride_output_kv_head, - stride_output_group_size, - stride_exp_sums_seq, - stride_exp_sums_head, - stride_exp_sums_part, - stride_logits_seq, - stride_logits_head, - stride_logits_part, - stride_logits_group, - query_group_size=query_group_size, - head_size=head_size, - context_partition_num=context_partition_num, - HEAD_SIZE_POW2=triton.next_power_of_2(head_size), - USE_SINKS=sink_token_ptr is not None, - MAX_CONTEXT_PARTITION_NUM=triton.next_power_of_2(context_partition_num), - ) + ps_reduce_grid = (grid[0], grid[1], query_seq_len * query_group_size) + paged_attention_decode_ps_reduce_kernel[ps_reduce_grid]( + output_ptr, + exp_sums_ptr, + max_logits_ptr, + logits_ptr, + sink_token_ptr, + stride_output_bs, + stride_output_len, + stride_output_kv_head, + stride_output_group_size, + stride_exp_sums_seq, + stride_exp_sums_head, + stride_exp_sums_part, + stride_logits_seq, + stride_logits_head, + stride_logits_part, + stride_logits_group, + query_group_size=query_group_size, + head_size=head_size, + context_partition_num=context_partition_num, + HEAD_SIZE_POW2=triton.next_power_of_2(head_size), + USE_SINKS=sink_token_ptr is not None, + MAX_CONTEXT_PARTITION_NUM=triton.next_power_of_2(context_partition_num), + ) else: paged_attention_decode_v2_reduce_kernel[grid]( output_ptr, From 9d1991ce05b3c8e758c66bc8f631f2cc78cb10f9 Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 25 Jun 2026 07:14:51 +0000 Subject: [PATCH 2/4] adapt for flydsl upstream arith --- aiter/ops/triton/gluon/pa_decode_gluon.py | 35 ++++++++++++----------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index d4670cdd44..2669a68340 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -24,7 +24,7 @@ try: import flydsl.compiler as flyc import flydsl.expr as fx - from flydsl.expr import arith, gpu, rocdl, buffer_ops, range_constexpr + from flydsl.expr import arith, gpu, rocdl, buffer_ops, range_constexpr, const_expr from flydsl.expr.typing import T, Int32 from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -40,6 +40,7 @@ rocdl = None buffer_ops = None range_constexpr = None + const_expr = None T = None Int32 = None SmemAllocator = None @@ -4568,7 +4569,7 @@ def pa_decode_ps_reduce_flydsl_kernel( smem_base = allocator.get_base() red_scratch = SmemPtr(smem_base, red_off, T.f32, shape=(red_slots,)) red_scratch.get() - if max_context_partition_num > FLYDSL_WARP_SIZE: + if const_expr(max_context_partition_num > FLYDSL_WARP_SIZE): part_weights_lds = SmemPtr( smem_base, part_weights_off, T.f32, shape=(max_context_partition_num,) ) @@ -4578,7 +4579,7 @@ def pa_decode_ps_reduce_flydsl_kernel( es_rsrc = buffer_ops.create_buffer_resource(exp_sums_ptr, max_size=True) ml_rsrc = buffer_ops.create_buffer_resource(max_logits_ptr, max_size=True) logits_rsrc = buffer_ops.create_buffer_resource(logits_ptr, max_size=True) - if use_sinks: + if const_expr(use_sinks): sink_rsrc = buffer_ops.create_buffer_resource(sink_token_ptr, max_size=True) c_zero_f = arith.constant(0.0, type=T.f32) @@ -4648,7 +4649,7 @@ def _block_reduce(val, mode): return red_scratch.load([arith.constant(0, index=True)]) - if max_context_partition_num <= FLYDSL_WARP_SIZE: + if const_expr(max_context_partition_num <= FLYDSL_WARP_SIZE): c_part_num = arith.constant(max_context_partition_num, type=T.i32) c_reduce_width = arith.constant(reduce_width, type=T.i32) c_four = arith.constant(4, type=T.i32) @@ -4704,13 +4705,13 @@ def _wave_reduce_sum(val): ) scaled_sum = part_sum * part_scale global_exp_sum = _wave_reduce_sum(scaled_sum) - if use_sinks: + if const_expr(use_sinks): sink_off = kv_head_idx * c_qgs + group_idx - if sink_dtype_str == "f32": + if const_expr(sink_dtype_str == "f32"): sink_value = buffer_ops.buffer_load( sink_rsrc, sink_off, vec_width=1, dtype=T.f32 ) - elif sink_dtype_str == "f16": + elif const_expr(sink_dtype_str == "f16"): sink_value_raw = buffer_ops.buffer_load( sink_rsrc, sink_off, vec_width=1, dtype=T.f16 ) @@ -4732,7 +4733,7 @@ def _wave_reduce_sum(val): c_one_f, ) weight_local = scaled_sum / safe_global_exp_sum - weight_local_i32 = arith.bitcast(T.i32, weight_local) + weight_local_i32 = arith.bitcast(T.i32, arith.unwrap(weight_local)) acc = c_zero_f for part_idx in range_constexpr(max_context_partition_num): @@ -4749,11 +4750,11 @@ def _wave_reduce_sum(val): + eqgs_idx * stride_logits_group + tid ) - if logits_dtype_str == "f32": + if const_expr(logits_dtype_str == "f32"): part_logits = buffer_ops.buffer_load( logits_rsrc, logits_off, vec_width=1, dtype=T.f32 ) - elif logits_dtype_str == "f16": + elif const_expr(logits_dtype_str == "f16"): part_logits_raw = buffer_ops.buffer_load( logits_rsrc, logits_off, vec_width=1, dtype=T.f16 ) @@ -4819,13 +4820,13 @@ def _wave_reduce_sum(val): chunk_sum = _block_reduce(part_sum * part_scale, "sum") global_exp_sum = global_exp_sum + chunk_sum - if use_sinks: + if const_expr(use_sinks): sink_off = kv_head_idx * c_qgs + group_idx - if sink_dtype_str == "f32": + if const_expr(sink_dtype_str == "f32"): sink_value = buffer_ops.buffer_load( sink_rsrc, sink_off, vec_width=1, dtype=T.f32 ) - elif sink_dtype_str == "f16": + elif const_expr(sink_dtype_str == "f16"): sink_value_raw = buffer_ops.buffer_load( sink_rsrc, sink_off, vec_width=1, dtype=T.f16 ) @@ -4892,11 +4893,11 @@ def _wave_reduce_sum(val): + eqgs_idx * stride_logits_group + tid ) - if logits_dtype_str == "f32": + if const_expr(logits_dtype_str == "f32"): part_logits = buffer_ops.buffer_load( logits_rsrc, logits_off, vec_width=1, dtype=T.f32 ) - elif logits_dtype_str == "f16": + elif const_expr(logits_dtype_str == "f16"): part_logits_raw = buffer_ops.buffer_load( logits_rsrc, logits_off, vec_width=1, dtype=T.f16 ) @@ -4917,9 +4918,9 @@ def _wave_reduce_sum(val): + group_idx * stride_output_group_size + tid ) - if output_dtype_str == "f32": + if const_expr(output_dtype_str == "f32"): out_val = acc - elif output_dtype_str == "f16": + elif const_expr(output_dtype_str == "f16"): out_val = arith.trunc_f(T.f16, acc) else: out_val = arith.trunc_f(T.bf16, acc) From bbb1bff5a573e577d2184a540a6a0dd3ac504815 Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 25 Jun 2026 13:03:52 +0000 Subject: [PATCH 3/4] support multi jit --- aiter/ops/triton/gluon/pa_decode_gluon.py | 157 ++++++++++++++++++---- 1 file changed, 133 insertions(+), 24 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index 2669a68340..0448c4c02c 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -2,6 +2,7 @@ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from functools import lru_cache +import threading import torch @@ -103,6 +104,11 @@ def get_occupancy(): return 2 +# Upper bound on the partition count get_recommended_splits returns at runtime; +# also the set of variants the flydsl PS-reduce path eagerly precompiles. +PA_DECODE_MAX_SPLITS = 8 + + def get_recommended_splits(num_sequences, num_kv_heads, split_kv_blocks=1): props = torch.cuda.get_device_properties() num_sm = props.multi_processor_count * get_occupancy() @@ -110,7 +116,7 @@ def get_recommended_splits(num_sequences, num_kv_heads, split_kv_blocks=1): num_sm, num_sequences * num_kv_heads * split_kv_blocks ) max_context_partition_num *= split_kv_blocks - return min(max_context_partition_num, 8) + return min(max_context_partition_num, PA_DECODE_MAX_SPLITS) DS_WRITE = gl.constexpr(0x200) @@ -4982,6 +4988,81 @@ def launch_pa_decode_ps_reduce_flydsl( } +# Config signatures whose full partition-variant set has already been eagerly +# compiled, so the precompile pass in launch_pa_decode_ps_reduce_flydsl runs at +# most once per config. Guarded by a lock so concurrent first calls don't both +# do the work (the lock only protects the tiny test-and-add, not compilation). +_PS_REDUCE_PRECOMPILED_SIGS = set() +_PS_REDUCE_PRECOMPILE_LOCK = threading.Lock() + + +def _precompile_ps_reduce_sibling_variants( + compile_kwargs, + skip_n, + output_dtype, + logits_dtype, + sink_dtype, + device, + stream, +): + """Compile every partition variant in 1..PA_DECODE_MAX_SPLITS except skip_n. + + Each variant is triggered by actually launching it once on small throwaway + buffers sized for that partition count. This avoids the process-global + COMPILE_ONLY env (which could make a concurrent real launch on another thread + skip execution) -- the cost is a handful of discarded dummy kernel runs on + the first call for a config. Tensor shapes are not in the FlyDSL cache key, + so these tiny launches produce the same artifacts later real calls hit. + """ + qlen = compile_kwargs["query_seq_len"] + group = compile_kwargs["query_group_size"] + head_size = compile_kwargs["head_size"] + use_sinks = compile_kwargs["use_sinks"] + eq_group = qlen * group + + for n in range(1, PA_DECODE_MAX_SPLITS + 1): + if n == skip_n: + continue + # Minimal valid shapes for partition count n (batch=1, num_kv_heads=1). + output = torch.empty( + 1, qlen, 1, group, head_size, device=device, dtype=output_dtype + ) + exp_sums = torch.zeros(1, 1, n, eq_group, device=device, dtype=torch.float32) + max_logits = torch.zeros(1, 1, n, eq_group, device=device, dtype=torch.float32) + temporary_output = torch.zeros( + 1, 1, n, eq_group, head_size, device=device, dtype=logits_dtype + ) + sink = ( + torch.empty(group, device=device, dtype=sink_dtype) + if use_sinks + else torch.empty(0, device=device, dtype=output_dtype) + ) + compiled = compile_pa_decode_ps_reduce_flydsl( + max_context_partition_num=n, **compile_kwargs + ) + compiled["launch"]( + output, + exp_sums, + max_logits, + temporary_output, + sink, + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + exp_sums.stride(0), + exp_sums.stride(1), + exp_sums.stride(2), + temporary_output.stride(0), + temporary_output.stride(1), + temporary_output.stride(2), + temporary_output.stride(3), + output.shape[0], + output.shape[2], + stream, + ) + + def launch_pa_decode_ps_reduce_flydsl( output_ptr, exp_sums_ptr, @@ -5009,8 +5090,7 @@ def launch_pa_decode_ps_reduce_flydsl( "FlyDSL PS reduce fallback: only bf16/fp16/fp32 logits supported" ) - compiled = compile_pa_decode_ps_reduce_flydsl( - max_context_partition_num=context_partition_num, + compile_kwargs = dict( query_seq_len=query_seq_len, query_group_size=query_group_size, head_size=head_size, @@ -5026,27 +5106,56 @@ def launch_pa_decode_ps_reduce_flydsl( sink_token_ptr = torch.empty( 0, dtype=output_ptr.dtype, device=output_ptr.device ) - compiled["launch"]( - output_ptr, - exp_sums_ptr, - max_logits_ptr, - logits_ptr, - sink_token_ptr, - stride_output_bs, - stride_output_len, - stride_output_kv_head, - stride_output_group_size, - stride_exp_sums_seq, - stride_exp_sums_head, - stride_exp_sums_part, - stride_logits_seq, - stride_logits_head, - stride_logits_part, - stride_logits_group, - output_ptr.shape[0], - output_ptr.shape[2], - torch.cuda.current_stream(output_ptr.device), - ) + stream = torch.cuda.current_stream(output_ptr.device) + + def _launch(n): + compiled = compile_pa_decode_ps_reduce_flydsl( + max_context_partition_num=n, **compile_kwargs + ) + compiled["launch"]( + output_ptr, + exp_sums_ptr, + max_logits_ptr, + logits_ptr, + sink_token_ptr, + stride_output_bs, + stride_output_len, + stride_output_kv_head, + stride_output_group_size, + stride_exp_sums_seq, + stride_exp_sums_head, + stride_exp_sums_part, + stride_logits_seq, + stride_logits_head, + stride_logits_part, + stride_logits_group, + output_ptr.shape[0], + output_ptr.shape[2], + stream, + ) + + # Real launch for the partition count this call actually needs. + _launch(context_partition_num) + + # First time we see this config, eagerly compile the sibling partition + # variants so a later call landing on a different get_recommended_splits value + # hits the cache instead of JIT-compiling mid-run. Thread-safe: no global env + # toggling, and the test-and-add is locked so only one thread does the work. + sig = tuple(sorted(compile_kwargs.items())) + with _PS_REDUCE_PRECOMPILE_LOCK: + do_precompile = sig not in _PS_REDUCE_PRECOMPILED_SIGS + if do_precompile: + _PS_REDUCE_PRECOMPILED_SIGS.add(sig) + if do_precompile: + _precompile_ps_reduce_sibling_variants( + compile_kwargs, + context_partition_num, + output_ptr.dtype, + logits_ptr.dtype, + sink_token_ptr.dtype, + output_ptr.device, + stream, + ) def _paged_attention_decode_v2_reduce_kernel_wrapper( From fc1430a3a292306d4f38cf345624050acd504b55 Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 25 Jun 2026 14:37:02 +0000 Subject: [PATCH 4/4] jit build multi instances --- aiter/ops/triton/gluon/pa_decode_gluon.py | 91 +++++++++++++---------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index 0448c4c02c..c48815c5f6 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -5023,44 +5023,59 @@ def _precompile_ps_reduce_sibling_variants( for n in range(1, PA_DECODE_MAX_SPLITS + 1): if n == skip_n: continue - # Minimal valid shapes for partition count n (batch=1, num_kv_heads=1). - output = torch.empty( - 1, qlen, 1, group, head_size, device=device, dtype=output_dtype - ) - exp_sums = torch.zeros(1, 1, n, eq_group, device=device, dtype=torch.float32) - max_logits = torch.zeros(1, 1, n, eq_group, device=device, dtype=torch.float32) - temporary_output = torch.zeros( - 1, 1, n, eq_group, head_size, device=device, dtype=logits_dtype - ) - sink = ( - torch.empty(group, device=device, dtype=sink_dtype) - if use_sinks - else torch.empty(0, device=device, dtype=output_dtype) - ) - compiled = compile_pa_decode_ps_reduce_flydsl( - max_context_partition_num=n, **compile_kwargs - ) - compiled["launch"]( - output, - exp_sums, - max_logits, - temporary_output, - sink, - output.stride(0), - output.stride(1), - output.stride(2), - output.stride(3), - exp_sums.stride(0), - exp_sums.stride(1), - exp_sums.stride(2), - temporary_output.stride(0), - temporary_output.stride(1), - temporary_output.stride(2), - temporary_output.stride(3), - output.shape[0], - output.shape[2], - stream, - ) + # Best-effort: a failure here (compile error, OOM, ...) must not break the + # already-completed real launch. Skip this variant and let a later real + # call JIT-compile it on demand. + try: + # Minimal valid shapes for partition count n (batch=1, num_kv_heads=1). + output = torch.empty( + 1, qlen, 1, group, head_size, device=device, dtype=output_dtype + ) + exp_sums = torch.zeros( + 1, 1, n, eq_group, device=device, dtype=torch.float32 + ) + max_logits = torch.zeros( + 1, 1, n, eq_group, device=device, dtype=torch.float32 + ) + temporary_output = torch.zeros( + 1, 1, n, eq_group, head_size, device=device, dtype=logits_dtype + ) + sink = ( + torch.empty(group, device=device, dtype=sink_dtype) + if use_sinks + else torch.empty(0, device=device, dtype=output_dtype) + ) + compiled = compile_pa_decode_ps_reduce_flydsl( + max_context_partition_num=n, **compile_kwargs + ) + compiled["launch"]( + output, + exp_sums, + max_logits, + temporary_output, + sink, + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + exp_sums.stride(0), + exp_sums.stride(1), + exp_sums.stride(2), + temporary_output.stride(0), + temporary_output.stride(1), + temporary_output.stride(2), + temporary_output.stride(3), + output.shape[0], + output.shape[2], + stream, + ) + except Exception as e: + aiter.logger.warning( + "pa_decode flydsl PS-reduce: best-effort precompile of partition " + "variant n=%d failed (%s); it will JIT-compile on demand if needed.", + n, + e, + ) def launch_pa_decode_ps_reduce_flydsl(