From aece3521b18f41bab9b163453d5c2b2244b21f49 Mon Sep 17 00:00:00 2001 From: ali-ch-001 Date: Thu, 25 Jun 2026 02:08:43 +0500 Subject: [PATCH 1/2] docs: update static quantization tutorial with modern torchao API usage and improved calibration workflow --- .../eager_tutorials/static_quantization.rst | 347 +++++++----------- 1 file changed, 129 insertions(+), 218 deletions(-) diff --git a/docs/source/eager_tutorials/static_quantization.rst b/docs/source/eager_tutorials/static_quantization.rst index d9f79c5c9b..a4d4db48c2 100644 --- a/docs/source/eager_tutorials/static_quantization.rst +++ b/docs/source/eager_tutorials/static_quantization.rst @@ -1,239 +1,150 @@ Static Quantization -------------------- -Static quantization refers to using a fixed quantization range for all inputs during inference or generation. Unlike dynamic quantization, which dynamically computes new quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly. +Static quantization refers to using a fixed quantization range for all inputs during inference. Unlike dynamic quantization, which recomputes quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly. -In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we first insert observers into the model to "observe" the distribution of the inputs to be quantized, and use this distribution to decide what scales and zero points to ultimately use when quantizing the model. +In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. 
During the calibration phase, we determine what scales and zero points to use, then lock them in for all future inference. -In this tutorial, we walk through an example of how to achieve this in torchao. Let's start with our toy linear model: +In this tutorial, we walk through an example of how to achieve this in torchao using ``Int8StaticActivationInt8WeightConfig`` and ``quantize_``. Let's start with our toy linear model: -.. code:: py +.. code:: python - import copy - import torch + from collections import OrderedDict + import copy + import torch - class ToyLinearModel(torch.nn.Module): - def __init__(self, m=64, n=32, k=64): - super().__init__() - self.linear1 = torch.nn.Linear(m, k, bias=False) - self.linear2 = torch.nn.Linear(k, n, bias=False) + from torchao.quantization import ( + AffineQuantizedMinMaxObserver, + FqnToConfig, + Int8StaticActivationInt8WeightConfig, + MappingType, + PerRow, + PerTensor, + quantize_, + ) - def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return ( - torch.randn( - batch_size, self.linear1.in_features, dtype=dtype, device=device - ), - ) + class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64): + super().__init__() + self.linear1 = torch.nn.Linear(m, k, bias=False) + self.linear2 = torch.nn.Linear(k, n, bias=False) - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) - dtype = torch.bfloat16 - m = ToyLinearModel().eval().to(dtype).to("cuda") - m = torch.compile(m, mode="max-autotune") + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x Calibration Phase ~~~~~~~~~~~~~~~~~ -torchao comes with a a simple observer implementation, `AffineQuantizedMinMaxObserver`, that records the min and max values that have flowed through the observer during the 
calibration phase. Users are welcome to implement their own desired, more advanced observation techniques, such as those relying on moving averages or histograms, and these may be added to torchao in the future. - -.. code:: py - - from torchao.quantization.granularity import PerAxis, PerTensor - from torchao.quantization.observer import AffineQuantizedMinMaxObserver - from torchao.quantization.quant_primitives import MappingType - - # per tensor input activation asymmetric quantization - act_obs = AffineQuantizedMinMaxObserver( - MappingType.ASYMMETRIC, - torch.uint8, - granularity=PerTensor(), - eps=torch.finfo(torch.float32).eps, - scale_dtype=torch.float32, - zero_point_dtype=torch.float32, - ) - - # per channel weight asymmetric quantization - weight_obs = AffineQuantizedMinMaxObserver( - MappingType.ASYMMETRIC, - torch.uint8, - granularity=PerAxis(axis=0), - eps=torch.finfo(torch.float32).eps, - scale_dtype=torch.float32, - zero_point_dtype=torch.float32, - ) - -Next, we define our observed linear that we will swap our `torch.nn.Linear` with. This is a high precision (e.g. fp32) linear module with the above observers inserted to record the input activation and weight values during calibration: - -.. 
code:: py - - import torch.nn.functional as F - - class ObservedLinear(torch.nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - act_obs: torch.nn.Module, - weight_obs: torch.nn.Module, - bias: bool = True, - device=None, - dtype=None, - ): - super().__init__(in_features, out_features, bias, device, dtype) - self.act_obs = act_obs - self.weight_obs = weight_obs - - def forward(self, input: torch.Tensor): - observed_input = self.act_obs(input) - observed_weight = self.weight_obs(self.weight) - return F.linear(observed_input, observed_weight, self.bias) - - @classmethod - def from_float(cls, float_linear, act_obs, weight_obs): - observed_linear = cls( - float_linear.in_features, - float_linear.out_features, - act_obs, - weight_obs, - False, - device=float_linear.weight.device, - dtype=float_linear.weight.dtype, - ) - observed_linear.weight = float_linear.weight - observed_linear.bias = float_linear.bias - return observed_linear - -To actually insert these observers into our toy model: - -.. code:: py - - from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, - ) - - def insert_observers_(model, act_obs, weight_obs): - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - - def replacement_fn(m): - copied_act_obs = copy.deepcopy(act_obs) - copied_weight_obs = copy.deepcopy(weight_obs) - return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs) - - _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) - - insert_observers_(m, act_obs, weight_obs) - -Now we are ready to calibrate the model, which populates the observers we inserted with statistics recorded during the calibration. We can do this simply by feeding some example inputs to our "observed" model: - -.. 
code:: py - - for _ in range(10): - example_inputs = m.example_inputs(dtype=dtype, device="cuda") - m(*example_inputs) +The goal of calibration is to determine fixed activation quantization parameters for each linear layer. We insert activation observers with forward pre-hooks, then run representative inputs through the original floating point model: + +.. code:: python + + dtype = torch.bfloat16 + m = ToyLinearModel().eval().to(dtype).to("cuda") + m_static = copy.deepcopy(m) + + activation_granularity = PerTensor() + weight_granularity = PerRow() + act_mapping_type = MappingType.SYMMETRIC + + activation_observers = OrderedDict() + observer_handles = [] + + def make_activation_observer(): + return AffineQuantizedMinMaxObserver( + act_mapping_type, + torch.int8, + granularity=activation_granularity, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float32, + zero_point_dtype=torch.int8, + keepdim=True, + ) + + def observe_input(module, inputs, observer): + observer(inputs[0]) + + for name, module in m.named_modules(): + if isinstance(module, torch.nn.Linear): + observer = make_activation_observer() + activation_observers[name] = observer + observer_handles.append( + module.register_forward_pre_hook( + lambda module, inputs, observer=observer: observe_input( + module, inputs, observer + ) + ) + ) + + with torch.no_grad(): + for _ in range(10): + example_inputs = m.example_inputs(dtype=dtype, device="cuda") + m(*example_inputs) + + for handle in observer_handles: + handle.remove() + +After calibration, each observer can compute the activation scale and zero point for its corresponding layer: + +.. code:: python + + act_scale, act_zero_point = activation_observers["linear1"].calculate_qparams() Quantization Phase ~~~~~~~~~~~~~~~~~~ -There are multiple ways to actually quantize the model. Here we walk through the simpler alternative, which is to define a `QuantizedLinear` class that we will swap our `ObservedLinear` to. - -.. 
code:: py - - from torchao.quantization import Int8Tensor - from torchao.quantization import PerRow, PerTensor - - class QuantizedLinear(torch.nn.Module): - def __init__( - self, - in_features: int, - out_features: int, - act_obs: torch.nn.Module, - weight_obs: torch.nn.Module, - weight: torch.Tensor, - bias: torch.Tensor, - ): - super().__init__() - self.act_scale, self.act_zero_point = act_obs.calculate_qparams() - weight_scale, weight_zero_point = weight_obs.calculate_qparams() - self.bias = bias - self.qweight = Int8Tensor.from_hp( - weight, granularity=PerRow(), - scale=weight_scale, zero_point=weight_zero_point, - ) - - def forward(self, input: torch.Tensor): - qinput = Int8Tensor.from_hp( - input, - granularity=PerTensor(), - scale=self.act_scale, - zero_point=self.act_zero_point, - ) - return F.linear(qinput, self.qweight, self.bias) - - @classmethod - def from_observed(cls, observed_linear, target_dtype): - quantized_linear = cls( - observed_linear.in_features, - observed_linear.out_features, - observed_linear.act_obs, - observed_linear.weight_obs, - observed_linear.weight, - observed_linear.bias, - target_dtype, - ) - return quantized_linear - -This linear class computes the scales and zero points for both input activations and weights in the beginning, effectively fixing the quantization range for future forward calls. Now, to actually quantize the model using this linear class, we can define the following config and pass it to torchao's main `quantize_` API: - -.. 
code:: py - - from dataclasses import dataclass - - from torchao.core.config import AOBaseConfig - from torchao.quantization import quantize_ - from torchao.quantization.transform_module import ( - register_quantize_module_handler, - ) - - @dataclass - class StaticQuantConfig(AOBaseConfig): - target_dtype: torch.dtype - - @register_quantize_module_handler(StaticQuantConfig) - def _apply_static_quant( - module: torch.nn.Module, - config: StaticQuantConfig, - ): - """ - Define a transformation associated with `StaticQuantConfig`. - This is called by `quantize_`, not by the user directly. - """ - return QuantizedLinear.from_observed(module, config.target_dtype) - - # filter function to identify which modules to swap - is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) - - # perform static quantization - quantize_(m, StaticQuantConfig(torch.uint8), is_observed_linear) - -Now, we will see that the linear layers in our model are swapped to our `QuantizedLinear` class, with a fixed input activation scale and a fixed quantized weight: - -.. code:: py - - >>> m - OptimizedModule( - (_orig_mod): ToyLinearModel( - (linear1): QuantizedLinear() - (linear2): QuantizedLinear() - ) - ) - >>> m.linear1.act_scale - tensor([0.0237], device='cuda:0') - >>> m.linear1.qweight # quantized weight tensor with scale and zero_point - IntxUnpackedToInt8Tensor(...) # actual repr depends on quantization config - -In this tutorial, we walked through a basic example of how to perform integer static quantization in torchao. +Now we create one ``Int8StaticActivationInt8WeightConfig`` per linear layer and apply the configs with ``FqnToConfig``. This keeps the model structure unchanged and only replaces each linear weight with an ``Int8Tensor`` carrying the calibrated activation quantization parameters: + +.. 
code:: python + + fqn_to_config = OrderedDict() + for name, observer in activation_observers.items(): + act_scale, act_zero_point = observer.calculate_qparams() + fqn_to_config[f"{name}.weight"] = Int8StaticActivationInt8WeightConfig( + act_quant_scale=act_scale, + act_quant_zero_point=act_zero_point, + granularity=[activation_granularity, weight_granularity], + act_mapping_type=act_mapping_type, + ) + + quantize_(m_static, FqnToConfig(fqn_to_config), filter_fn=None) + +Now, we will see that the linear layers in our model have fixed activation scales and quantized weights: + +.. code:: + + >>> m_static + ToyLinearModel( + (linear1): Linear(in_features=64, out_features=64, bias=False) + (linear2): Linear(in_features=64, out_features=32, bias=False) + ) + >>> type(m_static.linear1.weight) + <class 'torchao.quantization.quantize_.workflows.int8.int8_tensor.Int8Tensor'> + >>> m_static.linear1.weight.act_quant_scale # fixed at calibration time + tensor(..., device='cuda:0') + +The model structure is unchanged. Only the weight tensors are replaced with quantized tensor subclasses carrying fixed activation scales. All subsequent forward passes will use the same scales, unlike dynamic quantization which recomputes them per batch. + +Other Approaches +~~~~~~~~~~~~~~~~ + +The calibration phase can be customized: + +- **Observers**: You can use lower-level observer APIs (``AffineQuantizedMinMaxObserver``) to record min/max statistics over calibration data and compute scales yourself. This gives full control over scale computation (e.g., moving averages, histograms). +- **AWQ / SmoothQuant**: These algorithms provide a ``prepare`` -> calibrate -> ``convert`` flow via ``quantize_()`` that integrates activation pre-scaling and smoothing with static quantization. See the ``torchao.prototype.smoothquant`` and ``torchao.prototype.awq`` modules. +- **Float8 static quantization**: ``Float8StaticActivationFloat8WeightConfig`` (in ``torchao.prototype.quantization``) supports a built-in observer-based prepare/convert flow for float8 dtypes. 
+ +For a list of all available quantization configs, see the :doc:`API Reference <../api_reference/api_ref_quantization>`. + +In this tutorial, we walked through how to perform integer static quantization in torchao using the current ``quantize_`` API with ``Int8StaticActivationInt8WeightConfig``. From cbfef022ba49cfcca4ad24a82f0545febf7510f8 Mon Sep 17 00:00:00 2001 From: ali-ch-001 Date: Thu, 25 Jun 2026 15:03:39 +0500 Subject: [PATCH 2/2] docs: add protocol references for AWQ/SmoothQuant contracts --- docs/source/eager_tutorials/static_quantization.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/eager_tutorials/static_quantization.rst b/docs/source/eager_tutorials/static_quantization.rst index a4d4db48c2..371f7bb609 100644 --- a/docs/source/eager_tutorials/static_quantization.rst +++ b/docs/source/eager_tutorials/static_quantization.rst @@ -142,7 +142,7 @@ Other Approaches The calibration phase can be customized: - **Observers**: You can use lower-level observer APIs (``AffineQuantizedMinMaxObserver``) to record min/max statistics over calibration data and compute scales yourself. This gives full control over scale computation (e.g., moving averages, histograms). -- **AWQ / SmoothQuant**: These algorithms provide a ``prepare`` -> calibrate -> ``convert`` flow via ``quantize_()`` that integrates activation pre-scaling and smoothing with static quantization. See the ``torchao.prototype.smoothquant`` and ``torchao.prototype.awq`` modules. +- **AWQ / SmoothQuant**: These algorithms provide a ``prepare`` -> calibrate -> ``convert`` flow via ``quantize_()`` that integrates activation pre-scaling and smoothing with static quantization. They work with any quantization config and tensor subclass that implement the relevant protocols — ``SupportsActivationPreScaling`` (``act_pre_scale`` attribute, implemented by the tensor subclass) for AWQ, and both ``IsStaticQuantizationConfig`` (implemented by the config) and ``SupportsActivationPreScaling`` for SmoothQuant. 
See the ``torchao.prototype.smoothquant`` and ``torchao.prototype.awq`` modules. - **Float8 static quantization**: ``Float8StaticActivationFloat8WeightConfig`` (in ``torchao.prototype.quantization``) supports a built-in observer-based prepare/convert flow for float8 dtypes. For a list of all available quantization configs, see the :doc:`API Reference <../api_reference/api_ref_quantization>`.