diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 8c7180d74..892ead6f4 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -469,6 +469,7 @@ def online_quantize_weight(self): online_quant_func = get_hip_quant(online_quant_type) assert online_quant_dtype in [ torch.float8_e4m3fn, + torch.float8_e4m3fnuz, torch.float4_e2m1fn_x2, ], ( f"Unsupported online quant: " @@ -519,9 +520,21 @@ def online_quantize_weight(self): elif self.quant_type == QuantType.per_1x32: # dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale) weight = weight_dequant_mxfp8(weight, weight_scale) - q_weight, weight_scale = online_quant_func( - weight, quant_dtype=online_quant_dtype - ) + + if online_quant_type == QuantType.per_1x128: + # Linear per_1x128 path uses blockscale GEMM, which consumes + # 128x128 weight scales shaped as (N//128, K//128). + from quantization.quark.utils import ( + quantize_weight_to_fp8_128x128_blockscale, + ) + + q_weight, weight_scale = quantize_weight_to_fp8_128x128_blockscale( + weight, online_quant_dtype + ) + else: + q_weight, weight_scale = online_quant_func( + weight, quant_dtype=online_quant_dtype + ) if need_gather: q_weight, weight_scale = self._shard_quantized_weight( q_weight, weight_scale @@ -533,8 +546,11 @@ def online_quantize_weight(self): self.quant_type = online_quant_type self.params_dtype = online_quant_dtype self.quant_func = online_quant_func + # online_quant_func already returns fnuz when quant_dtype=fnuz on gfx942; + # only normalize when output is still non-fnuz. self.need_normalize_e4m3fn_to_e4m3fnuz = ( online_quant_dtype == torch.float8_e4m3fnuz + and q_weight.dtype != torch.float8_e4m3fnuz ) self._online_quant_info = { "layer": self.prefix, diff --git a/atom/quant_spec.py b/atom/quant_spec.py index e1d1389e2..f9a7fadb8 100644 --- a/atom/quant_spec.py +++ b/atom/quant_spec.py @@ -218,8 +218,9 @@ def parse(self, online_quant_config: dict) -> ParsedQuantConfig: if not isinstance(online_quant_config, dict): raise TypeError("online_quant_config must be a dict parsed from JSON.") - SCHEME_MAP = { + scheme_map = { "ptpc": QuantType.per_Token, + "per_block": QuantType.per_1x128, } def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig: @@ -231,10 +232,14 @@ def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig: quant_type = QuantType.per_1x32 dtype_str = quant_format_str[2:] else: - parts = quant_format_str.split("_", 1) - if len(parts) == 2 and parts[0] in SCHEME_MAP: - quant_type = SCHEME_MAP[parts[0]] - dtype_str = parts[1] + matched_scheme = None + for scheme in sorted(scheme_map, key=len, reverse=True): + if quant_format_str.startswith(scheme + "_"): + matched_scheme = scheme + break + if matched_scheme is not None: + quant_type = scheme_map[matched_scheme] + dtype_str = quant_format_str[len(matched_scheme) + 1 :] else: raise ValueError( f"Unsupported online quant format: '{quant_format_str}'. " diff --git a/atom/quantization/quark/utils.py b/atom/quantization/quark/utils.py index 3312fce74..a0e58bb11 100644 --- a/atom/quantization/quark/utils.py +++ b/atom/quantization/quark/utils.py @@ -124,3 +124,41 @@ def weight_dequant_mxfp8( y = x.to(torch.float32).reshape(M, n_blocks, block_size) y = y * scale.unsqueeze(-1) return y.reshape(M, K).to(out_dtype) + + +def quantize_weight_to_fp8_128x128_blockscale(weight, quant_dtype): + """Quantize a 2D weight to FP8 with 128x128 block scales. + + Returns: + q_weight: quantized weight with the same shape as input ``weight``. + scale: per-block scale with shape ``(ceil(N/128), ceil(K/128))``. + """ + assert weight.dim() == 2, f"expected 2D weight, got shape={tuple(weight.shape)}" + + w = weight.to(torch.float32).contiguous() + n, k = w.shape + n_blocks = (n + 127) // 128 + k_blocks = (k + 127) // 128 + n_padded = n_blocks * 128 + k_padded = k_blocks * 128 + + if n_padded != n or k_padded != k: + w = torch.nn.functional.pad(w, (0, k_padded - k, 0, n_padded - n)) + + w_blocks = w.view(n_blocks, 128, k_blocks, 128).permute(0, 2, 1, 3).contiguous() + + finfo = torch.finfo(quant_dtype) + block_amax = w_blocks.abs().amax(dim=(2, 3)) + scale = (block_amax / finfo.max).clamp_min(torch.finfo(torch.float32).tiny) + + q_blocks = torch.clamp( + w_blocks / scale.unsqueeze(-1).unsqueeze(-1), min=finfo.min, max=finfo.max + ).to(quant_dtype) + + q_weight = ( + q_blocks.permute(0, 2, 1, 3) + .contiguous() + .view(n_padded, k_padded)[:n, :k] + .contiguous() + ) + return q_weight, scale.contiguous()