From 01076520359b4a90e61037d29a02decb9164b090 Mon Sep 17 00:00:00 2001 From: aoli Date: Tue, 23 Jun 2026 12:47:51 +0000 Subject: [PATCH 1/2] update flydsl optimized gfx1250 gemm kernel --- .../ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | 2111 ++++++++--------- aiter/ops/flydsl/kernels/tdm_oob.py | 313 --- 2 files changed, 968 insertions(+), 1456 deletions(-) delete mode 100644 aiter/ops/flydsl/kernels/tdm_oob.py diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index a3223c4c5f..d37a2982f0 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -7,8 +7,6 @@ """ import functools -import inspect -import os import flydsl.compiler as flyc import flydsl.expr as fx @@ -32,7 +30,7 @@ from .gemm_common_gfx1250 import ( extract_lds_base_idx, get_lds_memref, - issue_tdm_loads, + lds_load_b32_raw, lds_load_b128_raw, pipeline_fence, pipeline_fence_signal, @@ -57,40 +55,24 @@ def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): _llvm.inline_asm(None, [], "\n".join(lines), "", has_side_effects=True) -# Feature-detect the installed flydsl's TDM descriptor builder. Older pinned -# flydsl predates these args; we apply each only when supported and otherwise -# fall back to the vendored OOB-capable builder for the non-tile-aligned-M path. -_TDM_SIG_PARAMS = inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters -_TDM_HAS_EARLY_TIMEOUT = "early_timeout" in _TDM_SIG_PARAMS -_TDM_HAS_OOB = "oob_outer_bound" in _TDM_SIG_PARAMS - - -def _make_tdm_desc(*, early_timeout=False, oob_outer_bound=None, **kwargs): - """Build a 2D TDM descriptor, transparently across flydsl versions.""" - strides = kwargs.get("strides") - runtime_stride = strides is not None and not isinstance(strides[0], int) - needs_oob = oob_outer_bound is not None - - if runtime_stride or (needs_oob and not _TDM_HAS_OOB): - from .tdm_oob import make_tensor_descriptor_2d as _vendored_make_desc - - return _vendored_make_desc( - early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs - ) - - if _TDM_HAS_OOB: - kwargs["oob_outer_bound"] = oob_outer_bound - if _TDM_HAS_EARLY_TIMEOUT: - kwargs["early_timeout"] = early_timeout - return tdm_ops.make_tensor_descriptor_2d(**kwargs) - - # Common constants WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 WAVE_SIZE = 32 SCALE_BLOCK = 32 SCALES_PER_WMMA = WMMA_K // SCALE_BLOCK # 4 + +def _vec_chunks(n: int): + """Compile-time split of n contiguous i32 into buffer_load widths (4/2/1).""" + chunks = [] + done = 0 + while done < n: + w = 4 if (n - done) >= 4 else (2 if (n - done) >= 2 else 1) + chunks.append((done, w)) + done += w + return chunks + + LDS_PAD_A_BYTES = 16 LDS_PAD_D_BYTES = 16 LDS_SEGMENT_BYTES = 64 * 1024 @@ -114,17 +96,12 @@ def compile_fp8fp4_gemm( l2_prefetch_distance: int = 2, cluster_m: int = 1, cluster_n: int = 1, - use_tdm_store: bool = True, out_dtype: str = "f32", inst_prefetch: bool = False, - wave_specialized_tdm: bool = False, split_k: int = 1, - use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, - b_streaming: bool = False, - scale_load_path: str = "tdm", - fp8_schedule: str = "auto", + ascale_load_path: str = "vgpr", ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -136,13 +113,16 @@ def compile_fp8fp4_gemm( Data layout: A: [M, K_packed] uint8 (FP4: K_packed=K//2, FP8: K_packed=K) B: [N, K_packed] uint8, preshuffled (16x16 byte tiles) - mxscale: scale_A [M, K//32], scale_B [N, K//32] uint8 E8M0 (preshuffled) + mxscale scale_A: + ascale_load_path="vgpr": [M, K//32] uint8 E8M0 + ascale_load_path="shuffled_tdm": [ceil(M/32), (K//128)*128] uint8 E8M0 + in 32x4 packed layout + mxscale scale_B: [N//32, (K//128)*128] uint8 E8M0 in 32x4 packed layout ptpc: scale_A [M], scale_B [N] fp32 Returns a JitFunction: launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, lda, ldc, stream) - where lda / ldc are the runtime leading-dim strides (in elements) of A / C. - Pass lda == K and ldc == N for dense (contiguous) tensors. + where lda/ldc are A/C runtime leading-dim strides in elements (dense: lda=K, ldc=N). """ if data_format not in ("fp4", "fp8", "a8w4"): raise ValueError( @@ -154,6 +134,10 @@ def compile_fp8fp4_gemm( raise ValueError( "scale_mode='ptpc' currently only supports data_format='fp8' or 'a8w4'" ) + if ascale_load_path not in ("vgpr", "shuffled_tdm"): + raise ValueError( + f"ascale_load_path must be 'vgpr' or 'shuffled_tdm', got {ascale_load_path!r}" + ) is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" @@ -164,31 +148,13 @@ def compile_fp8fp4_gemm( f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}" ) elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 - # scale_load_path: "tdm" = TDM->LDS (default); "vgpr" = buffer_load->VGPR, - # off the LDS/TDM/barrier path; "vgpr_ab_split" = "vgpr" plus repurposing the - # idle scale waves 2,3 to load the second A/B halves. - scale_load_paths = ("tdm", "vgpr", "vgpr_ab_split") - if scale_load_path not in scale_load_paths: - raise ValueError( - f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}" - ) - fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") - if fp8_schedule not in fp8_schedule_modes: - raise ValueError( - f"fp8_schedule must be one of {fp8_schedule_modes}, got {fp8_schedule!r}" - ) - if fp8_schedule != "auto" and data_format != "fp8": - raise ValueError( - f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'" - ) - if fp8_schedule != "auto" and b_streaming: - raise ValueError("fp8_schedule cannot be combined with b_streaming=True") effective_expert_sched_mode = bool(expert_sched_mode) - if num_buffers not in (2, 3, 4): - raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") + if num_buffers not in (2, 3, 4, 5, 6): + raise ValueError(f"num_buffers must be 2, 3, 4, 5 or 6, got {num_buffers}") if split_k < 1: raise ValueError(f"split_k must be >= 1, got {split_k}") + tdm_store_enabled = split_k == 1 use_cluster = cluster_m > 1 or cluster_n > 1 if use_cluster: @@ -203,12 +169,6 @@ def compile_fp8fp4_gemm( if block_threads > 1024: raise ValueError(f"block_threads must be <= 1024, got {block_threads}") - _min_wave_spec_warps = 2 if is_ptpc else 4 - if wave_specialized_tdm and num_warps < _min_wave_spec_warps: - raise ValueError( - f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}" - ) - # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) if is_a8w4: @@ -269,14 +229,20 @@ def compile_fp8fp4_gemm( f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N_EFF}" ) - if split_k > 1 and use_tdm_store: - raise ValueError("split_k > 1 currently requires use_tdm_store=False") + # mxscale B-scale is always the 32x4 `preshuffle_scale` layout: require N/tile_n a + # multiple of 32 and tile_k a multiple of 128 (no legacy sub-32 fallback). + if scale_mode == "mxscale" and ( + N % 32 != 0 or tile_n % 32 != 0 or tile_k % 128 != 0 + ): + raise ValueError( + f"mxscale 32x4 B-scale requires N%32==0, tile_n%32==0, tile_k%128==0; " + f"got N={N}, tile_n={tile_n}, tile_k={tile_k}" + ) num_k_tiles = split_k_chunk // tile_k if num_k_tiles < num_buffers: raise ValueError( - f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " - f"got {num_k_tiles}" + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, got {num_k_tiles}" ) gpu_arch = str(get_hip_arch()) @@ -290,32 +256,101 @@ def compile_fp8fp4_gemm( # FP4 A/B swap: BScale rep derived from WMMA_M, not WMMA_N_EFF b_scale_load_rep = warp_tile_n // WMMA_M if is_fp4 else wmma_n_rep + # mxscale carries per-K-block scales; ptpc has no K-loop scale (per-token/ + # per-channel fp32 applied in the epilogue). + is_mxscale = not is_ptpc + use_ascale_vgpr = is_mxscale and ascale_load_path == "vgpr" + use_ascale_shuffled_tdm = is_mxscale and ascale_load_path == "shuffled_tdm" + + # 32x4 A-scale layout (preshuffle_scale): [ceil(M/32), K//128, 32, 4]. + # One 128B block (32 rows x 4 K-scales) maps to one WMMA scale operand. + as32_block_bytes = 128 + as32_global_row_stride = 0 + as32_lds_row_stride = 0 + as32_tile_blocks_pad = 1 + as32_n_load = 0 + as32_opsel = False + # 32x4 B-scale layout (preshuffle_scale): [N//32, K//128, 32, 4]. + bs32_block_bytes = 128 + bs32_global_row_stride = 0 + bs32_lds_row_stride = 0 + bs32_tile_blocks_pad = 1 + bs32_n_load = 0 + bs32_opsel = False + if is_mxscale: + if use_ascale_shuffled_tdm: + as32_global_row_stride = (K // WMMA_K) * as32_block_bytes + as32_lds_row_stride = k_wmma_steps * as32_block_bytes + as32_tile_blocks = (tile_m + 31) // 32 + as32_tile_blocks_pad = 1 << (as32_tile_blocks - 1).bit_length() + # Adjacent 16-M WMMAs share one 32-row block when the warp M span is even. + as32_opsel = wmma_m_rep >= 2 and (wmma_m_rep % 2 == 0) + as32_n_load = (wmma_m_rep // 2) if as32_opsel else wmma_m_rep + + bs32_global_row_stride = ( + K // WMMA_K + ) * bs32_block_bytes # bytes per block row (= K) + bs32_lds_row_stride = k_wmma_steps * bs32_block_bytes # LDS bytes per block row + bs32_tile_blocks = tile_n // 32 + # Pad block count to pow2 so the TDM warp split stays clean (non-pow2, e.g. + # 6, miscopies LDS). Cost-free for pow2 block counts; else 1-2 oob-clipped. + bs32_tile_blocks_pad = 1 << (bs32_tile_blocks - 1).bit_length() + bs32_opsel = (not is_fp4) and (wmma_n_rep % 2 == 0) + bs32_n_load = ( + (wmma_n_rep // 2) if bs32_opsel else wmma_n_rep + ) # b32 loads per ks + + # A-scale VGPR path keeps the original [M, K//32] layout. Its op_sel pairing is + # by M-half because lane_kgrp selects the upper/lower half of the warp's M span. + ascale_opsel = ( + use_ascale_vgpr and wmma_m_rep >= 2 and (wmma_m_rep & (wmma_m_rep - 1)) == 0 + ) + ascale_half = wmma_m_rep // 2 + ascale_load = ascale_half if ascale_opsel else wmma_m_rep + + # TDM loader assignment: + # VGPR A-scale: wave0=A, wave1=B, wave2=B-scale; at 2 waves B-scale rides wave0. + # Shuffled A-scale: wave0=A, wave1=B, wave2=A-scale, wave3=B-scale; with 2/3 + # waves the missing scale descriptor rides as a secondary issue. + two_wave_bscale = use_ascale_vgpr and num_warps == 2 + two_wave_scale = use_ascale_shuffled_tdm and num_warps == 2 + three_wave_bscale = use_ascale_shuffled_tdm and num_warps == 3 + secondary_scale_tdm = two_wave_bscale or two_wave_scale or three_wave_bscale + + # mxscale uses at least A/B TDM waves; ptpc uses A/B only. + if num_warps < 2: + raise ValueError( + f"wave-specialized TDM requires at least 2 waves, got {num_warps}" + ) + _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 - _scale_ds_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 + # Scale ds_loads issued alongside A/B fragment loads in the streaming schedule + # (for the partial-drain s_wait_dscnt bookkeeping). + _a_scale_ds = as32_n_load if use_ascale_shuffled_tdm else 0 + _b_scale_ds = bs32_n_load if is_mxscale else 0 + _scale_ds_loads = _a_scale_ds + _b_scale_ds + _a_frag_ds = wmma_m_rep * _a_frag_loads_per_wm _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads - _as_ds_loads = wmma_m_rep * _a_frag_loads_per_wm + _scale_ds_loads + _as_ds_loads = _a_frag_ds + _scale_ds_loads + _row_major_k_prefetch_bundle_ds = _a_frag_ds + _bs_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES - if scale_load_path == "vgpr_ab_split": - if tile_m % 2 != 0: - raise ValueError( - f"scale_load_path='vgpr_ab_split' requires even tile_m, got {tile_m}" - ) - if tile_n % 32 != 0: - raise ValueError( - f"scale_load_path='vgpr_ab_split' requires tile_n divisible by 32, got {tile_n}" - ) lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b - ab_split_a_rows = tile_m // 2 - ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 - lds_a_scale_bytes = 0 if is_ptpc else tile_m * scale_k_per_tile + _scale_guard_bytes - lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes - interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile - interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile + # A-scale LDS is allocated only for the shuffled TDM path. + lds_a_scale_bytes = ( + (as32_tile_blocks_pad * as32_lds_row_stride + _scale_guard_bytes) + if use_ascale_shuffled_tdm + else 0 + ) + lds_b_scale_bytes = ( + (bs32_tile_blocks_pad * bs32_lds_row_stride + _scale_guard_bytes) + if is_mxscale + else 0 + ) def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -326,7 +361,7 @@ def _align_up(value: int, align: int) -> int: # deriving per-wave offsets from ``wave_id``. In wave-specialized mode we # dedicate one loader wave to each tensor (A/B/A_scale/B_scale), so each # active loader wave must issue a full-tile descriptor by itself. - tdm_desc_num_warps = 1 if wave_specialized_tdm else num_warps + tdm_desc_num_warps = 1 # All pipeline stages share the same intra-stage layout in the generic # arena path. The active gfx1250 FP8 TDM tile uses a separate reference @@ -357,114 +392,38 @@ def _align_up(value: int, align: int) -> int: None, arch=gpu_arch, global_sym_name=( - f"mxscale_{data_format}_{tile_m}x{tile_n}x{tile_k}_" - f"{m_warp}x{n_warp}_{num_buffers}buf_arena" + f"mxscale_{data_format}_{tile_m}x{tile_n}x{tile_k}_{m_warp}x{n_warp}_{num_buffers}buf_arena" ), ) - use_ref_segmented_lds_layout = ( - data_format == "fp8" - and tile_m == 256 - and tile_n == 256 - and tile_k == 128 - and m_warp == 2 - and n_warp == 2 - and num_buffers == 4 - and split_k == 1 - and wave_specialized_tdm - and not use_scale_opsel + stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] + stage_phys_order.append(_last_compute_stage) + stage_base_off = [0] * num_buffers + for phys_i, logical_i in enumerate(stage_phys_order): + stage_base_off[logical_i] = phys_i * stage_pitch_bytes + arena_alloc.ptr = stage_pitch_bytes * num_buffers + arena_total_bytes = arena_alloc.ptr + epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( + stage_base_off=stage_base_off, + tail_plan=_base_tail_plan, + loop_iters=loop_iters, + extra=extra, ) - # "vgpr"/"vgpr_ab_split": load scale global->VGPR via buffer_load, bypassing - # TDM+LDS entirely. Requires the reference segmented LDS layout. - use_buffer_vgpr_scale = scale_load_path in ("vgpr", "vgpr_ab_split") - if use_buffer_vgpr_scale and not use_ref_segmented_lds_layout: - raise ValueError( - f"scale_load_path={scale_load_path!r} requires the reference segmented " - "LDS layout (not active for this tile/format configuration)" - ) - # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the - # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. - _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1"))) - # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the - # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, - # wave2=A1, wave3=B1). Measured wall-neutral. - use_ab_half_split = scale_load_path == "vgpr_ab_split" - # The buffer_load->VGPR scale ring is built only when scale is actually loaded. - _bvs_active = use_buffer_vgpr_scale - - if use_ref_segmented_lds_layout: - # The A/B data pools are no longer packed into the same per-stage - # 64KiB segment window. Scale pools keep the reference 0x800 stride so - # every TDM LDS target remains 2KiB-aligned. - ref_a_stage_stride = 0x9000 - ref_b_stage_stride = 0x8000 - ref_scale_stage_stride = 0x800 - if lds_a_data_bytes > ref_a_stage_stride: - raise RuntimeError( - "reference segmented LDS layout requires A stage <= 0x9000 bytes, " - f"got {lds_a_data_bytes}" - ) - if lds_b_data_bytes > ref_b_stage_stride: - raise RuntimeError( - "reference segmented LDS layout requires B stage <= 0x8000 bytes, " - f"got {lds_b_data_bytes}" - ) - if ( - lds_a_scale_bytes > ref_scale_stage_stride - or lds_b_scale_bytes > ref_scale_stage_stride - ): - raise RuntimeError( - "reference segmented LDS layout requires scale stage <= 0x800 bytes, " - f"got A={lds_a_scale_bytes} B={lds_b_scale_bytes}" - ) - - stage_a_data_off = [0x00000, 0x09000, 0x16000, 0x1F000] - stage_a_scale_off = [ - 0x12000 + i * ref_scale_stage_stride for i in range(num_buffers) - ] - stage_b_scale_off = [ - 0x28000 + i * ref_scale_stage_stride for i in range(num_buffers) - ] - stage_b_data_off = [ - 0x30000 + i * ref_b_stage_stride for i in range(num_buffers) - ] - arena_alloc.ptr = LDS_GFX1250_MAX_BYTES - arena_total_bytes = arena_alloc.ptr - - # The epilogue may reuse the prefix only after all main/tail TDM traffic - # is fully fenced. This is outside the hot loop and avoids assuming a - # single monotonic per-stage base for the segmented pool layout. - epilogue_fence_threshold_bytes = 0 - else: - stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] - stage_phys_order.append(_last_compute_stage) - stage_base_off = [0] * num_buffers - for phys_i, logical_i in enumerate(stage_phys_order): - stage_base_off[logical_i] = phys_i * stage_pitch_bytes - arena_alloc.ptr = stage_pitch_bytes * num_buffers - arena_total_bytes = arena_alloc.ptr - epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( - stage_base_off=stage_base_off, - tail_plan=_base_tail_plan, - loop_iters=loop_iters, - extra=extra, - ) - - stage_a_data_off = [ - stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers) - ] - stage_b_data_off = [ - stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers) - ] - stage_a_scale_off = [ - stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers) - ] - stage_b_scale_off = [ - stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers) - ] + stage_a_data_off = [ + stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers) + ] + stage_b_data_off = [ + stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers) + ] + stage_a_scale_off = [ + stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers) + ] + stage_b_scale_off = [ + stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers) + ] - if use_tdm_store: + if tdm_store_enabled: lds_d_row_stride = warp_tile_n * elem_bytes_d + LDS_PAD_D_BYTES warp_d_bytes = warp_tile_m * lds_d_row_stride total_d_bytes = num_warps * warp_d_bytes @@ -478,12 +437,9 @@ def _align_up(value: int, align: int) -> int: arena_alloc.ptr = total_d_bytes check_smem_capacity(arena_total_bytes, gpu_arch) - # TENSORcnt is tracked per-wave in hardware. Wave-specialized TDM issues one - # tensor_load per wave per step; otherwise all 4 (A/B/A_scale/B_scale). - if wave_specialized_tdm: - TDM_LOADS_PER_STEP = 1 - else: - TDM_LOADS_PER_STEP = 4 + # TENSORcnt is tracked per-wave in hardware. Keep the fence budget in stage units; + # secondary scale descriptors on 2/3-wave mxscale paths only make this more conservative. + TDM_LOADS_PER_STEP = 1 tail_plan = [ (ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan @@ -509,10 +465,9 @@ def _align_up(value: int, align: int) -> int: _sub_tiles.append((acc_idx, 0, m_off, n_sub)) COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING = "row_major_streaming" - COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" + COMPUTE_SCHEDULE_FP4_QUADRANT = "fp4_quadrant" COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE = "fp8_deep_pipeline" - COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" fp8_deep_pipeline_eligible = ( data_format in ("fp8", "a8w4") @@ -522,88 +477,77 @@ def _align_up(value: int, align: int) -> int: and m_warp == 2 and n_warp == 2 and num_buffers == 4 - and wave_specialized_tdm and out_dtype == "bf16" - and not use_scale_opsel ) - if fp8_schedule == "deep-pipeline" and not fp8_deep_pipeline_eligible: - raise ValueError( - "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " - "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, " - "out_dtype='bf16', and use_scale_opsel=False" - ) def _pick_compute_schedule_kind(): - if b_streaming: - return COMPUTE_SCHEDULE_B_STREAMING if wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8: return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - # Quadrant schedules split B into left/right halves and compute - # top-left, bottom-left, top-right, bottom-right. FP4 additionally - # changes accumulator layout for bank friendliness; FP8 keeps row-major - # accumulators and uses the split to increase LDS-load-to-WMMA distance. + # Quadrant: split B left/right, compute the 4 quadrants to widen the + # LDS-load-to-WMMA distance. FP4/FP8 differ only in per-format wait tuning. if is_fp4: - return COMPUTE_SCHEDULE_FP4_COL_BAND + return COMPUTE_SCHEDULE_FP4_QUADRANT # A8W4 (FP8 act + FP4 weight) shares FP8's accumulator layout and operand # path, so it reuses the FP8 schedules. if data_format in ("fp8", "a8w4"): - if fp8_schedule == "deep-pipeline" or ( - fp8_schedule == "auto" and fp8_deep_pipeline_eligible - ): + if fp8_deep_pipeline_eligible: return COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE return COMPUTE_SCHEDULE_FP8_QUADRANT return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING compute_schedule_kind = _pick_compute_schedule_kind() - use_fp4_bank_friendly_schedule = ( - compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND + use_row_major_streaming_schedule = ( + compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING ) + use_fp4_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = ( compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE ) - use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING - if use_buffer_vgpr_scale and not use_fp8_deep_pipeline_schedule: - raise ValueError( - f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule" + use_row_major_k_prefetch = wmma_m_rep == 1 and k_wmma_steps > 1 + _row_major_k_prefetch_depth = 2 if use_row_major_k_prefetch else 1 + _row_major_k_prefetch_depth = max( + 0, min(k_wmma_steps - 1, _row_major_k_prefetch_depth) + ) + use_row_major_late_signal = use_row_major_k_prefetch + + # A-scale VGPR-ring prefetch depth (K-tiles ahead). Deeper K tiles expose + # more latency to hide; depth 4 improves the small-M row-major large-K path + if use_ascale_vgpr and use_row_major_streaming_schedule: + _bvs_D = 4 if num_buffers >= 4 else 3 + else: + _bvs_D = 1 + _bvs_active = use_ascale_vgpr + + if is_mxscale: + assert compute_schedule_kind in ( + COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, + COMPUTE_SCHEDULE_FP8_QUADRANT, + COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE, + COMPUTE_SCHEDULE_FP4_QUADRANT, ) use_ws_tdm_split_signal_overlap = ( - wave_specialized_tdm - and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) + (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) and num_buffers == 4 and use_cluster ) - if use_b_streaming_schedule: - print( - f"[b_streaming] {data_format} tile=({tile_m},{tile_n},{tile_k}) " - f"M_r={wmma_m_rep} N_r={wmma_n_rep}", - flush=True, - ) + use_tdm_late_signal_overlap = ( + use_ws_tdm_split_signal_overlap or use_row_major_late_signal + ) - if use_fp4_bank_friendly_schedule: - _bank_half_wm = wmma_m_rep // 2 - _bank_half_wn = wmma_n_rep // 2 - _bank_group_size = _bank_half_wm * _bank_half_wn - _bank_half_b_scale_rep = b_scale_load_rep // 2 - _bank_group_to_row_major = [] - for _wm in range(_bank_half_wm): - for _wn in range(_bank_half_wn): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm, wmma_m_rep): - for _wn in range(_bank_half_wn): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm): - for _wn in range(_bank_half_wn, wmma_n_rep): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm, wmma_m_rep): - for _wn in range(_bank_half_wn, wmma_n_rep): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) + if use_fp4_quadrant_schedule: + _fp4_half_wm = wmma_m_rep // 2 + _fp4_half_wn = wmma_n_rep // 2 + _fp4_group_size = _fp4_half_wm * _fp4_half_wn if use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule: _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn - _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 + if is_mxscale: + _fp8_b_scale_loads = bs32_n_load # 32x4: one b32 per block-or-WMMA per ks + else: + _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 if use_fp8_deep_pipeline_schedule: _fp8_pair_wm = 2 _fp8_pair_wn = 2 @@ -611,9 +555,8 @@ def _pick_compute_schedule_kind(): _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn _fp8_pair_a_loads = _fp8_pair_wm * DS_LOADS_PER_A_FRAG _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn - _fp8_scale_loads = ( - 0 if is_ptpc else (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 - ) + # Scale ds_loads issued at the loop top. Uses the finalized module-level counts. + _fp8_scale_loads = 0 if is_ptpc else (_a_scale_ds + _b_scale_ds) @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( @@ -672,48 +615,104 @@ def kernel_mxscale_gemm( warp_m_base = wave_m_idx * arith.index(warp_tile_m) warp_n_base = wave_n_idx * arith.index(warp_tile_n) + m_idx = fx.Index(i32_m) + + def _load_contig_i32(rsrc, base_idx, n, soff): + # Load n contiguous i32 values through the widest legal buffer_load chunks. + out = [None] * n + _chunks = _vec_chunks(n) + for _ci in range_constexpr(len(_chunks)): + start, w = _chunks[_ci] + off = arith.index_cast(T.i32, base_idx + arith.index(start)) + r = buffer_ops.buffer_load( + rsrc, off, vec_width=w, dtype=T.i32, soffset_bytes=soff + ) + if const_expr(w == 1): + out[start] = r + else: + rv = fx.Vector(r) + for c in range_constexpr(w): + out[start + c] = rv[c] + return out + + _scale_identity_i32 = arith.constant(0x7F7F7F7F, type=T.i32) - if const_expr(use_buffer_vgpr_scale): - # Direct global->VGPR scale load (no TDM/LDS). Coalesced lane-major - # host layout [M_block(128), K_tile, group(2), lane16(16), 4 i32], so - # each buffer_load_b128's 16 lanes read 256 contiguous bytes: - # i32_off(group) = (mb*Kt + kt)*128 + group*64 + lane16*4 - _bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - _bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) - _bvs_Kt = K // tile_k # total K-tiles - _bvs_mb_a = blk_m // arith.index(128) + wave_m_idx - _bvs_mb_b = blk_n // arith.index(128) + wave_n_idx - _bvs_lane4 = lane16 * arith.index(4) - - def _bvs_load_scales(rsrc, mb, rep, k_base): + if const_expr(use_ascale_vgpr): + # A-scale VGPR path: read scale_A[M, K//32] directly from its row-major layout. + _ascale_nbytes = m_idx * arith.index(K_scale) + _ascale_rsrc = buffer_ops.create_buffer_resource( + arg_a_scale, + max_size=False, + num_records_bytes=_ascale_nbytes, + ) + _ascale_row_i32 = K_scale // 4 + _ascale_row0 = blk_m + warp_m_base + lane16 + if const_expr(ascale_opsel): + _ascale_row0 = _ascale_row0 + lane_kgrp * arith.index( + ascale_half * WMMA_M + ) + _vs_tile_a = k_wmma_steps * ascale_load + + def _load_contig_i32_guarded_row(row, n, soff): + row_valid = row < m_idx + if_op = scf.IfOp(row_valid, [T.i32] * n, has_else=True) + with ir.InsertionPoint(if_op.then_block): + vals = _load_contig_i32( + _ascale_rsrc, + row * arith.index(_ascale_row_i32), + n, + soff, + ) + scf.YieldOp([arith.unwrap(v) for v in vals]) + with ir.InsertionPoint(if_op.else_block): + scf.YieldOp([arith.unwrap(_scale_identity_i32) for _ in range(n)]) + return list(if_op.results) + + def _load_ascale_impl(k_base, guarded): kt = k_base // arith.index(tile_k) - tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128) - vals = [] - for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32 - off = arith.index_cast( - T.i32, tile_i32 + arith.index(ld * 64) + _bvs_lane4 + soff = arith.index_cast(T.i32, kt * arith.index(scale_k_per_tile)) + vals = [None] * (k_wmma_steps * ascale_load) + for i in range_constexpr(ascale_load): + row = _ascale_row0 + arith.index(i * WMMA_M) + if const_expr(guarded): + ks_vals = _load_contig_i32_guarded_row(row, k_wmma_steps, soff) + else: + vidx = row * arith.index(_ascale_row_i32) + ks_vals = _load_contig_i32( + _ascale_rsrc, vidx, k_wmma_steps, soff + ) + for ks in range_constexpr(k_wmma_steps): + vals[ks * ascale_load + i] = ks_vals[ks] + return vals + + def _load_ascale(k_base): + full_tile = (blk_m + arith.index(tile_m)) <= m_idx + if_op = scf.IfOp(full_tile, [T.i32] * _vs_tile_a, has_else=True) + with ir.InsertionPoint(if_op.then_block): + scf.YieldOp( + [ + arith.unwrap(v) + for v in _load_ascale_impl(k_base, guarded=False) + ] ) - v = fx.Vector( - buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32) + with ir.InsertionPoint(if_op.else_block): + scf.YieldOp( + [ + arith.unwrap(v) + for v in _load_ascale_impl(k_base, guarded=True) + ] ) - for j in range_constexpr(4): - vals.append(v[j]) - return vals + return list(if_op.results) - def _bvs_prefetch(k_base): - # Issue scale buffer_load for one K-tile; returns (a[8], b[8]) VGPR. - a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) - b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) - return a, b + _bvs_prefetch = _load_ascale - m_idx = fx.Index(i32_m) - # Leading-dim strides arrive at runtime (strided A // C); the dense path - # passes lda == K and ldc == N, giving byte-identical addressing. A's - # stride is in packed-A elements (== lda for fp8 where PACK_FACTOR_A == 1). + # Runtime leading-dim strides (strided A/C). Dense callers pass lda == K, + # ldc == N for byte-identical addressing. A's stride is in packed elements. if const_expr(PACK_FACTOR_A == 1): lda_packed = fx.Index(i32_lda) else: - lda_packed = fx.Index(i32_lda) // arith.index(PACK_FACTOR_A) + lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) + n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) @@ -727,7 +726,7 @@ def _bvs_prefetch(k_base): def make_desc_a(memref, k_base): k_packed_off = k_base // arith.index(PACK_FACTOR_A) - return _make_tdm_desc( + return tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_a, lds_memref=memref, global_offset=(blk_m, k_packed_off), @@ -746,7 +745,7 @@ def make_desc_a(memref, k_base): def make_desc_b(memref, k_base): k_packed_off = k_base // arith.index(PACK_FACTOR_B) - return _make_tdm_desc( + return tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_b, lds_memref=memref, global_offset=( @@ -765,61 +764,42 @@ def make_desc_b(memref, k_base): early_timeout=True, ) - def make_desc_a_half(memref, k_base, m_half: int): - row_start = m_half * ab_split_a_rows - k_packed_off = k_base // arith.index(PACK_FACTOR_A) - return _make_tdm_desc( - global_ptr=arg_a, - lds_memref=memref, - global_offset=(blk_m + arith.index(row_start), k_packed_off), - tensor_shape=(tile_m, packed_tile_k_a), - strides=(lda_packed, 1), - tile_shape=(ab_split_a_rows, packed_tile_k_a), - elem_bytes=1, - pad_interval=packed_tile_k_a, - pad_amount=LDS_PAD_A_BYTES, - num_warps=1, - workgroup_mask=a_mcast_mask, - lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - oob_outer_bound=i32_m, - ) - - def make_desc_b_half(memref, k_base, n_half: int): - group_start = n_half * ab_split_b_groups - k_packed_off = k_base // arith.index(PACK_FACTOR_B) - return _make_tdm_desc( - global_ptr=arg_b, + def make_desc_bs(memref, k_base): + # 32x4: copy this tile's 32-N blocks x K-blocks slice of the preshuffled + # [N//32, (K//128)*128] B-scale tensor. + block_off = blk_n // arith.index(32) + col_off = (k_base // arith.index(WMMA_K)) * arith.index(bs32_block_bytes) + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b_scale, lds_memref=memref, - global_offset=( - blk_n // arith.index(16) + arith.index(group_start), - k_packed_off * arith.index(16), - ), - tensor_shape=(N // 16, K_packed_b * 16), - strides=(K_packed_b * 16, 1), - tile_shape=(ab_split_b_groups, packed_tile_k_b * 16), + global_offset=(block_off, col_off), + tensor_shape=(N // 32, bs32_global_row_stride), + strides=(bs32_global_row_stride, 1), + tile_shape=(bs32_tile_blocks_pad, bs32_lds_row_stride), elem_bytes=1, pad_interval=0, pad_amount=0, - num_warps=1, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask, - lds_byte_offset=arith.index(group_start * packed_tile_k_b * 16), atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=N // 32, ) def make_desc_as(memref, k_base): - k_scale_off = k_base // arith.index(SCALE_BLOCK) - outer_off = blk_m // arith.index(wmma_m_rep) - inner_off = k_scale_off * arith.index(wmma_m_rep) - return _make_tdm_desc( + # 32x4: copy this tile's M block rows from the packed A-scale tensor. + # Runtime OOB clips whole missing block rows; the LDS reader masks lanes + # inside the final partial block to the E8M0 identity value. + block_off = blk_m // arith.index(32) + col_off = (k_base // arith.index(WMMA_K)) * arith.index(as32_block_bytes) + m_block_bound = (m_idx + arith.index(31)) // arith.index(32) + return tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_a_scale, lds_memref=memref, - global_offset=(outer_off, inner_off), - tensor_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), - strides=(wmma_m_rep * K_scale, 1), - tile_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), + global_offset=(block_off, col_off), + tensor_shape=(as32_tile_blocks_pad, as32_global_row_stride), + strides=(as32_global_row_stride, 1), + tile_shape=(as32_tile_blocks_pad, as32_lds_row_stride), elem_bytes=1, pad_interval=0, pad_amount=0, @@ -827,38 +807,18 @@ def make_desc_as(memref, k_base): workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=m_block_bound, ) - def make_desc_bs(memref, k_base): - k_scale_off = k_base // arith.index(SCALE_BLOCK) - outer_off = blk_n // arith.index(b_scale_load_rep) - inner_off = k_scale_off * arith.index(b_scale_load_rep) - return _make_tdm_desc( - global_ptr=arg_b_scale, - lds_memref=memref, - global_offset=(outer_off, inner_off), - tensor_shape=(WMMA_M * n_warp, interleaved_scale_cols_b), - strides=(b_scale_load_rep * K_scale, 1), - tile_shape=(WMMA_M * n_warp, interleaved_scale_cols_b), - elem_bytes=1, - pad_interval=0, - pad_amount=0, - num_warps=tdm_desc_num_warps, - workgroup_mask=b_mcast_mask, - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - ) - - if const_expr(wave_specialized_tdm): - tdm_wave_id = rocdl.wave_id() - tdm_wave_is_a = tdm_wave_id == fx.Int32(0) - tdm_wave_is_b = tdm_wave_id == fx.Int32(1) - tdm_wave_is_as = tdm_wave_id == fx.Int32(2) + tdm_wave_id = rocdl.wave_id() + tdm_wave_is_a = tdm_wave_id == fx.Int32(0) + tdm_wave_is_b = tdm_wave_id == fx.Int32(1) + tdm_wave_is_as = tdm_wave_id == fx.Int32(2) - def _select_wave_tdm_value(a_value, b_value, as_value, bs_value): - result = arith.select(tdm_wave_is_as, as_value, bs_value) - result = arith.select(tdm_wave_is_b, b_value, result) - return arith.select(tdm_wave_is_a, a_value, result) + def _select_wave_tdm_value(a_value, b_value, as_value, bs_value): + result = arith.select(tdm_wave_is_as, as_value, bs_value) + result = arith.select(tdm_wave_is_b, b_value, result) + return arith.select(tdm_wave_is_a, a_value, result) elem_ty_lds = T.f16 @@ -991,75 +951,165 @@ def load_b_frag(lds_buffer, b_lane_bases, wn, ks): v23 = v2.shuffle(v3, list(range(8))) return v01.shuffle(v23, list(range(16))) - def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols): - """Precompute scale lane bases (byte offsets).""" - warp_lds_row = warp_base // arith.index(reps) + lane16 - base = warp_lds_row * arith.index(interleaved_cols) - if const_expr(is_fp4 or is_a8w4): - # FP4/A8W4: always add lane_kgrp offset (no opsel on BScale) - base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) - else: - # FP8: conditional on opsel - if const_expr(use_scale_opsel): - base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) - return lds_ptr, [base] - - def load_scale_b128(lds_buffer, scale_base, reps, ks=0): - """Load all wmma_rep scales via ds_load_b128(s) for K-subtile *ks*.""" - ks_byte_off = ks * reps * SCALES_PER_WMMA - eff_base = ( - scale_base - if ks_byte_off == 0 - else scale_base + arith.index(ks_byte_off) - ) - num_loads = (reps + 3) // 4 - vecs = [] - for ld in range_constexpr(num_loads): - off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) - vecs.append(fx.Vector(lds_load_b128_raw(lds_buffer, off))) + def _precompute_bs32_bases(lds_ptr): + """Tile-local 32-N block base for the warp's 32x4 B-scale read. + + An LDS block row (32 N-rows x 4 K-scales = 128B) is one 32-lane WMMA scale + operand. op_sel path (even rep): the warp owns whole blocks block0+j. Else + (fp4 / odd rep): each WMMA reads its own 16/32-N into the operand lanes. + """ + return lds_ptr, warp_n_base // arith.index(32) + + def _precompute_as32_bases(lds_ptr): + """Tile-local first A row, relative to the copied 32-row block base.""" + return lds_ptr, (blk_m % arith.index(32)) + warp_m_base + + def _mask_a_scale_oob(word, row_abs): + return arith.select(row_abs < m_idx, word, _scale_identity_i32) + + def _load_scale32_full_blocks( + lds_buffer, + block0, + ks, + row_stride_bytes, + block_bytes, + load_count, + row_abs0=None, + ): + stride = arith.index(row_stride_bytes) + ks_off = arith.index(ks * block_bytes) + lane32 = lane_kgrp * arith.index(16) + lane16 + lane = lane32 * arith.index(4) results = [] - for i in range_constexpr(reps): - results.append(vecs[i // 4][i % 4]) + for i in range_constexpr(load_count): + off = (block0 + arith.index(i)) * stride + ks_off + lane + word = lds_load_b32_raw(lds_buffer, off) + if const_expr(row_abs0 is not None): + word = _mask_a_scale_oob( + word, row_abs0 + arith.index(i * 32) + lane32 + ) + results.append(word) return results - def load_scale_slice_b128( - lds_buffer, scale_base, full_reps, rep_start, rep_count, ks=0 + def _load_scale32_half_blocks( + lds_buffer, + row16_base, + ks, + row_stride_bytes, + block_bytes, + load_count, + row_abs_base=None, ): - """Load a contiguous slice of packed scale VGPRs for one K-subtile.""" - ks_byte_off = (ks * full_reps + rep_start) * SCALES_PER_WMMA - eff_base = ( - scale_base - if ks_byte_off == 0 - else scale_base + arith.index(ks_byte_off) - ) - num_loads = (rep_count + 3) // 4 - vecs = [] - for ld in range_constexpr(num_loads): - off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) - vecs.append(fx.Vector(lds_load_b128_raw(lds_buffer, off))) + stride = arith.index(row_stride_bytes) + ks_off = arith.index(ks * block_bytes) results = [] - for i in range_constexpr(rep_count): - results.append(vecs[i // 4][i % 4]) + for i in range_constexpr(load_count): + row16 = row16_base + arith.index(i * 16) + off = ( + (row16 // arith.index(32)) * stride + + ks_off + + (row16 % arith.index(32) + lane16) * arith.index(4) + ) + word = lds_load_b32_raw(lds_buffer, off) + if const_expr(row_abs_base is not None): + word = _mask_a_scale_oob( + word, row_abs_base + arith.index(i * 16) + lane16 + ) + results.append(word) return results - def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): - """Load both scale tensors and apply op_sel downsampling per format. + def load_as32_ascale(lds_buffer, row0, ks): + """Load 32x4 A-scale i32s for K-subtile *ks*.""" + if const_expr(as32_opsel): + return _load_scale32_full_blocks( + lds_buffer, + row0 // arith.index(32), + ks, + as32_lds_row_stride, + as32_block_bytes, + wmma_m_rep // 2, + row_abs0=blk_m + warp_m_base, + ) + return _load_scale32_half_blocks( + lds_buffer, + row0, + ks, + as32_lds_row_stride, + as32_block_bytes, + wmma_m_rep, + row_abs_base=blk_m + warp_m_base, + ) - FP4 BScale has no op_sel (scaleAType=0 fixed); only AScale halves. - FP8/A8W4 16x16 supports op_sel on both. - """ + def load_bs32_bscale(lds_buffer, block0, ks): + """Load 32x4 B-scale i32s for K-subtile *ks* (one b32 per block-or-WMMA).""" + if const_expr(bs32_opsel): + # Even rep: full 32-lane block; op_sel picks the 16-half in _emit_wmma. + return _load_scale32_full_blocks( + lds_buffer, + block0, + ks, + bs32_lds_row_stride, + bs32_block_bytes, + wmma_n_rep // 2, + ) + elif const_expr(is_fp4): + # fp4: one 32-N block per WMMA (no op_sel). + return _load_scale32_full_blocks( + lds_buffer, + block0, + ks, + bs32_lds_row_stride, + bs32_block_bytes, + wmma_n_rep, + ) + # fp8 odd rep: each WMMA's 16-N into lanes 0-15 (op_sel=0); the block + # and its 16-half are runtime (warp may start mid-block). + return _load_scale32_half_blocks( + lds_buffer, + warp_n_base, + ks, + bs32_lds_row_stride, + bs32_block_bytes, + wmma_n_rep, + ) + + def _load_a_scale_lds(as_buf, as_row0, ks): + """Load 32x4 A-scale from LDS (mxscale only).""" + return load_as32_ascale(as_buf, as_row0, ks) + + # Current tile's VGPR-path A-scales, ordered [k_wmma_step][M-rep]. + _vgpr_scale_box = [None] + + def _set_vgpr_a_scales(scale_k_base, pf_a_scales): + if const_expr(use_ascale_vgpr): + if const_expr(pf_a_scales is not None): + _vgpr_scale_box[0] = pf_a_scales + else: + rocdl.sched_barrier(0) + _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) + + def _load_a_scale_vgpr(ks): + pf_a = _vgpr_scale_box[0] + return pf_a[ks * ascale_load : (ks + 1) * ascale_load] + + def _load_b_scale_lds(bs_buf, bs_block0, ks): + """Load 32x4 B-scale from LDS (mxscale only; ptpc reads no K-loop B-scale).""" + return load_bs32_bscale(bs_buf, bs_block0, ks) + + def _load_a_scale_operand(as_buf, as_bases, ks): + if const_expr(use_ascale_vgpr): + return _load_a_scale_vgpr(ks) + return _load_a_scale_lds(as_buf, as_bases, ks) + + def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): + """Load scale operands for K-subtile *ks*.""" if const_expr(is_ptpc): return None, None - a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - if const_expr(use_scale_opsel): - a = a_all[::2] - b = b_all if const_expr(is_fp4) else b_all[::2] - else: - a, b = a_all, b_all + a = _load_a_scale_operand(as_buf, as_bases, ks) + b = _load_b_scale_lds(bs_buf, bs_bases, ks) return a, b - def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): + def _load_b_and_scales(b_buf, b_bases, as_buf, as_bases, bs_buf, bs_bases, ks): b_frags = [ load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(wmma_n_rep) @@ -1069,16 +1119,6 @@ def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): ) return b_frags, b_scales, a_scales - def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): - a_frags = [ - load_a_frag(a_buf, a_bases[wm], ks) - for wm in range_constexpr(wmma_m_rep) - ] - a_scales, b_scales = _scales_for_emit( - as_buf, as_bases, bs_buf, bs_bases, ks - ) - return a_frags, a_scales, b_scales - def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): """Emit one WMMA instruction (format-specific).""" idx = wm * wmma_n_rep + wn @@ -1095,23 +1135,17 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): fmtB=0, ) else: - # PTPC-FP8 needs no per-K scaling. We emit the scaled f8f6f4 op - # with an identity E8M0 scale (0x7F = 2^0 = 1.0) for toolchain - # compatibility; it is numerically equivalent to the dedicated - # no-scale op. Future: switch to the equivalent no-scale wmma: - # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) - accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( - T.vec(8, T.f32), - b_frag, - a_frag, - accs[idx], - 0x7F7F7F7F, - 0x7F7F7F7F, - fmtA=0, - fmtB=0, + # PTPC-FP8 needs no per-K scaling: dedicated no-scale E4M3 WMMA. + accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8( + T.vec(8, T.f32), b_frag, a_frag, accs[idx] ) return - if const_expr(use_scale_opsel): + if const_expr(use_ascale_vgpr and ascale_opsel): + # VGPR path pairs M-blocks across the two lane_kgrp halves. + a_scale_idx = wm % ascale_half + a_opsel = wm // ascale_half + elif const_expr(use_ascale_shuffled_tdm and as32_opsel): + # Shuffled path pairs adjacent 16-M WMMAs in one 32-row block. a_scale_idx = wm // 2 a_opsel = wm % 2 else: @@ -1119,20 +1153,23 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): a_opsel = 0 if const_expr(is_fp4): - # 32x16 WMMA with A/B swap: SRC0=B, SRC1=A + # 32x16 WMMA with A/B swap: SRC0=B, SRC1=A. 32x4 reads one 32-N block + # per WMMA (idx wn). accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( T.vec(16, T.f32), b_frag, a_frag, accs[idx], - b_scales[wn * 2], + b_scales[wn], a_scales[a_scale_idx], scaleAType=0, scaleBType=a_opsel, ) else: - # 16x16x128 WMMA: A8W4 (fmtA=FP4) or FP8 (fmtA=FP8) - if const_expr(use_scale_opsel): + # 16x16x128 WMMA: A8W4 (fmtA=FP4) or FP8 (fmtA=FP8). op_sel pairs + # adjacent 16-N halves (32x4 even rep); else one scale per WMMA + # (32x4 odd rep, or no op_sel). + if const_expr(bs32_opsel): b_scale_idx = wn // 2 b_opsel = wn % 2 else: @@ -1201,11 +1238,11 @@ def _emit_rows(start_wm, a_frags): ) if const_expr(_use_partial_drain): - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks = ( + nb_buf, nb_bases, nas_buf, nas_bases, nbs_buf, nbs_bases, n_ks = ( next_bs_info ) next_result = _load_b_and_scales( - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks + nb_buf, nb_bases, nas_buf, nas_bases, nbs_buf, nbs_bases, n_ks ) rocdl.s_wait_dscnt(_bs_ds_loads) else: @@ -1229,81 +1266,15 @@ def _emit_rows(start_wm, a_frags): if const_expr(_use_partial_drain): return accs, next_result if const_expr(next_bs_info is not None): - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks = ( + nb_buf, nb_bases, nas_buf, nas_bases, nbs_buf, nbs_bases, n_ks = ( next_bs_info ) next_result = _load_b_and_scales( - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks + nb_buf, nb_bases, nas_buf, nas_bases, nbs_buf, nbs_bases, n_ks ) return accs, next_result return accs - def _b_streaming_compute( - accs, - b_buf, - b_bases, - a_frags, - a_scales, - b_scales, - ks, - emit_filler=None, - next_info=None, - mid_compute_callback=None, - ): - """B-streaming counterpart to _a_streaming_compute (A held, B streamed).""" - next_result = None - _front_wn = (wmma_n_rep + 1) // 2 - _back_wn = wmma_n_rep - _front_wn - - def _emit_cols(start_wn, b_frags_chunk): - for frag_i in range_constexpr(len(b_frags_chunk)): - wn = start_wn + frag_i - if const_expr(wn == wmma_n_rep - 1 and emit_filler is not None): - rocdl.sched_barrier(0) - emit_filler() - for wm_raw in range_constexpr(wmma_m_rep): - wm = (wmma_m_rep - 1 - wm_raw) if (wn % 2 == 1) else wm_raw - _emit_wmma( - accs, - wm, - wn, - a_frags[wm], - b_frags_chunk[frag_i], - a_scales, - b_scales, - ) - - b_frags_front = [ - load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(_front_wn) - ] - _use_partial_drain = next_info is not None and _front_wn * wmma_m_rep >= 4 - - if const_expr(_use_partial_drain): - next_result = _load_a_and_scales(*next_info) - rocdl.s_wait_dscnt(_as_ds_loads) - else: - rocdl.s_wait_dscnt(0) - - _emit_cols(0, b_frags_front) - - if const_expr(mid_compute_callback is not None): - rocdl.sched_barrier(0) - mid_compute_callback() - - if const_expr(_back_wn > 0): - b_frags_back = [ - load_b_frag(b_buf, b_bases, _front_wn + h, ks) - for h in range_constexpr(_back_wn) - ] - rocdl.s_wait_dscnt(_as_ds_loads if _use_partial_drain else 0) - _emit_cols(_front_wn, b_frags_back) - - if const_expr(_use_partial_drain): - return accs, next_result - if const_expr(next_info is not None): - return accs, _load_a_and_scales(*next_info) - return accs - # ── Compute on one LDS buffer ── def compute_tile( accs_in, @@ -1313,20 +1284,27 @@ def compute_tile( lds_bs, emit_filler=None, mid_compute_callback=None, + late_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, ): current_accs = list(accs_in) + _set_vgpr_a_scales(scale_k_base, pf_a_scales) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(is_mxscale): + as_buf, as_bases = _precompute_as32_bases(lds_as) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + else: + as_buf, as_bases = lds_as, None + bs_buf, bs_bases = ( + lds_bs, + None, + ) # ptpc: B-scale in epilogue, bases unused if const_expr(k_wmma_steps == 1): b_frags, b_scales, a_scales = _load_b_and_scales( - b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0 + b_buf, b_bases, as_buf, as_bases, bs_buf, bs_bases, 0 ) current_accs = _a_streaming_compute( current_accs, @@ -1340,8 +1318,64 @@ def compute_tile( mid_compute_callback=mid_compute_callback, ) else: + if const_expr(use_row_major_k_prefetch): + + def _load_bundle(ks): + b_frags, b_scales, a_scales = _load_b_and_scales( + b_buf, b_bases, as_buf, as_bases, bs_buf, bs_bases, ks + ) + a_frag = load_a_frag(a_buf, a_bases[0], ks) + return a_frag, b_frags, a_scales, b_scales + + def _emit_bundle(bundle, emit_filler_now=False): + a_frag, b_frags, a_scales, b_scales = bundle + if const_expr(emit_filler_now and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + for wn in range_constexpr(wmma_n_rep): + _emit_wmma( + current_accs, + 0, + wn, + a_frag, + b_frags[wn], + a_scales, + b_scales, + ) + + # Keep future K-subtile LDS reads outstanding while only draining + # the current bundle before its single row-major WMMA. + preload_depth = min(k_wmma_steps, _row_major_k_prefetch_depth + 1) + bundle_queue = [ + _load_bundle(pre_ks) + for pre_ks in range_constexpr(preload_depth) + ] + next_ks = preload_depth + for ks in range_constexpr(k_wmma_steps): + is_last_ks = ks == k_wmma_steps - 1 + cur_bundle = bundle_queue.pop(0) + rocdl.s_wait_dscnt( + len(bundle_queue) * _row_major_k_prefetch_bundle_ds + ) + + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + + _emit_bundle(cur_bundle, emit_filler_now=is_last_ks) + + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + + if const_expr(next_ks < k_wmma_steps): + bundle_queue.append(_load_bundle(next_ks)) + next_ks += 1 + + return current_accs + prev_b, prev_bs, prev_as = _load_b_and_scales( - b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0 + b_buf, b_bases, as_buf, as_bases, bs_buf, bs_bases, 0 ) for ks in range_constexpr(k_wmma_steps - 1): _mid_cb = mid_compute_callback if ks == 0 else None @@ -1356,10 +1390,10 @@ def compute_tile( next_bs_info=( b_buf, b_bases, - bs_buf, - bs_bases, as_buf, as_bases, + bs_buf, + bs_bases, ks + 1, ), mid_compute_callback=_mid_cb, @@ -1376,7 +1410,7 @@ def compute_tile( ) return current_accs - def compute_tile_fp4_bank_friendly( + def compute_tile_fp4_quadrant( accs_in, lds_a, lds_b, @@ -1384,21 +1418,22 @@ def compute_tile_fp4_bank_friendly( lds_bs, emit_filler=None, mid_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, ): current_accs = list(accs_in) + _set_vgpr_a_scales(scale_k_base, pf_a_scales) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) - _b_half_scale_loads = (_bank_half_b_scale_rep + 3) // 4 + as_buf, as_bases = _precompute_as32_bases(lds_as) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + _b_half_scale_loads = _fp4_half_wn # 32x4: one b32 per 32-N block/WMMA def _fp4_get_a_scale_and_opsel(a_scales_all, wm_idx): - if const_expr(use_scale_opsel): - return a_scales_all[(wm_idx // 2) * 2], wm_idx % 2 + if const_expr(use_ascale_vgpr and ascale_opsel): + return a_scales_all[wm_idx % ascale_half], wm_idx // ascale_half + if const_expr(use_ascale_shuffled_tdm and as32_opsel): + return a_scales_all[wm_idx // 2], wm_idx % 2 return a_scales_all[wm_idx], 0 def _load_a_group(wm_base, wm_count, ks): @@ -1410,23 +1445,27 @@ def _load_a_group(wm_base, wm_count, ks): def _load_b_half(wn_base, ks): return [ load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) - for wn_local in range_constexpr(_bank_half_wn) + for wn_local in range_constexpr(_fp4_half_wn) ] - def _load_b_half_bundle(wn_base, rep_start, ks): - b_frags = _load_b_half(wn_base, ks) - b_scales = load_scale_slice_b128( + def _load_bs32_b_half(block0, wn_base, ks): + # 32x4: load this N-half's blocks, one ds_load_b32 per 32-N WMMA (no op_sel). + return _load_scale32_full_blocks( bs_buf, - bs_bases[0], - b_scale_load_rep, - rep_start, - _bank_half_b_scale_rep, + block0 + arith.index(wn_base), ks, + bs32_lds_row_stride, + bs32_block_bytes, + _fp4_half_wn, ) + + def _load_b_half_bundle(wn_base, ks): + b_frags = _load_b_half(wn_base, ks) + b_scales = _load_bs32_b_half(bs_bases, wn_base, ks) return b_frags, b_scales def _emit_group_rows( - group_base, + wn_base, wm_base, a_frags, b_frags, @@ -1444,22 +1483,23 @@ def _emit_group_rows( a_frag = a_frags[wm_local] global_wm = wm_base + wm_local a_scale, a_opsel = _fp4_get_a_scale_and_opsel(a_scales, global_wm) - row_base = group_base + wm_local * _bank_half_wn - for wn_local in range_constexpr(_bank_half_wn): - idx = row_base + wn_local + for wn_local in range_constexpr(_fp4_half_wn): + idx = global_wm * wmma_n_rep + ( + wn_base + wn_local + ) # row-major slot current_accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( T.vec(16, T.f32), b_frags[wn_local], a_frag, current_accs[idx], - b_scales[wn_local * 2], + b_scales[wn_local], a_scale, scaleAType=0, scaleBType=a_opsel, ) def _emit_group( - group_base, + wn_base, wm_base, a_frags, b_frags, @@ -1468,28 +1508,28 @@ def _emit_group( emit_filler_now=False, ): _emit_group_rows( - group_base, + wn_base, wm_base, a_frags, b_frags, a_scales, b_scales, 0, - _bank_half_wm, + _fp4_half_wm, emit_filler_now=emit_filler_now, ) - b_left_frags, b_left_scales = _load_b_half_bundle(0, 0, 0) + b_left_frags, b_left_scales = _load_b_half_bundle(0, 0) for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - a_scales_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + a_scales_all = _load_a_scale_operand(as_buf, as_bases, ks) - a_top_frags = _load_a_group(0, _bank_half_wm, ks) - a_bottom_frags = _load_a_group(_bank_half_wm, _bank_half_wm, ks) + a_top_frags = _load_a_group(0, _fp4_half_wm, ks) + a_bottom_frags = _load_a_group(_fp4_half_wm, _fp4_half_wm, ks) # Wait for bottom-A loads; top-A stays in flight during Q1. - rocdl.s_wait_dscnt(_bank_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.s_wait_dscnt(_fp4_half_wm * DS_LOADS_PER_A_FRAG) _emit_group( 0, @@ -1504,17 +1544,15 @@ def _emit_group( rocdl.sched_barrier(0) mid_compute_callback() - b_right_frags, b_right_scales = _load_b_half_bundle( - _bank_half_wn, _bank_half_b_scale_rep, ks - ) + b_right_frags, b_right_scales = _load_b_half_bundle(_fp4_half_wn, ks) # Hold only the next B half outstanding while the second # quadrant consumes the current left-half fragments. - rocdl.s_wait_dscnt(_bank_half_wn * 4 + _b_half_scale_loads) + rocdl.s_wait_dscnt(_fp4_half_wn * 4 + _b_half_scale_loads) _emit_group( - _bank_group_size, - _bank_half_wm, + 0, + _fp4_half_wm, a_bottom_frags, b_left_frags, a_scales_all, @@ -1522,18 +1560,16 @@ def _emit_group( ) if const_expr(not is_last_ks): - next_left_frags, next_left_scales = _load_b_half_bundle( - 0, 0, ks + 1 - ) + next_left_frags, next_left_scales = _load_b_half_bundle(0, ks + 1) # Older right-half loads must be ready before consuming # them, while the next ks left-half preload can remain in # flight under the final two quadrants. - rocdl.s_wait_dscnt(_bank_half_wn * 4 + _b_half_scale_loads) + rocdl.s_wait_dscnt(_fp4_half_wn * 4 + _b_half_scale_loads) else: rocdl.s_wait_dscnt(0) _emit_group( - _bank_group_size * 2, + _fp4_half_wn, 0, a_top_frags, b_right_frags, @@ -1541,8 +1577,8 @@ def _emit_group( b_right_scales, ) _emit_group( - _bank_group_size * 3, - _bank_half_wm, + _fp4_half_wn, + _fp4_half_wm, a_bottom_frags, b_right_frags, a_scales_all, @@ -1565,16 +1601,22 @@ def compute_tile_fp8_quadrant( emit_filler=None, mid_compute_callback=None, late_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, ): current_accs = list(accs_in) + _set_vgpr_a_scales(scale_k_base, pf_a_scales) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(is_mxscale): + as_buf, as_bases = _precompute_as32_bases(lds_as) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + else: + as_buf, as_bases = lds_as, None + bs_buf, bs_bases = ( + lds_bs, + None, + ) # ptpc: B-scale in epilogue, bases unused _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads @@ -1593,18 +1635,14 @@ def _load_b_half(wn_base, ks): def _load_a_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - if const_expr(use_scale_opsel): - return a_scales[::2] - return a_scales + return _load_a_scale_operand(as_buf, as_bases, ks) def _load_b_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - b_scales = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - if const_expr(use_scale_opsel): - return b_scales[::2] - return b_scales + return _load_b_scale_lds( + bs_buf, bs_bases, ks + ) # 32x4; op_sel in _emit_wmma def _load_b_left_bundle(ks): return _load_b_half(0, ks), _load_b_scales(ks) @@ -1676,8 +1714,14 @@ def _emit_group_col( ) b_left_frags, b_scales = _load_b_left_bundle(0) + # Margin = a-top drain depth (b-scale is issued earlier, so it is unrelated); + # keep it at the per-WMMA count so op_sel's fewer b-scale loads don't widen + # keep and race the top-row A frags. + _top_keep_margin = ( + b_scale_load_rep if const_expr(bs32_opsel) else _fp8_b_scale_loads + ) _first_top_row_keep = max( - (_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _fp8_b_scale_loads, 0 + (_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _top_keep_margin, 0 ) _bottom_left_keep = max(_b_half_loads - DS_LOADS_PER_A_FRAG, 0) @@ -1783,17 +1827,20 @@ def compute_tile_fp8_deep_pipeline( a0_prefetch=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): current_accs = list(accs_in) + _set_vgpr_a_scales(scale_k_base, pf_a_scales) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(is_mxscale): + as_buf, as_bases = _precompute_as32_bases(lds_as) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + else: + as_buf, as_bases = lds_as, None + bs_buf, bs_bases = ( + lds_bs, + None, + ) # ptpc: B-scale in epilogue, bases unused def load_a_pair(wm_pair, ks): wm_base = wm_pair * _fp8_pair_wm @@ -1809,30 +1856,6 @@ def load_b_pair(wn_pair, ks): for wn_local in range_constexpr(_fp8_pair_wn) ] - def _load_a_scales(ks): - if const_expr(is_ptpc): - return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_buffer_vgpr_scale): - if const_expr(pf_a_scales is not None): - return ( - pf_a_scales # prefetched (issued in the prior compute tile) - ) - return _bvs_load_scales( - _bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base - ) - return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - - def _load_b_scales(ks): - if const_expr(is_ptpc): - return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_buffer_vgpr_scale): - if const_expr(pf_b_scales is not None): - return pf_b_scales - return _bvs_load_scales( - _bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base - ) - return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - def emit_panel_2x2( wm_pair, wn_pair, @@ -1889,8 +1912,9 @@ def emit_panel_2x2_row( for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - a_scales = _load_a_scales(ks) - b_scales = _load_b_scales(ks) + a_scales, b_scales = _scales_for_emit( + as_buf, as_bases, bs_buf, bs_bases, ks + ) scale_pair = (a_scales, b_scales) b0 = load_b_pair(0, ks) @@ -1983,93 +2007,44 @@ def _prefetch_a2(): return current_accs - def compute_tile_b_streaming( - accs_in, - lds_a, - lds_b, - lds_as, - lds_bs, - emit_filler=None, - mid_compute_callback=None, - ): - """compute_tile counterpart with A held and B streamed.""" - current_accs = list(accs_in) - a_buf, a_bases = _precompute_a_lane_bases(lds_a) - b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) - load_args = (a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases) - - if const_expr(k_wmma_steps == 1): - a_frags, a_scales, b_scales = _load_a_and_scales(*load_args, 0) - return _b_streaming_compute( - current_accs, - b_buf, - b_bases, - a_frags, - a_scales, - b_scales, - 0, - emit_filler=emit_filler, - mid_compute_callback=mid_compute_callback, - ) - - prev_a, prev_as, prev_bs = _load_a_and_scales(*load_args, 0) - for ks in range_constexpr(k_wmma_steps - 1): - current_accs, (prev_a, prev_as, prev_bs) = _b_streaming_compute( - current_accs, - b_buf, - b_bases, - prev_a, - prev_as, - prev_bs, - ks, - next_info=load_args + (ks + 1,), - mid_compute_callback=mid_compute_callback if ks == 0 else None, - ) - return _b_streaming_compute( - current_accs, - b_buf, - b_bases, - prev_a, - prev_as, - prev_bs, - k_wmma_steps - 1, - emit_filler=emit_filler, - ) - def hot_loop_scheduler(): + if const_expr(use_row_major_k_prefetch): + _queue_depth = min(k_wmma_steps, _row_major_k_prefetch_depth + 1) + for _ks in range_constexpr(k_wmma_steps): + if const_expr(_ks == 0): + rocdl.sched_dsrd(_row_major_k_prefetch_bundle_ds * _queue_depth) + elif const_expr(_ks + _queue_depth <= k_wmma_steps): + rocdl.sched_dsrd(_row_major_k_prefetch_bundle_ds) + rocdl.sched_mfma(wmma_n_rep) + rocdl.sched_barrier(0) + return + _half_wm = wmma_m_rep // 2 _half_wmma = _half_wm * wmma_n_rep _b_loads_per_frag = 2 if is_a8w4 else 4 - _scale_dsrd = 0 if is_ptpc else 2 + _scale_dsrd = _scale_ds_loads + _a_half_dsrd = _half_wm * DS_LOADS_PER_A_FRAG for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): rocdl.sched_dsrd( - wmma_n_rep * _b_loads_per_frag - + _scale_dsrd - + _half_wm * DS_LOADS_PER_A_FRAG + wmma_n_rep * _b_loads_per_frag + _scale_dsrd + _a_half_dsrd ) else: - rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(_a_half_dsrd) rocdl.sched_mfma(_half_wmma) - rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(_a_half_dsrd) rocdl.sched_mfma(_half_wmma) if const_expr(_ks < k_wmma_steps - 1): rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd) rocdl.sched_barrier(0) - def hot_loop_scheduler_fp4_bank_friendly(): + def hot_loop_scheduler_fp4_quadrant(): _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG - _a_scale_loads = (wmma_m_rep + 3) // 4 - _b_half_loads = _bank_half_wn * 4 - _b_half_scale_loads = (_bank_half_b_scale_rep + 3) // 4 - _group_wmma = _bank_group_size + _a_scale_loads = _a_scale_ds + _b_half_loads = _fp4_half_wn * 4 + _b_half_scale_loads = _fp4_half_wn # 32x4: one b32 per 32-N block/WMMA + _group_wmma = _fp4_group_size _right_half_loads = _b_half_loads + _b_half_scale_loads for _ks in range_constexpr(k_wmma_steps): @@ -2092,7 +2067,7 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_scale_loads = 0 if is_ptpc else (wmma_m_rep + 3) // 4 + _a_scale_loads = _a_scale_ds _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn @@ -2174,20 +2149,9 @@ def compute_tile_scheduled( a0_prefetch=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): - return compute_tile_b_streaming( - accs_in, - lds_a, - lds_b, - lds_as, - lds_bs, - emit_filler=emit_filler, - mid_compute_callback=mid_compute_callback, - ) - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - return compute_tile_fp4_bank_friendly( + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT): + return compute_tile_fp4_quadrant( accs_in, lds_a, lds_b, @@ -2195,6 +2159,8 @@ def compute_tile_scheduled( lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): return compute_tile_fp8_quadrant( @@ -2206,6 +2172,8 @@ def compute_tile_scheduled( emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, late_compute_callback=late_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): return compute_tile_fp8_deep_pipeline( @@ -2220,7 +2188,6 @@ def compute_tile_scheduled( a0_prefetch=a0_prefetch, scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, - pf_b_scales=pf_b_scales, ) return compute_tile( accs_in, @@ -2230,35 +2197,14 @@ def compute_tile_scheduled( lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + late_compute_callback=late_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, ) - def hot_loop_scheduler_b_streaming(): - """hot_loop_scheduler counterpart for B-streaming.""" - _front_wn = (wmma_n_rep + 1) // 2 - _back_wn = wmma_n_rep - _front_wn - _a_loads_total = wmma_m_rep * DS_LOADS_PER_A_FRAG - _front_b_loads = _front_wn * _b_frag_loads_per_wn - _back_b_loads = _back_wn * _b_frag_loads_per_wn - _next_ks_loads = _a_loads_total + _scale_ds_loads - - for _ks in range_constexpr(k_wmma_steps): - if const_expr(_ks == 0): - rocdl.sched_dsrd(_next_ks_loads + _front_b_loads) - else: - rocdl.sched_dsrd(_front_b_loads) - rocdl.sched_mfma(_front_wn * wmma_m_rep) - if const_expr(_back_wn > 0): - rocdl.sched_dsrd(_back_b_loads) - rocdl.sched_mfma(_back_wn * wmma_m_rep) - if const_expr(_ks < k_wmma_steps - 1): - rocdl.sched_dsrd(_next_ks_loads) - rocdl.sched_barrier(0) - def hot_loop_scheduler_scheduled(): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): - hot_loop_scheduler_b_streaming() - elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - hot_loop_scheduler_fp4_bank_friendly() + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT): + hot_loop_scheduler_fp4_quadrant() elif const_expr( compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE ): @@ -2397,17 +2343,6 @@ def epilogue_atomic_adds(final_accs, addrs): scf.YieldOp([]) addr_idx += n_slots - def grouped_accs_to_row_major(accs_grouped): - row_major = [None] * n_accs - for group_idx in range_constexpr(n_accs): - row_major[_bank_group_to_row_major[group_idx]] = accs_grouped[group_idx] - return row_major - - def finalize_acc_layout(accs_in): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - return grouped_accs_to_row_major(accs_in) - return accs_in - def epilogue_load_ptpc_scales(): # PTPC scales: sa[M] per-token (scalar per wm), sb[N] per-channel # (8 contiguous N cols per wn). Both fp32, constant along K. @@ -2528,8 +2463,8 @@ def _l2_prefetch(k_base): ] if const_expr(is_ptpc): # PTPC applies sa*sb in the epilogue from global memory: no scale LDS. - # Alias the scale stage handles to A/B so the shared plumbing stays - # valid; for PTPC they are never written (no scale TDM) or read. + # Alias the scale stage handles to A/B so the shared plumbing stays valid; + # for PTPC they are never written (no scale TDM) or read. stages_as = stages_a stages_bs = stages_b else: @@ -2570,7 +2505,7 @@ def _l2_prefetch(k_base): extract_lds_base_idx(stages_bs[i]) for i in range_constexpr(num_buffers) ] - if const_expr(use_tdm_store): + if const_expr(tdm_store_enabled): d_lds_base_ptr = arena_base_ptr d_lds_f16_count = total_d_bytes // 2 d_smem = SmemPtr( @@ -2599,7 +2534,7 @@ def _l2_prefetch(k_base): ) + arith.index(d_output_off) warp_m_off_sgpr = wave_m_sgpr * arith.index(warp_tile_m) warp_n_off_sgpr = wave_n_sgpr * arith.index(warp_tile_n) - d_desc = _make_tdm_desc( + d_desc = tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_c, lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), @@ -2634,10 +2569,11 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_b_lds_addr.append( _dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1) ) - if const_expr(not is_ptpc): + if const_expr(use_ascale_shuffled_tdm): stages_as_lds_addr.append( _dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1) ) + if const_expr(is_mxscale): stages_bs_lds_addr.append( _dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1) ) @@ -2651,127 +2587,140 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_bs_lds_addr = stages_b_lds_addr desc_as_init = desc_a_init desc_bs_init = desc_b_init + elif const_expr(use_ascale_vgpr): + # A-scale is not a TDM tensor in the VGPR path. Alias slot 2 so the + # generic 4-way selector stays well-formed; it is predicated off. + stages_as_lds_addr = stages_a_lds_addr + desc_as_init = desc_a_init + desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) else: desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) - if const_expr(use_ab_half_split): - stages_a0_lds_addr = [] - stages_b0_lds_addr = [] - stages_a1_lds_addr = [] - stages_b1_lds_addr = [] - for i in range_constexpr(num_buffers): - stages_a0_lds_addr.append( - _dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 0), 1) - ) - stages_b0_lds_addr.append( - _dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 0), 1) - ) - stages_a1_lds_addr.append( - _dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 1), 1) - ) - stages_b1_lds_addr.append( - _dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 1), 1) - ) - - desc_a0_init = make_desc_a_half(stages_a_mem[0], split_k_base, 0) - desc_b0_init = make_desc_b_half(stages_b_mem[0], split_k_base, 0) - desc_a1_init = make_desc_a_half(stages_a_mem[0], split_k_base, 1) - desc_b1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) - adv_as_i32 = fx.Int32(tile_k // SCALE_BLOCK * wmma_m_rep) - adv_bs_i32 = fx.Int32(tile_k // SCALE_BLOCK * b_scale_load_rep) + # 32x4 scale TDM descriptors advance one tile's K-blocks per K-step. + adv_as_i32 = fx.Int32( + as32_lds_row_stride + if use_ascale_shuffled_tdm + else tile_k // SCALE_BLOCK * wmma_m_rep + ) + adv_bs_i32 = fx.Int32( + bs32_lds_row_stride + if is_mxscale + else tile_k // SCALE_BLOCK * b_scale_load_rep + ) - pred_const = fx.Int32(1) - if const_expr(wave_specialized_tdm): - _drop_scale_waves = is_ptpc or ( - use_buffer_vgpr_scale and not use_ab_half_split - ) + _drop_scale_waves = is_ptpc + if const_expr(use_ascale_shuffled_tdm): + _active_wave_limit = min(num_warps, 4) + elif const_expr(use_ascale_vgpr): + _active_wave_limit = min(num_warps, 3) + else: _active_wave_limit = 2 if _drop_scale_waves else 4 - active_pred_const = arith.select( - tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0) - ) + active_pred_const = arith.select( + tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0) + ) - def _select4(values): - return _select_wave_tdm_value( - values[0], values[1], values[2], values[3] - ) + def _select4(values): + return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) - def _desc_lanes(descs, lane): - return [_dg0_lane(desc, lane) for desc in descs] + def _desc_lanes(descs, lane): + return [_dg0_lane(desc, lane) for desc in descs] - def _select_active_tdm(stage_lds_addrs, descs, advs): - active_stages = [ - _select_wave_tdm_value( - stage_lds_addrs[0][i], - stage_lds_addrs[1][i], - stage_lds_addrs[2][i], - stage_lds_addrs[3][i], - ) - for i in range_constexpr(num_buffers) - ] - return ( - active_stages, - _select4(_desc_lanes(descs, 2)), - _select4(_desc_lanes(descs, 3)), - _select4([desc.dgroup1 for desc in descs]), - _select4(advs), + def _select_active_tdm(stage_lds_addrs, descs, advs): + active_stages = [ + _select_wave_tdm_value( + stage_lds_addrs[0][i], + stage_lds_addrs[1][i], + stage_lds_addrs[2][i], + stage_lds_addrs[3][i], ) + for i in range_constexpr(num_buffers) + ] + return ( + active_stages, + _select4(_desc_lanes(descs, 2)), + _select4(_desc_lanes(descs, 3)), + _select4([desc.dgroup1 for desc in descs]), + _select4(advs), + ) - else: - active_pred_const = pred_const - - if const_expr(use_ab_half_split): - # All 4 waves load A/B halves: wave0=A0, wave1=B0, wave2=A1, wave3=B1. - # Both halves of A share adv_a (same K-step); both halves of B share adv_b. - ( - active_stage_lds_addr, - active_addr_lo, - active_addr_hi, - active_dgroup1, - active_adv_i32, - ) = _select_active_tdm( - ( - stages_a0_lds_addr, - stages_b0_lds_addr, - stages_a1_lds_addr, - stages_b1_lds_addr, - ), - (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), - (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), + if const_expr(use_ascale_shuffled_tdm): + _tdm_stage_sel = ( + stages_a_lds_addr, + stages_b_lds_addr, + stages_as_lds_addr, + stages_bs_lds_addr, ) - elif const_expr(wave_specialized_tdm): - ( - active_stage_lds_addr, - active_addr_lo, - active_addr_hi, - active_dgroup1, - active_adv_i32, - ) = _select_active_tdm( - ( - stages_a_lds_addr, - stages_b_lds_addr, - stages_as_lds_addr, - stages_bs_lds_addr, - ), - (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), - (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_as_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) + elif const_expr(use_ascale_vgpr): + # wave2 is B-scale; wave3 is a predicated padding slot for the 4-way selector. + _tdm_stage_sel = ( + stages_a_lds_addr, + stages_b_lds_addr, + stages_bs_lds_addr, + stages_bs_lds_addr, ) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_bs_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_bs_i32, adv_bs_i32) else: - addr_lo_a = _dg0_lane(desc_a_init, 2) - addr_hi_a = _dg0_lane(desc_a_init, 3) - addr_lo_b = _dg0_lane(desc_b_init, 2) - addr_hi_b = _dg0_lane(desc_b_init, 3) - addr_lo_as = _dg0_lane(desc_as_init, 2) - addr_hi_as = _dg0_lane(desc_as_init, 3) - addr_lo_bs = _dg0_lane(desc_bs_init, 2) - addr_hi_bs = _dg0_lane(desc_bs_init, 3) - - dgroup1_a = desc_a_init.dgroup1 - dgroup1_b = desc_b_init.dgroup1 - dgroup1_as = desc_as_init.dgroup1 - dgroup1_bs = desc_bs_init.dgroup1 + _tdm_stage_sel = ( + stages_a_lds_addr, + stages_b_lds_addr, + stages_as_lds_addr, + stages_bs_lds_addr, + ) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_as_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) + ( + active_stage_lds_addr, + active_addr_lo, + active_addr_hi, + active_dgroup1, + active_adv_i32, + ) = _select_active_tdm(_tdm_stage_sel, _tdm_desc_sel, _tdm_adv_sel) + if const_expr(secondary_scale_tdm): + if const_expr(two_wave_bscale): + sec_pred_const = arith.select(tdm_wave_is_a, fx.Int32(1), fx.Int32(0)) + sec_stage_lds_addr = stages_bs_lds_addr + sec_addr_hi = _dg0_lane(desc_bs_init, 3) + sec_dgroup1 = desc_bs_init.dgroup1 + sec_adv_i32 = adv_bs_i32 + sec_addr_lo_init = _dg0_lane(desc_bs_init, 2) + elif const_expr(two_wave_scale): + sec_pred_const = arith.select( + tdm_wave_id < fx.Int32(2), fx.Int32(1), fx.Int32(0) + ) + sec_stage_lds_addr = [ + arith.select( + tdm_wave_is_a, stages_bs_lds_addr[i], stages_as_lds_addr[i] + ) + for i in range_constexpr(num_buffers) + ] + sec_addr_hi = arith.select( + tdm_wave_is_a, + _dg0_lane(desc_bs_init, 3), + _dg0_lane(desc_as_init, 3), + ) + sec_dgroup1 = arith.select( + tdm_wave_is_a, desc_bs_init.dgroup1, desc_as_init.dgroup1 + ) + sec_adv_i32 = arith.select(tdm_wave_is_a, adv_bs_i32, adv_as_i32) + sec_addr_lo_init = arith.select( + tdm_wave_is_a, + _dg0_lane(desc_bs_init, 2), + _dg0_lane(desc_as_init, 2), + ) + else: + # 3-wave compatibility: wave2 carries A-scale, wave0 carries B-scale. + sec_pred_const = arith.select(tdm_wave_is_a, fx.Int32(1), fx.Int32(0)) + sec_stage_lds_addr = stages_bs_lds_addr + sec_addr_hi = _dg0_lane(desc_bs_init, 3) + sec_dgroup1 = desc_bs_init.dgroup1 + sec_adv_i32 = adv_bs_i32 + sec_addr_lo_init = _dg0_lane(desc_bs_init, 2) def _pipeline_fence(outstanding=0): pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) @@ -2779,62 +2728,52 @@ def _pipeline_fence(outstanding=0): def _pipeline_fence_signal(outstanding=0): pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) - if const_expr(wave_specialized_tdm): - - def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): - dg0 = _pack_dg0( - active_pred_const, - active_stage_lds_addr[load_stage], - addr_box[0], - active_addr_hi, + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None, sec_box=None): + dg0 = _pack_dg0( + active_pred_const, + active_stage_lds_addr[load_stage], + addr_box[0], + active_addr_hi, + ) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) + addr_box[0] = addr_box[0] + active_adv_i32 + if const_expr(secondary_scale_tdm): + dg0s = _pack_dg0( + sec_pred_const, + sec_stage_lds_addr[load_stage], + sec_box[0], + sec_addr_hi, ) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) - addr_box[0] = addr_box[0] + active_adv_i32 - if k_prefetch is not None: - _l2_prefetch(k_prefetch) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0s, sec_dgroup1)) + sec_box[0] = sec_box[0] + sec_adv_i32 + if k_prefetch is not None: + _l2_prefetch(k_prefetch) # Prologue - if const_expr(wave_specialized_tdm): - for i in range_constexpr(pre_loaded): - addr_box = [active_addr_lo] + if const_expr(secondary_scale_tdm): + active_sec_lo = sec_addr_lo_init + for i in range_constexpr(pre_loaded): + addr_box = [active_addr_lo] + if const_expr(secondary_scale_tdm): + sec_box = [active_sec_lo] + _issue_active_tdm(i, addr_box, sec_box=sec_box) + active_sec_lo = sec_box[0] + else: _issue_active_tdm(i, addr_box) - active_addr_lo = addr_box[0] - else: - for i in range_constexpr(pre_loaded): - dg0_a = _pack_dg0( - pred_const, stages_a_lds_addr[i], addr_lo_a, addr_hi_a - ) - dg0_b = _pack_dg0( - pred_const, stages_b_lds_addr[i], addr_lo_b, addr_hi_b - ) - dg0_as = _pack_dg0( - pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as - ) - dg0_bs = _pack_dg0( - pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs - ) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - - addr_lo_a = addr_lo_a + adv_a_i32 - addr_lo_b = addr_lo_b + adv_b_i32 - addr_lo_as = addr_lo_as + adv_as_i32 - addr_lo_bs = addr_lo_bs + adv_bs_i32 - + active_addr_lo = addr_box[0] + _bvs_tail_seed = [] + _bvs_tail_issue_start = loop_iters * num_buffers if const_expr(_bvs_active): - # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as - # FLAT lists of i32 (list-of-tuples can't be loop-carried). + _bvs_initial_depth = _bvs_D if loop_iters > 0 else min(_bvs_D, num_k_tiles) _bvs_pf = [ _bvs_prefetch(split_k_base + arith.index(_d * tile_k)) - for _d in range(_bvs_D) + for _d in range(_bvs_initial_depth) ] - _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] - _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] + if const_expr(loop_iters > 0): + _bvs_ra = [_v for _a in _bvs_pf for _v in _a] + else: + _bvs_tail_seed = list(_bvs_pf) + _bvs_tail_issue_start = _bvs_initial_depth _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) @@ -2842,186 +2781,108 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): # This overlaps TDM DMA with the remaining WMMA instructions, _fence_outstanding = TDM_LOADS_PER_STEP * (num_buffers - 2) - if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + if const_expr(loop_iters > 0 and use_tdm_late_signal_overlap): _pipeline_fence_signal(outstanding=_fence_outstanding) if const_expr(loop_iters > 0): - if const_expr(wave_specialized_tdm): - init_args = list(accs) + [active_addr_lo] + init_args = list(accs) + [active_addr_lo] + if const_expr(secondary_scale_tdm): + init_args = init_args + [active_sec_lo] + if const_expr(_bvs_active): + init_args = init_args + _bvs_ra + + for loop_iter, state in range(0, loop_iters, 1, init=init_args): + accs_in = list(state[:n_accs]) + cur_addr_lo = state[n_accs] + _state_off = n_accs + 1 + if const_expr(secondary_scale_tdm): + cur_sec_lo = state[_state_off] + _state_off = _state_off + 1 if const_expr(_bvs_active): - init_args = init_args + _bvs_ra + _bvs_rb - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_addr_lo = state[n_accs] - if const_expr(_bvs_active): - _ra0 = n_accs + 1 - _ring_a = list(state[_ra0 : _ra0 + _bvs_D * wmma_m_rep]) - _rb0 = _ra0 + _bvs_D * wmma_m_rep - _ring_b = list(state[_rb0 : _rb0 + _bvs_D * b_scale_load_rep]) - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - addr_box = [cur_addr_lo] - - def _mid_tdm_ws( - _ls=load_stage, - _ab=addr_box, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) - - if const_expr(not use_ws_tdm_split_signal_overlap): - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) + _ra0 = _state_off + _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) + _state_off = _ra0 + _bvs_D * _vs_tile_a + + for buf_idx in range_constexpr(num_buffers): + load_stage = (buf_idx + num_buffers - 1) % num_buffers + + addr_box = [cur_addr_lo] + sec_box = [cur_sec_lo] if secondary_scale_tdm else None + + def _mid_tdm_ws( + _ls=load_stage, + _ab=addr_box, + _sb=sec_box, + _k_off=( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index(buf_idx * tile_k) + ), + ): + _issue_active_tdm(_ls, _ab, k_prefetch=_k_off, sec_box=_sb) - _late_tdm_ws_fence_signal = None - if const_expr(use_ws_tdm_split_signal_overlap): + if const_expr(not use_tdm_late_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) - def _late_tdm_ws_split_signal(): - _pipeline_fence_signal(outstanding=_fence_outstanding) + _late_tdm_ws_fence_signal = None + if const_expr(use_tdm_late_signal_overlap): - _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + def _late_tdm_ws_split_signal(): + _pipeline_fence_signal(outstanding=_fence_outstanding) - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - # Consume scale prefetched _bvs_D K-tiles ago; issue the - # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). - # NOTE: must stay AFTER the fence; issuing the scale - # buffer_loads before the cluster barrier hangs the vgpr path. - if const_expr(_bvs_active): - _cur_a = _ring_a[:wmma_m_rep] - _cur_b = _ring_b[:b_scale_load_rep] - _next_kb = ( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index((buf_idx + _bvs_D) * tile_k) - ) - _na, _nb2 = _bvs_prefetch(_next_kb) - _ring_a = _ring_a[wmma_m_rep:] + list(_na) - _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) - else: - _cur_a = None - _cur_b = None - - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - late_compute_callback=_late_tdm_ws_fence_signal, - a0_prefetch=a0_prefetch, - pf_a_scales=_cur_a, - pf_b_scales=_cur_b, - ) - cur_addr_lo = addr_box[0] - hot_loop_scheduler_scheduled() + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) + rocdl.sched_barrier(0) if const_expr(_bvs_active): - _bvs_yield = _ring_a + _ring_b - else: - _bvs_yield = [] - results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield - - accs = list(results[:n_accs]) - active_addr_lo = results[n_accs] - else: - init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_lo_a = state[n_accs] - cur_lo_b = state[n_accs + 1] - cur_lo_as = state[n_accs + 2] - cur_lo_bs = state[n_accs + 3] - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - addr_boxes = [[cur_lo_a], [cur_lo_b], [cur_lo_as], [cur_lo_bs]] - - def _mid_tdm_nws( - _ls=load_stage, - _ab=addr_boxes, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - dg0_a = _pack_dg0( - pred_const, stages_a_lds_addr[_ls], _ab[0][0], addr_hi_a - ) - dg0_b = _pack_dg0( - pred_const, stages_b_lds_addr[_ls], _ab[1][0], addr_hi_b - ) - dg0_as = _pack_dg0( - pred_const, - stages_as_lds_addr[_ls], - _ab[2][0], - addr_hi_as, - ) - dg0_bs = _pack_dg0( - pred_const, - stages_bs_lds_addr[_ls], - _ab[3][0], - addr_hi_bs, - ) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _ab[2][0] = _ab[2][0] + adv_as_i32 - _ab[3][0] = _ab[3][0] + adv_bs_i32 - _l2_prefetch(_k_off) - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_nws, - a0_prefetch=a0_prefetch, + _cur_a = _ring_a[:_vs_tile_a] + _next_kb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _bvs_D) * tile_k) ) - cur_lo_a = addr_boxes[0][0] - cur_lo_b = addr_boxes[1][0] - cur_lo_as = addr_boxes[2][0] - cur_lo_bs = addr_boxes[3][0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [ - cur_lo_a, - cur_lo_b, - cur_lo_as, - cur_lo_bs, - ] - - accs = list(results[:n_accs]) - addr_lo_a = results[n_accs] - addr_lo_b = results[n_accs + 1] - addr_lo_as = results[n_accs + 2] - addr_lo_bs = results[n_accs + 3] - + _ring_a = _ring_a[_vs_tile_a:] + list(_bvs_prefetch(_next_kb)) + else: + _cur_a = None + + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, + pf_a_scales=_cur_a, + ) + cur_addr_lo = addr_box[0] + if const_expr(secondary_scale_tdm): + cur_sec_lo = sec_box[0] + hot_loop_scheduler_scheduled() + + _sec_yield = [cur_sec_lo] if secondary_scale_tdm else [] + _bvs_yield = _ring_a if _bvs_active else [] + results = yield list(accs_in) + [cur_addr_lo] + _sec_yield + _bvs_yield + + accs = list(results[:n_accs]) + active_addr_lo = results[n_accs] + _result_off = n_accs + 1 + if const_expr(secondary_scale_tdm): + active_sec_lo = results[n_accs + 1] + _result_off = _result_off + 1 + if const_expr(_bvs_active): + _bvs_tail_flat = list( + results[_result_off : _result_off + _bvs_D * _vs_tile_a] + ) + _bvs_tail_seed = [ + _bvs_tail_flat[_d * _vs_tile_a : (_d + 1) * _vs_tile_a] + for _d in range(_bvs_D) + ] + _bvs_tail_issue_start = loop_iters * num_buffers + _bvs_D # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. - if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + if const_expr(loop_iters > 0 and use_tdm_late_signal_overlap): pipeline_fence_wait(use_cluster=use_cluster) if const_expr(loop_iters > 0): _pipeline_fence(outstanding=0) @@ -3035,7 +2896,6 @@ def _load_ptpc_scales_once(): _ptpc_scale_box[0] = epilogue_load_ptpc_scales() _tail_had_load = False - # Tail K-tile index, so the VGPR-path scale buffer_load uses the right k_base. _bvs_tail_kt = [loop_iters * num_buffers] def _bvs_tail_kb(): @@ -3045,12 +2905,29 @@ def _bvs_tail_kb(): _bvs_tail_kt[0] += 1 return kb + _bvs_tail_ring = list(_bvs_tail_seed) + _bvs_tail_issue_kt = [_bvs_tail_issue_start] + + def _bvs_tail_issue_one(): + if const_expr(_bvs_active and _bvs_tail_issue_kt[0] < num_k_tiles): + kb = split_k_base + arith.index(_bvs_tail_issue_kt[0] * tile_k) + _bvs_tail_ring.append(_bvs_prefetch(kb)) + _bvs_tail_issue_kt[0] += 1 + + def _bvs_tail_scales(): + if const_expr(_bvs_active): + return None, _bvs_tail_ring.pop(0) + return _bvs_tail_kb(), None + + if const_expr(_bvs_active): + rocdl.sched_barrier(0) + for _load_stage, _compute_stage, _outstanding in tail_plan: - _entry_kb = _bvs_tail_kb() + _entry_kb, _pf_a_scales = _bvs_tail_scales() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): _pipeline_fence(outstanding=0) - if const_expr(use_tdm_store): + if const_expr(tdm_store_enabled): a0_prefetch = maybe_prefetch_fp8_deep_a0( stages_a_idx[_compute_stage] ) @@ -3063,6 +2940,7 @@ def _bvs_tail_kb(): emit_filler=(_load_ptpc_scales_once if is_ptpc else None), a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, ) else: @@ -3082,6 +2960,7 @@ def _emit_epi_addrs(): emit_filler=_emit_epi_addrs, a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, ) else: _pipeline_fence_signal(outstanding=_outstanding) @@ -3090,56 +2969,19 @@ def _emit_epi_addrs(): _tail_mid_cb = None if const_expr(_load_stage is not None): _tail_had_load = True - if const_expr(wave_specialized_tdm): - _tail_addr_box = [active_addr_lo] + _tail_addr_box = [active_addr_lo] + _tail_sec_box = [active_sec_lo] if secondary_scale_tdm else None - def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): - _issue_active_tdm(_ls, _ab) + def _tail_mid_ws( + _ls=_load_stage, _ab=_tail_addr_box, _sb=_tail_sec_box + ): + _issue_active_tdm(_ls, _ab, sec_box=_sb) - _tail_mid_cb = _tail_mid_ws - else: - _tail_ab = [ - [addr_lo_a], - [addr_lo_b], - [addr_lo_as], - [addr_lo_bs], - ] - - def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): - dg0_a = _pack_dg0( - pred_const, stages_a_lds_addr[_ls], _ab[0][0], addr_hi_a - ) - dg0_b = _pack_dg0( - pred_const, stages_b_lds_addr[_ls], _ab[1][0], addr_hi_b - ) - dg0_as = _pack_dg0( - pred_const, - stages_as_lds_addr[_ls], - _ab[2][0], - addr_hi_as, - ) - dg0_bs = _pack_dg0( - pred_const, - stages_bs_lds_addr[_ls], - _ab[3][0], - addr_hi_bs, - ) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _ab[2][0] = _ab[2][0] + adv_as_i32 - _ab[3][0] = _ab[3][0] + adv_bs_i32 - - _tail_mid_cb = _tail_mid_nws + _tail_mid_cb = _tail_mid_ws a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) + _bvs_tail_issue_one() accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -3149,21 +2991,16 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): mid_compute_callback=_tail_mid_cb, a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, ) if const_expr(_load_stage is not None): - if const_expr(wave_specialized_tdm): - active_addr_lo = _tail_addr_box[0] - else: - addr_lo_a = _tail_ab[0][0] - addr_lo_b = _tail_ab[1][0] - addr_lo_as = _tail_ab[2][0] - addr_lo_bs = _tail_ab[3][0] + active_addr_lo = _tail_addr_box[0] + if const_expr(secondary_scale_tdm): + active_sec_lo = _tail_sec_box[0] hot_loop_scheduler_scheduled() - accs = finalize_acc_layout(accs) - if const_expr(is_ptpc): _load_ptpc_scales_once() _ptpc_sa, _ptpc_sb = _ptpc_scale_box[0] @@ -3187,10 +3024,7 @@ def _emit_buffer_store(): else: epilogue_stores(accs, epi_addrs_box[0]) - if const_expr(use_tdm_store): - # Full M-tiles take the fast TDM store; the partial last M-tile - # (rows >= M) falls back to the buffer store, whose num_records clip - # drops the OOB rows. + if const_expr(tdm_store_enabled): full_tile = (blk_m + arith.index(tile_m)) <= m_idx if_op = scf.IfOp(full_tile, [], has_else=True) with ir.InsertionPoint(if_op.then_block): @@ -3217,17 +3051,15 @@ def _emit_buffer_store(): l2_prefetch_distance, cluster_m, cluster_n, - use_tdm_store, + tdm_store_enabled, out_dtype, inst_prefetch, - wave_specialized_tdm, split_k, - use_scale_opsel, expert_sched_mode, atomic_barrier_enable, - b_streaming, - scale_load_path, - fp8_schedule, + ascale_load_path, + _row_major_k_prefetch_depth, + _bvs_D, ) @flyc.jit @@ -3250,12 +3082,11 @@ def launch_mxscale_gemm( arena_alloc.finalize() gx = (i32_m + (tile_m - 1)) // tile_m - gy = (i32_n + (tile_n - 1)) // tile_n + gy = N // tile_n gz = split_k if const_expr(use_cluster): - # Cluster launch needs a cluster-divisible grid; the extra M-tiles - # are fully OOB (rows >= M) and the kernel clips them. + # Cluster launch needs a cluster-divisible grid gx = ((gx + (cluster_m - 1)) // cluster_m) * cluster_m cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None @@ -3336,17 +3167,11 @@ def compile_ptpc_gemm( the epilogue in fp32. split_k>1 is supported (atomic add path). data_format: "fp8" (FP8 act + FP8 weight) or "a8w4" (FP8 act + FP4 weight). - wave_specialized_tdm=True requires m_warp*n_warp >= 2. + Requires m_warp*n_warp >= 2 (wave-specialized TDM). """ return compile_fp8fp4_gemm( data_format=data_format, scale_mode="ptpc", - b_streaming=False, - wave_specialized_tdm=True, - use_scale_opsel=False, - fp8_schedule="auto", - scale_load_path="tdm", - use_tdm_store=(split_k == 1), N=N, K=K, tile_m=tile_m, diff --git a/aiter/ops/flydsl/kernels/tdm_oob.py b/aiter/ops/flydsl/kernels/tdm_oob.py deleted file mode 100644 index b77963621d..0000000000 --- a/aiter/ops/flydsl/kernels/tdm_oob.py +++ /dev/null @@ -1,313 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Vendored OOB-capable TDM 2D descriptor builder for gfx1250. - -This is a faithful copy of ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` -as of the FlyDSL "add M out-of-bounds support" change, carried here so the -non-tile-aligned-M (OOB) GEMM path works against the *older* flydsl this aiter -build pins, whose ``make_tensor_descriptor_2d`` predates the ``oob_outer_bound`` -argument. - -The kernel only routes through this fallback when the installed flydsl lacks -native ``oob_outer_bound`` support (see ``_make_tdm_desc`` in -``gemm_fp8fp4_gfx1250``); when flydsl has it, the native builder is used. - -To stay robust across flydsl internal-layout changes, every low-level symbol is -sourced from the installed ``tdm_ops`` module namespace (i.e. whatever that -module successfully imported) rather than re-imported from private paths. The -only behavioural delta vs. the pinned builder is the ``oob_outer_bound`` branch -that computes a runtime ``tensor_dim1``; with ``oob_outer_bound=None`` the output -is byte-identical to the original path. -""" - -from __future__ import annotations - -import math -from typing import Tuple, Union - -from flydsl.expr.rocdl import tdm_ops as _tdm - -# Reuse whatever the installed tdm_ops bound — keeps us in lock-step with the -# pinned flydsl's lower-level primitives instead of guessing private paths. -ir = _tdm.ir -std_arith = _tdm.std_arith -llvm_dialect = _tdm.llvm_dialect -memref_dialect = _tdm.memref_dialect -arith = _tdm.arith -vector = _tdm.vector -_raw = _tdm._raw -T = _tdm.T -_ArithValue = _tdm._ArithValue -compute_warp_distribution = _tdm.compute_warp_distribution -compute_padding_encoding = _tdm.compute_padding_encoding -TDMDescriptor2D = _tdm.TDMDescriptor2D - - -def make_tensor_descriptor_2d( - global_ptr, - lds_memref, - global_offset: Tuple, - tensor_shape: Tuple[int, int], - strides: Tuple[int, int], - tile_shape: Tuple[int, int], - elem_bytes: int = 2, - pad_interval: int = 0, - pad_amount: int = 0, - num_warps: int = 1, - cache_policy: int = 0, - pred: int = 1, - workgroup_mask: Union[int, "ir.Value"] = 0, - lds_byte_offset=None, - for_store: bool = False, - atomic_barrier_enable: bool = False, - early_timeout: bool = False, - oob_outer_bound=None, -) -> "TDMDescriptor2D": - """Build a 2D TDM descriptor (vendored, OOB-capable). - - See ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` for the full - argument reference. ``oob_outer_bound`` is the runtime outer-dim global - extent (real M for a row-major A/C); when given, ``tensor_dim1`` is set to - the tile-start-relative remaining extent - ``max(0, oob_outer_bound - (outer_off + warp_off_outer))`` while - ``tile_dim1`` stays the full per-warp tile, so the partial last tile exceeds - the tensor bound and the hardware OOB-handles the overhang (fault-safe load, - zero-fill in LDS). Accepts a Python int or an i32/index ir.Value. ``None`` - keeps ``tensor_dim1 == tile_dim1`` (OOB off) — byte-identical to the - non-OOB path. - """ - from flydsl._mlir.dialects import fly as _fly_d - - outer_stride, inner_stride = strides - outer_tile, inner_tile = tile_shape - outer_off, inner_off = global_offset - - # The outer (leading-dim) stride may be a compile-time int or a runtime - # i32/index ir.Value (strided A/C, e.g. a row-slice whose row pitch exceeds - # the logical inner extent). Normalise to an index value for address math and - # an i32 value for the descriptor's stride field (sgpr5). - if isinstance(outer_stride, int): - outer_stride_idx = arith.index(outer_stride) - outer_stride_is_runtime = False - else: - os_val = ( - outer_stride.ir_value() - if hasattr(outer_stride, "ir_value") - else outer_stride - ) - if not isinstance(os_val, ir.Value): - raise TypeError( - f"outer stride must be int or i32/index ir.Value, " - f"got {type(outer_stride).__name__}" - ) - if isinstance(os_val.type, ir.IndexType): - # Wrap raw ir.Value so it supports the _ArithValue ops below (*, cast). - outer_stride_idx = _ArithValue(os_val) - elif isinstance(os_val.type, ir.IntegerType) and os_val.type.width == 32: - outer_stride_idx = arith.index_cast(T.index, os_val) - else: - raise TypeError( - f"outer stride ir.Value must be index or i32, got {os_val.type}" - ) - outer_stride_is_runtime = True - - # -- Warp distribution -- - warps_per_dim, block_per_warp = compute_warp_distribution( - [outer_tile, inner_tile], - num_warps, - ) - bpw_outer, bpw_inner = block_per_warp - warps_dim0 = warps_per_dim[0] - - if num_warps > 1: - # Auto-acquire SGPR wave_id via hardware register (TTMP8[29:25]). - # This keeps the entire descriptor address chain in SALU, - from flydsl.expr import rocdl as _rocdl_ext - - _wid_i32 = _rocdl_ext.wave_id() - wave_id = arith.index_cast(T.index, _wid_i32) - warp_coord_outer = wave_id % arith.index(warps_dim0) - warp_coord_inner = wave_id // arith.index(warps_dim0) - warp_off_outer = warp_coord_outer * arith.index(bpw_outer) - warp_off_inner = warp_coord_inner * arith.index(bpw_inner) - else: - warp_off_outer = arith.index(0) - warp_off_inner = arith.index(0) - - # -- Global address (byte address for descriptor) -- - glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") - i64 = ir.IntegerType.get_signless(64) - a_raw = global_ptr.__extract_to_ir_values__()[0] - glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) - glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) - glb_elem_off = (outer_off + warp_off_outer) * outer_stride_idx + ( - inner_off + warp_off_inner - ) * arith.index(inner_stride) - glb_byte_off = glb_elem_off * arith.index(elem_bytes) - glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) - glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 - - # -- LDS address (byte address within shared memory) -- - lds_base_idx = _ArithValue( - memref_dialect.extract_aligned_pointer_as_index(lds_memref) - ) - # Compute padded LDS stride (elements) for the outer dim - if pad_interval > 0 and pad_amount > 0: - lds_inner_stride = inner_tile + pad_amount # padded row width - else: - lds_inner_stride = inner_tile - lds_warp_elem_off = warp_off_outer * arith.index(lds_inner_stride) + warp_off_inner - lds_warp_byte_off = lds_warp_elem_off * arith.index(elem_bytes) - lds_total_off = lds_base_idx + lds_warp_byte_off - if lds_byte_offset is not None: - lds_total_off = lds_total_off + lds_byte_offset - lds_addr_i32 = arith.index_cast(T.i32, lds_total_off) - - # ================================================================ - # GROUP0 (vector<4xi32>): pred, lds_addr, global_addr_lo/hi - # ================================================================ - g0_s0 = arith.constant(pred, type=T.i32) - g0_s1 = lds_addr_i32 - i32 = ir.IntegerType.get_signless(32) - g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_addr_i64)).result) - hi_raw = _ArithValue(_raw(glb_addr_i64)).shrui(arith.constant(32, type=T.i64)) - g0_s3 = _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) | arith.constant( - 1 << 31, type=T.i32 - ) # type field = 2 in [31:30] - dgroup0 = vector.from_elements(T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3]) - - # ================================================================ - # GROUP1 (vector<8xi32>): config + tensor dims + strides + tile - # ================================================================ - # Descriptor dim ordering: dim0=innermost, dim1=outermost - tdim0 = bpw_inner # innermost extent per warp - tdim1 = bpw_outer # outermost extent per warp - tile_d0 = bpw_inner # block dim0 per warp - tile_d1 = bpw_outer # block dim1 per warp - - # Padding can be applied to the LDS address when copying from memory to LDS, - # but not when copying from LDS to memory - # (there is no "de-padding" operation; padding is ignored). - if for_store and pad_interval > 0 and pad_amount > 0: - tile_d0 += pad_amount - pad_interval = 0 - pad_amount = 0 - - # stride_dim0 in descriptor = outermost stride in elements - stride0 = outer_stride - - # data_size = log2(elem_bytes) - data_size_code = int(math.log2(elem_bytes)) - - # Padding encoding - if pad_interval > 0 and pad_amount > 0: - elem_bits = elem_bytes * 8 - enc_interval, enc_amount = compute_padding_encoding( - pad_interval, pad_amount, elem_bits - ) - pad_enable = 1 - else: - enc_interval, enc_amount = 0, 0 - pad_enable = 0 - - # sgpr0: config bitfields - _abe = 1 if atomic_barrier_enable else 0 - _early_timeout = 1 if early_timeout else 0 - g1_s0_upper = ( - (data_size_code << 16) # data_size [17:16] - | (_abe << 18) # atomic_barrier_enable - | (0 << 19) # iterate_enable - | (pad_enable << 20) # pad_enable - | (_early_timeout << 21) # early_timeout - | (enc_interval << 22) # pad_interval [24:22] - | (enc_amount << 25) # pad_amount [31:25] - ) - - if isinstance(workgroup_mask, int): - g1_s0_val = (workgroup_mask & 0xFFFF) | g1_s0_upper - g1_s0 = arith.constant(g1_s0_val, type=T.i32) - else: - upper_const = arith.constant(g1_s0_upper, type=T.i32) - mask_i32 = arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32)) - g1_s0 = arith.ori(upper_const, mask_i32) - - # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] - g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) - - if oob_outer_bound is None: - # Compile-time tensor_dim1 == tile extent: OOB checking off. - # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] - g1_s2 = arith.constant( - ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), - type=T.i32, - ) - # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] - g1_s3 = arith.constant( - ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), - type=T.i32, - ) - else: - # Runtime tensor_dim1 = max(0, oob_outer_bound - (outer_off + warp_off_outer)), - # tile-start-relative (the descriptor's global address already includes the - # tile/warp start). tile_dim1 (sgpr4) stays the full per-warp tile, so the - # partial last tile exceeds the tensor bound and the HW OOB-handles the - # overhang. tensor_dim0 (innermost) and the tile dims stay compile-time. - if isinstance(oob_outer_bound, int): - ob_i32 = arith.constant(oob_outer_bound, type=T.i32) - else: - ob_i32 = ( - oob_outer_bound.ir_value() - if hasattr(oob_outer_bound, "ir_value") - else oob_outer_bound - ) - if not isinstance(ob_i32, ir.Value): - raise TypeError( - f"oob_outer_bound must be int or i32/index ir.Value, " - f"got {type(oob_outer_bound).__name__}" - ) - if isinstance(ob_i32.type, ir.IndexType): - ob_i32 = arith.index_cast(T.i32, ob_i32) - elif not ( - isinstance(ob_i32.type, ir.IntegerType) and ob_i32.type.width == 32 - ): - raise TypeError( - f"oob_outer_bound ir.Value must be index or i32, got {ob_i32.type}" - ) - start_i32 = arith.index_cast(T.i32, outer_off + warp_off_outer) - tdim1_rt = arith.maxsi( - arith.subi(ob_i32, start_i32), arith.constant(0, type=T.i32) - ) - c16 = arith.constant(16, type=T.i32) - c_mask16 = arith.constant(0xFFFF, type=T.i32) - # sgpr2: tensor_dim0_hi[15:0] (const) | tensor_dim1_lo[31:16] (runtime) - g1_s2 = arith.ori( - arith.constant((tdim0 >> 16) & 0xFFFF, type=T.i32), - arith.shli(arith.andi(tdim1_rt, c_mask16), c16), - ) - # sgpr3: tensor_dim1_hi[15:0] (runtime) | tile_dim0[31:16] (const) - g1_s3 = arith.ori( - arith.andi(arith.shrui(tdim1_rt, c16), c_mask16), - arith.constant(tile_d0 << 16, type=T.i32), - ) - - # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 (always the full per-warp tile) - g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) - - # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim - if outer_stride_is_runtime: - # Runtime leading-dim stride: truncate the index to i32 (strides < 2^31). - g1_s5 = arith.index_cast(T.i32, outer_stride_idx) - else: - g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) - - # sgpr6-7: for 2D, no higher-dim strides - g1_s6 = arith.constant(0, type=T.i32) - g1_s7 = arith.constant(0, type=T.i32) - - dgroup1 = vector.from_elements( - T.vec(8, T.i32), - [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], - ) - - return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) From 158951e25fd8a83f4262c4d56b86d2546ae05175 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sat, 27 Jun 2026 16:29:35 +0000 Subject: [PATCH 2/2] add tuned csv for new flydsl gemm kernel --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index d8561d72bd..a3d439835f 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,51 +549,51 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 -gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,9,0,7.4638,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,64.91,2052.72,0.0103 -gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,10,0,7.5824,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,127.78,2044.65,0.0097 -gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,10,0,7.8757,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,246.04,2014.79,0.0098 -gfx1250,256,128,2112,7168,torch.float8_e4m3fn,flydsl,11,0,8.326,flydsl_bpreshuffle_wmma_t64x64x512_mw1_nw2_nb4_sk1_cm1_cn1,465.47,1993.39,0.0096 -gfx1250,256,256,2112,7168,torch.float8_e4m3fn,flydsl,19,0,25.7514,flydsl_bpreshuffle_wmma_t256x192x128_mw1_nw2_nb4_sk1_cm1_cn1,301.0,701.13,0.0098 -gfx1250,256,512,2112,7168,torch.float8_e4m3fn,flydsl,398,0,40.0855,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,386.73,523.17,0.0098 -gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096 -gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,398,0,126.0903,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,491.78,305.1,0.0098 -gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,398,0,240.7634,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,515.1,256.69,0.0098 -gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 -gfx1250,256,16384,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 -gfx1250,256,32768,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 -gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,3,0,30.6273,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,122.7,3850.55,0.0243 -gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.9644,flydsl_bpreshuffle_wmma_t32x64x512_mw1_nw2_nb4_sk1_cm1_cn1,235.14,3704.86,0.0237 -gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,1,0,49.3665,flydsl_bpreshuffle_wmma_t64x64x512_mw1_nw2_nb4_sk1_cm1_cn1,76.13,2388.91,0.0243 -gfx1250,256,128,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.7604,flydsl_bpreshuffle_wmma_t128x256x128_mw1_nw2_nb4_sk1_cm1_cn1,917.72,3704.86,0.0237 -gfx1250,256,256,7168,16384,torch.float8_e4m3fn,flydsl,500,0,33.8217,flydsl_bpreshuffle_wmma_t256x256x128_mw1_nw2_nb4_sk1_cm1_cn1,1777.84,3704.86,0.0237 -gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,500,0,51.8551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2319.14,2568.1,0.0237 -gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2612.66,1617.42,0.0237 -gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237 -gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 -gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237 -gfx1250,256,16384,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 -gfx1250,256,32768,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237 -gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,11,0,5.5141,flydsl_bpreshuffle_wmma_t16x64x256_mw1_nw2_nb4_sk1_cm1_cn1,219.07,6992.94,0.0 -gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,13,0,5.2752,flydsl_bpreshuffle_wmma_t32x128x256_mw2_nw2_nb4_sk1_cm1_cn1,457.98,7463.37,0.0 -gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,16,0,6.1534,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,785.23,6661.81,0.0 -gfx1250,256,128,24576,1536,torch.float8_e4m3fn,flydsl,25,0,8.7809,flydsl_bpreshuffle_wmma_t128x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1100.53,5037.84,0.0 -gfx1250,256,256,24576,1536,torch.float8_e4m3fn,flydsl,28,0,10.5346,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,1834.65,4815.07,0.0 -gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,26,0,28.5168,flydsl_bpreshuffle_wmma_t128x192x128_mw2_nw2_nb4_sk1_cm1_cn1,1355.51,2233.81,0.0 -gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0 -gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0 -gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 -gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0 -gfx1250,256,16384,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 -gfx1250,256,32768,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0 -gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,16,0,4.5141,flydsl_bpreshuffle_wmma_t16x128x128_mw1_nw4_nb4_sk1_cm1_cn1,114.44,3801.5,0.0 -gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,17,0,4.7678,flydsl_bpreshuffle_wmma_t32x128x128_mw1_nw4_nb4_sk1_cm1_cn1,247.66,4357.23,0.0 -gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,21,0,5.5112,flydsl_bpreshuffle_wmma_t64x256x128_mw1_nw4_nb4_sk1_cm1_cn1,389.66,3811.20,0.0 -gfx1250,256,128,32768,512,torch.float8_e4m3fn,flydsl,26,0,5.6324,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,762.55,4479.68,0.0 -gfx1250,256,256,32768,512,torch.float8_e4m3fn,flydsl,29,0,6.6907,flydsl_bpreshuffle_wmma_t256x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1283.86,5034.68,0.0 -gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,31,0,9.5102,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1806.47,5319.95,0.0 -gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,27,0,30.9049,flydsl_bpreshuffle_wmma_t128x512x128_mw2_nw4_nb3_sk1_cm1_cn1,1111.79,2731.29,0.0 -gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,31,0,49.0573,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1400.8,3099.3,0.0 -gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,31,0,56.5638,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2429.8,5079.39,0.0 -gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,31,0,114.0927,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2409.25,4889.38,0.0 -gfx1250,256,16384,32768,512,torch.float8_e4m3fn,flydsl,31,0,231.7468,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2372.23,4741.85,0.0 -gfx1250,256,32768,32768,512,torch.float8_e4m3fn,flydsl,31,0,466.7685,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2355.58,4672.63,0.0 +gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,49,0,4.7681,flydsl_bpreshuffle_wmma_t16x32x1024_mw1_nw2_nb4_sk1_cm1_cn1,101.6,3213.25,0.0103 +gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,266,0,5.3731,flydsl_bpreshuffle_wmma_t32x32x1024_mw2_nw1_nb4_sk1_cm1_cn1,180.32,2885.37,0.0097 +gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,266,0,5.8099,flydsl_bpreshuffle_wmma_t32x32x1024_mw2_nw1_nb4_sk1_cm1_cn1,333.53,2731.18,0.0098 +gfx1250,256,128,2112,7168,torch.float8_e4m3fn,flydsl,651,0,7.8921,flydsl_bpreshuffle_wmma_t64x32x512_mw4_nw1_nb4_sk1_cm1_cn1,491.07,2102.99,0.0096 +gfx1250,256,256,2112,7168,torch.float8_e4m3fn,flydsl,1008,0,12.7104,flydsl_bpreshuffle_wmma_t128x32x256_mw4_nw1_nb4_sk1_cm1_cn1,609.82,1420.5,0.0098 +gfx1250,256,512,2112,7168,torch.float8_e4m3fn,flydsl,1274,0,15.7872,flydsl_bpreshuffle_wmma_t256x32x256_mw2_nw1_nb4_sk1_cm1_cn1,981.94,1328.39,0.0096 +gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,1309,0,25.2951,flydsl_bpreshuffle_wmma_t256x64x128_mw4_nw1_nb4_sk1_cm1_cn1,1225.7,1059.66,0.0098 +gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,1330,0,28.4157,flydsl_bpreshuffle_wmma_t256x96x128_mw4_nw1_nb4_sk1_cm1_cn1,2182.19,1353.82,0.0097 +gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,1358,0,33.786,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,3670.67,1829.17,0.0097 +gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,194,0,73.6239,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,3368.94,1473.19,0.0097 +gfx1250,256,16384,2112,7168,torch.float8_e4m3fn,flydsl,194,0,122.1808,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,4060.12,1651.53,0.0097 +gfx1250,256,32768,2112,7168,torch.float8_e4m3fn,flydsl,194,0,261.7924,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,3789.79,1483.74,0.0097 +gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,3,0,10.4336,flydsl_bpreshuffle_wmma_t16x32x1024_mw1_nw2_nb4_sk1_cm1_cn1,360.19,11303.1,0.0243 +gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,39,0,12.0338,flydsl_bpreshuffle_wmma_t32x32x1024_mw2_nw2_nb4_sk1_cm1_cn1,624.59,9840.91,0.024 +gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,93,0,16.3277,flydsl_bpreshuffle_wmma_t64x32x512_mw4_nw1_nb4_sk1_cm1_cn1,920.67,7313.13,0.0237 +gfx1250,256,128,7168,16384,torch.float8_e4m3fn,flydsl,108,0,20.6842,flydsl_bpreshuffle_wmma_t64x64x512_mw4_nw1_nb4_sk1_cm1_cn1,1453.51,5867.89,0.0236 +gfx1250,256,256,7168,16384,torch.float8_e4m3fn,flydsl,171,0,33.9351,flydsl_bpreshuffle_wmma_t128x128x256_mw2_nw2_nb4_sk1_cm1_cn1,1771.9,3692.48,0.0236 +gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,192,0,59.7158,flydsl_bpreshuffle_wmma_t256x128x128_mw2_nw2_nb4_sk1_cm1_cn1,2013.86,2230.05,0.0236 +gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,195,0,83.2491,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2889.14,1788.58,0.0237 +gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,195,0,102.0803,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4712.33,1766.8,0.0237 +gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,195,0,201.5675,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4772.96,1206.89,0.0237 +gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,195,0,391.806,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4910.96,942.04,0.0237 +gfx1250,256,16384,7168,16384,torch.float8_e4m3fn,flydsl,195,0,752.1932,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,5116.09,825.26,0.0237 +gfx1250,256,32768,7168,16384,torch.float8_e4m3fn,flydsl,195,0,1543.2655,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4987.2,728.37,0.0237 +gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,15,0,5.2254,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,231.17,7379.29,0.0 +gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,68,0,5.6662,flydsl_bpreshuffle_wmma_t32x128x256_mw2_nw2_nb4_sk1_cm1_cn1,426.37,6948.35,0.0 +gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,123,0,6.7094,flydsl_bpreshuffle_wmma_t64x128x256_mw1_nw4_nb4_sk1_cm1_cn1,720.16,6109.75,0.0 +gfx1250,256,128,24576,1536,torch.float8_e4m3fn,flydsl,171,0,8.457,flydsl_bpreshuffle_wmma_t128x128x256_mw2_nw2_nb4_sk1_cm1_cn1,1142.68,5230.79,0.0 +gfx1250,256,256,24576,1536,torch.float8_e4m3fn,flydsl,177,0,11.917,flydsl_bpreshuffle_wmma_t128x256x128_mw1_nw4_nb4_sk1_cm1_cn1,1621.83,4256.51,0.0 +gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,195,0,18.2548,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2117.51,3489.55,0.0 +gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,195,0,33.4933,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2308.21,2676.75,0.0 +gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,195,0,70.2553,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2200.81,2014.91,0.0 +gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,195,0,147.0666,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2102.7,1668.41,0.0 +gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,195,0,414.2206,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1493.11,1093.58,0.0 +gfx1250,256,16384,24576,1536,torch.float8_e4m3fn,flydsl,195,0,811.9439,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1523.44,1069.31,0.0 +gfx1250,256,32768,24576,1536,torch.float8_e4m3fn,flydsl,195,0,2120.4384,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1166.69,801.1,0.0 +gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,24,0,4.558,flydsl_bpreshuffle_wmma_t16x256x128_mw1_nw4_nb4_sk1_cm1_cn1,117.79,3912.68,0.0 +gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,76,0,4.674,flydsl_bpreshuffle_wmma_t32x256x128_mw1_nw4_nb4_sk1_cm1_cn1,229.73,4041.67,0.0 +gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,134,0,5.0051,flydsl_bpreshuffle_wmma_t64x256x128_mw1_nw4_nb4_sk1_cm1_cn1,429.06,4196.58,0.0 +gfx1250,256,128,32768,512,torch.float8_e4m3fn,flydsl,178,0,6.1586,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw2_nb4_sk1_cm1_cn1,697.39,4096.93,0.0 +gfx1250,256,256,32768,512,torch.float8_e4m3fn,flydsl,195,0,8.1141,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1058.64,4151.48,0.0 +gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,195,0,16.8347,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1020.5,3005.33,0.0 +gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,193,0,30.0237,flydsl_bpreshuffle_wmma_t256x128x128_mw4_nw1_nb4_sk1_cm1_cn1,1144.42,2811.46,0.0 +gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,195,0,44.2996,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1551.24,3432.16,0.0 +gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,195,0,98.8264,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1390.71,2907.22,0.0 +gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,195,0,231.7324,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1186.19,2407.27,0.0 +gfx1250,256,16384,32768,512,torch.float8_e4m3fn,flydsl,195,0,548.4978,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1002.29,2003.49,0.0 +gfx1250,256,32768,32768,512,torch.float8_e4m3fn,flydsl,195,0,1252.4232,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,877.91,1741.45,0.0