Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def online_quantize_weight(self):
online_quant_func = get_hip_quant(online_quant_type)
assert online_quant_dtype in [
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float4_e2m1fn_x2,
], (
f"Unsupported online quant: "
Expand Down Expand Up @@ -519,9 +520,21 @@ def online_quantize_weight(self):
elif self.quant_type == QuantType.per_1x32:
# dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale)
weight = weight_dequant_mxfp8(weight, weight_scale)
q_weight, weight_scale = online_quant_func(
weight, quant_dtype=online_quant_dtype
)

if online_quant_type == QuantType.per_1x128:
# Linear per_1x128 path uses blockscale GEMM, which consumes
# 128x128 weight scales shaped as (N//128, K//128).
from quantization.quark.utils import (
quantize_weight_to_fp8_128x128_blockscale,
)

q_weight, weight_scale = quantize_weight_to_fp8_128x128_blockscale(
weight, online_quant_dtype
)
else:
q_weight, weight_scale = online_quant_func(
weight, quant_dtype=online_quant_dtype
)
if need_gather:
q_weight, weight_scale = self._shard_quantized_weight(
q_weight, weight_scale
Expand All @@ -533,8 +546,11 @@ def online_quantize_weight(self):
self.quant_type = online_quant_type
self.params_dtype = online_quant_dtype
self.quant_func = online_quant_func
# online_quant_func already returns fnuz when quant_dtype=fnuz on gfx942;
# only normalize when output is still non-fnuz.
self.need_normalize_e4m3fn_to_e4m3fnuz = (
online_quant_dtype == torch.float8_e4m3fnuz
and q_weight.dtype != torch.float8_e4m3fnuz
)
self._online_quant_info = {
"layer": self.prefix,
Expand Down
15 changes: 10 additions & 5 deletions atom/quant_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,9 @@ def parse(self, online_quant_config: dict) -> ParsedQuantConfig:
if not isinstance(online_quant_config, dict):
raise TypeError("online_quant_config must be a dict parsed from JSON.")

SCHEME_MAP = {
scheme_map = {
"ptpc": QuantType.per_Token,
"per_block": QuantType.per_1x128,
}

def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
Expand All @@ -231,10 +232,14 @@ def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
quant_type = QuantType.per_1x32
dtype_str = quant_format_str[2:]
else:
parts = quant_format_str.split("_", 1)
if len(parts) == 2 and parts[0] in SCHEME_MAP:
quant_type = SCHEME_MAP[parts[0]]
dtype_str = parts[1]
matched_scheme = None
for scheme in sorted(scheme_map, key=len, reverse=True):
if quant_format_str.startswith(scheme + "_"):
matched_scheme = scheme
break
if matched_scheme is not None:
quant_type = scheme_map[matched_scheme]
dtype_str = quant_format_str[len(matched_scheme) + 1 :]
else:
raise ValueError(
f"Unsupported online quant format: '{quant_format_str}'. "
Expand Down
38 changes: 38 additions & 0 deletions atom/quantization/quark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,41 @@ def weight_dequant_mxfp8(
y = x.to(torch.float32).reshape(M, n_blocks, block_size)
y = y * scale.unsqueeze(-1)
return y.reshape(M, K).to(out_dtype)


def quantize_weight_to_fp8_128x128_blockscale(weight, quant_dtype):
"""Quantize a 2D weight to FP8 with 128x128 block scales.

Returns:
q_weight: quantized weight with the same shape as input ``weight``.
scale: per-block scale with shape ``(ceil(N/128), ceil(K/128))``.
"""
assert weight.dim() == 2, f"expected 2D weight, got shape={tuple(weight.shape)}"

w = weight.to(torch.float32).contiguous()
n, k = w.shape
n_blocks = (n + 127) // 128
k_blocks = (k + 127) // 128
n_padded = n_blocks * 128
k_padded = k_blocks * 128

if n_padded != n or k_padded != k:
w = torch.nn.functional.pad(w, (0, k_padded - k, 0, n_padded - n))

w_blocks = w.view(n_blocks, 128, k_blocks, 128).permute(0, 2, 1, 3).contiguous()

finfo = torch.finfo(quant_dtype)
block_amax = w_blocks.abs().amax(dim=(2, 3))
scale = (block_amax / finfo.max).clamp_min(torch.finfo(torch.float32).tiny)

q_blocks = torch.clamp(
w_blocks / scale.unsqueeze(-1).unsqueeze(-1), min=finfo.min, max=finfo.max
).to(quant_dtype)

q_weight = (
q_blocks.permute(0, 2, 1, 3)
.contiguous()
.view(n_padded, k_padded)[:n, :k]
.contiguous()
)
return q_weight, scale.contiguous()
Loading