From 0a222136f5820497f18fcd676a1d7fa4315469b3 Mon Sep 17 00:00:00 2001 From: Santosh Mohan Date: Fri, 19 Jun 2026 16:04:31 -0700 Subject: [PATCH] [mx_formats/cutedsl] Unified NVFP4 + MXFP4 (+/- RHT) quantize kernel Adds a self-contained CuTeDSL FP4 quantize subpackage: one no-smem streaming kernel that serves both FP4 formats and all three scale layouts the GEMM consumers use, with optional fused RHT. Supersedes the separate per-format nvfp4_rht / mxfp4_rht casts. * fmt="nvfp4": block 16, two-level E4M3 block scale + per-tensor global scale (float8_e4m3fn scales); supports arbitrary K % 16 via a masked remainder. * fmt="mxfp4": block 32, single-level E8M0 block scale (float8_e8m0fnu), floor or rceil; requires K % 32. * scale_layout in {linear, cublas_blocked, mma_tiled} selected at compile time (cublas_blocked feeds f4f4bf16; mma_tiled feeds the SM100 blockscaled GEMM with no separate scale-conversion pass). * optional fused RHT (register FWHT16/32 + sign), skipped via a constexpr on the plain path. A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one MXFP4 block; the per-format scale recipe, FWHT size, and MMA row-atom are compile-time FORMAT-selected so a single kernel body covers both formats. Two byte-identical thread mappings are exposed via mapping=: "striped" (best at very large N) and "wpr" (warp-per-row + grid.x column split; best at small/mid N). Files: cute_utils.py (E2M1 pack + E4M3/E8M0 scale recipes + amax, bit-exact vs eager), fwht.py (register FWHT16/32 + sign), fp4_unified_quantize.py (the kernel + torchao::fp4_quantize_unified op + gated fp4_quantize_unified_2d). Test: test/prototype/mx_formats/test_fp4_unified_cutedsl.py (B200-gated) checks the plain-cast per-block scales byte-exact vs the cute_utils host references (themselves bit-exact vs eager), wpr == striped across all layouts, qdata invariance across the three scale layouts, and a plain-cast dequant round-trip. 
--- .../mx_formats/test_fp4_unified_cutedsl.py | 156 +++++ .../prototype/mx_formats/cutedsl/__init__.py | 75 +++ .../mx_formats/cutedsl/cute_utils.py | 607 +++++++++++++++++ .../cutedsl/fp4_unified_quantize.py | 622 ++++++++++++++++++ torchao/prototype/mx_formats/cutedsl/fwht.py | 298 +++++++++ 5 files changed, 1758 insertions(+) create mode 100644 test/prototype/mx_formats/test_fp4_unified_cutedsl.py create mode 100644 torchao/prototype/mx_formats/cutedsl/__init__.py create mode 100644 torchao/prototype/mx_formats/cutedsl/cute_utils.py create mode 100644 torchao/prototype/mx_formats/cutedsl/fp4_unified_quantize.py create mode 100644 torchao/prototype/mx_formats/cutedsl/fwht.py diff --git a/test/prototype/mx_formats/test_fp4_unified_cutedsl.py b/test/prototype/mx_formats/test_fp4_unified_cutedsl.py new file mode 100644 index 0000000000..b913f42682 --- /dev/null +++ b/test/prototype/mx_formats/test_fp4_unified_cutedsl.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the unified FP4 (NVFP4 + MXFP4) +/- RHT CuTeDSL quantize cast. + +Validates that the plain (no-RHT) cast's per-block scales are byte-exact vs the +``cute_utils`` host scale references (themselves bit-exact vs eager torchao), +that the two thread mappings (striped / warp-per-row) produce identical output, +that qdata is invariant across the three scale layouts, and that the plain cast +round-trips to within FP4 error. 
+""" + +import pytest +import torch + +from torchao.prototype.mx_formats.cutedsl import _fp4_cutedsl_kernels_available + +pytestmark = pytest.mark.skipif( + not _fp4_cutedsl_kernels_available, + reason="requires SM 10.x (Blackwell), CUDA>=12.8, and the CuTeDSL runtime", +) + +_E2M1_VALS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + + +def _x(M, K, dtype=torch.bfloat16): + torch.manual_seed(0) + return (torch.randn(M, K, device="cuda", dtype=dtype) * 5).contiguous() + + +def _dequant_linear(qdata, scales_u8, gs, M, K, fmt): + blk = 16 if fmt == "nvfp4" else 32 + kb = K // blk + lut = torch.tensor(_E2M1_VALS + [-v for v in _E2M1_VALS], device=qdata.device) + codes = torch.stack( + [lut[(qdata & 0xF).long()], lut[(qdata >> 4).long()]], dim=-1 + ).reshape(M, K) + if fmt == "nvfp4": + sc = scales_u8.view(torch.float8_e4m3fn).float().view(M, kb) / gs + else: + sc = scales_u8.view(torch.float8_e8m0fnu).float().view(M, kb) + return (codes.view(M, kb, blk) * sc.unsqueeze(-1)).view(M, K) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("M,K", [(256, 512), (2304, 4096), (1024, 2048)]) +def test_nvfp4_plain_scales_match_reference(dtype, M, K): + # Unified linear-layout E4M3 scales must be byte-exact vs the cute_utils + # per-block host reference (bit-exact vs eager NVFP4). 
+ from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + compute_block_scale_e4m3_nvfp4, + ) + from torchao.prototype.mx_formats.cutedsl.fp4_unified_quantize import ( + _fp4_quantize_unified_impl, + ) + + x = _x(M, K, dtype) + gs = 2688.0 / x.abs().max().item() + ref = compute_block_scale_e4m3_nvfp4(x, gs) # (M, K//16) float8_e4m3fn + _, su = _fp4_quantize_unified_impl( + x, fmt="nvfp4", scale_layout="linear", global_scale=gs, mapping="striped" + ) + torch.cuda.synchronize() + assert int((ref.view(torch.uint8).view(-1) != su.view(-1)).sum()) == 0 + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("M,K", [(256, 512), (2304, 4096), (1024, 2048)]) +@pytest.mark.parametrize("mode", ["floor", "rceil"]) +def test_mxfp4_plain_scales_match_reference(dtype, M, K, mode): + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + compute_block_scale_e8m0_fp4, + ) + from torchao.prototype.mx_formats.cutedsl.fp4_unified_quantize import ( + _fp4_quantize_unified_impl, + ) + + x = _x(M, K, dtype) + ref = compute_block_scale_e8m0_fp4(x, mode) # (M, K//32) float8_e8m0fnu + _, su = _fp4_quantize_unified_impl( + x, fmt="mxfp4", scaling_mode=mode, scale_layout="linear", mapping="striped" + ) + torch.cuda.synchronize() + assert int((ref.view(torch.uint8).view(-1) != su.view(-1)).sum()) == 0 + + +@pytest.mark.parametrize("fmt", ["nvfp4", "mxfp4"]) +@pytest.mark.parametrize("rht", [False, True]) +@pytest.mark.parametrize("scale_layout", ["linear", "cublas_blocked", "mma_tiled"]) +def test_wpr_matches_striped(fmt, rht, scale_layout): + # The two mappings are a pure work-distribution change: identical output. 
+ from torchao.prototype.mx_formats.cutedsl.fp4_unified_quantize import ( + _fp4_quantize_unified_impl, + ) + + x = _x(256, 2048) + gs = 2688.0 / x.abs().max().item() if fmt == "nvfp4" else 1.0 + sign = ([1, -1] * (8 if fmt == "nvfp4" else 16)) if rht else None + qs, ss = _fp4_quantize_unified_impl( + x, sign_vector=sign, fmt=fmt, scale_layout=scale_layout, global_scale=gs, + mapping="striped", + ) + qw, sw = _fp4_quantize_unified_impl( + x, sign_vector=sign, fmt=fmt, scale_layout=scale_layout, global_scale=gs, + mapping="wpr", warps=4, xsplit=2, ilp=2, + ) + torch.cuda.synchronize() + assert int((qs != qw).sum()) == 0 + assert int((ss.view(-1) != sw.view(-1)).sum()) == 0 + + +@pytest.mark.parametrize("fmt", ["nvfp4", "mxfp4"]) +def test_qdata_layout_invariant(fmt): + from torchao.prototype.mx_formats.cutedsl.fp4_unified_quantize import ( + _fp4_quantize_unified_impl, + ) + + x = _x(256, 2048) + gs = 2688.0 / x.abs().max().item() if fmt == "nvfp4" else 1.0 + q = {} + for lay in ("linear", "cublas_blocked", "mma_tiled"): + q[lay], _ = _fp4_quantize_unified_impl( + x, fmt=fmt, scale_layout=lay, global_scale=gs, mapping="striped" + ) + torch.cuda.synchronize() + assert int((q["linear"] != q["cublas_blocked"]).sum()) == 0 + assert int((q["linear"] != q["mma_tiled"]).sum()) == 0 + + +@pytest.mark.parametrize("fmt", ["nvfp4", "mxfp4"]) +def test_dequant_roundtrip_plain(fmt): + from torchao.prototype.mx_formats.cutedsl.fp4_unified_quantize import ( + _fp4_quantize_unified_impl, + ) + + x = _x(256, 2048) + gs = 2688.0 / x.abs().max().item() if fmt == "nvfp4" else 1.0 + q, s = _fp4_quantize_unified_impl( + x, fmt=fmt, scale_layout="linear", global_scale=gs, mapping="striped" + ) + torch.cuda.synchronize() + deq = _dequant_linear(q, s, gs, 256, 2048, fmt) + xf = x.float() + m = xf.abs() > 0.1 + rel = ((deq - xf).abs()[m] / xf.abs()[m]).median().item() + assert rel < 0.3 + + +def test_custom_op_smoke(): + from torchao.prototype.mx_formats.cutedsl import 
fp4_quantize_unified_2d + + q, s = fp4_quantize_unified_2d(_x(128, 2048), None, "nvfp4", "floor", "mma_tiled") + assert q.shape == (128, 1024) diff --git a/torchao/prototype/mx_formats/cutedsl/__init__.py b/torchao/prototype/mx_formats/cutedsl/__init__.py new file mode 100644 index 0000000000..56d59d0362 --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Unified FP4 (NVFP4 + MXFP4) +/- Random Hadamard Transform CuTeDSL casts. + +A single no-smem streaming CuTeDSL quantize kernel for both FP4 formats +(NVFP4 E2M1 block-16 with a two-level E4M3 scale; MXFP4 E2M1 block-32 with an +E8M0 scale), all three GEMM scale layouts (``linear`` / ``cublas_blocked`` / +``mma_tiled``), and an optional fused Random Hadamard Transform. Gated behind: + +* a Blackwell-family GPU (SM 10.x), +* CUDA >= 12.8, +* the CuTeDSL runtime packages (``nvidia-cutlass-dsl`` and friends). + +No ``cutlass`` import happens at module scope so importing this package is safe +on machines without the CuTeDSL runtime (the gate flag simply evaluates False). +""" + +import torch + +from torchao.utils import is_cuda_version_at_least + +from .cute_utils import _cutedsl_runtime_available + + +def _is_sm_10x() -> bool: + """Return True iff a Blackwell-family (SM 10.x) GPU is available.""" + return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10 + + +_fp4_cutedsl_kernels_available = ( + _is_sm_10x() and is_cuda_version_at_least(12, 8) and _cutedsl_runtime_available() +) + + +def pack32_e2m1_to_bytes(x: torch.Tensor) -> torch.Tensor: + """Lazily re-exported test/validation entry for E2M1 packing. + + See ``cute_utils.pack32_e2m1_to_bytes``. Imported lazily so that importing + this package does not require the CuTeDSL runtime. 
+ """ + from .cute_utils import pack32_e2m1_to_bytes as _impl + + return _impl(x) + + +def fp4_quantize_unified_2d( + x, + sign_vector=None, + fmt: str = "nvfp4", + scaling_mode: str = "floor", + scale_layout: str = "cublas_blocked", +): + """Lazily re-exported gated wrapper for the unified FP4 (+/- RHT) cast. + + See ``fp4_unified_quantize.fp4_quantize_unified_2d``. Imported lazily so + that importing this package does not require the CuTeDSL runtime. One kernel + serves NVFP4 (``fmt="nvfp4"``) and MXFP4 (``fmt="mxfp4"``) across the + ``linear`` / ``cublas_blocked`` / ``mma_tiled`` scale layouts; an empty / + ``None`` ``sign_vector`` selects the plain cast. + """ + from .fp4_unified_quantize import fp4_quantize_unified_2d as _impl + + return _impl(x, sign_vector, fmt, scaling_mode, scale_layout) + + +__all__ = [ + "_is_sm_10x", + "_fp4_cutedsl_kernels_available", + "pack32_e2m1_to_bytes", + "fp4_quantize_unified_2d", +] diff --git a/torchao/prototype/mx_formats/cutedsl/cute_utils.py b/torchao/prototype/mx_formats/cutedsl/cute_utils.py new file mode 100644 index 0000000000..785b9dc64e --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/cute_utils.py @@ -0,0 +1,607 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared utilities for the MXFP4 + RHT CuTeDSL quantize kernels.""" + +import importlib.util + +import torch + +# Runtime package detection +_CUTEDSL_RUNTIME_PACKAGES = { + "cuda.bindings.driver": "cuda-python", + "cutlass": "nvidia-cutlass-dsl", + "cutlass.cute": "nvidia-cutlass-dsl", + "tvm_ffi": "apache-tvm-ffi", +} + + +def _missing_cutedsl_runtime_packages() -> list[str]: + """Check which CuTeDSL runtime packages are missing. 
+ + Returns: + List of missing package names + """ + missing = [] + for module_name, package_name in _CUTEDSL_RUNTIME_PACKAGES.items(): + try: + spec = importlib.util.find_spec(module_name) + except (ModuleNotFoundError, ValueError): + # ModuleNotFoundError: parent module doesn't exist (e.g., 'cuda' on CPU) + # ValueError: can occur with malformed module names + spec = None + + if spec is None and package_name not in missing: + missing.append(package_name) + return missing + + +def _cutedsl_runtime_available() -> bool: + """Check if all CuTeDSL runtime packages are available. + + Returns: + True if all required packages are installed + """ + return len(_missing_cutedsl_runtime_packages()) == 0 + + +if _cutedsl_runtime_available(): + import cutlass + import cutlass.cute as cute + from cutlass._mlir.dialects import llvm + from cutlass.base_dsl._mlir_helpers import arith as _dsl_arith + from cutlass.cutlass_dsl import T, dsl_user_op + + # FP4 (E2M1) constants. F4_E2M1_MAX == 6.0. + INV_F4_E2M1_MAX = cutlass.Float32(1.0 / 6.0) + + # NVFP4 two-level (E4M3 block scale) constants. + # E4M3_EPS is torch.finfo(float8_e4m3fn).tiny == 2**-6 == 0.015625, the + # smallest *normal* E4M3 value. The eager NVFP4 path clamps the block scale + # to [E4M3_EPS, F8E4M3_MAX] before the float8_e4m3fn cast, so we replicate + # the lower clamp here (the saturating PTX cvt handles the upper bound). + F8E4M3_MAX = cutlass.Float32(448.0) + E4M3_EPS = cutlass.Float32(0.015625) + + # FP4 E8M0 scale constants. NOTE: these are the FP4 values, NOT the FP8 + # ones -- `F4_E2M1_MAX_POW2 == 2` (the FP8 helper uses 8), and the RCEIL + # descale divisor is `F4_E2M1_MAX == 6.0` (the FP8 helper uses 448). 
+ F4_E2M1_MAX_POW2 = 2 # log2 of the largest power-of-two <= F4_E2M1_MAX (6.0) + E8M0_EXPONENT_BIAS = 127 + E8M0_EXPONENT_NAN_VAL = 255 + + @dsl_user_op + def _cvt_rn_satfinite_e2m1x2_f32( + hi: cutlass.Float32, + lo: cutlass.Float32, + *, + loc=None, + ip=None, + ) -> cutlass.Uint8: + """PTX inline assembly that packs two f32 values into one E2M1x2 byte. + + Uses inline PTX on Blackwell-family targets because CuTeDSL does not + currently lower the float32 -> E2M1 pair conversion to + ``cvt.rn.satfinite.e2m1x2.f32`` on its own. + + The PTX result of ``cvt.rn.satfinite.e2m1x2.f32 d, a, b`` is a ``.b8`` + value, but inline-asm output registers must be at least 16-bit and + ptxas rejects a 16-bit register directly as the ``cvt`` destination. So + (mirroring cutlass's ``cvt.rn.f16x2.e2m1x2`` wrappers in + ``cute/arch/nvvm_wrappers.py``) we ``cvt`` into a ``.reg .b8`` and + assemble it into a ``.b16`` output via ``mov.b16 $0, {d, z}`` with a + zero high byte. The low byte is then masked out and returned. + + Convention (validated bit-exactly by ``test_mxfp4_rht_cutedsl``): the + second PTX operand ``b`` (``lo``) lands in the LOW nibble and the first + operand ``a`` (``hi``) lands in the HIGH nibble, i.e. the returned byte + is ``(e2m1(hi) << 4) | e2m1(lo)``. + """ + packed = cutlass.Uint16( + llvm.inline_asm( + T.i16(), + [ + cutlass.Float32(hi).ir_value(loc=loc, ip=ip), + cutlass.Float32(lo).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b8 d, z, w;\n\t" + ".reg .b16 zero16;\n\t" + "mov.u16 zero16, 0;\n\t" + "mov.b16 {z, w}, zero16;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 d, $1, $2;\n\t" + "mov.b16 $0, {d, z};\n\t" + "}", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return cutlass.Uint8(packed & cutlass.Uint16(0xFF)) + + @cute.kernel + def _pack32_e2m1_kernel( + gx: cute.Tensor, + gout: cute.Tensor, + ): + """One-block, single-thread kernel: pack 32 f32 -> 16 E2M1x2 bytes. 
+ + For ``p`` in ``0..15`` it emits + ``_cvt_rn_satfinite_e2m1x2_f32(hi=x[2p+1], lo=x[2p])`` so that the even + column ``2p`` lands in the low nibble and the odd column ``2p+1`` in the + high nibble of output byte ``p``. + """ + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + for p in cutlass.range_constexpr(16): + lo = cutlass.Float32(gx[2 * p]) + hi = cutlass.Float32(gx[2 * p + 1]) + gout[p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + + @cute.jit + def _pack32_e2m1_launch( + gx: cute.Tensor, + gout: cute.Tensor, + stream, + ): + _pack32_e2m1_kernel(gx, gout).launch( + grid=(1, 1, 1), + block=(32, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + def pack32_e2m1_to_bytes(x: torch.Tensor) -> torch.Tensor: + """Pack a length-32 fp32 CUDA vector into 16 E2M1x2 bytes. + + Test/validation entry only -- the production kernel inlines the same + ``_cvt_rn_satfinite_e2m1x2_f32`` logic. Even input columns go to the low + nibble, odd columns to the high nibble (matching torchao's packed-fp4 + byte order). + + Args: + x: 1D float32 CUDA tensor of length 32. + + Returns: + A ``(16,)`` uint8 CUDA tensor. 
+ """ + import cuda.bindings.driver as cuda + from cutlass.cute.runtime import from_dlpack + + assert x.is_cuda, "input must be a CUDA tensor" + assert x.dtype == torch.float32, "input must be float32" + assert x.numel() == 32, "input must have exactly 32 elements" + x = x.contiguous() + out = torch.empty((16,), device=x.device, dtype=torch.uint8) + + gx = from_dlpack(x) + gout = from_dlpack(out) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + _pack32_e2m1_launch(gx, gout, stream) + return out + + # ------------------------------------------------------------------ + # FP4 E8M0 block-scale helpers + # ------------------------------------------------------------------ + + @dsl_user_op + def _cvt_rp_satfinite_ue8m0x2_f32( + a: cutlass.Float32, + *, + loc=None, + ip=None, + ) -> cutlass.Uint16: + """PTX inline assembly for RCEIL E8M0 conversion (Blackwell SM10.x). + + ``cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, descale`` packs two e8m0 + results into a uint16; the low byte holds the e8m0 of ``descale``. + Hardware handles NaN -> 255, Inf -> 254, subnormals -> 0. + """ + return cutlass.Uint16( + llvm.inline_asm( + T.i16(), + [cutlass.Float32(a).ir_value(loc=loc, ip=ip)], + "cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @cute.jit + def compute_amax(vals_block: cute.Tensor): + """Compute the absolute maximum of a block of values as Float32.""" + vals_vec = vals_block.load() + abs_vec = cute.where(vals_vec < 0, -vals_vec, vals_vec) + return cutlass.Float32( + abs_vec.reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) + ) + + @cute.jit + def compute_scale_floor_fp4(amax: cutlass.Float32): + """FP4 FLOOR-mode E8M0 biased scale byte + fp32 inverse scale. 
+ + Mirrors torchao eager ``to_mx`` (FLOOR branch) for + ``elem_dtype=torch.float4_e2m1fn_x2``: + + extracted_pow2 = ((bits(amax) >> 23) & 0xFF) - 127 + scale_unbiased = extracted_pow2 - F4_E2M1_MAX_POW2 # -2, NOT -8 + scale_unbiased = clamp(scale_unbiased, -127, 128) + scale_biased(byte) = scale_unbiased + 127 + inv_scale = 2 ** (-scale_unbiased) + + The ``inv_scale`` is the value the quantizer multiplies each element by + before the E2M1 conversion (it is ``1 / 2**scale_unbiased``). The eager + FLOOR path floors the divisor to ``2**-126``; the matching cap here is + ``inv_scale <= 2**126`` (i.e. ``scale_unbiased >= -126``), so we clamp + the exponent used for ``inv_scale`` at the bottom. + + Returns ``(scale_biased, inv_scale)`` where ``scale_biased`` is the + biased E8M0 byte as Int32 (caller stores low 8 bits) and ``inv_scale`` + is a ``Float32``. + """ + bits = _dsl_arith.bitcast(amax.ir_value(), _dsl_arith.T.i32()) + exp_i = ((bits >> cutlass.Int32(23)) & cutlass.Int32(0xFF)) - cutlass.Int32( + E8M0_EXPONENT_BIAS + ) + scale_exp_unbiased = exp_i - cutlass.Int32(F4_E2M1_MAX_POW2) + if scale_exp_unbiased < -E8M0_EXPONENT_BIAS: + scale_exp_unbiased = cutlass.Int32(-E8M0_EXPONENT_BIAS) + if scale_exp_unbiased > E8M0_EXPONENT_BIAS + 1: + scale_exp_unbiased = cutlass.Int32(E8M0_EXPONENT_BIAS + 1) + scale_biased = scale_exp_unbiased + E8M0_EXPONENT_BIAS + # inv_scale = 2 ** (-scale_unbiased); divisor floored to 2**-126, so the + # inverse is capped at 2**126 (scale_exp_unbiased >= -126). + inv_exp = -scale_exp_unbiased + if inv_exp > cutlass.Int32(126): + inv_exp = cutlass.Int32(126) + inv_scale = cute.exp2(cutlass.Float32(inv_exp)) + return scale_biased, inv_scale + + @cute.jit + def compute_scale_rceil_fp4(amax: cutlass.Float32): + """FP4 RCEIL-mode E8M0 biased scale byte + fp32 inverse scale. + + Mirrors torchao eager ``_to_mx_rceil`` for + ``elem_dtype=torch.float4_e2m1fn_x2`` (``max_pos == F4_E2M1_MAX == 6.0``): + + descale = amax / 6.0 (i.e. 
amax * INV_F4_E2M1_MAX, NOT 1/448) + biased = cvt.rp.satfinite.ue8m0x2.f32(0.0, descale) # low byte + inv_scale = 2 ** (127 - biased) + + Hardware saturates / handles NaN -> 255, Inf -> 254. The ``inv_scale`` + is derived from the biased byte exactly like the mxfp8 RCEIL helper: + byte ``0xFF`` (NaN) -> 0.0, byte ``0`` -> ``2**126`` (subnormal floor), + else ``2**(127 - biased)``. + + Returns ``(scale_biased, inv_scale)`` where ``scale_biased`` is the + biased E8M0 byte as Int32 (caller stores low 8 bits) and ``inv_scale`` + is a ``Float32``. + """ + descale = amax * INV_F4_E2M1_MAX + scale_biased = cutlass.Int32(_cvt_rp_satfinite_ue8m0x2_f32(descale)) + inv_scale = cutlass.Float32(1.0) + if scale_biased == 0xFF: + inv_scale = cutlass.Float32(0.0) + elif scale_biased == 0: + inv_scale = cute.exp2(cutlass.Float32(126.0)) + else: + inv_scale = cute.exp2(cutlass.Float32(127 - scale_biased)) + return scale_biased, inv_scale + + @cute.jit + def compute_scale_byte_fp4( + amax: cutlass.Float32, + USE_RCEIL: cutlass.Constexpr[bool], + ): + """Dispatch to the FP4 FLOOR / RCEIL E8M0 biased scale byte + inv_scale. + + Matches eager ``to_mx``: a NaN ``amax`` maps to the E8M0 NaN byte + (255). For FLOOR a non-NaN ``amax`` always uses the bit-extraction + path (an all-zero block yields byte ``0``: its extracted exponent is + ``-127``, so ``scale_unbiased = -129`` clamps to ``-127`` and the + biased byte is ``0``). RCEIL uses the saturating PTX cvt, + which itself maps subnormal/zero descale -> 0. + + Returns ``(scale_biased, inv_scale)``. For a NaN ``amax`` the + ``inv_scale`` is ``0.0`` (matching the E8M0 NaN byte). + """ + # NaN amax -> E8M0 NaN byte, matching eager to_mx. 
+ scale_biased = cutlass.Int32(E8M0_EXPONENT_NAN_VAL) + inv_scale = cutlass.Float32(0.0) + if amax == amax: # not NaN + if cutlass.const_expr(USE_RCEIL): + scale_biased, inv_scale = compute_scale_rceil_fp4(amax) + else: + scale_biased, inv_scale = compute_scale_floor_fp4(amax) + return scale_biased, inv_scale + + @cute.kernel + def _block_scale_e8m0_fp4_kernel( + gx: cute.Tensor, + gscale: cute.Tensor, + num_blocks: cutlass.Int32, + USE_RCEIL: cutlass.Constexpr[bool], + ): + """One thread per 32-element block: amax -> E8M0 biased scale byte. + + ``gx`` is a 2D ``(num_blocks, 32)`` float32 view; ``gscale`` is a 1D + ``(num_blocks,)`` uint8 view (the flattened ``(N, K//32)`` scales). + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + block = bidx * bdim + tidx + if block < num_blocks: + vals = cute.make_rmem_tensor((32,), cutlass.Float32) + for j in cutlass.range_constexpr(32): + vals[j] = cutlass.Float32(gx[block, j]) + amax = compute_amax(vals) + scale_biased, _ = compute_scale_byte_fp4(amax, USE_RCEIL) + gscale[block] = cutlass.Uint8(scale_biased & cutlass.Int32(0xFF)) + + @cute.jit + def _block_scale_e8m0_fp4_launch( + gx: cute.Tensor, + gscale: cute.Tensor, + num_blocks: cutlass.Int32, + stream, + USE_RCEIL: cutlass.Constexpr[bool], + ): + threads = 128 + grid = (num_blocks + threads - 1) // threads + _block_scale_e8m0_fp4_kernel(gx, gscale, num_blocks, USE_RCEIL).launch( + grid=(grid, 1, 1), + block=(threads, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + def compute_block_scale_e8m0_fp4(x: torch.Tensor, mode: str) -> torch.Tensor: + """Compute per-32-block E8M0 biased scale bytes for FP4. + + Bit-exact (validated by ``test_mxfp4_rht_cutedsl``) against torchao + eager ``to_mx(..., elem_dtype=torch.float4_e2m1fn_x2, block_size=32)`` + plain (unswizzled) scales, for FLOOR and RCEIL modes. + + Args: + x: 2D CUDA tensor ``[N, K]`` (bf16 or fp32) with ``K % 32 == 0``. 
+ mode: ``"floor"`` or ``"rceil"``. + + Returns: + ``(N, K // 32)`` ``torch.float8_e8m0fnu`` CUDA tensor of scale bytes. + """ + import cuda.bindings.driver as cuda + from cutlass.cute.runtime import from_dlpack + + assert x.is_cuda, "input must be a CUDA tensor" + assert x.dim() == 2, "input must be 2D [N, K]" + n, k = x.shape + assert k % 32 == 0, "K must be divisible by 32" + mode = mode.lower() + assert mode in ("floor", "rceil"), f"unsupported scaling mode: {mode}" + use_rceil = mode == "rceil" + + kb = k // 32 + num_blocks = n * kb + # float32 [num_blocks, 32] view of the per-block elements. + x_f32 = x.to(torch.float32).contiguous().reshape(num_blocks, 32) + scale_u8 = torch.empty((num_blocks,), device=x.device, dtype=torch.uint8) + + gx = from_dlpack(x_f32) + gscale = from_dlpack(scale_u8) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + _block_scale_e8m0_fp4_launch( + gx, gscale, cutlass.Int32(num_blocks), stream, use_rceil + ) + return scale_u8.view(torch.float8_e8m0fnu).reshape(n, kb) + + # ------------------------------------------------------------------ + # NVFP4 E4M3 two-level block-scale helpers + # ------------------------------------------------------------------ + + @dsl_user_op + def _cvt_rn_satfinite_e4m3x2_f32( + hi: cutlass.Float32, + lo: cutlass.Float32, + *, + loc=None, + ip=None, + ) -> cutlass.Uint8: + """PTX inline assembly: pack two f32 values as E4M3x2; return the low byte. + + ``cvt.rn.satfinite.e4m3x2.f32 d, a, b`` produces a ``.b16`` whose low + byte is ``e4m3(b)`` (``lo``) and high byte is ``e4m3(a)`` (``hi``); the + instruction itself round-to-nearest-saturates to the E4M3 finite max + (448). Only the low byte is returned, so the result is always + ``e4m3(lo)`` (``compute_nvfp4_scale_e4m3`` passes ``hi == lo``). + + Unlike the E2M1 variant the E4M3x2 cvt already has a ``.b16`` output, so + no ``mov.b16`` reassembly is needed. 
+ """ + packed = cutlass.Uint16( + llvm.inline_asm( + T.i16(), + [ + cutlass.Float32(hi).ir_value(loc=loc, ip=ip), + cutlass.Float32(lo).ir_value(loc=loc, ip=ip), + ], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return cutlass.Uint8(packed & cutlass.Uint16(0xFF)) + + @dsl_user_op + def _cvt_e4m3_byte_to_f32( + byte: cutlass.Uint8, + *, + loc=None, + ip=None, + ) -> cutlass.Float32: + """Dequantize a single E4M3 byte back to Float32. + + Zero-extend the byte into a ``.b16`` (low byte = the e4m3 byte, high + byte = 0) and feed that 16-bit value straight to + ``cvt.rn.f16x2.e4m3x2`` (whose source operand is a ``.b16`` holding the + e4m3x2 pair). The low f16 of the result holds ``e4m3(byte)``, which we + widen to f32 -- the exact value a float8_e4m3fn ``.to(torch.float32)`` + produces for the byte. + + NOTE: we pass the zero-extended ``.b16`` directly as the cvt source. + Extracting it into a ``.reg .b8`` first (``mov.b16 {b, z}, $1`` then + ``cvt ... b``) is rejected by ptxas on the Blackwell + (``sm_100a`` / ``--enable-tvm-ffi``) lowering path with "Arguments + mismatch for instruction 'cvt'" -- ``cvt.f16x2.e4m3x2`` requires a 16-bit + source operand. + """ + src_i16 = cutlass.Uint16(byte) + rst_i32 = cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [src_i16.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16x2.e4m3x2 $0, $1;", + "=r,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + # Low 16 bits = f16(e4m3(byte)); widen to f32. + lo_f16_bits = cutlass.Uint16(rst_i32 & cutlass.Uint32(0xFFFF)) + f16_val = cutlass.Float16( + llvm.bitcast(T.f16(), lo_f16_bits.ir_value(loc=loc, ip=ip)) + ) + return cutlass.Float32(f16_val) + + @cute.jit + def compute_nvfp4_scale_e4m3( + amax: cutlass.Float32, + global_scale: cutlass.Float32, + ): + """NVFP4 two-level E4M3 block scale byte + fp32 data inverse-scale. 
+ + Bit-exactly mirrors the two-level path of torchao eager + ``nvfp4_quantize`` (``nvfp4_tensor.py``), where ``global_scale`` is the + multiplicative reciprocal of torchao's stored ``per_tensor_scale`` + (i.e. ``global_scale = 1 / per_tensor_scale``): + + block_scale = amax / 6.0 # amax * INV_F4_E2M1_MAX + local = block_scale * global_scale # == block_scale / per_tensor_scale + local = clamp(local, E4M3_EPS, 448) + e4m3_byte = cvt.rn.satfinite.e4m3x2.f32(local) + inv_scale = global_scale / dequant_e4m3(e4m3_byte) + + The ``inv_scale`` is what the data quantizer multiplies each element by + before the E2M1 conversion; it matches eager's + ``reciprocal_scale = (1/per_tensor_scale) / scaled_block_scales_fp32``. + + Returns ``(e4m3_byte: Uint8, inv_scale: Float32)``. + """ + block_scale = amax * INV_F4_E2M1_MAX + local = block_scale * global_scale + # Lower clamp to the smallest E4M3 normal (matches eager torch.clamp); + # the upper clamp to 448 is handled by the saturating PTX cvt, but we + # apply it explicitly to be faithful to the eager op order. + if local < E4M3_EPS: + local = E4M3_EPS + if local > F8E4M3_MAX: + local = F8E4M3_MAX + e4m3_byte = _cvt_rn_satfinite_e4m3x2_f32(local, local) + dequant = _cvt_e4m3_byte_to_f32(e4m3_byte) + inv_scale = global_scale / dequant + return e4m3_byte, inv_scale + + @cute.kernel + def _block_scale_e4m3_nvfp4_kernel( + gx: cute.Tensor, + gscale: cute.Tensor, + global_scale: cutlass.Float32, + num_blocks: cutlass.Int32, + ): + """One thread per 16-element block: amax -> NVFP4 E4M3 scale byte. + + ``gx`` is a 2D ``(num_blocks, 16)`` float32 view; ``gscale`` is a 1D + ``(num_blocks,)`` uint8 view (the flattened ``(N, K//16)`` scales). 
+ """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + block = bidx * bdim + tidx + if block < num_blocks: + vals = cute.make_rmem_tensor((16,), cutlass.Float32) + for j in cutlass.range_constexpr(16): + vals[j] = cutlass.Float32(gx[block, j]) + amax = compute_amax(vals) + e4m3_byte, _ = compute_nvfp4_scale_e4m3(amax, global_scale) + gscale[block] = e4m3_byte + + @cute.jit + def _block_scale_e4m3_nvfp4_launch( + gx: cute.Tensor, + gscale: cute.Tensor, + global_scale: cutlass.Float32, + num_blocks: cutlass.Int32, + stream, + ): + threads = 128 + grid = (num_blocks + threads - 1) // threads + _block_scale_e4m3_nvfp4_kernel(gx, gscale, global_scale, num_blocks).launch( + grid=(grid, 1, 1), + block=(threads, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + def compute_block_scale_e4m3_nvfp4( + x: torch.Tensor, global_scale: float + ) -> torch.Tensor: + """Compute per-16-block NVFP4 E4M3 scale bytes. + + Bit-exact (validated by ``test_nvfp4_rht_cutedsl``) against the + two-level path of torchao eager + ``nvfp4_tensor.nvfp4_quantize(x, block_size=16, per_tensor_scale=...)`` + plain (unswizzled) ``float8_e4m3fn`` scales, where + ``global_scale = 1 / per_tensor_scale``. + + Args: + x: 2D CUDA tensor ``[N, K]`` (bf16 or fp32) with ``K % 16 == 0``. + global_scale: Per-tensor multiplicative global scale (the reciprocal + of torchao's ``per_tensor_scale``). + + Returns: + ``(N, K // 16)`` ``torch.float8_e4m3fn`` CUDA tensor of scale bytes. + """ + import cuda.bindings.driver as cuda + from cutlass.cute.runtime import from_dlpack + + assert x.is_cuda, "input must be a CUDA tensor" + assert x.dim() == 2, "input must be 2D [N, K]" + n, k = x.shape + assert k % 16 == 0, "K must be divisible by 16" + + kb = k // 16 + num_blocks = n * kb + # float32 [num_blocks, 16] view of the per-block elements. 
+ x_f32 = x.to(torch.float32).contiguous().reshape(num_blocks, 16) + scale_u8 = torch.empty((num_blocks,), device=x.device, dtype=torch.uint8) + + gx = from_dlpack(x_f32) + gscale = from_dlpack(scale_u8) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + _block_scale_e4m3_nvfp4_launch( + gx, + gscale, + cutlass.Float32(global_scale), + cutlass.Int32(num_blocks), + stream, + ) + return scale_u8.view(torch.float8_e4m3fn).reshape(n, kb) diff --git a/torchao/prototype/mx_formats/cutedsl/fp4_unified_quantize.py b/torchao/prototype/mx_formats/cutedsl/fp4_unified_quantize.py new file mode 100644 index 0000000000..be2f26fb2d --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/fp4_unified_quantize.py @@ -0,0 +1,622 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Unified FP4 (NVFP4 + MXFP4) + optional RHT CuTeDSL quantize cast. + +A single no-smem streaming kernel that supersedes the separate ``nvfp4_rht`` +and ``mxfp4_rht`` maxbw casts: it serves both FP4 formats and all three scale +layouts the GEMM consumers use, with optional fused RHT. + +* ``fmt="nvfp4"`` -- block 16, two-level E4M3 block scale + per-tensor global + scale (``float8_e4m3fn`` scales). Supports arbitrary ``K % 16`` (a masked + remainder handles an odd number of 16-blocks). +* ``fmt="mxfp4"`` -- block 32, single-level E8M0 block scale + (``float8_e8m0fnu`` scales), ``"floor"`` or ``"rceil"``. Requires ``K % 32``. +* ``scale_layout in {"linear", "cublas_blocked", "mma_tiled"}`` is selected at + compile time. ``cublas_blocked`` is the to_blocked padded swizzle consumed by + the f4f4bf16 GEMM; ``mma_tiled`` is the SM100 blockscaled-GEMM atom layout + (no separate scale-conversion pass); ``linear`` is the plain ``(M, K//blk)``. 
+* optional fused RHT (register FWHT16/32 + sign) is applied per block before + amax / scale / pack; an empty ``sign_vector`` skips it via a compile-time + constexpr (no FWHT overhead on the plain path). + +A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one +MXFP4 block; the per-format scale recipe, FWHT size, and MMA row-atom are +compile-time ``FORMAT``-selected so a single kernel body covers both formats. + +Two byte-identical thread mappings are exposed via ``mapping=``: +* ``"striped"`` -- threads stripe a row's groups; grid-strided rows. Best at + very large N. +* ``"wpr"`` -- warp-per-row: warp ``w`` owns contiguous row ``bidy*WARPS+w``, + with the 32 lanes + a ``grid.x`` column split + ILP covering the columns + (replicates the dense-GEMM grid). Best at small / mid N; requires ``K % 32``. + +Gated behind a Blackwell (SM 10.x) GPU, CUDA >= 12.8, and the CuTeDSL runtime +(see ``cutedsl/__init__.py``). Output is byte-exact vs the eager torchao FP4 +casts (and vs the per-format ``{nvfp4,mxfp4}_rht`` CuTeDSL kernels). +""" + +from typing import Optional, Tuple + +import torch + +from torchao.utils import ceil_div + +from .cute_utils import ( + _cvt_rn_satfinite_e2m1x2_f32, + compute_amax, + compute_nvfp4_scale_e4m3, + compute_scale_byte_fp4, +) +from .fwht import fwht16_sign, fwht32_sign + +# NVFP4 two-level global-scale numerator: F8E4M3_MAX * F4_E2M1_MAX = 448 * 6. +_GLOBAL_SCALE_NUMERATOR = 2688.0 +_LAYOUTS = {"linear": 0, "cublas_blocked": 1, "mma_tiled": 2} +_E8M0_NEUTRAL = 127 # 2**0 +_E4M3_NEUTRAL = 0x38 # 1.0 + +# Compiled-launcher + per-shape kernel caches (populated lazily). +_LAUNCH_CACHE = {} +_JIT_CACHE = {} + + +def _get_launches(): + """Define and cache the (uncompiled) striped + warp-per-row ``@cute.jit`` + launchers. 
``cutlass`` is imported here (not at module scope) so importing + this module is safe without the CuTeDSL runtime.""" + if "v" in _LAUNCH_CACHE: + return _LAUNCH_CACHE["v"] + + import cutlass + import cutlass.cute as cute + import cutlass.cute.nvgpu as nv + + def _scale_offset(row, kb, LAYOUT, ATOM_M0, pad_cols, rest_m, rest_k, K, NBLK): + if LAYOUT == 2: # mma_tiled (ATOM_M0=128 NVFP4 / 32 MXFP4) + atom_m0 = row % cutlass.Int32(ATOM_M0) + atom_m1 = (row // cutlass.Int32(ATOM_M0)) % cutlass.Int32(4) + rest_m_idx = row // cutlass.Int32(ATOM_M0 * 4) + atom_k = kb % cutlass.Int32(4) + rest_k_idx = kb // cutlass.Int32(4) + stride_rest_m = rest_k * cutlass.Int32(4) + stride_m1 = stride_rest_m * rest_m + stride_m0 = stride_m1 * cutlass.Int32(4) + return ( + atom_m0 * stride_m0 + + atom_m1 * stride_m1 + + rest_m_idx * stride_rest_m + + atom_k * rest_k + + rest_k_idx + ) + elif LAYOUT == 1: # cublas_blocked (format-independent) + r128 = row // cutlass.Int32(128) + r32 = row % cutlass.Int32(32) + r32_4 = (row // cutlass.Int32(32)) % cutlass.Int32(4) + kb4 = kb // cutlass.Int32(4) + kbm = kb % cutlass.Int32(4) + return ( + r128 * cutlass.Int32(128) * pad_cols + + kb4 * cutlass.Int32(512) + + r32 * cutlass.Int32(16) + + r32_4 * cutlass.Int32(4) + + kbm + ) + else: # linear + return row * (K // cutlass.Int32(NBLK)) + kb + + @cute.kernel + def _striped_kernel( # noqa: C901 + gx: cute.Tensor, + gq: cute.Tensor, + gscale: cute.Tensor, + gsign: cute.Tensor, + M: cutlass.Int32, + K: cutlass.Int32, + GPR: cutlass.Int32, + pad_cols: cutlass.Int32, + rest_m: cutlass.Int32, + rest_k: cutlass.Int32, + qstride: cutlass.Int32, + global_scale: cutlass.Float32, + ILP: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + LAYOUT: cutlass.Constexpr[int], + FORMAT: cutlass.Constexpr[int], + USE_RCEIL: cutlass.Constexpr[bool], + LDWIDTH: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy_init, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() 
+ gid = bidx * bdim + tidx + nthreads_x = cute.arch.grid_dim()[0] * bdim + NBLK = 16 if FORMAT == 0 else 32 + NB = 32 // NBLK + HALF = NBLK // 2 + ATOM_M0 = 128 if FORMAT == 0 else 32 + in_dt = gx.element_type + ld = cute.make_copy_atom(nv.CopyUniversalOp(), in_dt, num_bits_per_copy=128) + st = cute.make_copy_atom( + nv.CopyUniversalOp(), cutlass.Uint8, num_bits_per_copy=128 + ) + st64 = cute.make_copy_atom( + nv.CopyUniversalOp(), cutlass.Uint8, num_bits_per_copy=64 + ) + NLD = 32 // LDWIDTH + + sign_reg = cute.make_rmem_tensor((NBLK,), cutlass.Float32) + if cutlass.const_expr(APPLY_RHT): + for j in cutlass.range_constexpr(NBLK): + sign_reg[j] = cutlass.Float32(gsign[j]) + + row = bidy_init + while row < M: + base = gid + while base < GPR: + fragbuf = cute.make_rmem_tensor((ILP * 32,), in_dt) + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_x + if gc < GPR: + off = cute.assume(row * K + gc * cutlass.Int32(32), divby=32) + for w in cutlass.range_constexpr(NLD): + s = cute.make_tensor( + gx.iterator + off + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + f = cute.make_tensor( + fragbuf.iterator + jj * 32 + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + cute.copy(ld, s, f) + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_x + if gc < GPR: + vals = cute.make_rmem_tensor((32,), cutlass.Float32) + for i in cutlass.range_constexpr(32): + vals[i] = cutlass.Float32(fragbuf[jj * 32 + i]) + packed = cute.make_rmem_tensor((16,), cutlass.Uint8) + for b in cutlass.range_constexpr(NB): + blk = cute.make_tensor( + vals.iterator + b * NBLK, + cute.make_layout((NBLK,), stride=(1,)), + ) + if cutlass.const_expr(APPLY_RHT): + if cutlass.const_expr(FORMAT == 0): + fwht16_sign(blk, sign_reg) + else: + fwht32_sign(blk, sign_reg) + amax = compute_amax(blk) + if cutlass.const_expr(FORMAT == 0): + scb, inv = compute_nvfp4_scale_e4m3(amax, global_scale) + else: + sbi, inv = compute_scale_byte_fp4(amax, USE_RCEIL) + 
scb = cutlass.Uint8(sbi & cutlass.Int32(0xFF)) + kb = gc * cutlass.Int32(NB) + b + gscale[ + _scale_offset(row, kb, LAYOUT, ATOM_M0, pad_cols, + rest_m, rest_k, K, NBLK) + ] = scb + for p in cutlass.range_constexpr(HALF): + lo = blk[2 * p] * inv + hi = blk[2 * p + 1] * inv + if cutlass.const_expr(FORMAT == 1 and not USE_RCEIL): + if lo > cutlass.Float32(6.0): + lo = cutlass.Float32(6.0) + if lo < cutlass.Float32(-6.0): + lo = cutlass.Float32(-6.0) + if hi > cutlass.Float32(6.0): + hi = cutlass.Float32(6.0) + if hi < cutlass.Float32(-6.0): + hi = cutlass.Float32(-6.0) + packed[b * HALF + p] = _cvt_rn_satfinite_e2m1x2_f32( + hi, lo + ) + offq = cute.assume( + row * qstride + gc * cutlass.Int32(16), divby=16 + ) + d = cute.make_tensor( + gq.iterator + offq, cute.make_layout((16,), stride=(1,)) + ) + cute.copy(st, packed, d) + base = base + nthreads_x * cutlass.Int32(ILP) + # masked remainder (NVFP4 only): leftover 16-block when k_blocks odd + if cutlass.const_expr(FORMAT == 0): + rem = K // cutlass.Int32(16) - cutlass.Int32(2) * GPR + if gid == cutlass.Int32(0): + if rem > cutlass.Int32(0): + kb = cutlass.Int32(2) * GPR + offr = cute.assume(row * K + kb * cutlass.Int32(16), divby=16) + rbuf = cute.make_rmem_tensor((16,), in_dt) + for w in cutlass.range_constexpr(16 // LDWIDTH): + s = cute.make_tensor( + gx.iterator + offr + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + f = cute.make_tensor( + rbuf.iterator + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + cute.copy(ld, s, f) + blkv = cute.make_rmem_tensor((16,), cutlass.Float32) + for i in cutlass.range_constexpr(16): + blkv[i] = cutlass.Float32(rbuf[i]) + if cutlass.const_expr(APPLY_RHT): + fwht16_sign(blkv, sign_reg) + amax = compute_amax(blkv) + e4m3, inv = compute_nvfp4_scale_e4m3(amax, global_scale) + gscale[ + _scale_offset(row, kb, LAYOUT, 128, pad_cols, rest_m, + rest_k, K, 16) + ] = e4m3 + rpacked = cute.make_rmem_tensor((8,), cutlass.Uint8) + for p in 
cutlass.range_constexpr(8): + lo = blkv[2 * p] * inv + hi = blkv[2 * p + 1] * inv + rpacked[p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + offrq = cute.assume( + row * qstride + kb * cutlass.Int32(8), divby=8 + ) + dr = cute.make_tensor( + gq.iterator + offrq, cute.make_layout((8,), stride=(1,)) + ) + cute.copy(st64, rpacked, dr) + row = row + cute.arch.grid_dim()[1] + + @cute.jit + def _striped_launch( + gx, gq, gscale, gsign, M, K, GPR, pad_cols, rest_m, rest_k, qstride, gs, + threads, ncta_x, ncta_y, stream, + ILP: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + LAYOUT: cutlass.Constexpr[int], + FORMAT: cutlass.Constexpr[int], + USE_RCEIL: cutlass.Constexpr[bool], + LDWIDTH: cutlass.Constexpr[int], + ): + _striped_kernel( + gx, gq, gscale, gsign, M, K, GPR, pad_cols, rest_m, rest_k, qstride, gs, + ILP, APPLY_RHT, LAYOUT, FORMAT, USE_RCEIL, LDWIDTH, + ).launch( + grid=(ncta_x, ncta_y, 1), + block=(threads, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def _wpr_kernel( # noqa: C901 + gx: cute.Tensor, + gq: cute.Tensor, + gscale: cute.Tensor, + gsign: cute.Tensor, + M: cutlass.Int32, + K: cutlass.Int32, + GPR: cutlass.Int32, + pad_cols: cutlass.Int32, + rest_m: cutlass.Int32, + rest_k: cutlass.Int32, + qstride: cutlass.Int32, + global_scale: cutlass.Float32, + ILP: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + LAYOUT: cutlass.Constexpr[int], + FORMAT: cutlass.Constexpr[int], + USE_RCEIL: cutlass.Constexpr[bool], + WARPS: cutlass.Constexpr[int], + LDWIDTH: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + warp_id = tidx // cutlass.Int32(32) + lane = tidx % cutlass.Int32(32) + NBLK = 16 if FORMAT == 0 else 32 + NB = 32 // NBLK + HALF = NBLK // 2 + ATOM_M0 = 128 if FORMAT == 0 else 32 + in_dt = gx.element_type + ld = cute.make_copy_atom(nv.CopyUniversalOp(), in_dt, num_bits_per_copy=128) + st = cute.make_copy_atom( + nv.CopyUniversalOp(), cutlass.Uint8, 
num_bits_per_copy=128 + ) + NLD = 32 // LDWIDTH + sign_reg = cute.make_rmem_tensor((NBLK,), cutlass.Float32) + if cutlass.const_expr(APPLY_RHT): + for j in cutlass.range_constexpr(NBLK): + sign_reg[j] = cutlass.Float32(gsign[j]) + + row = bidy * cutlass.Int32(WARPS) + warp_id + if row < M: + nthreads_row = cute.arch.grid_dim()[0] * cutlass.Int32(32) + base = bidx * cutlass.Int32(32) + lane + while base < GPR: + fragbuf = cute.make_rmem_tensor((ILP * 32,), in_dt) + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_row + if gc < GPR: + off = cute.assume(row * K + gc * cutlass.Int32(32), divby=32) + for w in cutlass.range_constexpr(NLD): + s = cute.make_tensor( + gx.iterator + off + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + f = cute.make_tensor( + fragbuf.iterator + jj * 32 + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + cute.copy(ld, s, f) + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_row + if gc < GPR: + vals = cute.make_rmem_tensor((32,), cutlass.Float32) + for i in cutlass.range_constexpr(32): + vals[i] = cutlass.Float32(fragbuf[jj * 32 + i]) + packed = cute.make_rmem_tensor((16,), cutlass.Uint8) + for b in cutlass.range_constexpr(NB): + blk = cute.make_tensor( + vals.iterator + b * NBLK, + cute.make_layout((NBLK,), stride=(1,)), + ) + if cutlass.const_expr(APPLY_RHT): + if cutlass.const_expr(FORMAT == 0): + fwht16_sign(blk, sign_reg) + else: + fwht32_sign(blk, sign_reg) + amax = compute_amax(blk) + if cutlass.const_expr(FORMAT == 0): + scb, inv = compute_nvfp4_scale_e4m3(amax, global_scale) + else: + sbi, inv = compute_scale_byte_fp4(amax, USE_RCEIL) + scb = cutlass.Uint8(sbi & cutlass.Int32(0xFF)) + kb = gc * cutlass.Int32(NB) + b + gscale[ + _scale_offset(row, kb, LAYOUT, ATOM_M0, pad_cols, + rest_m, rest_k, K, NBLK) + ] = scb + for p in cutlass.range_constexpr(HALF): + lo = blk[2 * p] * inv + hi = blk[2 * p + 1] * inv + if cutlass.const_expr(FORMAT == 1 and not USE_RCEIL): + if 
lo > cutlass.Float32(6.0): + lo = cutlass.Float32(6.0) + if lo < cutlass.Float32(-6.0): + lo = cutlass.Float32(-6.0) + if hi > cutlass.Float32(6.0): + hi = cutlass.Float32(6.0) + if hi < cutlass.Float32(-6.0): + hi = cutlass.Float32(-6.0) + packed[b * HALF + p] = _cvt_rn_satfinite_e2m1x2_f32( + hi, lo + ) + offq = cute.assume( + row * qstride + gc * cutlass.Int32(16), divby=16 + ) + d = cute.make_tensor( + gq.iterator + offq, cute.make_layout((16,), stride=(1,)) + ) + cute.copy(st, packed, d) + base = base + nthreads_row * cutlass.Int32(ILP) + + @cute.jit + def _wpr_launch( + gx, gq, gscale, gsign, M, K, GPR, pad_cols, rest_m, rest_k, qstride, gs, + block_threads, ncta_x, ncta_y, stream, + ILP: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + LAYOUT: cutlass.Constexpr[int], + FORMAT: cutlass.Constexpr[int], + USE_RCEIL: cutlass.Constexpr[bool], + WARPS: cutlass.Constexpr[int], + LDWIDTH: cutlass.Constexpr[int], + ): + _wpr_kernel( + gx, gq, gscale, gsign, M, K, GPR, pad_cols, rest_m, rest_k, qstride, gs, + ILP, APPLY_RHT, LAYOUT, FORMAT, USE_RCEIL, WARPS, LDWIDTH, + ).launch( + grid=(ncta_x, ncta_y, 1), + block=(block_threads, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + _LAUNCH_CACHE["v"] = (_striped_launch, _wpr_launch) + return _LAUNCH_CACHE["v"] + + +def _fp4_quantize_unified_impl( + x: torch.Tensor, + sign_vector: Optional[list] = None, + fmt: str = "nvfp4", + scaling_mode: str = "floor", + scale_layout: str = "cublas_blocked", + global_scale: Optional[float] = None, + mapping: str = "wpr", + threads: int = 128, + ilp: int = 2, + rows_per_cta: int = 1, + warps: int = 4, + xsplit: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + import cuda.bindings.driver as cuda + import cutlass + from cutlass.cute.runtime import from_dlpack + + assert x.is_cuda and x.dim() == 2 and x.is_contiguous() + assert x.dtype in (torch.float32, torch.bfloat16) + M, K = x.shape + FORMAT = 0 if fmt == "nvfp4" else 1 + BLK = 16 if FORMAT == 0 else 32 + ATOM_M0 = 
128 if FORMAT == 0 else 32 + USE_RCEIL = scaling_mode.lower() == "rceil" + L = _LAYOUTS[scale_layout] + if FORMAT == 0: + assert K % 16 == 0, "NVFP4 requires K % 16 == 0" + else: + assert K % 32 == 0, "MXFP4 requires K % 32 == 0" + k_blocks = K // BLK + apply_rht = sign_vector is not None and len(sign_vector) > 0 + if apply_rht: + assert len(sign_vector) == BLK, f"sign_vector must have length {BLK}" + + gs_val = global_scale + if FORMAT == 0 and gs_val is None: + gs_val = _GLOBAL_SCALE_NUMERATOR / x.abs().max().item() + if gs_val is None: + gs_val = 1.0 + + qstride = ceil_div(K // 2, 16) * 16 + q_data = torch.empty((M, qstride), device=x.device, dtype=torch.uint8) + pad_rows = ceil_div(M, 128) * 128 + pad_cols = ceil_div(k_blocks, 4) * 4 + rest_m = ceil_div(M, ATOM_M0 * 4) + rest_k = ceil_div(k_blocks, 4) + neutral = _E4M3_NEUTRAL if FORMAT == 0 else _E8M0_NEUTRAL + if L == 0: + scales_u8 = torch.empty((M, k_blocks), device=x.device, dtype=torch.uint8) + elif L == 1: + scales_u8 = torch.full( + (pad_rows * pad_cols,), neutral, device=x.device, dtype=torch.uint8 + ) + else: + scales_u8 = torch.full( + (ATOM_M0 * 16 * rest_m * rest_k,), neutral, device=x.device, + dtype=torch.uint8, + ) + + sign_src = sign_vector if apply_rht else [0] * BLK + sign_dev = torch.tensor( + [int(s) for s in sign_src], device=x.device, dtype=torch.int32 + ) + + striped_launch, wpr_launch = _get_launches() + GPR = K // 32 + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + ldwidth = 4 if x.dtype == torch.float32 else 8 + common = ( + from_dlpack(x.view(-1), assumed_align=16), + from_dlpack(q_data.view(-1), assumed_align=16), + from_dlpack(scales_u8.view(-1), assumed_align=16), + from_dlpack(sign_dev, assumed_align=16), + cutlass.Int32(M), cutlass.Int32(K), cutlass.Int32(GPR), + cutlass.Int32(pad_cols), cutlass.Int32(rest_m), cutlass.Int32(rest_k), + cutlass.Int32(qstride), cutlass.Float32(gs_val), + ) + + if mapping == "wpr": + assert K % 32 == 0, "wpr mapping requires 
K % 32 == 0" + block_threads = 32 * warps + ncta_y = ceil_div(M, warps) + wargs = common + (block_threads, xsplit, ncta_y, stream) + key = (str(x.dtype), apply_rht, L, FORMAT, USE_RCEIL, "wpr", warps, ilp) + compiled = _JIT_CACHE.get(key) + if compiled is None: + import cutlass.cute as cute + + compiled = cute.compile( + wpr_launch, *wargs, ilp, apply_rht, L, FORMAT, USE_RCEIL, warps, + ldwidth, + ) + _JIT_CACHE[key] = compiled + compiled(*wargs) + else: + nthreads_needed = (GPR + ilp - 1) // ilp + ncta_x = max(1, (nthreads_needed + threads - 1) // threads) + ncta_y = ceil_div(M, rows_per_cta) + args = common + (threads, ncta_x, ncta_y, stream) + key = (str(x.dtype), apply_rht, L, threads, ilp, FORMAT, USE_RCEIL) + compiled = _JIT_CACHE.get(key) + if compiled is None: + import cutlass.cute as cute + + compiled = cute.compile( + striped_launch, *args, ilp, apply_rht, L, FORMAT, USE_RCEIL, ldwidth + ) + _JIT_CACHE[key] = compiled + compiled(*args) + + return q_data[:, : K // 2], scales_u8 + + +@torch.library.custom_op("torchao::fp4_quantize_unified", mutates_args=()) +def fp4_quantize_unified( + input: torch.Tensor, + sign_vector: list[int], + fmt: str = "nvfp4", + scaling_mode: str = "floor", + scale_layout: str = "cublas_blocked", +) -> Tuple[torch.Tensor, torch.Tensor]: + """Unified FP4 (NVFP4/MXFP4 +/- RHT) quantize custom op. + + Empty ``sign_vector`` selects the plain cast. For NVFP4 the per-tensor + global scale is computed from the input amax. 
+ """ + return _fp4_quantize_unified_impl( + input, + sign_vector=list(sign_vector) if len(sign_vector) > 0 else None, + fmt=fmt, + scaling_mode=scaling_mode, + scale_layout=scale_layout, + ) + + +@fp4_quantize_unified.register_fake +def _( + input: torch.Tensor, + sign_vector: list[int], + fmt: str = "nvfp4", + scaling_mode: str = "floor", + scale_layout: str = "cublas_blocked", +) -> Tuple[torch.Tensor, torch.Tensor]: + M, K = input.shape + blk = 16 if fmt == "nvfp4" else 32 + k_blocks = K // blk + qdata = torch.empty((M, K // 2), device=input.device, dtype=torch.uint8) + if scale_layout == "linear": + scales = torch.empty((M, k_blocks), device=input.device, dtype=torch.uint8) + elif scale_layout == "cublas_blocked": + pr = ceil_div(M, 128) * 128 + pc = ceil_div(k_blocks, 4) * 4 + scales = torch.empty((pr * pc,), device=input.device, dtype=torch.uint8) + else: + atom_m0 = 128 if fmt == "nvfp4" else 32 + rest_m = ceil_div(M, atom_m0 * 4) + rest_k = ceil_div(k_blocks, 4) + scales = torch.empty( + (atom_m0 * 16 * rest_m * rest_k,), device=input.device, dtype=torch.uint8 + ) + return qdata, scales + + +def fp4_quantize_unified_2d( + x: torch.Tensor, + sign_vector=None, + fmt: str = "nvfp4", + scaling_mode: str = "floor", + scale_layout: str = "cublas_blocked", +) -> Tuple[torch.Tensor, torch.Tensor]: + """Gated public wrapper for the unified FP4 (+/- RHT) CuTeDSL quantize cast. + + Raises ``NotImplementedError`` (with the missing-runtime detail) when the + CuTeDSL runtime / SM 10.x / CUDA >= 12.8 requirements are not met. An empty + / ``None`` ``sign_vector`` selects the plain cast. 
+ """ + from torchao.prototype.mx_formats.cutedsl import ( + _fp4_cutedsl_kernels_available, + ) + + if not _fp4_cutedsl_kernels_available: + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + _missing_cutedsl_runtime_packages, + ) + + raise NotImplementedError( + "fp4_quantize_unified requires CUDA SM10.x, CUDA>=12.8, and: " + f"{_missing_cutedsl_runtime_packages() or 'nvidia-cutlass-dsl'}" + ) + return fp4_quantize_unified( + x, list(sign_vector) if sign_vector is not None else [], fmt, scaling_mode, + scale_layout, + ) diff --git a/torchao/prototype/mx_formats/cutedsl/fwht.py b/torchao/prototype/mx_formats/cutedsl/fwht.py new file mode 100644 index 0000000000..efea5539a3 --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/fwht.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Register-resident FWHT + sign transforms for the MXFP4/NVFP4 + RHT casts. + +Implements, per block, the normalized Fast Walsh-Hadamard Transform followed by +an elementwise sign multiply: + + out = (FWHT(vals) / sqrt(N)) * sign + +This equals ``vals @ hadamard_matrix(N) * sign`` where ``hadamard_matrix(N)`` +is torchao's normalized (1/sqrt(N)) symmetric Sylvester/Walsh-Hadamard matrix +(see ``torchao/prototype/spinquant/hadamard_utils.py``). + +* ``fwht32_sign`` (N=32, 5 butterfly stages) is used by the MXFP4 (block 32) + cast. +* ``fwht16_sign`` (N=16, 4 butterfly stages) is used by the NVFP4 (block 16) + cast. + +The transforms are pure device helpers so the fused kernels can call them per +block from registers. ``fwht{32,16}_sign_host`` are tiny one-block-per-row +kernels used only to validate the device helpers against a dense torch +reference. +""" + +import math + +import torch + +from .cute_utils import _cutedsl_runtime_available + +# 1 / sqrt(32); applied once after the 5 butterfly stages. 
_INV_SQRT_32 = 1.0 / math.sqrt(32.0)
# 1 / sqrt(16) == 0.25; applied once after the 4 butterfly stages.
_INV_SQRT_16 = 1.0 / math.sqrt(16.0)


if _cutedsl_runtime_available():
    import cutlass
    import cutlass.cute as cute

    # Normalization scalars as device Float32 constants.
    INV_SQRT_32 = cutlass.Float32(_INV_SQRT_32)
    INV_SQRT_16 = cutlass.Float32(_INV_SQRT_16)

    @cute.jit
    def fwht32_sign(vals: cute.Tensor, sign: cute.Tensor) -> cute.Tensor:
        """In-register normalized FWHT(32) followed by an elementwise sign mul.

        Computes, for the length-32 register fragment ``vals``::

            out[j] = (FWHT(vals) / sqrt(32))[j] * sign[j]

        which equals ``(vals @ hadamard_matrix(32))[j] * sign[j]``.

        The transform is the standard radix-2 decimation-in-time Walsh-Hadamard
        butterfly over strides ``s in {1, 2, 4, 8, 16}``. For each stride, every
        pair ``(i, i + s)`` with ``(i & s) == 0`` is touched exactly once::

            a, b = vals[i] + vals[i + s], vals[i] - vals[i + s]
            vals[i], vals[i + s] = a, b

        All pair indices are compile-time constants (the 5 stages and their
        pairings are fixed for a length-32 block), so the loops are unrolled via
        ``cutlass.range_constexpr``. This pairing/orientation is validated
        bit-for-bit-close against ``hadamard_matrix(32)`` by
        ``test_mxfp4_rht_cutedsl``.

        Args:
            vals: length-32 ``Float32`` register fragment (modified in place and
                also returned).
            sign: length-32 register fragment of ``{-1, +1}`` values (any
                arithmetic dtype; cast to ``Float32`` for the multiply).

        Returns:
            The transformed length-32 ``Float32`` fragment (the same ``vals``).
        """
        # 5-stage in-register butterfly. Strides are powers of two up to 16.
        for stage in cutlass.range_constexpr(5):
            s = 1 << stage
            for i in cutlass.range_constexpr(32):
                # Only the lower index of each pair drives the butterfly, so
                # every (i, i + s) pair is updated exactly once per stage.
                if cutlass.const_expr((i & s) == 0):
                    a = cutlass.Float32(vals[i])
                    b = cutlass.Float32(vals[i + s])
                    vals[i] = a + b
                    vals[i + s] = a - b

        # Normalize then apply the sign vector.
        for j in cutlass.range_constexpr(32):
            vals[j] = cutlass.Float32(vals[j]) * INV_SQRT_32 * cutlass.Float32(sign[j])

        return vals

    @cute.jit
    def fwht16_sign(vals: cute.Tensor, sign: cute.Tensor) -> cute.Tensor:
        """In-register normalized FWHT(16) followed by an elementwise sign mul.

        Computes, for the length-16 register fragment ``vals``::

            out[j] = (FWHT(vals) / sqrt(16))[j] * sign[j]

        which equals ``(vals @ hadamard_matrix(16))[j] * sign[j]``.

        Same radix-2 decimation-in-time Walsh-Hadamard butterfly as
        ``fwht32_sign`` but for a 16-element block: 4 stages over strides
        ``s in {1, 2, 4, 8}``. For each stride, every pair ``(i, i + s)`` with
        ``(i & s) == 0`` is touched exactly once::

            a, b = vals[i] + vals[i + s], vals[i] - vals[i + s]
            vals[i], vals[i + s] = a, b

        All pair indices are compile-time constants (the 4 stages and their
        pairings are fixed for a length-16 block), so the loops are unrolled via
        ``cutlass.range_constexpr``.

        Args:
            vals: length-16 ``Float32`` register fragment (modified in place and
                also returned).
            sign: length-16 register fragment of ``{-1, +1}`` values (any
                arithmetic dtype; cast to ``Float32`` for the multiply).

        Returns:
            The transformed length-16 ``Float32`` fragment (the same ``vals``).
        """
        # 4-stage in-register butterfly. Strides are powers of two up to 8.
        for stage in cutlass.range_constexpr(4):
            s = 1 << stage
            for i in cutlass.range_constexpr(16):
                # Only the lower index of each pair drives the butterfly, so
                # every (i, i + s) pair is updated exactly once per stage.
                if cutlass.const_expr((i & s) == 0):
                    a = cutlass.Float32(vals[i])
                    b = cutlass.Float32(vals[i + s])
                    vals[i] = a + b
                    vals[i + s] = a - b

        # Normalize (1/sqrt(16) == 0.25) then apply the sign vector.
        for j in cutlass.range_constexpr(16):
            vals[j] = cutlass.Float32(vals[j]) * INV_SQRT_16 * cutlass.Float32(sign[j])

        return vals

    def _fwht_sign_host(x: torch.Tensor, sign: torch.Tensor, block_size: int, launch):
        """Shared host driver for ``fwht16_sign_host`` / ``fwht32_sign_host``.

        Validates the inputs, allocates the output, and invokes ``launch``
        (one of the ``_fwht{16,32}_sign_launch`` JIT launchers) on the current
        CUDA stream of ``x``'s device.

        Args:
            x: 2D float32 CUDA tensor ``[N, block_size]``.
            sign: CUDA tensor with exactly ``block_size`` elements.
            block_size: transform length (16 or 32); must match ``launch``.
            launch: launcher called as
                ``launch(gx, gsign, gout, num_rows, stream)``.

        Returns:
            A ``(N, block_size)`` float32 CUDA tensor on ``x.device``.

        Raises:
            AssertionError: if any input precondition is violated.
        """
        import cuda.bindings.driver as cuda
        from cutlass.cute.runtime import from_dlpack

        assert x.is_cuda, "input must be a CUDA tensor"
        assert x.dim() == 2 and x.shape[1] == block_size, (
            f"input must be 2D [N, {block_size}]"
        )
        assert x.dtype == torch.float32, "input must be float32"
        assert sign.is_cuda, "sign must be a CUDA tensor"
        assert sign.numel() == block_size, (
            f"sign must have exactly {block_size} elements"
        )

        num_rows = x.shape[0]
        out = torch.empty((num_rows, block_size), device=x.device, dtype=torch.float32)
        if num_rows == 0:
            # The launchers size the grid as ceil(num_rows / 128); zero rows
            # would request a zero-sized grid, which is not a valid launch.
            return out

        x = x.contiguous()
        # The kernel indexes ``gsign[j]`` as a flat int32 vector and reads it
        # from the launching device, so flatten it and co-locate it with ``x``.
        sign_i32 = (
            sign.to(device=x.device, dtype=torch.int32).contiguous().reshape(-1)
        )

        gx = from_dlpack(x)
        gsign = from_dlpack(sign_i32)
        gout = from_dlpack(out)
        # Use the current stream of the device that owns ``x`` rather than the
        # process-wide current device, which may differ in multi-GPU setups.
        stream = cuda.CUstream(int(torch.cuda.current_stream(x.device).cuda_stream))
        launch(gx, gsign, gout, cutlass.Int32(num_rows), stream)
        return out

    @cute.kernel
    def _fwht16_sign_kernel(
        gx: cute.Tensor,
        gsign: cute.Tensor,
        gout: cute.Tensor,
        num_rows: cutlass.Int32,
    ):
        """One thread per row: load 16 f32 + sign, apply ``fwht16_sign``, store.

        ``gx`` / ``gout`` are 2D ``(num_rows, 16)`` float32 views; ``gsign`` is a
        1D ``(16,)`` int32 view broadcast across all rows.
        """
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        bdim, _, _ = cute.arch.block_dim()
        row = bidx * bdim + tidx
        # Threads past the last row (grid is rounded up) do nothing.
        if row < num_rows:
            vals = cute.make_rmem_tensor((16,), cutlass.Float32)
            sign = cute.make_rmem_tensor((16,), cutlass.Float32)
            for j in cutlass.range_constexpr(16):
                vals[j] = cutlass.Float32(gx[row, j])
                sign[j] = cutlass.Float32(gsign[j])
            fwht16_sign(vals, sign)
            for j in cutlass.range_constexpr(16):
                gout[row, j] = cutlass.Float32(vals[j])

    @cute.jit
    def _fwht16_sign_launch(
        gx: cute.Tensor,
        gsign: cute.Tensor,
        gout: cute.Tensor,
        num_rows: cutlass.Int32,
        stream,
    ):
        """Launch ``_fwht16_sign_kernel`` with 128 threads per block and
        ``ceil(num_rows / 128)`` blocks on ``stream``."""
        threads = 128
        grid = (num_rows + threads - 1) // threads
        _fwht16_sign_kernel(gx, gsign, gout, num_rows).launch(
            grid=(grid, 1, 1),
            block=(threads, 1, 1),
            cluster=(1, 1, 1),
            stream=stream,
        )

    def fwht16_sign_host(x: torch.Tensor, sign: torch.Tensor) -> torch.Tensor:
        """Apply the normalized FWHT(16) + sign transform to each row of ``x``.

        Test/validation entry only -- the production kernel inlines the same
        ``fwht16_sign`` device helper per block. Matches (to fp32 rounding)
        ``(x @ hadamard_matrix(16)) * sign`` row-by-row.

        Args:
            x: 2D float32 CUDA tensor ``[N, 16]`` (``N`` may be 0).
            sign: length-16 integer CUDA tensor of ``{-1, +1}`` values.

        Returns:
            A ``(N, 16)`` float32 CUDA tensor.
        """
        return _fwht_sign_host(x, sign, 16, _fwht16_sign_launch)

    @cute.kernel
    def _fwht32_sign_kernel(
        gx: cute.Tensor,
        gsign: cute.Tensor,
        gout: cute.Tensor,
        num_rows: cutlass.Int32,
    ):
        """One thread per row: load 32 f32 + sign, apply ``fwht32_sign``, store.

        ``gx`` / ``gout`` are 2D ``(num_rows, 32)`` float32 views; ``gsign`` is a
        1D ``(32,)`` int32 view broadcast across all rows.
        """
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        bdim, _, _ = cute.arch.block_dim()
        row = bidx * bdim + tidx
        # Threads past the last row (grid is rounded up) do nothing.
        if row < num_rows:
            vals = cute.make_rmem_tensor((32,), cutlass.Float32)
            sign = cute.make_rmem_tensor((32,), cutlass.Float32)
            for j in cutlass.range_constexpr(32):
                vals[j] = cutlass.Float32(gx[row, j])
                sign[j] = cutlass.Float32(gsign[j])
            fwht32_sign(vals, sign)
            for j in cutlass.range_constexpr(32):
                gout[row, j] = cutlass.Float32(vals[j])

    @cute.jit
    def _fwht32_sign_launch(
        gx: cute.Tensor,
        gsign: cute.Tensor,
        gout: cute.Tensor,
        num_rows: cutlass.Int32,
        stream,
    ):
        """Launch ``_fwht32_sign_kernel`` with 128 threads per block and
        ``ceil(num_rows / 128)`` blocks on ``stream``."""
        threads = 128
        grid = (num_rows + threads - 1) // threads
        _fwht32_sign_kernel(gx, gsign, gout, num_rows).launch(
            grid=(grid, 1, 1),
            block=(threads, 1, 1),
            cluster=(1, 1, 1),
            stream=stream,
        )

    def fwht32_sign_host(x: torch.Tensor, sign: torch.Tensor) -> torch.Tensor:
        """Apply the normalized FWHT(32) + sign transform to each row of ``x``.

        Test/validation entry only -- the production kernel inlines the same
        ``fwht32_sign`` device helper per block. Matches (to fp32 rounding)
        ``(x @ hadamard_matrix(32)) * sign`` row-by-row.

        Args:
            x: 2D float32 CUDA tensor ``[N, 32]`` (``N`` may be 0).
            sign: length-32 integer CUDA tensor of ``{-1, +1}`` values.

        Returns:
            A ``(N, 32)`` float32 CUDA tensor.
        """
        return _fwht_sign_host(x, sign, 32, _fwht32_sign_launch)