diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 7a1cc76e97..e6fab04bed 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -8,8 +8,13 @@ import fire import torch -import triton -from triton.testing import do_bench + +try: + import triton + from triton.testing import do_bench +except ImportError: + triton = None + do_bench = None from torchao.prototype.mx_formats.config import ScaleCalculationMode from torchao.prototype.mx_formats.kernels import ( @@ -94,7 +99,19 @@ def to_nvfp4_reference_triton_swizzle(x_hp): def benchmark_cuda_function_in_microseconds(f, *args): - return do_bench(lambda: f(*args), return_mode="median") * 1e3 + if do_bench is not None: + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + else: + # Fallback timing when triton is not available + import time + + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(100): + f(*args) + torch.cuda.synchronize() + end = time.perf_counter() + return ((end - start) / 100) * 1e6 # Convert to microseconds def run( @@ -106,7 +123,9 @@ def run( print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}") print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"torch version: {torch.__version__}") - print(f"triton version: {triton.__version__}") + print( + f"triton version: {triton.__version__ if triton is not None else 'not available'}" + ) print(f"mode: {mode}") assert mode in ( "memcpy", @@ -130,6 +149,7 @@ def run( "dim0_mxfp8_cutedsl_2d_rceil", "dim1_mxfp8_cutedsl_2d_floor", "dim1_mxfp8_cutedsl_2d_rceil", + "dim1_mxfp4_rht_cutedsl_floor", ) x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 @@ -597,6 +617,35 @@ def run( bytes_r = x.numel() * bytes_per_el_bf16 bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mxfp4_rht_cutedsl_floor": + from torchao.prototype.mx_formats.cutedsl import mxfp4_rht_quantize_cutedsl_2d + + # 
Generate sign vector: 32 random signs + sign = ( + (torch.randint(0, 2, (32,), device="cuda") * 2 - 1).to(torch.int32).tolist() + ) + + # Warmup + for _ in range(2): + __ = mxfp4_rht_quantize_cutedsl_2d(x, sign, 32, "floor", True) + + # Benchmark + time_us = benchmark_cuda_function_in_microseconds( + lambda x: mxfp4_rht_quantize_cutedsl_2d(x, sign, 32, "floor", True), + x, + ) + + # Validate output types + y_d1, s_d1 = mxfp4_rht_quantize_cutedsl_2d(x, sign, 32, "floor", True) + assert y_d1.dtype == torch.uint8 + assert s_d1.dtype == torch.float8_e8m0fnu + + # Throughput accounting: input bf16, output packed fp4 (uint8) + scales (fp8) + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + else: raise AssertionError(f"unknown mode {mode}") diff --git a/test/prototype/mx_formats/test_mxfp4_rht_cutedsl.py b/test/prototype/mx_formats/test_mxfp4_rht_cutedsl.py new file mode 100644 index 0000000000..cb75155fff --- /dev/null +++ b/test/prototype/mx_formats/test_mxfp4_rht_cutedsl.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch + +from torchao.prototype.mx_formats.cutedsl import _mxfp4_rht_cutedsl_kernels_available + +pytestmark = pytest.mark.skipif( + not _mxfp4_rht_cutedsl_kernels_available, + reason="mxfp4 rht cutedsl unavailable (needs SM10.x, CUDA>=12.8, nvidia-cutlass-dsl)", +) + + +class TestE2M1Packing: + def test_pack_bitexact_vs_reference(self): + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + pack32_e2m1_to_bytes, + ) + from torchao.prototype.mx_formats.kernels import ( + f32_to_f4_unpacked, + pack_uint4, + ) + + torch.manual_seed(0) + # len 32, spans the E2M1 grid + saturation + both parities/signs + x = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.5, + -1.0, + -6.0, + 7.0, + -7.0, + 0.25, + 5.0, + -2.5, + ] + * 2, + dtype=torch.float32, + device="cuda", + ) + ours = pack32_e2m1_to_bytes(x) # (16,) uint8 + ref = ( + pack_uint4(f32_to_f4_unpacked(x.reshape(1, 32))).view(torch.uint8).flatten() + ) # (16,) + torch.testing.assert_close(ours, ref, rtol=0, atol=0) + + +class TestFp4Scale: + @pytest.mark.parametrize("mode", ["floor", "rceil"]) + def test_scale_bytes_match_eager(self, mode): + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + compute_block_scale_e8m0_fp4, + ) + from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx + + torch.manual_seed(0) + x = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") * 7.0 + m = { + "floor": ScaleCalculationMode.FLOOR, + "rceil": ScaleCalculationMode.RCEIL, + }[mode] + # plain (unswizzled) e8m0 scales, shape (64, 1) + s_ref, _ = to_mx(x, torch.float4_e2m1fn_x2, block_size=32, scaling_mode=m) + s_ours = compute_block_scale_e8m0_fp4(x, mode) # (64, 1) float8_e8m0fnu + torch.testing.assert_close( + s_ours.view(torch.uint8).flatten(), + s_ref.view(torch.uint8).flatten(), + rtol=0, + atol=0, + ) + + +class TestFwht32: + def test_fwht32_sign_matches_dense(self): + from torchao.prototype.mx_formats.cutedsl.fwht import 
fwht32_sign_host + from torchao.prototype.spinquant.hadamard_utils import hadamard_matrix + + torch.manual_seed(0) + x = torch.randn(128, 32, dtype=torch.float32, device="cuda") + sign = (torch.randint(0, 2, (32,), device="cuda") * 2 - 1).to(torch.int32) + # normalized (1/sqrt(32)), symmetric Sylvester/Walsh-Hadamard matrix + H = hadamard_matrix(32, device="cuda").to(torch.float32) + ref = (x @ H) * sign.to(torch.float32) + ours = fwht32_sign_host(x, sign) # (128, 32) f32 + torch.testing.assert_close(ours, ref, rtol=1e-3, atol=1e-3) + + +class TestMxfp4RhtSmoke: + def test_runs_and_shapes(self): + from torchao.prototype.mx_formats.cutedsl import ( + mxfp4_rht_quantize_cutedsl_2d, + ) + + torch.manual_seed(0) + M, K = 128, 256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + sign = [1, -1] * 16 + q, s = mxfp4_rht_quantize_cutedsl_2d(x, sign, 32, "floor", True) + assert q.shape == (M, K // 2) + assert q.dtype == torch.uint8 + assert q.stride() == (K // 2, 1) + assert s.dtype == torch.float8_e8m0fnu + assert int((s.view(torch.uint8) != 0).sum()) > 0 + + +class TestMxfp4RhtE2E: + @pytest.mark.parametrize("mode", ["floor", "rceil"]) + @pytest.mark.parametrize("shape", [(128, 256), (256, 512), (512, 128)]) + def test_bitexact_vs_emulated_same_rht(self, mode, shape): + # (A) Feed the SAME RHT values to both sides via the validated + # ``fwht32_sign_host`` helper (the EXACT transform the kernel applies + # internally). This isolates quant/pack/scale/swizzle: since the FWHT is + # identical on both sides and Tasks 1-2 validated quant/pack/scale + # bit-exactly, this must be bit-exact -- any diff is a real kernel bug. 
+ from torchao.prototype.mx_formats.cutedsl import ( + mxfp4_rht_quantize_cutedsl_2d, + ) + from torchao.prototype.mx_formats.cutedsl.fwht import fwht32_sign_host + from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx + from torchao.prototype.mx_formats.utils import to_blocked + + torch.manual_seed(0) + M, K = shape + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + sign = (torch.randint(0, 2, (32,), device="cuda") * 2 - 1).to(torch.int32) + + # The kernel TMA-loads the bf16 input and widens each element to fp32 + # before the FWHT, so the reference must apply the FWHT to the SAME + # bf16-rounded-then-widened values (``x.float()``). ``fwht32_sign_host`` + # is bit-identical to the device ``fwht32_sign`` the kernel inlines. + rht = fwht32_sign_host(x.float().reshape(-1, 32), sign).reshape(M, K) + + sm = { + "floor": ScaleCalculationMode.FLOOR, + "rceil": ScaleCalculationMode.RCEIL, + }[mode] + # to_mx returns (scale, data) in that order. + s_ref, q_ref = to_mx( + rht, torch.float4_e2m1fn_x2, block_size=32, scaling_mode=sm + ) + s_ref_sw = to_blocked(s_ref.view(M, K // 32)) + + q, s = mxfp4_rht_quantize_cutedsl_2d(x, sign.tolist(), 32, mode, True) + + torch.testing.assert_close( + q.view(torch.uint8), q_ref.view(torch.uint8), rtol=0, atol=0 + ) + torch.testing.assert_close( + s.view(torch.uint8).flatten()[: s_ref_sw.numel()], + s_ref_sw.view(torch.uint8).flatten(), + rtol=0, + atol=0, + ) + assert q.stride() == (K // 2, 1) + + @pytest.mark.parametrize("mode", ["floor", "rceil"]) + def test_sqnr_vs_dense_reference(self, mode): + # (B) Faithfulness of the WHOLE fused pipeline (incl. the FWHT) vs the + # true high-precision dense RHT ``(x @ H) * sign``. 
+ from torchao.prototype.mx_formats.cutedsl import ( + mxfp4_rht_quantize_cutedsl_2d, + ) + from torchao.prototype.mx_formats.kernels import ( + f4_unpacked_to_f32, + unpack_uint4, + ) + from torchao.prototype.spinquant.hadamard_utils import hadamard_matrix + from torchao.quantization.utils import compute_error + + torch.manual_seed(0) + M, K = 256, 512 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + sign = (torch.randint(0, 2, (32,), device="cuda") * 2 - 1).to(torch.int32) + + # True high-precision dense RHT on the (bf16-rounded) input the kernel + # sees: (x @ H) * sign. hadamard_matrix(32) is normalized (1/sqrt(32)). + H = hadamard_matrix(32, device="cuda").to(torch.float32) + rht_true = ((x.float().reshape(-1, 32) @ H) * sign.float()).reshape(M, K) + + # Plain (unswizzled) scales of shape (M, K // 32) for easy dequant. + q, s = mxfp4_rht_quantize_cutedsl_2d(x, sign.tolist(), 32, mode, False) + + # Dequant: unpack two nibbles/byte -> e2m1 codes (low nibble first), + # decode to fp32 values, then multiply by 2^(e8m0_byte - 127) per block. + codes = unpack_uint4(q) # (M, K) uint8 fp4 codes in bits 0-3 + vals = f4_unpacked_to_f32(codes).reshape(M, K) + e8 = s.view(torch.uint8).to(torch.int32).reshape(M, K // 32) + scale = torch.pow( + torch.tensor(2.0, device="cuda"), (e8 - 127).float() + ).repeat_interleave(32, dim=1) + deq = vals * scale + + sqnr = compute_error(rht_true, deq).item() + assert sqnr >= 13.0, f"SQNR {sqnr} dB below 13 dB for mode={mode}" + + +class TestMxTensorThreading: + @pytest.mark.parametrize("mode", ["floor", "rceil"]) + def test_mxtensor_cutedsl_matches_standalone(self, mode): + # The opt-in CUTEDSL path through MXTensor.to_mx must produce qdata/scale + # bit-identical to the standalone op called with the same arguments + # (same is_swizzled_scales=True -> apples-to-apples). 
+ from torchao.prototype.mx_formats.config import MXFP4CastKernelChoice + from torchao.prototype.mx_formats.cutedsl import ( + mxfp4_rht_quantize_cutedsl_2d, + ) + from torchao.prototype.mx_formats.mx_tensor import ( + MXTensor, + ScaleCalculationMode, + ) + + torch.manual_seed(0) + M, K = 256, 512 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + sign = ( + (torch.randint(0, 2, (32,), device="cuda") * 2 - 1).to(torch.int32).tolist() + ) + sm = { + "floor": ScaleCalculationMode.FLOOR, + "rceil": ScaleCalculationMode.RCEIL, + }[mode] + mxt = MXTensor.to_mx( + x, + torch.float4_e2m1fn_x2, + block_size=32, + scaling_mode=sm, + is_swizzled_scales=True, + mxfp4_cast_kernel_choice=MXFP4CastKernelChoice.CUTEDSL, + rht_sign_vector=sign, + ) + q_ref, s_ref = mxfp4_rht_quantize_cutedsl_2d(x, sign, 32, mode, True) + torch.testing.assert_close( + mxt.qdata.view(torch.uint8), q_ref.view(torch.uint8), rtol=0, atol=0 + ) + torch.testing.assert_close( + mxt.scale.view(torch.uint8).flatten(), + s_ref.view(torch.uint8).flatten(), + rtol=0, + atol=0, + ) + + def test_default_path_unchanged(self): + # The default (TORCH) fp4 cast still works and does NOT require a sign + # vector -- the new trailing params are opt-in only. + from torchao.prototype.mx_formats.mx_tensor import MXTensor + + torch.manual_seed(0) + x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") + mxt = MXTensor.to_mx(x, torch.float4_e2m1fn_x2, block_size=32) + assert mxt.qdata.shape == (128, 128) diff --git a/test/prototype/mx_formats/test_nvfp4_rht_cutedsl.py b/test/prototype/mx_formats/test_nvfp4_rht_cutedsl.py new file mode 100644 index 0000000000..0ac29e06ed --- /dev/null +++ b/test/prototype/mx_formats/test_nvfp4_rht_cutedsl.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch + +from torchao.prototype.mx_formats.cutedsl import ( + _mxfp4_rht_cutedsl_kernels_available, +) + +pytestmark = pytest.mark.skipif( + not _mxfp4_rht_cutedsl_kernels_available, + reason="cutedsl nvfp4 unavailable (SM10.x, CUDA>=12.8, nvidia-cutlass-dsl)", +) + + +def _eager_nvfp4_e4m3_block_scales( + x: torch.Tensor, per_tensor_scale: torch.Tensor +) -> torch.Tensor: + """torchao eager NVFP4 per-16-block E4M3 scale bytes, plain (unswizzled). + + Mirrors the two-level path of + ``torchao.prototype.mx_formats.nvfp4_tensor.nvfp4_quantize``: + + block_scale = amax / F4_E2M1_MAX # amax / 6.0, fp32 + scaled_block_scales = block_scale / per_tensor_scale + e4m3 = clamp(scaled, E4M3_EPS, 448).to(float8_e4m3fn) + """ + from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize + + scales, _ = nvfp4_quantize(x, block_size=16, per_tensor_scale=per_tensor_scale) + return scales + + +class TestNvfp4E4M3Scale: + def test_scale_bytes_match_eager(self): + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + compute_block_scale_e4m3_nvfp4, + ) + from torchao.prototype.mx_formats.nvfp4_tensor import ( + per_tensor_amax_to_scale, + ) + + torch.manual_seed(0) + x = torch.randn(64, 16, dtype=torch.bfloat16, device="cuda") * 7.0 + + # global_scale per torchao's convention. torchao stores it as a divisor + # (``per_tensor_scale = amax_global / (448 * 6)``); our host helper takes + # the multiplicative ``global_scale = 1 / per_tensor_scale``. 
+ amax_global = torch.max(torch.abs(x)) + per_tensor_scale = per_tensor_amax_to_scale(amax_global) + global_scale = (1.0 / per_tensor_scale).item() + + s_ref = _eager_nvfp4_e4m3_block_scales(x, per_tensor_scale) + assert s_ref.shape == (64, 1) + assert s_ref.dtype == torch.float8_e4m3fn + + s_ours = compute_block_scale_e4m3_nvfp4(x, global_scale) + assert s_ours.shape == (64, 1) + assert s_ours.dtype == torch.float8_e4m3fn + + torch.testing.assert_close( + s_ours.view(torch.uint8).flatten(), + s_ref.view(torch.uint8).flatten(), + rtol=0, + atol=0, + ) + + def test_scale_bytes_match_eager_wide(self): + # Wider K (multiple 16-blocks per row) + a separate seed/scale. + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + compute_block_scale_e4m3_nvfp4, + ) + from torchao.prototype.mx_formats.nvfp4_tensor import ( + per_tensor_amax_to_scale, + ) + + torch.manual_seed(7) + x = torch.randn(33, 128, dtype=torch.bfloat16, device="cuda") * 3.0 + + amax_global = torch.max(torch.abs(x)) + per_tensor_scale = per_tensor_amax_to_scale(amax_global) + global_scale = (1.0 / per_tensor_scale).item() + + s_ref = _eager_nvfp4_e4m3_block_scales(x, per_tensor_scale) + s_ours = compute_block_scale_e4m3_nvfp4(x, global_scale) + assert s_ref.shape == (33, 8) + assert s_ours.shape == (33, 8) + + torch.testing.assert_close( + s_ours.view(torch.uint8).flatten(), + s_ref.view(torch.uint8).flatten(), + rtol=0, + atol=0, + ) + + +class TestNvfp4RhtSmoke: + """Shape / dtype / stride + non-zero-scale smoke for the fused kernel. + + Exercises BOTH the plain NVFP4 cast (no RHT) and the fused FWHT(16) + NVFP4 + RHT cast through the public gated wrapper. Bit-exact correctness vs eager is + a separate task; here we only check the kernel compiles, launches, and emits + well-formed outputs. + """ + + def _global_scale(self, x: torch.Tensor) -> float: + # 2688 == F8E4M3_MAX (448) * F4_E2M1_MAX (6); global_scale is the + # multiplicative reciprocal of torchao's per_tensor_scale. 
+ return 2688.0 / x.abs().max().item() + + def _check_outputs(self, q, s, M, K, block_size=16): + assert q.shape == (M, K // 2) + assert q.dtype == torch.uint8 + assert q.stride() == (K // 2, 1) + assert s.dtype == torch.float8_e4m3fn + # scales must be non-zero (a degenerate all-zero scale tensor would mean + # the consumer never wrote anything). + s_u8 = s.view(torch.uint8) + assert int((s_u8 != 0).sum().item()) > 0, "scales are all zero" + + def test_plain_nvfp4_cast(self): + from torchao.prototype.mx_formats.cutedsl import ( + nvfp4_rht_quantize_cutedsl_2d, + ) + + torch.manual_seed(0) + M, K = 128, 256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 5.0 + global_scale = self._global_scale(x) + + # sign_vector=None -> plain NVFP4 cast (no RHT). + q, s = nvfp4_rht_quantize_cutedsl_2d( + x, global_scale, sign_vector=None, is_swizzled_scales=True + ) + self._check_outputs(q, s, M, K) + + # Empty list is the same plain-cast path. + q2, s2 = nvfp4_rht_quantize_cutedsl_2d( + x, global_scale, sign_vector=[], is_swizzled_scales=True + ) + self._check_outputs(q2, s2, M, K) + + def test_rht_nvfp4_cast(self): + from torchao.prototype.mx_formats.cutedsl import ( + nvfp4_rht_quantize_cutedsl_2d, + ) + + torch.manual_seed(0) + M, K = 128, 256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 5.0 + global_scale = self._global_scale(x) + + sign_vector = [1, -1] * 8 # len 16 + q, s = nvfp4_rht_quantize_cutedsl_2d( + x, global_scale, sign_vector=sign_vector, is_swizzled_scales=True + ) + self._check_outputs(q, s, M, K) + + def test_plain_nvfp4_cast_unswizzled(self): + from torchao.prototype.mx_formats.cutedsl import ( + nvfp4_rht_quantize_cutedsl_2d, + ) + + torch.manual_seed(1) + M, K = 128, 256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 5.0 + global_scale = self._global_scale(x) + + q, s = nvfp4_rht_quantize_cutedsl_2d( + x, global_scale, sign_vector=None, is_swizzled_scales=False + ) + self._check_outputs(q, s, M, K) + assert 
s.shape == (M, K // 16) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 33cf98cce5..6d343e2085 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -33,6 +33,19 @@ class MXFP8Dim1CastKernelChoice(Enum): CUTEDSL = "cutedsl" +class MXFP4CastKernelChoice(str, Enum): + """ + Defines which kernel to use for mxfp4 casting. + + TORCH: emulated PyTorch cast (default). + CUTEDSL: fused Random Hadamard Transform + E2M1 quantize CuTeDSL kernel + (requires CUDA SM10.x, CUDA>=12.8, and nvidia-cutlass-dsl). + """ + + TORCH = "torch" + CUTEDSL = "cutedsl" + + # register as pytree constant so we can use dynamo nonstrict trace in torchao.prototype.moe_training.ep @register_as_pytree_constant class ScaleCalculationMode(Enum): diff --git a/torchao/prototype/mx_formats/cutedsl/__init__.py b/torchao/prototype/mx_formats/cutedsl/__init__.py new file mode 100644 index 0000000000..105e05a198 --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/__init__.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Fused MXFP4 + Random Hadamard Transform CuTeDSL quantize cast. + +This subpackage holds an optional, self-contained CuTeDSL implementation of an +MXFP4 (E2M1, block 32, E8M0 scales) quantize cast fused with a Random Hadamard +Transform. It is gated behind: + +* a Blackwell-family GPU (SM 10.x), +* CUDA >= 12.8, +* the CuTeDSL runtime packages (``nvidia-cutlass-dsl`` and friends). + +No ``cutlass`` import happens at module scope so importing this package is safe +on machines without the CuTeDSL runtime (the gate flag simply evaluates False). 
+""" + +import torch + +from torchao.utils import is_cuda_version_at_least + +from .cute_utils import _cutedsl_runtime_available + + +def _is_sm_10x() -> bool: + """Return True iff a Blackwell-family (SM 10.x) GPU is available.""" + return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10 + + +_mxfp4_rht_cutedsl_kernels_available = ( + _is_sm_10x() and is_cuda_version_at_least(12, 8) and _cutedsl_runtime_available() +) + + +def pack32_e2m1_to_bytes(x: torch.Tensor) -> torch.Tensor: + """Lazily re-exported test/validation entry for E2M1 packing. + + See ``cute_utils.pack32_e2m1_to_bytes``. Imported lazily so that importing + this package does not require the CuTeDSL runtime. + """ + from .cute_utils import pack32_e2m1_to_bytes as _impl + + return _impl(x) + + +def mxfp4_rht_quantize_cutedsl_2d( + x, + sign_vector, + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, +): + """Lazily re-exported gated wrapper for the fused MXFP4 + RHT cast. + + See ``mxfp4_rht_quantize.mxfp4_rht_quantize_cutedsl_2d``. Imported lazily so + that importing this package does not require the CuTeDSL runtime. + """ + from .mxfp4_rht_quantize import mxfp4_rht_quantize_cutedsl_2d as _impl + + return _impl(x, sign_vector, block_size, scaling_mode, is_swizzled_scales) + + +def mxfp4_rht_quantize_cutedsl( + x, + sign_vector, + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, + stage_count: int = 2, +): + """Lazily re-exported custom op for the fused MXFP4 + RHT cast. + + See ``mxfp4_rht_quantize.mxfp4_rht_quantize_cutedsl``. Imported lazily so + that importing this package does not require the CuTeDSL runtime. 
+ """ + from .mxfp4_rht_quantize import mxfp4_rht_quantize_cutedsl as _impl + + return _impl( + x, sign_vector, block_size, scaling_mode, is_swizzled_scales, stage_count + ) + + +def nvfp4_rht_quantize_cutedsl_2d( + x, + global_scale, + sign_vector=None, + block_size: int = 16, + is_swizzled_scales: bool = True, +): + """Lazily re-exported gated wrapper for the fused NVFP4 (+/- RHT) cast. + + See ``nvfp4_rht_quantize.nvfp4_rht_quantize_cutedsl_2d``. Imported lazily so + that importing this package does not require the CuTeDSL runtime. + ``sign_vector=None`` (or empty) selects the plain NVFP4 cast. + """ + from .nvfp4_rht_quantize import nvfp4_rht_quantize_cutedsl_2d as _impl + + return _impl(x, global_scale, sign_vector, block_size, is_swizzled_scales) + + +def nvfp4_rht_quantize_cutedsl( + x, + global_scale, + sign_vector, + block_size: int = 16, + is_swizzled_scales: bool = True, + stage_count: int = 2, +): + """Lazily re-exported custom op for the fused NVFP4 (+/- RHT) cast. + + See ``nvfp4_rht_quantize.nvfp4_rht_quantize_cutedsl``. Imported lazily so + that importing this package does not require the CuTeDSL runtime. + """ + from .nvfp4_rht_quantize import nvfp4_rht_quantize_cutedsl as _impl + + return _impl( + x, global_scale, sign_vector, block_size, is_swizzled_scales, stage_count + ) + + +__all__ = [ + "_is_sm_10x", + "_mxfp4_rht_cutedsl_kernels_available", + "pack32_e2m1_to_bytes", + "mxfp4_rht_quantize_cutedsl_2d", + "mxfp4_rht_quantize_cutedsl", + "nvfp4_rht_quantize_cutedsl_2d", + "nvfp4_rht_quantize_cutedsl", +] diff --git a/torchao/prototype/mx_formats/cutedsl/cute_utils.py b/torchao/prototype/mx_formats/cutedsl/cute_utils.py new file mode 100644 index 0000000000..785b9dc64e --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/cute_utils.py @@ -0,0 +1,607 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared utilities for the MXFP4 + RHT CuTeDSL quantize kernels.""" + +import importlib.util + +import torch + +# Runtime package detection +_CUTEDSL_RUNTIME_PACKAGES = { + "cuda.bindings.driver": "cuda-python", + "cutlass": "nvidia-cutlass-dsl", + "cutlass.cute": "nvidia-cutlass-dsl", + "tvm_ffi": "apache-tvm-ffi", +} + + +def _missing_cutedsl_runtime_packages() -> list[str]: + """Check which CuTeDSL runtime packages are missing. + + Returns: + List of missing package names + """ + missing = [] + for module_name, package_name in _CUTEDSL_RUNTIME_PACKAGES.items(): + try: + spec = importlib.util.find_spec(module_name) + except (ModuleNotFoundError, ValueError): + # ModuleNotFoundError: parent module doesn't exist (e.g., 'cuda' on CPU) + # ValueError: can occur with malformed module names + spec = None + + if spec is None and package_name not in missing: + missing.append(package_name) + return missing + + +def _cutedsl_runtime_available() -> bool: + """Check if all CuTeDSL runtime packages are available. + + Returns: + True if all required packages are installed + """ + return len(_missing_cutedsl_runtime_packages()) == 0 + + +if _cutedsl_runtime_available(): + import cutlass + import cutlass.cute as cute + from cutlass._mlir.dialects import llvm + from cutlass.base_dsl._mlir_helpers import arith as _dsl_arith + from cutlass.cutlass_dsl import T, dsl_user_op + + # FP4 (E2M1) constants. F4_E2M1_MAX == 6.0. + INV_F4_E2M1_MAX = cutlass.Float32(1.0 / 6.0) + + # NVFP4 two-level (E4M3 block scale) constants. + # E4M3_EPS is torch.finfo(float8_e4m3fn).tiny == 2**-6 == 0.015625, the + # smallest *normal* E4M3 value. The eager NVFP4 path clamps the block scale + # to [E4M3_EPS, F8E4M3_MAX] before the float8_e4m3fn cast, so we replicate + # the lower clamp here (the saturating PTX cvt handles the upper bound). 
+ F8E4M3_MAX = cutlass.Float32(448.0) + E4M3_EPS = cutlass.Float32(0.015625) + + # FP4 E8M0 scale constants. NOTE: these are the FP4 values, NOT the FP8 + # ones -- `F4_E2M1_MAX_POW2 == 2` (the FP8 helper uses 8), and the RCEIL + # descale divisor is `F4_E2M1_MAX == 6.0` (the FP8 helper uses 448). + F4_E2M1_MAX_POW2 = 2 # log2 of the largest power-of-two <= F4_E2M1_MAX (6.0) + E8M0_EXPONENT_BIAS = 127 + E8M0_EXPONENT_NAN_VAL = 255 + + @dsl_user_op + def _cvt_rn_satfinite_e2m1x2_f32( + hi: cutlass.Float32, + lo: cutlass.Float32, + *, + loc=None, + ip=None, + ) -> cutlass.Uint8: + """PTX inline assembly that packs two f32 values into one E2M1x2 byte. + + Uses inline PTX on Blackwell-family targets because CuTeDSL does not + currently lower the float32 -> E2M1 pair conversion to + ``cvt.rn.satfinite.e2m1x2.f32`` on its own. + + The PTX result of ``cvt.rn.satfinite.e2m1x2.f32 d, a, b`` is a ``.b8`` + value, but inline-asm output registers must be at least 16-bit and + ptxas rejects a 16-bit register directly as the ``cvt`` destination. So + (mirroring cutlass's ``cvt.rn.f16x2.e2m1x2`` wrappers in + ``cute/arch/nvvm_wrappers.py``) we ``cvt`` into a ``.reg .b8`` and + assemble it into a ``.b16`` output via ``mov.b16 $0, {d, z}`` with a + zero high byte. The low byte is then masked out and returned. + + Convention (validated bit-exactly by ``test_mxfp4_rht_cutedsl``): the + second PTX operand ``b`` (``lo``) lands in the LOW nibble and the first + operand ``a`` (``hi``) lands in the HIGH nibble, i.e. the returned byte + is ``(e2m1(hi) << 4) | e2m1(lo)``. 
+ """ + packed = cutlass.Uint16( + llvm.inline_asm( + T.i16(), + [ + cutlass.Float32(hi).ir_value(loc=loc, ip=ip), + cutlass.Float32(lo).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b8 d, z, w;\n\t" + ".reg .b16 zero16;\n\t" + "mov.u16 zero16, 0;\n\t" + "mov.b16 {z, w}, zero16;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 d, $1, $2;\n\t" + "mov.b16 $0, {d, z};\n\t" + "}", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return cutlass.Uint8(packed & cutlass.Uint16(0xFF)) + + @cute.kernel + def _pack32_e2m1_kernel( + gx: cute.Tensor, + gout: cute.Tensor, + ): + """One-block, single-thread kernel: pack 32 f32 -> 16 E2M1x2 bytes. + + For ``p`` in ``0..15`` it emits + ``_cvt_rn_satfinite_e2m1x2_f32(hi=x[2p+1], lo=x[2p])`` so that the even + column ``2p`` lands in the low nibble and the odd column ``2p+1`` in the + high nibble of output byte ``p``. + """ + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + for p in cutlass.range_constexpr(16): + lo = cutlass.Float32(gx[2 * p]) + hi = cutlass.Float32(gx[2 * p + 1]) + gout[p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + + @cute.jit + def _pack32_e2m1_launch( + gx: cute.Tensor, + gout: cute.Tensor, + stream, + ): + _pack32_e2m1_kernel(gx, gout).launch( + grid=(1, 1, 1), + block=(32, 1, 1), + cluster=(1, 1, 1), + stream=stream, + ) + + def pack32_e2m1_to_bytes(x: torch.Tensor) -> torch.Tensor: + """Pack a length-32 fp32 CUDA vector into 16 E2M1x2 bytes. + + Test/validation entry only -- the production kernel inlines the same + ``_cvt_rn_satfinite_e2m1x2_f32`` logic. Even input columns go to the low + nibble, odd columns to the high nibble (matching torchao's packed-fp4 + byte order). + + Args: + x: 1D float32 CUDA tensor of length 32. + + Returns: + A ``(16,)`` uint8 CUDA tensor. 
+ """ + import cuda.bindings.driver as cuda + from cutlass.cute.runtime import from_dlpack + + assert x.is_cuda, "input must be a CUDA tensor" + assert x.dtype == torch.float32, "input must be float32" + assert x.numel() == 32, "input must have exactly 32 elements" + x = x.contiguous() + out = torch.empty((16,), device=x.device, dtype=torch.uint8) + + gx = from_dlpack(x) + gout = from_dlpack(out) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + _pack32_e2m1_launch(gx, gout, stream) + return out + + # ------------------------------------------------------------------ + # FP4 E8M0 block-scale helpers + # ------------------------------------------------------------------ + + @dsl_user_op + def _cvt_rp_satfinite_ue8m0x2_f32( + a: cutlass.Float32, + *, + loc=None, + ip=None, + ) -> cutlass.Uint16: + """PTX inline assembly for RCEIL E8M0 conversion (Blackwell SM10.x). + + ``cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, descale`` packs two e8m0 + results into a uint16; the low byte holds the e8m0 of ``descale``. + Hardware handles NaN -> 255, Inf -> 254, subnormals -> 0. + """ + return cutlass.Uint16( + llvm.inline_asm( + T.i16(), + [cutlass.Float32(a).ir_value(loc=loc, ip=ip)], + "cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @cute.jit + def compute_amax(vals_block: cute.Tensor): + """Compute the absolute maximum of a block of values as Float32.""" + vals_vec = vals_block.load() + abs_vec = cute.where(vals_vec < 0, -vals_vec, vals_vec) + return cutlass.Float32( + abs_vec.reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) + ) + + @cute.jit + def compute_scale_floor_fp4(amax: cutlass.Float32): + """FP4 FLOOR-mode E8M0 biased scale byte + fp32 inverse scale. 
+ + Mirrors torchao eager ``to_mx`` (FLOOR branch) for + ``elem_dtype=torch.float4_e2m1fn_x2``: + + extracted_pow2 = ((bits(amax) >> 23) & 0xFF) - 127 + scale_unbiased = extracted_pow2 - F4_E2M1_MAX_POW2 # -2, NOT -8 + scale_unbiased = clamp(scale_unbiased, -127, 128) + scale_biased(byte) = scale_unbiased + 127 + inv_scale = 2 ** (-scale_unbiased) + + The ``inv_scale`` is the value the quantizer multiplies each element by + before the E2M1 conversion (it is ``1 / 2**scale_unbiased``). The eager + FLOOR path floors the divisor to ``2**-126``; the matching cap here is + ``inv_scale <= 2**126`` (i.e. ``scale_unbiased >= -126``), so we clamp + the exponent used for ``inv_scale`` at the bottom. + + Returns ``(scale_biased, inv_scale)`` where ``scale_biased`` is the + biased E8M0 byte as Int32 (caller stores low 8 bits) and ``inv_scale`` + is a ``Float32``. + """ + bits = _dsl_arith.bitcast(amax.ir_value(), _dsl_arith.T.i32()) + exp_i = ((bits >> cutlass.Int32(23)) & cutlass.Int32(0xFF)) - cutlass.Int32( + E8M0_EXPONENT_BIAS + ) + scale_exp_unbiased = exp_i - cutlass.Int32(F4_E2M1_MAX_POW2) + if scale_exp_unbiased < -E8M0_EXPONENT_BIAS: + scale_exp_unbiased = cutlass.Int32(-E8M0_EXPONENT_BIAS) + if scale_exp_unbiased > E8M0_EXPONENT_BIAS + 1: + scale_exp_unbiased = cutlass.Int32(E8M0_EXPONENT_BIAS + 1) + scale_biased = scale_exp_unbiased + E8M0_EXPONENT_BIAS + # inv_scale = 2 ** (-scale_unbiased); divisor floored to 2**-126, so the + # inverse is capped at 2**126 (scale_exp_unbiased >= -126). + inv_exp = -scale_exp_unbiased + if inv_exp > cutlass.Int32(126): + inv_exp = cutlass.Int32(126) + inv_scale = cute.exp2(cutlass.Float32(inv_exp)) + return scale_biased, inv_scale + + @cute.jit + def compute_scale_rceil_fp4(amax: cutlass.Float32): + """FP4 RCEIL-mode E8M0 biased scale byte + fp32 inverse scale. + + Mirrors torchao eager ``_to_mx_rceil`` for + ``elem_dtype=torch.float4_e2m1fn_x2`` (``max_pos == F4_E2M1_MAX == 6.0``): + + descale = amax / 6.0 (i.e. 
@cute.jit
def compute_scale_byte_fp4(
    amax: cutlass.Float32,
    USE_RCEIL: cutlass.Constexpr[bool],
):
    """Dispatch to the FP4 FLOOR / RCEIL E8M0 scale helper, handling NaN.

    A NaN ``amax`` (detected via ``amax == amax`` being false) bypasses both
    helpers and yields the E8M0 NaN byte ``E8M0_EXPONENT_NAN_VAL`` together
    with ``inv_scale = 0.0``.

    For a non-NaN ``amax``:

    * FLOOR (``USE_RCEIL`` false) calls ``compute_scale_floor_fp4``, the
      bit-extraction path. For an all-zero block the extracted exponent
      field is 0, which lands on that helper's lower clamp, so the biased
      byte is ``0`` (assuming ``F4_E2M1_MAX_POW2`` is positive -- it is
      defined outside this view).
    * RCEIL (``USE_RCEIL`` true) calls ``compute_scale_rceil_fp4``, which
      uses the saturating PTX cvt.

    Args:
        amax: block absolute maximum as ``Float32``.
        USE_RCEIL: compile-time flag (evaluated through
            ``cutlass.const_expr``) selecting RCEIL over FLOOR.

    Returns:
        ``(scale_biased, inv_scale)`` exactly as produced by the selected
        helper, or ``(E8M0_EXPONENT_NAN_VAL, 0.0)`` for a NaN ``amax``.
    """
    # NaN amax -> E8M0 NaN byte, matching eager to_mx.
    scale_biased = cutlass.Int32(E8M0_EXPONENT_NAN_VAL)
    inv_scale = cutlass.Float32(0.0)
    if amax == amax:  # not NaN
        if cutlass.const_expr(USE_RCEIL):
            scale_biased, inv_scale = compute_scale_rceil_fp4(amax)
        else:
            scale_biased, inv_scale = compute_scale_floor_fp4(amax)
    return scale_biased, inv_scale
@dsl_user_op
def _cvt_rn_satfinite_e4m3x2_f32(
    hi: cutlass.Float32,
    lo: cutlass.Float32,
    *,
    loc=None,
    ip=None,
) -> cutlass.Uint8:
    """PTX inline assembly: E4M3x2 cvt of an f32 pair, returning the ``lo`` byte.

    ``cvt.rn.satfinite.e4m3x2.f32 d, a, b`` produces a ``.b16`` whose low
    byte is ``e4m3(b)`` (``lo``) and high byte is ``e4m3(a)`` (``hi``); the
    instruction itself round-to-nearest-saturates to the E4M3 finite max
    (448). Only the low byte is returned, so the result is ``e4m3(lo)`` and
    the converted ``hi`` value is discarded. The caller visible in this file
    (``compute_nvfp4_scale_e4m3``) passes the same value for both operands.

    Unlike the E2M1 variant the E4M3x2 cvt already has a ``.b16`` output, so
    no ``mov.b16`` reassembly is needed.

    Args:
        hi: value converted into the (discarded) high byte.
        lo: value converted into the returned low byte.
        loc: forwarded to ``ir_value`` (DSL source-location handle).
        ip: forwarded to ``ir_value`` (DSL insertion-point handle).

    Returns:
        The E4M3 encoding of ``lo`` as a ``Uint8``.
    """
    packed = cutlass.Uint16(
        llvm.inline_asm(
            T.i16(),
            [
                cutlass.Float32(hi).ir_value(loc=loc, ip=ip),
                cutlass.Float32(lo).ir_value(loc=loc, ip=ip),
            ],
            # $0 = packed .b16 result, $1 = hi (operand a), $2 = lo (operand b).
            "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;",
            "=h,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
    # Keep only the low byte, i.e. e4m3(lo).
    return cutlass.Uint8(packed & cutlass.Uint16(0xFF))
+ + Bit-exactly mirrors the two-level path of torchao eager + ``nvfp4_quantize`` (``nvfp4_tensor.py``), where ``global_scale`` is the + multiplicative reciprocal of torchao's stored ``per_tensor_scale`` + (i.e. ``global_scale = 1 / per_tensor_scale``): + + block_scale = amax / 6.0 # amax * INV_F4_E2M1_MAX + local = block_scale * global_scale # == block_scale / per_tensor_scale + local = clamp(local, E4M3_EPS, 448) + e4m3_byte = cvt.rn.satfinite.e4m3x2.f32(local) + inv_scale = global_scale / dequant_e4m3(e4m3_byte) + + The ``inv_scale`` is what the data quantizer multiplies each element by + before the E2M1 conversion; it matches eager's + ``reciprocal_scale = (1/per_tensor_scale) / scaled_block_scales_fp32``. + + Returns ``(e4m3_byte: Uint8, inv_scale: Float32)``. + """ + block_scale = amax * INV_F4_E2M1_MAX + local = block_scale * global_scale + # Lower clamp to the smallest E4M3 normal (matches eager torch.clamp); + # the upper clamp to 448 is handled by the saturating PTX cvt, but we + # apply it explicitly to be faithful to the eager op order. + if local < E4M3_EPS: + local = E4M3_EPS + if local > F8E4M3_MAX: + local = F8E4M3_MAX + e4m3_byte = _cvt_rn_satfinite_e4m3x2_f32(local, local) + dequant = _cvt_e4m3_byte_to_f32(e4m3_byte) + inv_scale = global_scale / dequant + return e4m3_byte, inv_scale + + @cute.kernel + def _block_scale_e4m3_nvfp4_kernel( + gx: cute.Tensor, + gscale: cute.Tensor, + global_scale: cutlass.Float32, + num_blocks: cutlass.Int32, + ): + """One thread per 16-element block: amax -> NVFP4 E4M3 scale byte. + + ``gx`` is a 2D ``(num_blocks, 16)`` float32 view; ``gscale`` is a 1D + ``(num_blocks,)`` uint8 view (the flattened ``(N, K//16)`` scales). 
def compute_block_scale_e4m3_nvfp4(
    x: torch.Tensor, global_scale: float
) -> torch.Tensor:
    """Compute per-16-block NVFP4 E4M3 scale bytes.

    Bit-exact (validated by ``test_nvfp4_rht_cutedsl``) against the
    two-level path of torchao eager
    ``nvfp4_tensor.nvfp4_quantize(x, block_size=16, per_tensor_scale=...)``
    plain (unswizzled) ``float8_e4m3fn`` scales, where
    ``global_scale = 1 / per_tensor_scale``.

    The input is converted to a contiguous float32 ``[num_blocks, 16]`` view
    and one scale byte per row is produced by
    ``_block_scale_e4m3_nvfp4_launch`` on the current CUDA stream.

    Args:
        x: 2D CUDA tensor ``[N, K]`` (bf16 or fp32) with ``K % 16 == 0``.
        global_scale: Per-tensor multiplicative global scale (the reciprocal
            of torchao's ``per_tensor_scale``).

    Returns:
        ``(N, K // 16)`` ``torch.float8_e4m3fn`` CUDA tensor of scale bytes.
        When ``N == 0`` or ``K == 0`` an empty tensor of that shape is
        returned without launching a kernel.

    Raises:
        AssertionError: if ``x`` is not a CUDA tensor, is not 2D, or ``K`` is
            not divisible by 16.
    """
    import cuda.bindings.driver as cuda
    from cutlass.cute.runtime import from_dlpack

    assert x.is_cuda, "input must be a CUDA tensor"
    assert x.dim() == 2, "input must be 2D [N, K]"
    n, k = x.shape
    assert k % 16 == 0, "K must be divisible by 16"

    kb = k // 16
    num_blocks = n * kb
    if num_blocks == 0:
        # Nothing to scale. Return early: the launcher would otherwise ask
        # for a grid of zero CTAs, which is not a valid CUDA launch shape.
        return torch.empty((n, kb), device=x.device, dtype=torch.float8_e4m3fn)

    # float32 [num_blocks, 16] view of the per-block elements.
    x_f32 = x.to(torch.float32).contiguous().reshape(num_blocks, 16)
    scale_u8 = torch.empty((num_blocks,), device=x.device, dtype=torch.uint8)

    gx = from_dlpack(x_f32)
    gscale = from_dlpack(scale_u8)
    # Launch on whatever stream PyTorch currently considers active.
    stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
    _block_scale_e4m3_nvfp4_launch(
        gx,
        gscale,
        cutlass.Float32(global_scale),
        cutlass.Int32(num_blocks),
        stream,
    )
    # Reinterpret the raw bytes as E4M3 and restore the (N, K // 16) shape.
    return scale_u8.view(torch.float8_e4m3fn).reshape(n, kb)
@cute.jit
def fwht32_sign(vals: cute.Tensor, sign: cute.Tensor) -> cute.Tensor:
    """In-register normalized FWHT(32) followed by an elementwise sign mul.

    Computes, for the length-32 register fragment ``vals``::

        out[j] = (FWHT(vals) / sqrt(32))[j] * sign[j]

    which equals ``(vals @ hadamard_matrix(32))[j] * sign[j]``.

    The transform is a radix-2 Walsh-Hadamard butterfly over strides
    ``s in {1, 2, 4, 8, 16}``. For each stride, every pair ``(i, i + s)``
    with ``(i & s) == 0`` is touched exactly once::

        a, b = vals[i] + vals[i + s], vals[i] - vals[i + s]
        vals[i], vals[i + s] = a, b

    All pair indices are compile-time constants (the 5 stages and their
    pairings are fixed for a length-32 block), so the loops are unrolled via
    ``cutlass.range_constexpr`` and the pair filter is a
    ``cutlass.const_expr``. The pairing/orientation is validated against
    ``hadamard_matrix(32)`` by ``test_mxfp4_rht_cutedsl``.

    Args:
        vals: length-32 ``Float32`` register fragment (modified in place and
            also returned).
        sign: length-32 register fragment of ``{-1, +1}`` values (any
            arithmetic dtype; cast to ``Float32`` for the multiply). It is
            only read, never written.

    Returns:
        The transformed length-32 ``Float32`` fragment (the same ``vals``).
    """
    # 5-stage in-register butterfly. Strides are powers of two up to 16.
    for stage in cutlass.range_constexpr(5):
        s = 1 << stage
        for i in cutlass.range_constexpr(32):
            # Visit each pair once, from its lower index (bit ``s`` clear).
            if cutlass.const_expr((i & s) == 0):
                # Read both operands before writing either one.
                a = cutlass.Float32(vals[i])
                b = cutlass.Float32(vals[i + s])
                vals[i] = a + b
                vals[i + s] = a - b

    # Normalize then apply the sign vector.
    for j in cutlass.range_constexpr(32):
        vals[j] = cutlass.Float32(vals[j]) * INV_SQRT_32 * cutlass.Float32(sign[j])

    return vals
def fwht16_sign_host(x: torch.Tensor, sign: torch.Tensor) -> torch.Tensor:
    """Run the normalized FWHT(16) + sign transform over every row of ``x``.

    Validation-only entry point: the production kernel inlines the same
    ``fwht16_sign`` device helper per block. The result matches (to fp32
    rounding) ``(x @ hadamard_matrix(16)) * sign`` row-by-row.

    Args:
        x: 2D float32 CUDA tensor ``[N, 16]``.
        sign: CUDA tensor with exactly 16 ``{-1, +1}`` entries; it is
            converted to int32 before being handed to the kernel.

    Returns:
        A freshly allocated ``(N, 16)`` float32 CUDA tensor.

    Raises:
        AssertionError: if any of the input requirements above is violated.
    """
    import cuda.bindings.driver as cuda
    from cutlass.cute.runtime import from_dlpack

    # Input contract, checked in this order.
    assert x.is_cuda, "input must be a CUDA tensor"
    assert x.dim() == 2 and x.shape[1] == 16, "input must be 2D [N, 16]"
    assert x.dtype == torch.float32, "input must be float32"
    assert sign.is_cuda, "sign must be a CUDA tensor"
    assert sign.numel() == 16, "sign must have exactly 16 elements"

    row_count = x.shape[0]
    src = x.contiguous()
    sign_words = sign.to(torch.int32).contiguous()
    result = torch.empty((row_count, 16), device=src.device, dtype=torch.float32)

    # Launch on the stream PyTorch currently treats as active.
    cu_stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
    _fwht16_sign_launch(
        from_dlpack(src),
        from_dlpack(sign_words),
        from_dlpack(result),
        cutlass.Int32(row_count),
        cu_stream,
    )
    return result
def fwht32_sign_host(x: torch.Tensor, sign: torch.Tensor) -> torch.Tensor:
    """Row-wise normalized FWHT(32) followed by the sign multiply.

    Exists for tests/validation only; the production kernel calls the
    ``fwht32_sign`` device helper directly per block. Output agrees (to fp32
    rounding) with ``(x @ hadamard_matrix(32)) * sign`` for each row.

    Args:
        x: 2D float32 CUDA tensor ``[N, 32]``.
        sign: CUDA tensor holding exactly 32 ``{-1, +1}`` values; cast to
            int32 before the launch.

    Returns:
        A new ``(N, 32)`` float32 CUDA tensor.

    Raises:
        AssertionError: if an input requirement listed above is not met.
    """
    import cuda.bindings.driver as cuda
    from cutlass.cute.runtime import from_dlpack

    assert x.is_cuda, "input must be a CUDA tensor"
    assert x.dim() == 2 and x.shape[1] == 32, "input must be 2D [N, 32]"
    assert x.dtype == torch.float32, "input must be float32"
    assert sign.is_cuda, "sign must be a CUDA tensor"
    assert sign.numel() == 32, "sign must have exactly 32 elements"

    rows = x.shape[0]
    x_dense = x.contiguous()
    sign_i32 = sign.to(torch.int32).contiguous()
    y = torch.empty((rows, 32), device=x_dense.device, dtype=torch.float32)

    # Wrap the three buffers for the DSL in input, sign, output order.
    gx, gsign, gout = (from_dlpack(t) for t in (x_dense, sign_i32, y))
    active_stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
    _fwht32_sign_launch(gx, gsign, gout, cutlass.Int32(rows), active_stream)
    return y
+ +Clones the structure of the MXFP8 1x32 CuTeDSL quantize kernel +(``torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_2d_1x32.py``) +and applies the MXFP4 + RHT deltas: + +* the E2M1 output is two codes per byte, so the output smem / global qdata are + half-width ``[M, K // 2]`` ``uint8``; +* per 32-element block the consumer applies the register-resident FWHT(32) + + sign transform (``fwht.fwht32_sign``) before computing amax / scale / packing; +* values are clamped to ``+-6.0`` (``F4_E2M1_MAX``) before the E2M1 cvt in FLOOR + mode (RCEIL's cvt saturates on its own); +* 16 E2M1x2 bytes per block are assembled and written as 4 ``uint32`` stores; +* the biased E8M0 scale byte is written either in the cuBLAS-blocked padded + layout (``is_swizzled_scales=True``) or as a plain ``(M, K // 32)`` tensor. + +The op is gated behind a Blackwell (SM 10.x) GPU, CUDA >= 12.8, and the CuTeDSL +runtime packages (see ``cutedsl/__init__.py``). +""" + +import functools +from typing import Tuple + +import torch + +from torchao.utils import ceil_div + +from .cute_utils import ( + _cvt_rn_satfinite_e2m1x2_f32, + compute_amax, + compute_scale_byte_fp4, +) +from .fwht import fwht32_sign + +# Config format: +# (compute_warps, tile_m, tile_k, k_tiles_per_cta) +_CUTEDSL_CONFIGS = { + "bf16_default": (4, 128, 32, 4), + "fallback": (6, 128, 32, 2), +} + + +def _select_cutedsl_config( + input_dtype: torch.dtype, + scaling_mode: str, +) -> Tuple[str, Tuple[int, int, int, int]]: + """Select kernel configuration based on input dtype.""" + if input_dtype == torch.bfloat16: + config_name = "bf16_default" + else: + config_name = "fallback" + return config_name, _CUTEDSL_CONFIGS[config_name] + + +def _make_tile_smem_layouts(tile_m: int, tile_k: int): + """Row-major smem layouts for the input ``(TILE_M, TILE_K)`` and the + half-width output ``(TILE_M, TILE_K // 2)`` tiles.""" + import cutlass.cute as cute + + smem_layout_in = cute.make_layout( + (tile_m, tile_k), + stride=(tile_k, 1), + 
) + smem_layout_out = cute.make_layout( + (tile_m, tile_k // 2), + stride=(tile_k // 2, 1), + ) + return smem_layout_in, smem_layout_out + + +@functools.cache +def _compile_mxfp4_rht_quantize_2d_cutedsl( + input_dtype_name: str, + scaling_mode: str, + compute_warps: int, + tile_m: int, + tile_k: int, + requested_stage_count: int, + k_tiles_per_cta: int, + blocked_scale_output: bool, +): + """Compile the fused 2D MXFP4 + RHT quantize kernel using CuTeDSL. + + Warp-specialized TMA kernel mirroring the MXFP8 1x32 template: + - warp 0: producer (TMA global->shared input, shared->global half-width + output); + - warps [1..compute_warps]: consumers (FWHT + quantize + E2M1 pack). + """ + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.utils as utils + from cutlass.cute.nvgpu import cpasync, tcgen05 + from cutlass.cute.runtime import make_fake_stream, make_fake_tensor + + if input_dtype_name == "torch.float32": + INPUT_CUTLASS_DTYPE = cutlass.Float32 + elif input_dtype_name == "torch.bfloat16": + INPUT_CUTLASS_DTYPE = cutlass.BFloat16 + else: + raise ValueError( + f"Unsupported input dtype for CuTeDSL mxfp4 quantize_2d: {input_dtype_name}" + ) + + COMPUTE_WARPS = compute_warps + TILE_M = tile_m + TILE_K = tile_k + TILE_K_HALF = tile_k // 2 + K_TILES_PER_CTA = k_tiles_per_cta + BLOCKED_SCALE_OUTPUT_VALUE = blocked_scale_output + + THREADS_PER_BLOCK = (1 + COMPUTE_WARPS) * 32 + assert COMPUTE_WARPS >= 1 + assert TILE_M > 0 and TILE_K > 0 + assert TILE_K % 32 == 0 + + SCALE_DIM_K_VALUE = 32 + SCALE_DIM_K_HALF = SCALE_DIM_K_VALUE // 2 # 16 packed bytes per block + K_BLOCKS_PER_TILE = TILE_K // SCALE_DIM_K_VALUE + assert K_BLOCKS_PER_TILE > 0 + assert requested_stage_count >= 1 + assert requested_stage_count <= 2 + assert K_TILES_PER_CTA >= 1 + STAGE_COUNT_VALUE = min(requested_stage_count, K_TILES_PER_CTA) + + input_elem_bytes = 4 if input_dtype_name == "torch.float32" else 2 + TILE_COPY_BYTES_IN = TILE_M * TILE_K * 
input_elem_bytes + # Half-width uint8 output: 1 byte per packed E2M1 pair. + TILE_COPY_BYTES_OUT = TILE_M * TILE_K_HALF # noqa: F841 (documented contract) + M_THREADS = COMPUTE_WARPS * 32 + M_ITERS_PER_LANE = ceil_div(TILE_M, M_THREADS) + + USE_RCEIL_VALUE = scaling_mode == "rceil" + # F4_E2M1_MAX == 6.0; clamp pre-cvt values for FLOOR (RCEIL cvt saturates). + F4_MAX = cutlass.Float32(6.0) + + @cute.struct + class SharedStorage: + tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, STAGE_COUNT_VALUE] + in_smem: cute.struct.Align[ + cute.struct.MemRange[ + INPUT_CUTLASS_DTYPE, STAGE_COUNT_VALUE * TILE_M * TILE_K + ], + 128, + ] + out_smem: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, STAGE_COUNT_VALUE * TILE_M * TILE_K_HALF + ], + 128, + ] + + class Mxfp4RhtQuantize2dKernel: + @cute.jit + def _load_block_full_smem_to_reg( + self, + sIN_tile: cute.Tensor, + m_rel: cutlass.Int32, + k_base: cutlass.Int32, + ): + """Load a full 32-element quantization block from smem to registers.""" + vals_block = cute.make_rmem_tensor((SCALE_DIM_K_VALUE,), cutlass.Float32) + for i in cutlass.range_constexpr(SCALE_DIM_K_VALUE): + vals_block[i] = cutlass.Float32(sIN_tile[m_rel, k_base + i]) + return vals_block + + @cute.jit + def _store_scales_reg_to_gmem_vec( + self, + scales_tensor: cute.Tensor, + m: cutlass.Int64, + k_block_base: cutlass.Int64, + scale_buffer: cute.Tensor, + num_scales: cutlass.Int32, + BLOCKED_SCALE_OUTPUT: cutlass.Constexpr[bool], + ): + """Store scales from registers to gmem (uint32 vectorized for blocked).""" + if cutlass.const_expr(BLOCKED_SCALE_OUTPUT): + if num_scales == 4: + scales_tensor_u32 = cute.recast_tensor( + scales_tensor, cutlass.Uint32 + ) + scale_buffer_u32 = cute.recast_tensor(scale_buffer, cutlass.Uint32) + scales_tensor_u32[m, k_block_base // cutlass.Int64(4)] = ( + scale_buffer_u32[0] + ) + else: + for i in range(num_scales): + k_block = k_block_base + i + scales_tensor[m, k_block] = scale_buffer[i] + else: + for i in 
range(num_scales): + k_block = k_block_base + i + scales_tensor[m, k_block] = scale_buffer[i] + + @cute.jit + def _store_q_e2m1_block_to_smem( + self, + packed_bytes: cute.Tensor, + sOUT_tile: cute.Tensor, + m_rel: cutlass.Int32, + sout_base: cutlass.Int32, + ): + """Store the 16 packed E2M1x2 bytes of a block via 4 uint32 writes. + + ``packed_bytes`` is a length-16 ``Uint8`` register fragment holding + the bytes for one 32-element block. ``sout_base`` is the byte offset + of the block within the half-width output tile (``k_base // 2``). + """ + sOUT_tile_u32 = cute.recast_tensor(sOUT_tile, cutlass.Uint32) + packed_u32 = cute.recast_tensor(packed_bytes, cutlass.Uint32) + base_u32 = sout_base // cutlass.Int32(4) + for w in cutlass.range_constexpr(SCALE_DIM_K_HALF // 4): + sOUT_tile_u32[m_rel, base_u32 + w] = packed_u32[w] + + @cute.jit + def _quantize_block_then_store_reg_to_smem_full( + self, + rht_block: cute.Tensor, + inv_scale: cutlass.Float32, + sOUT_tile: cute.Tensor, + m_rel: cutlass.Int32, + k_base: cutlass.Int32, + USE_RCEIL: cutlass.Constexpr[bool], + ): + """Scale, (clamp for FLOOR), pack 32 RHT values to 16 E2M1x2 bytes. + + ``rht_block`` is the post-FWHT length-32 ``Float32`` fragment. Even + column ``2p`` -> low nibble, odd column ``2p + 1`` -> high nibble of + output byte ``p`` (validated bit-exactly in Task 1). + """ + packed_bytes = cute.make_rmem_tensor((SCALE_DIM_K_HALF,), cutlass.Uint8) + for p in cutlass.range_constexpr(SCALE_DIM_K_HALF): + lo = cutlass.Float32(rht_block[2 * p]) * inv_scale + hi = cutlass.Float32(rht_block[2 * p + 1]) * inv_scale + if not cutlass.const_expr(USE_RCEIL): + # FLOOR: clamp to +-F4_E2M1_MAX before the cvt. 
+ if lo > F4_MAX: + lo = F4_MAX + if lo < -F4_MAX: + lo = -F4_MAX + if hi > F4_MAX: + hi = F4_MAX + if hi < -F4_MAX: + hi = -F4_MAX + packed_bytes[p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + sout_base = k_base // cutlass.Int32(2) + self._store_q_e2m1_block_to_smem(packed_bytes, sOUT_tile, m_rel, sout_base) + + @cute.jit + def _issue_tma_load( + self, + tma_atom_in: cute.CopyAtom, + gIN_tile: cute.Tensor, + sIN_tile: cute.Tensor, + tma_mbar_ptr: cutlass.Int64, + warp_idx: cutlass.Int32, + ): + """Issue TMA load from global to shared memory (producer warp only).""" + if warp_idx == 0: + cta_layout = cute.make_layout((1,)) + sIN_for_tma_partition = cute.group_modes(sIN_tile, 0, 1) + gIN_for_tma_partition = cute.group_modes(gIN_tile, 0, 1) + tINs, tINg = cpasync.tma_partition( + tma_atom_in, + 0, + cta_layout, + sIN_for_tma_partition, + gIN_for_tma_partition, + ) + tINg_stage0 = tINg[(None, 0)] + tINs_stage0 = tINs[(None, 0)] + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + tma_mbar_ptr, TILE_COPY_BYTES_IN + ) + cute.copy( + tma_atom_in, + tINg_stage0, + tINs_stage0, + tma_bar_ptr=tma_mbar_ptr, + ) + + @cute.jit + def _issue_tma_store( + self, + tma_atom_out: cute.CopyAtom, + gOUT_tile: cute.Tensor, + sOUT_tile: cute.Tensor, + warp_idx: cutlass.Int32, + ): + """Issue TMA store from shared to global memory (producer warp only).""" + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + cute.arch.sync_threads() + if warp_idx == 0: + cta_layout = cute.make_layout((1,)) + sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 1) + gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 1) + tOUTs, tOUTg = cpasync.tma_partition( + tma_atom_out, + 0, + cta_layout, + sOUT_for_tma_partition, + gOUT_for_tma_partition, + ) + tOUTs_stage0 = tOUTs[(None, 0)] + tOUTg_stage0 = tOUTg[(None, 0)] + cute.copy( + tma_atom_out, + tOUTs_stage0, + tOUTg_stage0, + ) + + @cute.kernel + def kernel( + self, + inp_mk: cute.Tensor, + tma_atom_in: cute.CopyAtom, 
+ tma_tensor_in: cute.Tensor, + out_mk: cute.Tensor, + tma_atom_out: cute.CopyAtom, + tma_tensor_out: cute.Tensor, + scales_out_u8: cute.Tensor, + sign_vec: cute.Tensor, + M: cutlass.Int64, + K: cutlass.Int64, + k_blocks: cutlass.Int64, + m_cta_tiles: cutlass.Int64, + k_cta_tiles: cutlass.Int64, + blocked_scale_layout: cute.Layout, + SCALE_DIM_K: cutlass.Constexpr[int], + USE_RCEIL: cutlass.Constexpr[bool], + STAGE_COUNT: cutlass.Constexpr[int], + ): + """Main fused MXFP4 + RHT quantize kernel (warp-specialized TMA).""" + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + bidx, bidy, _ = cute.arch.block_idx() + + smem_allocator = utils.SmemAllocator() + storage = smem_allocator.allocate(SharedStorage) + tma_mbar_ptr0 = storage.tma_mbar_ptr.data_ptr() + tma_mbar_ptr1 = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + tma_mbar_ptr1 = tma_mbar_ptr0 + 1 + + smem_layout_in, smem_layout_out = _make_tile_smem_layouts(TILE_M, TILE_K) + staged_layout_in = cute.make_layout( + (STAGE_COUNT_VALUE, TILE_M, TILE_K), + stride=(TILE_M * TILE_K, TILE_K, 1), + ) + staged_layout_out = cute.make_layout( + (STAGE_COUNT_VALUE, TILE_M, TILE_K_HALF), + stride=(TILE_M * TILE_K_HALF, TILE_K_HALF, 1), + ) + sIN_staged = storage.in_smem.get_tensor(staged_layout_in) + sOUT_staged = storage.out_smem.get_tensor(staged_layout_out) + stage_elems_in = TILE_M * TILE_K + stage_elems_out = TILE_M * TILE_K_HALF + sIN_tile0 = cute.make_tensor( + sIN_staged.iterator + 0 * stage_elems_in, smem_layout_in + ) + sOUT_tile0 = cute.make_tensor( + sOUT_staged.iterator + 0 * stage_elems_out, smem_layout_out + ) + sIN_tile1 = sIN_tile0 + sOUT_tile1 = sOUT_tile0 + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + sIN_tile1 = cute.make_tensor( + sIN_staged.iterator + 1 * stage_elems_in, smem_layout_in + ) + sOUT_tile1 = cute.make_tensor( + sOUT_staged.iterator + 1 * stage_elems_out, smem_layout_out + ) + + if tidx == 0: + 
cpasync.prefetch_descriptor(tma_atom_in) + cpasync.prefetch_descriptor(tma_atom_out) + cute.arch.mbarrier_init(tma_mbar_ptr0, 1) + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + cute.arch.mbarrier_init(tma_mbar_ptr1, 1) + cute.arch.mbarrier_init_fence() + cute.arch.sync_threads() + + # Load the length-32 sign vector into registers (broadcast over rows). + sign_reg = cute.make_rmem_tensor((SCALE_DIM_K_VALUE,), cutlass.Float32) + for j in cutlass.range_constexpr(SCALE_DIM_K_VALUE): + sign_reg[j] = cutlass.Float32(sign_vec[j]) + + k_tile_group_idx = cutlass.Int64(bidx) + m_tile = cutlass.Int64(bidy) + m0 = m_tile * TILE_M + if cutlass.const_expr(BLOCKED_SCALE_OUTPUT_VALUE): + scales_tensor = cute.make_tensor( + scales_out_u8.iterator, + blocked_scale_layout, + ) + else: + scales_tensor = scales_out_u8 + + for tile_step in cutlass.range_constexpr(K_TILES_PER_CTA): + k_tile_eff = k_tile_group_idx * K_TILES_PER_CTA + tile_step + + stage_idx = tile_step % STAGE_COUNT + + sIN_tile = sIN_tile0 + sOUT_tile = sOUT_tile0 + tma_mbar_ptr = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT > 1): + tma_mbar_ptr = tma_mbar_ptr0 + stage_idx + if cutlass.const_expr(STAGE_COUNT > 1): + if stage_idx == 1: + sIN_tile = sIN_tile1 + sOUT_tile = sOUT_tile1 + + tma_phase = (tile_step // STAGE_COUNT) % 2 + + if cutlass.const_expr( + tile_step == 0 or not (STAGE_COUNT > 1 and K_TILES_PER_CTA > 1) + ): + gIN_tile = cute.local_tile( + tma_tensor_in, (TILE_M, TILE_K), (m_tile, k_tile_eff) + ) + self._issue_tma_load( + tma_atom_in, + gIN_tile, + sIN_tile, + tma_mbar_ptr, + warp_idx, + ) + + if cutlass.const_expr(STAGE_COUNT > 1 and K_TILES_PER_CTA > 1): + if cutlass.const_expr(tile_step + 1 < K_TILES_PER_CTA): + k_tile_next = k_tile_group_idx * K_TILES_PER_CTA + tile_step + 1 + next_stage_idx = (tile_step + 1) % STAGE_COUNT + sIN_tile_next = sIN_tile0 + tma_mbar_ptr_next = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT > 1): + tma_mbar_ptr_next = tma_mbar_ptr0 + next_stage_idx + if 
cutlass.const_expr(STAGE_COUNT > 1): + if next_stage_idx == 1: + sIN_tile_next = sIN_tile1 + + gIN_tile_next = cute.local_tile( + tma_tensor_in, (TILE_M, TILE_K), (m_tile, k_tile_next) + ) + self._issue_tma_load( + tma_atom_in, + gIN_tile_next, + sIN_tile_next, + tma_mbar_ptr_next, + warp_idx, + ) + + if warp_idx >= 1 and warp_idx <= compute_warps: + cute.arch.mbarrier_wait(tma_mbar_ptr, tma_phase) + lane = tidx % 32 + m_lane = (warp_idx - 1) * 32 + lane + + for mm in cutlass.range_constexpr(M_ITERS_PER_LANE): + m_rel = m_lane + mm * M_THREADS + m = m0 + m_rel + if m_rel < TILE_M: + scale_buffer = cute.make_rmem_tensor( + (K_BLOCKS_PER_TILE,), cutlass.Uint8 + ) + + for kb in cutlass.range_constexpr(K_BLOCKS_PER_TILE): + k_base = kb * SCALE_DIM_K_VALUE + vals_block = self._load_block_full_smem_to_reg( + sIN_tile, + m_rel, + k_base, + ) + + # Fused RHT: in-register FWHT(32) + sign. + fwht32_sign(vals_block, sign_reg) + + amax = compute_amax(vals_block) + scale_biased, inv_scale = compute_scale_byte_fp4( + amax, USE_RCEIL + ) + scale_buffer[kb] = cutlass.Uint8( + scale_biased & cutlass.Int32(0xFF) + ) + + self._quantize_block_then_store_reg_to_smem_full( + vals_block, + inv_scale, + sOUT_tile, + m_rel, + k_base, + USE_RCEIL, + ) + + k_block_base = k_tile_eff * K_BLOCKS_PER_TILE + self._store_scales_reg_to_gmem_vec( + scales_tensor, + m, + k_block_base, + scale_buffer, + cutlass.Int32(K_BLOCKS_PER_TILE), + BLOCKED_SCALE_OUTPUT_VALUE, + ) + + gOUT_tile = cute.local_tile( + tma_tensor_out, (TILE_M, TILE_K_HALF), (m_tile, k_tile_eff) + ) + self._issue_tma_store( + tma_atom_out, + gOUT_tile, + sOUT_tile, + warp_idx, + ) + + @cute.jit + def __call__( + self, + inp_mk: cute.Tensor, + out_mk: cute.Tensor, + scales_out_u8: cute.Tensor, + sign_vec: cute.Tensor, + M: cutlass.Int64, + K: cutlass.Int64, + k_blocks: cutlass.Int64, + m_cta_tiles: cutlass.Int64, + k_cta_tiles: cutlass.Int64, + stream: cuda.CUstream, + ): + """Kernel launcher: set up TMA descriptors and blocked scale 
layout.""" + smem_layout_in, smem_layout_out = _make_tile_smem_layouts(TILE_M, TILE_K) + g2s_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + tma_atom_in, tma_tensor_in = cpasync.make_tiled_tma_atom( + g2s_op, + inp_mk, + smem_layout_in, + (TILE_M, TILE_K), + ) + tma_atom_out, tma_tensor_out = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + out_mk, + smem_layout_out, + (TILE_M, TILE_K_HALF), + ) + + blocked_scale_layout = cute.make_layout((1,)) + if cutlass.const_expr(BLOCKED_SCALE_OUTPUT_VALUE): + padded_scale_cols = cute.round_up(k_blocks, 4) + m_block_tiles = cute.ceil_div(M, 128) + k_block_tiles = padded_scale_cols // cutlass.Int64(4) + blocked_scale_layout = cute.make_layout( + ((32, 4, m_block_tiles), (4, k_block_tiles)), + stride=( + (16, 4, cutlass.Int64(128) * padded_scale_cols), + (1, cutlass.Int64(512)), + ), + ) + + self.kernel( + inp_mk, + tma_atom_in, + tma_tensor_in, + out_mk, + tma_atom_out, + tma_tensor_out, + scales_out_u8, + sign_vec, + M, + K, + k_blocks, + m_cta_tiles, + k_cta_tiles, + blocked_scale_layout, + SCALE_DIM_K=SCALE_DIM_K_VALUE, + USE_RCEIL=USE_RCEIL_VALUE, + STAGE_COUNT=STAGE_COUNT_VALUE, + ).launch( + grid=(k_cta_tiles, m_cta_tiles, 1), + block=(THREADS_PER_BLOCK, 1, 1), + cluster=(1, 1, 1), + smem=SharedStorage.size_in_bytes(), # pyrefly: ignore [missing-attribute] + stream=stream, + ) + + kernel = Mxfp4RhtQuantize2dKernel() + + m = cute.sym_int(divisibility=128) + k = cute.sym_int(divisibility=128) + k_half = cute.sym_int(divisibility=64) + kb = cute.sym_int() + inp_stride0 = cute.sym_int() + inp_stride1 = cute.sym_int() + out_stride0 = cute.sym_int() + out_stride1 = cute.sym_int() + scale_stride0 = cute.sym_int() + scale_stride1 = cute.sym_int() + sign_stride0 = cute.sym_int() + + fake_inp = make_fake_tensor( + INPUT_CUTLASS_DTYPE, + (m, k), + stride=(inp_stride0, inp_stride1), + ) + fake_out = make_fake_tensor( + cutlass.Uint8, + (m, k_half), + stride=(out_stride0, out_stride1), + ) + if 
blocked_scale_output: + scale_flat = cute.sym_int() + fake_scales = make_fake_tensor( + cutlass.Uint8, + (scale_flat,), + stride=(scale_stride0,), + ) + else: + fake_scales = make_fake_tensor( + cutlass.Uint8, + (m, kb), + stride=(scale_stride0, scale_stride1), + ) + fake_sign = make_fake_tensor( + cutlass.Int32, + (32,), + stride=(sign_stride0,), + ) + fake_stream = make_fake_stream() + + return cute.compile( + kernel, + inp_mk=fake_inp, + out_mk=fake_out, + scales_out_u8=fake_scales, + sign_vec=fake_sign, + M=0, + K=0, + k_blocks=0, + m_cta_tiles=1, + k_cta_tiles=1, + stream=fake_stream, + options="--enable-tvm-ffi", + ) + + +def _mxfp4_rht_quantize_cutedsl_impl( + x: torch.Tensor, + sign_vector: list[int], + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, + stage_count: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Host wrapper: launch the fused MXFP4 + RHT CuTeDSL quantize kernel. + + Args: + x: 2D contiguous bf16/fp32 input ``(M, K)`` with ``M % 128 == 0`` and + ``K % 128 == 0``. + sign_vector: length-32 list of ``{-1, +1}`` for the RHT sign multiply. + block_size: only 32 is supported. + scaling_mode: ``"floor"`` or ``"rceil"``. + is_swizzled_scales: write scales in the cuBLAS-blocked padded layout + (``True``) or a plain ``(M, K // 32)`` tensor (``False``). + stage_count: pipeline stages (1 or 2). + + Returns: + ``(qdata, scales)`` where ``qdata`` is row-major ``(M, K // 2)`` uint8 + (packed E2M1x2) and ``scales`` is ``float8_e8m0fnu`` in the requested + layout. 
+ """ + import cuda.bindings.driver as cuda + + assert x.is_cuda, "Input tensor must be CUDA" + assert x.dim() == 2, "Input tensor must be 2D" + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dtype in ( + torch.float32, + torch.bfloat16, + ), "Input tensor must be float32 or bfloat16" + assert block_size == 32, "Only block_size=32 is supported" + scaling_mode = scaling_mode.lower() + assert scaling_mode in ("floor", "rceil"), ( + f"Unsupported scaling_mode={scaling_mode!r}; expected 'floor' or 'rceil'" + ) + assert len(sign_vector) == 32, "sign_vector must have length 32" + + M, K = x.shape + assert K % 32 == 0, "K must be divisible by 32" + assert M % 128 == 0, "M must be divisible by 128 (TMA tiling)" + assert K % 128 == 0, "K must be divisible by 128 (TMA tiling)" + + _, config = _select_cutedsl_config(x.dtype, scaling_mode) + compute_warps, tile_m, tile_k, k_tiles_per_cta = config + assert stage_count >= 1, "stage_count must be >= 1" + assert stage_count <= 2, "stage_count must be <= 2" + + k_blocks = K // block_size + + # Half-width row-major packed output: stride (K // 2, 1). 
+ q_data = torch.empty_strided( + (M, K // 2), + (K // 2, 1), + device=x.device, + dtype=torch.uint8, + ) + + padded_scale_rows = ceil_div(M, 128) * 128 + padded_scale_cols = ceil_div(k_blocks, 4) * 4 + if is_swizzled_scales: + scales_u8 = torch.empty( + (padded_scale_rows * padded_scale_cols,), + device=x.device, + dtype=torch.uint8, + ) + else: + scales_u8 = torch.empty( + (M, k_blocks), + device=x.device, + dtype=torch.uint8, + ) + + sign_dev = torch.tensor( + [int(s) for s in sign_vector], device=x.device, dtype=torch.int32 + ) + + compiled = _compile_mxfp4_rht_quantize_2d_cutedsl( + str(x.dtype), + scaling_mode, + compute_warps, + tile_m, + tile_k, + stage_count, + k_tiles_per_cta, + is_swizzled_scales, + ) + + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + m_cta_tiles = ceil_div(M, tile_m) + k_cta_tiles = ceil_div(K, tile_k * k_tiles_per_cta) + + compiled( + x, + q_data, + scales_u8, + sign_dev, + int(M), + int(K), + int(k_blocks), + int(m_cta_tiles), + int(k_cta_tiles), + stream, + ) + + scales = scales_u8.view(torch.float8_e8m0fnu) + scales = ( + scales.view(padded_scale_rows, padded_scale_cols) + if is_swizzled_scales + else scales.view(M, k_blocks) + ) + return q_data, scales + + +@torch.library.custom_op("torchao::mxfp4_rht_quantize_cutedsl", mutates_args=()) +def mxfp4_rht_quantize_cutedsl( + x: torch.Tensor, + sign_vector: list[int], + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, + stage_count: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _mxfp4_rht_quantize_cutedsl_impl( + x, + sign_vector, + block_size=block_size, + scaling_mode=scaling_mode, + is_swizzled_scales=is_swizzled_scales, + stage_count=stage_count, + ) + + +@mxfp4_rht_quantize_cutedsl.register_fake +def _( + x: torch.Tensor, + sign_vector: list[int], + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, + stage_count: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + m, k 
= x.shape + q = torch.empty_strided( + (m, k // 2), (k // 2, 1), device=x.device, dtype=torch.uint8 + ) # row-major pinned + kb = k // block_size + if is_swizzled_scales: + scales = x.new_empty( + (ceil_div(m, 128) * 128, ceil_div(kb, 4) * 4), + dtype=torch.float8_e8m0fnu, + ) + else: + scales = x.new_empty((m, kb), dtype=torch.float8_e8m0fnu) + return q, scales + + +def mxfp4_rht_quantize_cutedsl_2d( + x: torch.Tensor, + sign_vector, + block_size: int = 32, + scaling_mode: str = "floor", + is_swizzled_scales: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Gated public wrapper for the fused MXFP4 + RHT CuTeDSL quantize op. + + Raises ``NotImplementedError`` (with the missing-runtime detail) when the + CuTeDSL runtime / SM 10.x / CUDA >= 12.8 requirements are not met. + """ + from torchao.prototype.mx_formats.cutedsl import ( + _mxfp4_rht_cutedsl_kernels_available, + ) + + if not _mxfp4_rht_cutedsl_kernels_available: + from torchao.prototype.mx_formats.cutedsl.cute_utils import ( + _missing_cutedsl_runtime_packages, + ) + + raise NotImplementedError( + "mxfp4_rht_quantize_cutedsl requires CUDA SM10.x, CUDA>=12.8, and: " + f"{_missing_cutedsl_runtime_packages() or 'nvidia-cutlass-dsl'}" + ) + return mxfp4_rht_quantize_cutedsl( + x, list(sign_vector), block_size, scaling_mode, is_swizzled_scales + ) diff --git a/torchao/prototype/mx_formats/cutedsl/nvfp4_rht_quantize.py b/torchao/prototype/mx_formats/cutedsl/nvfp4_rht_quantize.py new file mode 100644 index 0000000000..e410502ebe --- /dev/null +++ b/torchao/prototype/mx_formats/cutedsl/nvfp4_rht_quantize.py @@ -0,0 +1,1214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Fused NVFP4 (E2M1, block 16, E4M3 two-level scale) +/- RHT CuTeDSL cast. 
+ +Clones the structure of the fused MXFP4 + RHT CuTeDSL kernel +(``mxfp4_rht_quantize.py``) and applies the NVFP4 deltas: + +* block size is 16 (not 32): each quantization block has 16 elements, packed to + 8 ``E2M1x2`` bytes (8 ``cvt.rn.satfinite.e2m1x2.f32`` calls per block, written + as 2 ``uint32`` stores). The output qdata is still half-width ``[M, K // 2]`` + ``uint8`` row-major; a K-tile of 128 holds 8 sixteen-element blocks; +* the block scale is the NVFP4 two-level E4M3 scale: per block, + ``compute_amax`` -> ``compute_nvfp4_scale_e4m3(amax, global_scale)`` -> + ``(e4m3_byte, inv_scale)``. The E4M3 byte is stored as ``float8_e4m3fn`` and + the data is quantized with ``inv_scale`` (``code = e2m1(val * inv_scale)``); +* the Random Hadamard Transform is optional. When a length-16 ``sign_vector`` is + provided the consumer applies the register-resident FWHT(16) + sign transform + (``fwht.fwht16_sign``) before amax / scale / packing; when it is empty (``[]``) + the transform is skipped via a compile-time ``apply_rht`` constexpr so the + no-RHT path carries no FWHT overhead (plain NVFP4 cast); +* values are clamped to ``+-6.0`` (``F4_E2M1_MAX``) before the E2M1 cvt; +* the E4M3 scale byte is written either in the cuBLAS-blocked padded layout + (``is_swizzled_scales=True``) or as a plain ``(M, K // 16)`` tensor. + +The op is gated behind a Blackwell (SM 10.x) GPU, CUDA >= 12.8, and the CuTeDSL +runtime packages (see ``cutedsl/__init__.py``). 
+""" + +import functools +from typing import Tuple + +import torch + +from torchao.utils import ceil_div + +from .cute_utils import ( + _cvt_rn_satfinite_e2m1x2_f32, + compute_amax, + compute_nvfp4_scale_e4m3, +) +from .fwht import fwht16_sign + +# Config format: +# (compute_warps, tile_m, tile_k, k_tiles_per_cta) +_CUTEDSL_CONFIGS = { + "bf16_default": (4, 128, 128, 4), + "fallback": (6, 128, 128, 2), +} + + +def _select_cutedsl_config( + input_dtype: torch.dtype, +) -> Tuple[str, Tuple[int, int, int, int]]: + """Select kernel configuration based on input dtype.""" + if input_dtype == torch.bfloat16: + config_name = "bf16_default" + else: + config_name = "fallback" + return config_name, _CUTEDSL_CONFIGS[config_name] + + +def _make_tile_smem_layouts(tile_m: int, tile_k: int): + """Row-major smem layouts for the input ``(TILE_M, TILE_K)`` and the + half-width output ``(TILE_M, TILE_K // 2)`` tiles.""" + import cutlass.cute as cute + + smem_layout_in = cute.make_layout( + (tile_m, tile_k), + stride=(tile_k, 1), + ) + smem_layout_out = cute.make_layout( + (tile_m, tile_k // 2), + stride=(tile_k // 2, 1), + ) + return smem_layout_in, smem_layout_out + + +@functools.cache +def _compile_nvfp4_rht_quantize_2d_cutedsl( + input_dtype_name: str, + apply_rht: bool, + compute_warps: int, + tile_m: int, + tile_k: int, + requested_stage_count: int, + k_tiles_per_cta: int, + blocked_scale_output: bool, +): + """Compile the fused 2D NVFP4 (+/- RHT) quantize kernel using CuTeDSL. + + Warp-specialized TMA kernel mirroring the MXFP4 1x32 template: + - warp 0: producer (TMA global->shared input, shared->global half-width + output); + - warps [1..compute_warps]: consumers (optional FWHT(16) + quantize + E2M1 + pack). 
+ """ + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.utils as utils + from cutlass.cute.nvgpu import cpasync, tcgen05 + from cutlass.cute.runtime import make_fake_stream, make_fake_tensor + + if input_dtype_name == "torch.float32": + INPUT_CUTLASS_DTYPE = cutlass.Float32 + elif input_dtype_name == "torch.bfloat16": + INPUT_CUTLASS_DTYPE = cutlass.BFloat16 + else: + raise ValueError( + f"Unsupported input dtype for CuTeDSL nvfp4 quantize_2d: {input_dtype_name}" + ) + + COMPUTE_WARPS = compute_warps + TILE_M = tile_m + TILE_K = tile_k + TILE_K_HALF = tile_k // 2 + K_TILES_PER_CTA = k_tiles_per_cta + BLOCKED_SCALE_OUTPUT_VALUE = blocked_scale_output + APPLY_RHT_VALUE = apply_rht + + THREADS_PER_BLOCK = (1 + COMPUTE_WARPS) * 32 + assert COMPUTE_WARPS >= 1 + assert TILE_M > 0 and TILE_K > 0 + assert TILE_K % 16 == 0 + + SCALE_DIM_K_VALUE = 16 + SCALE_DIM_K_HALF = SCALE_DIM_K_VALUE // 2 # 8 packed bytes per block + K_BLOCKS_PER_TILE = TILE_K // SCALE_DIM_K_VALUE + assert K_BLOCKS_PER_TILE > 0 + assert requested_stage_count >= 1 + assert requested_stage_count <= 2 + assert K_TILES_PER_CTA >= 1 + STAGE_COUNT_VALUE = min(requested_stage_count, K_TILES_PER_CTA) + + input_elem_bytes = 4 if input_dtype_name == "torch.float32" else 2 + TILE_COPY_BYTES_IN = TILE_M * TILE_K * input_elem_bytes + # Half-width uint8 output: 1 byte per packed E2M1 pair. + TILE_COPY_BYTES_OUT = TILE_M * TILE_K_HALF # noqa: F841 (documented contract) + M_THREADS = COMPUTE_WARPS * 32 + M_ITERS_PER_LANE = ceil_div(TILE_M, M_THREADS) + + # F4_E2M1_MAX == 6.0; clamp pre-cvt values before the E2M1 conversion. 
+ F4_MAX = cutlass.Float32(6.0) + + @cute.struct + class SharedStorage: + tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, STAGE_COUNT_VALUE] + in_smem: cute.struct.Align[ + cute.struct.MemRange[ + INPUT_CUTLASS_DTYPE, STAGE_COUNT_VALUE * TILE_M * TILE_K + ], + 128, + ] + out_smem: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, STAGE_COUNT_VALUE * TILE_M * TILE_K_HALF + ], + 128, + ] + + class Nvfp4RhtQuantize2dKernel: + @cute.jit + def _load_block_full_smem_to_reg( + self, + sIN_tile: cute.Tensor, + m_rel: cutlass.Int32, + k_base: cutlass.Int32, + ): + """Load a full 16-element quantization block from smem to registers.""" + vals_block = cute.make_rmem_tensor((SCALE_DIM_K_VALUE,), cutlass.Float32) + for i in cutlass.range_constexpr(SCALE_DIM_K_VALUE): + vals_block[i] = cutlass.Float32(sIN_tile[m_rel, k_base + i]) + return vals_block + + @cute.jit + def _store_scales_reg_to_gmem_vec( + self, + scales_tensor: cute.Tensor, + m: cutlass.Int64, + k_block_base: cutlass.Int64, + scale_buffer: cute.Tensor, + num_scales: cutlass.Int32, + BLOCKED_SCALE_OUTPUT: cutlass.Constexpr[bool], + ): + """Store scales from registers to gmem (uint32 vectorized for blocked).""" + if cutlass.const_expr(BLOCKED_SCALE_OUTPUT): + if num_scales == 4: + scales_tensor_u32 = cute.recast_tensor( + scales_tensor, cutlass.Uint32 + ) + scale_buffer_u32 = cute.recast_tensor(scale_buffer, cutlass.Uint32) + scales_tensor_u32[m, k_block_base // cutlass.Int64(4)] = ( + scale_buffer_u32[0] + ) + else: + for i in range(num_scales): + k_block = k_block_base + i + scales_tensor[m, k_block] = scale_buffer[i] + else: + for i in range(num_scales): + k_block = k_block_base + i + scales_tensor[m, k_block] = scale_buffer[i] + + @cute.jit + def _store_q_e2m1_block_to_smem( + self, + packed_bytes: cute.Tensor, + sOUT_tile: cute.Tensor, + m_rel: cutlass.Int32, + sout_base: cutlass.Int32, + ): + """Store the 8 packed E2M1x2 bytes of a block via 2 uint32 writes. 
+ + ``packed_bytes`` is a length-8 ``Uint8`` register fragment holding + the bytes for one 16-element block. ``sout_base`` is the byte offset + of the block within the half-width output tile (``k_base // 2``). + """ + sOUT_tile_u32 = cute.recast_tensor(sOUT_tile, cutlass.Uint32) + packed_u32 = cute.recast_tensor(packed_bytes, cutlass.Uint32) + base_u32 = sout_base // cutlass.Int32(4) + for w in cutlass.range_constexpr(SCALE_DIM_K_HALF // 4): + sOUT_tile_u32[m_rel, base_u32 + w] = packed_u32[w] + + @cute.jit + def _quantize_block_then_store_reg_to_smem_full( + self, + rht_block: cute.Tensor, + inv_scale: cutlass.Float32, + sOUT_tile: cute.Tensor, + m_rel: cutlass.Int32, + k_base: cutlass.Int32, + ): + """Scale, clamp, pack 16 values to 8 E2M1x2 bytes. + + ``rht_block`` is the (optionally post-FWHT) length-16 ``Float32`` + fragment. Even column ``2p`` -> low nibble, odd column ``2p + 1`` -> + high nibble of output byte ``p`` (validated bit-exactly in Task 1). + """ + packed_bytes = cute.make_rmem_tensor((SCALE_DIM_K_HALF,), cutlass.Uint8) + for p in cutlass.range_constexpr(SCALE_DIM_K_HALF): + lo = cutlass.Float32(rht_block[2 * p]) * inv_scale + hi = cutlass.Float32(rht_block[2 * p + 1]) * inv_scale + # Clamp to +-F4_E2M1_MAX before the cvt. 
+ if lo > F4_MAX: + lo = F4_MAX + if lo < -F4_MAX: + lo = -F4_MAX + if hi > F4_MAX: + hi = F4_MAX + if hi < -F4_MAX: + hi = -F4_MAX + packed_bytes[p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + sout_base = k_base // cutlass.Int32(2) + self._store_q_e2m1_block_to_smem(packed_bytes, sOUT_tile, m_rel, sout_base) + + @cute.jit + def _issue_tma_load( + self, + tma_atom_in: cute.CopyAtom, + gIN_tile: cute.Tensor, + sIN_tile: cute.Tensor, + tma_mbar_ptr: cutlass.Int64, + warp_idx: cutlass.Int32, + ): + """Issue TMA load from global to shared memory (producer warp only).""" + if warp_idx == 0: + cta_layout = cute.make_layout((1,)) + sIN_for_tma_partition = cute.group_modes(sIN_tile, 0, 1) + gIN_for_tma_partition = cute.group_modes(gIN_tile, 0, 1) + tINs, tINg = cpasync.tma_partition( + tma_atom_in, + 0, + cta_layout, + sIN_for_tma_partition, + gIN_for_tma_partition, + ) + tINg_stage0 = tINg[(None, 0)] + tINs_stage0 = tINs[(None, 0)] + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + tma_mbar_ptr, TILE_COPY_BYTES_IN + ) + cute.copy( + tma_atom_in, + tINg_stage0, + tINs_stage0, + tma_bar_ptr=tma_mbar_ptr, + ) + + @cute.jit + def _issue_tma_store( + self, + tma_atom_out: cute.CopyAtom, + gOUT_tile: cute.Tensor, + sOUT_tile: cute.Tensor, + warp_idx: cutlass.Int32, + ): + """Issue TMA store from shared to global memory (producer warp only). + + The store is a ``cp.async.bulk`` (S2G TMA) copy and is asynchronous. + We commit it into a bulk-async group and then wait for the group to + drain (``read=True`` -> the smem source has been fully read) before + returning. Without this drain, the next ``tile_step`` consumer would + overwrite ``sOUT_tile`` (single-buffered, or the same stage buffer in + the 2-stage pipeline) while this store is still reading it -- a + store/recompute race that intermittently corrupted both qdata and + scales for multi-K-tile shapes (K >= TILE_K * 2). 
The trailing + ``sync_threads`` makes the drain CTA-wide so all consumer warps see a + free buffer before reuse. + """ + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + cute.arch.sync_threads() + if warp_idx == 0: + cta_layout = cute.make_layout((1,)) + sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 1) + gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 1) + tOUTs, tOUTg = cpasync.tma_partition( + tma_atom_out, + 0, + cta_layout, + sOUT_for_tma_partition, + gOUT_for_tma_partition, + ) + tOUTs_stage0 = tOUTs[(None, 0)] + tOUTg_stage0 = tOUTg[(None, 0)] + cute.copy( + tma_atom_out, + tOUTs_stage0, + tOUTg_stage0, + ) + cute.arch.cp_async_bulk_commit_group() + # Wait for the store to finish reading sOUT before it is reused. + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_threads() + + @cute.kernel + def kernel( + self, + inp_mk: cute.Tensor, + tma_atom_in: cute.CopyAtom, + tma_tensor_in: cute.Tensor, + out_mk: cute.Tensor, + tma_atom_out: cute.CopyAtom, + tma_tensor_out: cute.Tensor, + scales_out_u8: cute.Tensor, + sign_vec: cute.Tensor, + global_scale: cutlass.Float32, + M: cutlass.Int64, + K: cutlass.Int64, + k_blocks: cutlass.Int64, + m_cta_tiles: cutlass.Int64, + k_cta_tiles: cutlass.Int64, + blocked_scale_layout: cute.Layout, + SCALE_DIM_K: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + STAGE_COUNT: cutlass.Constexpr[int], + ): + """Main fused NVFP4 (+/- RHT) quantize kernel (warp-specialized TMA).""" + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + bidx, bidy, _ = cute.arch.block_idx() + + smem_allocator = utils.SmemAllocator() + storage = smem_allocator.allocate(SharedStorage) + tma_mbar_ptr0 = storage.tma_mbar_ptr.data_ptr() + tma_mbar_ptr1 = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + tma_mbar_ptr1 = tma_mbar_ptr0 + 1 + + smem_layout_in, smem_layout_out = _make_tile_smem_layouts(TILE_M, TILE_K) + 
staged_layout_in = cute.make_layout( + (STAGE_COUNT_VALUE, TILE_M, TILE_K), + stride=(TILE_M * TILE_K, TILE_K, 1), + ) + staged_layout_out = cute.make_layout( + (STAGE_COUNT_VALUE, TILE_M, TILE_K_HALF), + stride=(TILE_M * TILE_K_HALF, TILE_K_HALF, 1), + ) + sIN_staged = storage.in_smem.get_tensor(staged_layout_in) + sOUT_staged = storage.out_smem.get_tensor(staged_layout_out) + stage_elems_in = TILE_M * TILE_K + stage_elems_out = TILE_M * TILE_K_HALF + sIN_tile0 = cute.make_tensor( + sIN_staged.iterator + 0 * stage_elems_in, smem_layout_in + ) + sOUT_tile0 = cute.make_tensor( + sOUT_staged.iterator + 0 * stage_elems_out, smem_layout_out + ) + sIN_tile1 = sIN_tile0 + sOUT_tile1 = sOUT_tile0 + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + sIN_tile1 = cute.make_tensor( + sIN_staged.iterator + 1 * stage_elems_in, smem_layout_in + ) + sOUT_tile1 = cute.make_tensor( + sOUT_staged.iterator + 1 * stage_elems_out, smem_layout_out + ) + + if tidx == 0: + cpasync.prefetch_descriptor(tma_atom_in) + cpasync.prefetch_descriptor(tma_atom_out) + cute.arch.mbarrier_init(tma_mbar_ptr0, 1) + if cutlass.const_expr(STAGE_COUNT_VALUE > 1): + cute.arch.mbarrier_init(tma_mbar_ptr1, 1) + cute.arch.mbarrier_init_fence() + cute.arch.sync_threads() + + # Load the length-16 sign vector into registers (broadcast over rows). + # Only needed for the RHT path; the no-RHT path skips the FWHT entirely. + sign_reg = cute.make_rmem_tensor((SCALE_DIM_K_VALUE,), cutlass.Float32) + if cutlass.const_expr(APPLY_RHT): + for j in cutlass.range_constexpr(SCALE_DIM_K_VALUE): + sign_reg[j] = cutlass.Float32(sign_vec[j]) + + k_tile_group_idx = cutlass.Int64(bidx) + m_tile = cutlass.Int64(bidy) + m0 = m_tile * TILE_M + # Real number of K-tiles. A CTA always loops K_TILES_PER_CTA times, + # but when ``K < TILE_K * K_TILES_PER_CTA`` (e.g. K=128 with TILE_K=128 + # and K_TILES_PER_CTA=4) the trailing tile_steps address K-tiles past + # the real K range. 
The TMA load/store are bounds-clamped by their + # descriptors, but the raw (non-TMA) per-block scale store is NOT, so + # those OOB k_tile_eff would write into wrong scale columns/rows. Guard + # the scale store with this count. + k_tiles_total = K // cutlass.Int64(TILE_K) + if cutlass.const_expr(BLOCKED_SCALE_OUTPUT_VALUE): + scales_tensor = cute.make_tensor( + scales_out_u8.iterator, + blocked_scale_layout, + ) + else: + scales_tensor = scales_out_u8 + + for tile_step in cutlass.range_constexpr(K_TILES_PER_CTA): + k_tile_eff = k_tile_group_idx * K_TILES_PER_CTA + tile_step + + stage_idx = tile_step % STAGE_COUNT + + sIN_tile = sIN_tile0 + sOUT_tile = sOUT_tile0 + tma_mbar_ptr = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT > 1): + tma_mbar_ptr = tma_mbar_ptr0 + stage_idx + if cutlass.const_expr(STAGE_COUNT > 1): + if stage_idx == 1: + sIN_tile = sIN_tile1 + sOUT_tile = sOUT_tile1 + + tma_phase = (tile_step // STAGE_COUNT) % 2 + + if cutlass.const_expr( + tile_step == 0 or not (STAGE_COUNT > 1 and K_TILES_PER_CTA > 1) + ): + gIN_tile = cute.local_tile( + tma_tensor_in, (TILE_M, TILE_K), (m_tile, k_tile_eff) + ) + self._issue_tma_load( + tma_atom_in, + gIN_tile, + sIN_tile, + tma_mbar_ptr, + warp_idx, + ) + + if cutlass.const_expr(STAGE_COUNT > 1 and K_TILES_PER_CTA > 1): + if cutlass.const_expr(tile_step + 1 < K_TILES_PER_CTA): + k_tile_next = k_tile_group_idx * K_TILES_PER_CTA + tile_step + 1 + next_stage_idx = (tile_step + 1) % STAGE_COUNT + sIN_tile_next = sIN_tile0 + tma_mbar_ptr_next = tma_mbar_ptr0 + if cutlass.const_expr(STAGE_COUNT > 1): + tma_mbar_ptr_next = tma_mbar_ptr0 + next_stage_idx + if cutlass.const_expr(STAGE_COUNT > 1): + if next_stage_idx == 1: + sIN_tile_next = sIN_tile1 + + gIN_tile_next = cute.local_tile( + tma_tensor_in, (TILE_M, TILE_K), (m_tile, k_tile_next) + ) + self._issue_tma_load( + tma_atom_in, + gIN_tile_next, + sIN_tile_next, + tma_mbar_ptr_next, + warp_idx, + ) + + if warp_idx >= 1 and warp_idx <= compute_warps: + 
cute.arch.mbarrier_wait(tma_mbar_ptr, tma_phase) + lane = tidx % 32 + m_lane = (warp_idx - 1) * 32 + lane + + for mm in cutlass.range_constexpr(M_ITERS_PER_LANE): + m_rel = m_lane + mm * M_THREADS + m = m0 + m_rel + if m_rel < TILE_M: + scale_buffer = cute.make_rmem_tensor( + (K_BLOCKS_PER_TILE,), cutlass.Uint8 + ) + + for kb in cutlass.range_constexpr(K_BLOCKS_PER_TILE): + k_base = kb * SCALE_DIM_K_VALUE + vals_block = self._load_block_full_smem_to_reg( + sIN_tile, + m_rel, + k_base, + ) + + # Optional fused RHT: in-register FWHT(16) + sign. + if cutlass.const_expr(APPLY_RHT): + fwht16_sign(vals_block, sign_reg) + + amax = compute_amax(vals_block) + e4m3_byte, inv_scale = compute_nvfp4_scale_e4m3( + amax, global_scale + ) + scale_buffer[kb] = e4m3_byte + + # Always quantize into smem so the (TMA-clamped) + # output store has valid data and the producer/ + # consumer pipeline stays balanced for every + # tile_step, exactly as in the all-tiles-valid + # case. + self._quantize_block_then_store_reg_to_smem_full( + vals_block, + inv_scale, + sOUT_tile, + m_rel, + k_base, + ) + + # Guard the raw (non-bounds-checked) scale store: for + # K < TILE_K*K_TILES_PER_CTA the trailing tile_steps + # address K-tiles past the real K range; their + # k_block_base would scribble into wrong scale + # columns/rows. + if k_tile_eff < k_tiles_total: + k_block_base = k_tile_eff * K_BLOCKS_PER_TILE + self._store_scales_reg_to_gmem_vec( + scales_tensor, + m, + k_block_base, + scale_buffer, + cutlass.Int32(K_BLOCKS_PER_TILE), + BLOCKED_SCALE_OUTPUT_VALUE, + ) + + # Guard the output qdata TMA store too: for K-tiles past the real + # K range the destination ``gOUT_tile`` addresses byte columns + # beyond the ``(M, K // 2)`` output tensor; the bulk S2G copy is + # NOT bounds-clamped here, so an OOB store scribbles into adjacent + # device memory (deterministically corrupting the *next* op's + # freshly-allocated output). 
@cute.jit
def __call__(
    self,
    inp_mk: cute.Tensor,
    out_mk: cute.Tensor,
    scales_out_u8: cute.Tensor,
    sign_vec: cute.Tensor,
    global_scale: cutlass.Float32,
    M: cutlass.Int64,
    K: cutlass.Int64,
    k_blocks: cutlass.Int64,
    m_cta_tiles: cutlass.Int64,
    k_cta_tiles: cutlass.Int64,
    stream: cuda.CUstream,
):
    """Launcher: build the TMA atoms and the scale layout, then launch.

    Creates one tiled TMA atom for the ``(TILE_M, TILE_K)`` input tile (a
    ``CopyBulkTensorTileG2SOp``) and one for the ``(TILE_M, TILE_K_HALF)``
    packed output tile (a ``CopyBulkTensorTileS2GOp``). When
    ``BLOCKED_SCALE_OUTPUT_VALUE`` is set, a blocked layout over the scale
    buffer is built from ``M`` and ``k_blocks``; otherwise a one-element
    placeholder layout is forwarded. The kernel is launched on a
    ``(k_cta_tiles, m_cta_tiles, 1)`` grid on ``stream``.
    """
    in_smem_layout, out_smem_layout = _make_tile_smem_layouts(TILE_M, TILE_K)

    # Input descriptor: full-width (TILE_M, TILE_K) tiles of ``inp_mk``.
    tma_atom_in, tma_tensor_in = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE),
        inp_mk,
        in_smem_layout,
        (TILE_M, TILE_K),
    )

    # Output descriptor: half-width (TILE_M, TILE_K_HALF) tiles of ``out_mk``.
    s2g_op = cpasync.CopyBulkTensorTileS2GOp()
    tma_atom_out, tma_tensor_out = cpasync.make_tiled_tma_atom(
        s2g_op,
        out_mk,
        out_smem_layout,
        (TILE_M, TILE_K_HALF),
    )

    # Placeholder layout; replaced below only for the blocked-scale variant.
    blocked_scale_layout = cute.make_layout((1,))
    if cutlass.const_expr(BLOCKED_SCALE_OUTPUT_VALUE):
        # Scale columns are rounded up to a multiple of 4 and rows are
        # grouped in tiles of 128, matching the shape/stride tuples below.
        scale_cols_padded = cute.round_up(k_blocks, 4)
        row_tiles = cute.ceil_div(M, 128)
        col_tiles = scale_cols_padded // cutlass.Int64(4)
        blocked_scale_layout = cute.make_layout(
            ((32, 4, row_tiles), (4, col_tiles)),
            stride=(
                (16, 4, cutlass.Int64(128) * scale_cols_padded),
                (1, cutlass.Int64(512)),
            ),
        )

    self.kernel(
        inp_mk,
        tma_atom_in,
        tma_tensor_in,
        out_mk,
        tma_atom_out,
        tma_tensor_out,
        scales_out_u8,
        sign_vec,
        global_scale,
        M,
        K,
        k_blocks,
        m_cta_tiles,
        k_cta_tiles,
        blocked_scale_layout,
        SCALE_DIM_K=SCALE_DIM_K_VALUE,
        APPLY_RHT=APPLY_RHT_VALUE,
        STAGE_COUNT=STAGE_COUNT_VALUE,
    ).launch(
        # grid.x walks K-tile groups, grid.y walks M tiles.
        grid=(k_cta_tiles, m_cta_tiles, 1),
        block=(THREADS_PER_BLOCK, 1, 1),
        cluster=(1, 1, 1),
        smem=SharedStorage.size_in_bytes(),  # pyrefly: ignore [missing-attribute]
        stream=stream,
    )
+# Byte-for-byte identical output for every mode (dtype x RHT x swizzled x shape) +# but ~1.5x faster (reaches the HBM roofline) via a simpler streaming design: +# each thread owns 2 contiguous 16-element NVFP4 blocks (= 32 input elems), reads +# them with forced 128-bit ``LDG`` (``CopyUniversalOp`` num_bits_per_copy=128 + +# ``cute.assume`` on the offset -- a plain ``iterator + dynamic_offset`` silently +# degrades to scalar loads), writes the 16 packed bytes with one wide 128-bit +# ``STG``, and uses heavy per-thread ILP to hide the load latency + the e2m1/e4m3 +# cvts. No explicit ``+-6`` clamps (``cvt.rn.satfinite`` saturates -> still +# bit-exact). The 2D ``(row = grid.y, group-in-row = grid.x)`` map keeps +# ``(row, kb)`` division-free for the swizzled scale offset. +# +# IMPORTANT: compiled with CONCRETE ``from_dlpack`` tensors, NOT +# ``make_fake_tensor`` -- the symbolic AOT path mis-lowers the single-byte scale +# store (corrupting scales while leaving the 128-bit qdata store correct). The +# kernel indexes purely through flat iterators + runtime ``Int32`` shape args, so +# one concrete compile generalizes across shapes; cached on +# ``(dtype, apply_rht, is_swizzled_scales, threads, ilp)``. + +# (dtype, apply_rht, is_swizzled_scales, threads, ilp) -> compiled cute kernel. +_MAXBW_JIT_CACHE: dict = {} + + +@functools.cache +def _get_maxbw_launch(): + """Define and return the (uncompiled) ``@cute.jit`` max-bandwidth launcher. + + Gated behind the CuTeDSL runtime: all ``cutlass`` imports and the kernel + definition live inside this function so importing the module never requires + the runtime to be installed. 
+ """ + import cutlass + import cutlass.cute as cute + import cutlass.cute.nvgpu as nv + + from .cute_utils import ( + _cvt_rn_satfinite_e2m1x2_f32, + compute_amax, + compute_nvfp4_scale_e4m3, + ) + from .fwht import fwht16_sign + + @cute.kernel + def _maxbw_kernel( + gx: cute.Tensor, # (M*K,) input flat (bf16 or fp32) + gq: cute.Tensor, # (M*K // 2,) uint8 flat (packed E2M1x2) + gscale: cute.Tensor, # scale buffer flat uint8 + gsign: cute.Tensor, # (16,) int32 sign vector (RHT only) + M: cutlass.Int32, + K: cutlass.Int32, + GPR: cutlass.Int32, # groups per row = K // 32 + pad_cols: cutlass.Int32, # padded scale cols = ceil(K // 16, 4) * 4 (swizzled) + global_scale: cutlass.Float32, + ILP: cutlass.Constexpr[int], + APPLY_RHT: cutlass.Constexpr[bool], + SWIZZLED: cutlass.Constexpr[bool], + LDWIDTH: cutlass.Constexpr[int], # elems per 128-bit load (8 bf16, 4 fp32) + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy_init, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + gid = bidx * bdim + tidx + nthreads_x = cute.arch.grid_dim()[0] * bdim + + in_dt = gx.element_type + ld = cute.make_copy_atom(nv.CopyUniversalOp(), in_dt, num_bits_per_copy=128) + st = cute.make_copy_atom( + nv.CopyUniversalOp(), cutlass.Uint8, num_bits_per_copy=128 + ) + + # number of 128-bit loads to cover 32 input elems + NLD = 32 // LDWIDTH + + # Load the length-16 sign vector once (RHT only). + sign_reg = cute.make_rmem_tensor((16,), cutlass.Float32) + if cutlass.const_expr(APPLY_RHT): + for j in cutlass.range_constexpr(16): + sign_reg[j] = cutlass.Float32(gsign[j]) + + # 2D grid-stride over (row, gcol). grid.y indexes rows. 
+ row = bidy_init + while row < M: + base = gid + while base < GPR: + # ---- issue ALL loads first (ILP) into a flat (ILP, 32) buffer ---- + fragbuf = cute.make_rmem_tensor((ILP * 32,), in_dt) + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_x + if gc < GPR: + off = cute.assume(row * K + gc * cutlass.Int32(32), divby=32) + for w in cutlass.range_constexpr(NLD): + s = cute.make_tensor( + gx.iterator + off + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + f = cute.make_tensor( + fragbuf.iterator + jj * 32 + w * LDWIDTH, + cute.make_layout((LDWIDTH,), stride=(1,)), + ) + cute.copy(ld, s, f) + + # ---- consume: math + wide store per group ---- + for jj in cutlass.range_constexpr(ILP): + gc = base + jj * nthreads_x + if gc < GPR: + vals = cute.make_rmem_tensor((32,), cutlass.Float32) + for i in cutlass.range_constexpr(32): + vals[i] = cutlass.Float32(fragbuf[jj * 32 + i]) + packed = cute.make_rmem_tensor((16,), cutlass.Uint8) + for b in cutlass.range_constexpr(2): + blk = cute.make_tensor( + vals.iterator + b * 16, + cute.make_layout((16,), stride=(1,)), + ) + if cutlass.const_expr(APPLY_RHT): + fwht16_sign(blk, sign_reg) + amax = compute_amax(blk) + e4m3, inv = compute_nvfp4_scale_e4m3(amax, global_scale) + kb = gc * cutlass.Int32(2) + b + # scale store + if cutlass.const_expr(SWIZZLED): + r128 = row // cutlass.Int32(128) + r32 = row % cutlass.Int32(32) + r32_4 = (row // cutlass.Int32(32)) % cutlass.Int32(4) + kb4 = kb // cutlass.Int32(4) + kbm = kb % cutlass.Int32(4) + soff = ( + r128 * cutlass.Int32(128) * pad_cols + + kb4 * cutlass.Int32(512) + + r32 * cutlass.Int32(16) + + r32_4 * cutlass.Int32(4) + + kbm + ) + gscale[soff] = e4m3 + else: + gscale[row * (K // cutlass.Int32(16)) + kb] = e4m3 + for p in cutlass.range_constexpr(8): + lo = blk[2 * p] * inv + hi = blk[2 * p + 1] * inv + packed[b * 8 + p] = _cvt_rn_satfinite_e2m1x2_f32(hi, lo) + offq = cute.assume( + row * (K // cutlass.Int32(2)) + gc * cutlass.Int32(16), + divby=16, 
def _maxbw_quantize(
    x: torch.Tensor,
    global_scale: float,
    sign_vector,
    is_swizzled_scales: bool = True,
    threads: int = 128,
    ilp: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Host wrapper: launch the max-bandwidth NVFP4 (+/- RHT) quantize kernel.

    Output contract matches the warp-specialized TMA path exactly: ``qdata`` is
    row-major ``(M, K // 2)`` uint8 (packed E2M1x2); ``scales`` is
    ``float8_e4m3fn`` in the cuBLAS-blocked padded layout (``is_swizzled_scales``)
    or a plain ``(M, K // 16)`` tensor.

    Args:
        x: 2D contiguous CUDA tensor, float32 or bfloat16, with ``M`` and ``K``
            both divisible by 128 and fewer than ``2**31`` elements.
        global_scale: per-tensor scale forwarded to the kernel as ``Float32``.
        sign_vector: length-16 sequence enabling the fused RHT, or ``None`` /
            empty to skip it.
        is_swizzled_scales: select the blocked padded scale layout.
        threads: threads per CTA.
        ilp: number of 32-element groups each thread handles per iteration.

    Returns:
        ``(q_data, scales)`` as described above.

    Raises:
        AssertionError: if any input requirement above is violated.
    """
    import cuda.bindings.driver as cuda
    import cutlass
    from cutlass.cute.runtime import from_dlpack

    assert x.is_cuda and x.dim() == 2 and x.is_contiguous()
    assert x.dtype in (torch.float32, torch.bfloat16)
    M, K = x.shape
    assert M % 128 == 0 and K % 128 == 0
    # The kernel takes M/K as Int32 and forms flat element offsets
    # (``row * K + gc * 32``) in 32-bit arithmetic; reject inputs whose offsets
    # would wrap instead of silently reading/writing the wrong addresses.
    assert x.numel() < 2**31, (
        "max-bandwidth NVFP4 kernel uses 32-bit indexing; "
        f"input has {x.numel()} elements (must be < 2**31)"
    )
    k_blocks = K // 16
    apply_rht = sign_vector is not None and len(sign_vector) > 0
    if apply_rht:
        assert len(sign_vector) == 16

    q_data = torch.empty_strided(
        (M, K // 2), (K // 2, 1), device=x.device, dtype=torch.uint8
    )
    padded_scale_rows = ceil_div(M, 128) * 128
    padded_scale_cols = ceil_div(k_blocks, 4) * 4
    if is_swizzled_scales:
        scales_u8 = torch.empty(
            (padded_scale_rows * padded_scale_cols,),
            device=x.device,
            dtype=torch.uint8,
        )
    else:
        scales_u8 = torch.empty((M, k_blocks), device=x.device, dtype=torch.uint8)

    # The float8 view aliases ``scales_u8``, so it can be built before the
    # kernel fills the buffer.
    scales = scales_u8.view(torch.float8_e4m3fn)
    scales = (
        scales.view(padded_scale_rows, padded_scale_cols)
        if is_swizzled_scales
        else scales.view(M, k_blocks)
    )

    # Nothing to quantize: a zero-sized grid is not a valid launch
    # configuration, so return the (empty) outputs without launching.
    if x.numel() == 0:
        return q_data, scales

    # The kernel always receives a length-16 sign tensor; zeros are a
    # placeholder when RHT is disabled.
    sign_src = sign_vector if apply_rht else [0] * 16
    sign_dev = torch.tensor(
        [int(s) for s in sign_src], device=x.device, dtype=torch.int32
    )

    GPR = K // 32  # 32-element groups per row
    nthreads_needed = (GPR + ilp - 1) // ilp
    ncta_x = (nthreads_needed + threads - 1) // threads
    # CUDA caps grid.y at 65535. The kernel advances ``row`` by
    # ``grid_dim()[1]`` until ``row >= M``, so a clamped grid still covers
    # every row; an unclamped ``M >= 65536`` would make the launch fail.
    max_grid_y = 65535
    ncta_y = min(M, max_grid_y)
    ldwidth = 4 if x.dtype == torch.float32 else 8  # elements per 128-bit load

    launch = _get_maxbw_launch()
    stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))

    def _args():
        return (
            from_dlpack(x.view(-1), assumed_align=16),
            from_dlpack(q_data.view(-1), assumed_align=16),
            from_dlpack(scales_u8.view(-1), assumed_align=16),
            from_dlpack(sign_dev, assumed_align=16),
            cutlass.Int32(M),
            cutlass.Int32(K),
            cutlass.Int32(GPR),
            cutlass.Int32(padded_scale_cols),
            cutlass.Float32(float(global_scale)),
            threads,
            ncta_x,
            ncta_y,
            stream,
        )

    key = (str(x.dtype), apply_rht, is_swizzled_scales, threads, ilp)
    compiled = _MAXBW_JIT_CACHE.get(key)
    if compiled is None:
        # Compile with CONCRETE tensors (see section header); constexprs baked.
        compiled = cutlass.cute.compile(
            launch,
            *_args(),
            ilp,
            apply_rht,
            is_swizzled_scales,
            ldwidth,
        )
        _MAXBW_JIT_CACHE[key] = compiled
    compiled(*_args())

    return q_data, scales
+ """ + import cuda.bindings.driver as cuda + + assert x.is_cuda, "Input tensor must be CUDA" + assert x.dim() == 2, "Input tensor must be 2D" + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dtype in ( + torch.float32, + torch.bfloat16, + ), "Input tensor must be float32 or bfloat16" + assert block_size == 16, "Only block_size=16 is supported" + apply_rht = len(sign_vector) > 0 + if apply_rht: + assert len(sign_vector) == 16, "sign_vector must have length 16 (or be empty)" + + M, K = x.shape + assert K % 16 == 0, "K must be divisible by 16" + assert M % 128 == 0, "M must be divisible by 128 (TMA tiling)" + assert K % 128 == 0, "K must be divisible by 128 (TMA tiling)" + + # Default fast path: the max-bandwidth streaming kernel (``_maxbw_quantize`` + # above). Byte-for-byte identical output to the legacy TMA path below + # (validated across dtype x RHT x swizzled x shape), at ~1.5x the throughput. + if use_maxbw: + return _maxbw_quantize( + x, + float(global_scale), + sign_vector if apply_rht else None, + is_swizzled_scales=is_swizzled_scales, + ) + + _, config = _select_cutedsl_config(x.dtype) + compute_warps, tile_m, tile_k, k_tiles_per_cta = config + assert stage_count >= 1, "stage_count must be >= 1" + assert stage_count <= 2, "stage_count must be <= 2" + + k_blocks = K // block_size + + # Half-width row-major packed output: stride (K // 2, 1). 
+ q_data = torch.empty_strided( + (M, K // 2), + (K // 2, 1), + device=x.device, + dtype=torch.uint8, + ) + + padded_scale_rows = ceil_div(M, 128) * 128 + padded_scale_cols = ceil_div(k_blocks, 4) * 4 + if is_swizzled_scales: + scales_u8 = torch.empty( + (padded_scale_rows * padded_scale_cols,), + device=x.device, + dtype=torch.uint8, + ) + else: + scales_u8 = torch.empty( + (M, k_blocks), + device=x.device, + dtype=torch.uint8, + ) + + # The kernel always reads a length-16 sign tensor; for the no-RHT path it is + # never dereferenced (guarded by the ``apply_rht`` constexpr), so a zero + # placeholder is fine. + sign_src = sign_vector if apply_rht else [0] * 16 + sign_dev = torch.tensor( + [int(s) for s in sign_src], device=x.device, dtype=torch.int32 + ) + + compiled = _compile_nvfp4_rht_quantize_2d_cutedsl( + str(x.dtype), + apply_rht, + compute_warps, + tile_m, + tile_k, + stage_count, + k_tiles_per_cta, + is_swizzled_scales, + ) + + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + m_cta_tiles = ceil_div(M, tile_m) + k_cta_tiles = ceil_div(K, tile_k * k_tiles_per_cta) + + compiled( + x, + q_data, + scales_u8, + sign_dev, + float(global_scale), + int(M), + int(K), + int(k_blocks), + int(m_cta_tiles), + int(k_cta_tiles), + stream, + ) + + scales = scales_u8.view(torch.float8_e4m3fn) + scales = ( + scales.view(padded_scale_rows, padded_scale_cols) + if is_swizzled_scales + else scales.view(M, k_blocks) + ) + return q_data, scales + + +@torch.library.custom_op("torchao::nvfp4_rht_quantize_cutedsl", mutates_args=()) +def nvfp4_rht_quantize_cutedsl( + x: torch.Tensor, + global_scale: float, + sign_vector: list[int], + block_size: int = 16, + is_swizzled_scales: bool = True, + stage_count: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _nvfp4_rht_quantize_cutedsl_impl( + x, + global_scale, + sign_vector, + block_size=block_size, + is_swizzled_scales=is_swizzled_scales, + stage_count=stage_count, + ) + + 
@nvfp4_rht_quantize_cutedsl.register_fake
def _(
    x: torch.Tensor,
    global_scale: float,
    sign_vector: list[int],
    block_size: int = 16,
    is_swizzled_scales: bool = True,
    stage_count: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: allocate outputs with the real op's shapes/dtypes.

    Returns a row-major ``(m, k // 2)`` uint8 tensor and a ``float8_e4m3fn``
    scale tensor that is either padded to ``(ceil(m / 128) * 128,
    ceil(kb / 4) * 4)`` or plain ``(m, kb)``, where ``kb = k // block_size``.
    """
    rows, cols = x.shape
    packed_cols = cols // 2
    # Strides are given explicitly so the packed output is pinned row-major.
    qdata = torch.empty_strided(
        (rows, packed_cols),
        (packed_cols, 1),
        device=x.device,
        dtype=torch.uint8,
    )
    n_blocks = cols // block_size
    if is_swizzled_scales:
        scale_shape = (ceil_div(rows, 128) * 128, ceil_div(n_blocks, 4) * 4)
    else:
        scale_shape = (rows, n_blocks)
    return qdata, x.new_empty(scale_shape, dtype=torch.float8_e4m3fn)


def nvfp4_rht_quantize_cutedsl_2d(
    x: torch.Tensor,
    global_scale,
    sign_vector=None,
    block_size: int = 16,
    is_swizzled_scales: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Public, availability-gated entry point for the NVFP4 (+/- RHT) op.

    Passing ``sign_vector=None`` or an empty sequence requests the plain NVFP4
    cast; a length-16 ``sign_vector`` requests the fused RHT. ``global_scale``
    is converted with ``float()`` before dispatch.

    Raises:
        NotImplementedError: when ``_mxfp4_rht_cutedsl_kernels_available`` is
            falsy; the message lists the missing runtime packages.
    """
    from torchao.prototype.mx_formats.cutedsl import (
        _mxfp4_rht_cutedsl_kernels_available,
    )

    if _mxfp4_rht_cutedsl_kernels_available:
        # The custom op takes a concrete list; ``None`` maps to "no RHT".
        signs = [] if sign_vector is None else list(sign_vector)
        return nvfp4_rht_quantize_cutedsl(
            x, float(global_scale), signs, block_size, is_swizzled_scales
        )

    # Imported only on the failure path, to build the error detail.
    from torchao.prototype.mx_formats.cutedsl.cute_utils import (
        _missing_cutedsl_runtime_packages,
    )

    raise NotImplementedError(
        "nvfp4_rht_quantize_cutedsl requires CUDA SM10.x, CUDA>=12.8, and: "
        f"{_missing_cutedsl_runtime_packages() or 'nvidia-cutlass-dsl'}"
    )
The + resulting ``MXTensor`` therefore represents ``RHT(x)`` and dequantizes to + ``≈RHT(x)``; the caller is responsible for applying the inverse rotation + after dequantize (RHT is caller-managed; no inverse is stored on the + tensor). The default ``MXFP4CastKernelChoice.TORCH`` path is unchanged and + does not require ``rht_sign_vector``. + """ assert mxfp8_dim0_cast_kernel_choice in ( MXFP8Dim0CastKernelChoice.TRITON, MXFP8Dim0CastKernelChoice.TORCH, @@ -651,6 +668,35 @@ def to_mx( f"unsupported kernel choice for mxfp8_dim0_cast_kernel_choice: {mxfp8_dim0_cast_kernel_choice}" ) + if ( + elem_dtype == torch.float4_e2m1fn_x2 + and mxfp4_cast_kernel_choice == MXFP4CastKernelChoice.CUTEDSL + ): + assert rht_sign_vector is not None and len(rht_sign_vector) == block_size, ( + "MXFP4 CUTEDSL cast requires rht_sign_vector of length block_size" + ) + from torchao.prototype.mx_formats.cutedsl import ( + mxfp4_rht_quantize_cutedsl_2d, + ) + + data_lp, scale_e8m0 = mxfp4_rht_quantize_cutedsl_2d( + data_hp, + rht_sign_vector, + block_size, + scaling_mode.value.lower(), + is_swizzled_scales, + ) + return MXTensor( + data_lp, + scale_e8m0, + elem_dtype, + block_size, + data_hp.dtype, + kernel_preference, + act_quant_kwargs, + is_swizzled_scales, + ) + triton_kernel_supported = ( elem_dtype == torch.float8_e4m3fn and not is_swizzled_scales ) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 6bbcde79b4..a12be578c6 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -46,6 +46,7 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs): block_size: int = 16 is_swizzled_scales: bool = False use_triton_kernel: bool = False + use_cutedsl_kernel: bool = False use_dynamic_per_tensor_scale: bool = False @@ -75,6 +76,7 @@ class NVFP4Tensor(TorchAOBaseTensor): optional_tensor_attribute_names = [ "is_swizzled_scales", "use_triton_kernel", + "use_cutedsl_kernel", 
"act_quant_kwargs", ] @@ -88,6 +90,7 @@ def __new__( act_per_tensor_scale=None, is_swizzled_scales=False, use_triton_kernel=False, + use_cutedsl_kernel=False, act_quant_kwargs=None, ): # FP4 tensor size handling two paths, contiguous or not @@ -119,6 +122,7 @@ def __new__( self.act_per_tensor_scale = act_per_tensor_scale self.is_swizzled_scales = is_swizzled_scales self.use_triton_kernel = use_triton_kernel + self.use_cutedsl_kernel = use_cutedsl_kernel self.act_quant_kwargs = act_quant_kwargs return self @@ -126,7 +130,7 @@ def __repr__(self): return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.dequantize(self.orig_dtype)}" def _quantization_type(self): - return f"{self.is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}" + return f"{self.is_swizzled_scales=}, {self.use_triton_kernel=}, {self.use_cutedsl_kernel=}, {self.act_quant_kwargs=}" @staticmethod def to_nvfp4( @@ -136,6 +140,7 @@ def to_nvfp4( act_per_tensor_scale: Optional[torch.Tensor] = None, is_swizzled_scales: bool = False, use_triton_kernel: bool = False, + use_cutedsl_kernel: bool = False, act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None, ): """Convert high precision tensor to NVFP4 format. 
@@ -157,7 +162,31 @@ def to_nvfp4( assert len(data_hp.shape) in (2, 3), "unsupported" leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1] - if use_triton_kernel: + if use_cutedsl_kernel: + assert len(data_hp.shape) == 2, ( + "CuTeDSL NVFP4 backend only supports 2D inputs" + ) + assert M % 128 == 0 and K % 128 == 0, ( + "CuTeDSL NVFP4 backend requires M and K divisible by 128" + ) + if per_tensor_scale is not None: + assert per_tensor_scale.numel() == 1, ( + "CuTeDSL NVFP4 backend only supports a scalar per_tensor_scale" + ) + global_scale = float((1.0 / per_tensor_scale).item()) + else: + global_scale = 1.0 + from .cutedsl import nvfp4_rht_quantize_cutedsl_2d + + data_lp, blockwise_scales = nvfp4_rht_quantize_cutedsl_2d( + data_hp.contiguous(), + global_scale, + None, + block_size=block_size, + is_swizzled_scales=is_swizzled_scales, + ) + blockwise_scales = blockwise_scales.flatten() + elif use_triton_kernel: assert is_swizzled_scales, "Triton kernel only supports swizzled scales" assert K % 16 == 0, ( f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}" @@ -190,6 +219,7 @@ def to_nvfp4( act_per_tensor_scale, is_swizzled_scales, use_triton_kernel, + use_cutedsl_kernel, act_quant_kwargs, ) @@ -331,6 +361,7 @@ def nvfp4_to_copy(func, types, args, kwargs): tensor.act_per_tensor_scale, tensor.is_swizzled_scales, tensor.use_triton_kernel, + tensor.use_cutedsl_kernel, tensor.act_quant_kwargs, ) return res @@ -365,6 +396,7 @@ def nvfp4_pin_memory(func, types, args, kwargs): else None, tensor.is_swizzled_scales, tensor.use_triton_kernel, + tensor.use_cutedsl_kernel, tensor.act_quant_kwargs, ) @@ -394,6 +426,7 @@ def nvfp4_slice(func, types, args, kwargs): x.act_per_tensor_scale, x.is_swizzled_scales, x.use_triton_kernel, + x.use_cutedsl_kernel, x.act_quant_kwargs, ) @@ -414,6 +447,7 @@ def nvfp4_t(func, types, args, kwargs): old.act_per_tensor_scale, old.is_swizzled_scales, old.use_triton_kernel, + old.use_cutedsl_kernel, 
old.act_quant_kwargs, ) return new @@ -436,6 +470,7 @@ def nvfp4_transpose(func, types, args, kwargs): old.act_per_tensor_scale, old.is_swizzled_scales, old.use_triton_kernel, + old.use_cutedsl_kernel, old.act_quant_kwargs, ) return new @@ -457,6 +492,7 @@ def nvfp4_view_op(func, types, args, kwargs): args[0].act_per_tensor_scale, args[0].is_swizzled_scales, args[0].use_triton_kernel, + args[0].use_cutedsl_kernel, args[0].act_quant_kwargs, ) @@ -479,6 +515,7 @@ def nvfp4_select(func, types, args, kwargs): old.act_per_tensor_scale, old.is_swizzled_scales, old.use_triton_kernel, + old.use_cutedsl_kernel, old.act_quant_kwargs, ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -610,6 +647,7 @@ def nvfp4_linear(func, types, args, kwargs): per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, + use_cutedsl_kernel=k.use_cutedsl_kernel, ) res = _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias) res = res.reshape(*orig_shape[:-1], res.shape[-1]) @@ -645,6 +683,7 @@ def nvfp4_mm(func, types, args, kwargs): per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, + use_cutedsl_kernel=k.use_cutedsl_kernel, ) return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func) @@ -699,6 +738,7 @@ def nvfp4_addmm(func, types, args, kwargs): per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, + use_cutedsl_kernel=k.use_cutedsl_kernel, ) return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias)