Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
347 changes: 129 additions & 218 deletions docs/source/eager_tutorials/static_quantization.rst
Original file line number Diff line number Diff line change
@@ -1,239 +1,150 @@
Static Quantization
--------------------

Static quantization refers to using a fixed quantization range for all inputs during inference or generation. Unlike dynamic quantization, which dynamically computes new quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly.
Static quantization refers to using a fixed quantization range for all inputs during inference. Unlike dynamic quantization, which recomputes quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly.

In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we first insert observers into the model to "observe" the distribution of the inputs to be quantized, and use this distribution to decide what scales and zero points to ultimately use when quantizing the model.
In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we determine what scales and zero points to use, then lock them in for all future inference.

In this tutorial, we walk through an example of how to achieve this in torchao. Let's start with our toy linear model:
In this tutorial, we walk through an example of how to achieve this in torchao using ``Int8StaticActivationInt8WeightConfig`` and ``quantize_``. Let's start with our toy linear model:

.. code:: py
.. code:: python

import copy
import torch
from collections import OrderedDict
import copy
import torch

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, k, bias=False)
self.linear2 = torch.nn.Linear(k, n, bias=False)
from torchao.quantization import (
AffineQuantizedMinMaxObserver,
FqnToConfig,
Int8StaticActivationInt8WeightConfig,
MappingType,
PerRow,
PerTensor,
quantize_,
)

def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (
torch.randn(
batch_size, self.linear1.in_features, dtype=dtype, device=device
),
)
class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, k, bias=False)
self.linear2 = torch.nn.Linear(k, n, bias=False)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (
torch.randn(
batch_size, self.linear1.in_features, dtype=dtype, device=device
),
)

dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")
m = torch.compile(m, mode="max-autotune")
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


Calibration Phase
~~~~~~~~~~~~~~~~~

torchao comes with a a simple observer implementation, `AffineQuantizedMinMaxObserver`, that records the min and max values that have flowed through the observer during the calibration phase. Users are welcome to implement their own desired, more advanced observation techniques, such as those relying on moving averages or histograms, and these may be added to torchao in the future.

.. code:: py

from torchao.quantization.granularity import PerAxis, PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType

# per tensor input activation asymmetric quantization
act_obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float32,
zero_point_dtype=torch.float32,
)

# per channel weight asymmetric quantization
weight_obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float32,
zero_point_dtype=torch.float32,
)

Next, we define our observed linear that we will swap our `torch.nn.Linear` with. This is a high precision (e.g. fp32) linear module with the above observers inserted to record the input activation and weight values during calibration:

.. code:: py

import torch.nn.functional as F

class ObservedLinear(torch.nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
act_obs: torch.nn.Module,
weight_obs: torch.nn.Module,
bias: bool = True,
device=None,
dtype=None,
):
super().__init__(in_features, out_features, bias, device, dtype)
self.act_obs = act_obs
self.weight_obs = weight_obs

def forward(self, input: torch.Tensor):
observed_input = self.act_obs(input)
observed_weight = self.weight_obs(self.weight)
return F.linear(observed_input, observed_weight, self.bias)

@classmethod
def from_float(cls, float_linear, act_obs, weight_obs):
observed_linear = cls(
float_linear.in_features,
float_linear.out_features,
act_obs,
weight_obs,
False,
device=float_linear.weight.device,
dtype=float_linear.weight.dtype,
)
observed_linear.weight = float_linear.weight
observed_linear.bias = float_linear.bias
return observed_linear

To actually insert these observers into our toy model:

.. code:: py

from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)

def insert_observers_(model, act_obs, weight_obs):
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

def replacement_fn(m):
copied_act_obs = copy.deepcopy(act_obs)
copied_weight_obs = copy.deepcopy(weight_obs)
return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

insert_observers_(m, act_obs, weight_obs)

Now we are ready to calibrate the model, which populates the observers we inserted with statistics recorded during the calibration. We can do this simply by feeding some example inputs to our "observed" model:

.. code:: py

for _ in range(10):
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
m(*example_inputs)
The goal of calibration is to determine fixed activation quantization parameters for each linear layer. We insert activation observers with forward pre-hooks, then run representative inputs through the original floating point model:

.. code:: python

dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")
m_static = copy.deepcopy(m)

activation_granularity = PerTensor()
weight_granularity = PerRow()
act_mapping_type = MappingType.SYMMETRIC

activation_observers = OrderedDict()
observer_handles = []

def make_activation_observer():
return AffineQuantizedMinMaxObserver(
act_mapping_type,
torch.int8,
granularity=activation_granularity,
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float32,
zero_point_dtype=torch.int8,
keepdim=True,
)

def observe_input(module, inputs, observer):
observer(inputs[0])

for name, module in m.named_modules():
if isinstance(module, torch.nn.Linear):
observer = make_activation_observer()
activation_observers[name] = observer
observer_handles.append(
module.register_forward_pre_hook(
lambda module, inputs, observer=observer: observe_input(
module, inputs, observer
)
)
)

with torch.no_grad():
for _ in range(10):
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
m(*example_inputs)

for handle in observer_handles:
handle.remove()

After calibration, each observer can compute the activation scale and zero point for its corresponding layer:

.. code:: python

act_scale, act_zero_point = activation_observers["linear1"].calculate_qparams()


Quantization Phase
~~~~~~~~~~~~~~~~~~

There are multiple ways to actually quantize the model. Here we walk through the simpler alternative, which is to define a `QuantizedLinear` class that we will swap our `ObservedLinear` to.

.. code:: py

from torchao.quantization import Int8Tensor
from torchao.quantization import PerRow, PerTensor

class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
act_obs: torch.nn.Module,
weight_obs: torch.nn.Module,
weight: torch.Tensor,
bias: torch.Tensor,
):
super().__init__()
self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
weight_scale, weight_zero_point = weight_obs.calculate_qparams()
self.bias = bias
self.qweight = Int8Tensor.from_hp(
weight, granularity=PerRow(),
scale=weight_scale, zero_point=weight_zero_point,
)

def forward(self, input: torch.Tensor):
qinput = Int8Tensor.from_hp(
input,
granularity=PerTensor(),
scale=self.act_scale,
zero_point=self.act_zero_point,
)
return F.linear(qinput, self.qweight, self.bias)

@classmethod
def from_observed(cls, observed_linear, target_dtype):
quantized_linear = cls(
observed_linear.in_features,
observed_linear.out_features,
observed_linear.act_obs,
observed_linear.weight_obs,
observed_linear.weight,
observed_linear.bias,
target_dtype,
)
return quantized_linear

This linear class computes the scales and zero points for both input activations and weights in the beginning, effectively fixing the quantization range for future forward calls. Now, to actually quantize the model using this linear class, we can define the following config and pass it to torchao's main `quantize_` API:

.. code:: py

from dataclasses import dataclass

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)

@dataclass
class StaticQuantConfig(AOBaseConfig):
target_dtype: torch.dtype

@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(
module: torch.nn.Module,
config: StaticQuantConfig,
):
"""
Define a transformation associated with `StaticQuantConfig`.
This is called by `quantize_`, not by the user directly.
"""
return QuantizedLinear.from_observed(module, config.target_dtype)

# filter function to identify which modules to swap
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# perform static quantization
quantize_(m, StaticQuantConfig(torch.uint8), is_observed_linear)

Now, we will see that the linear layers in our model are swapped to our `QuantizedLinear` class, with a fixed input activation scale and a fixed quantized weight:

.. code:: py

>>> m
OptimizedModule(
(_orig_mod): ToyLinearModel(
(linear1): QuantizedLinear()
(linear2): QuantizedLinear()
)
)
>>> m.linear1.act_scale
tensor([0.0237], device='cuda:0')
>>> m.linear1.qweight # quantized weight tensor with scale and zero_point
IntxUnpackedToInt8Tensor(...) # actual repr depends on quantization config

In this tutorial, we walked through a basic example of how to perform integer static quantization in torchao.
Now we create one ``Int8StaticActivationInt8WeightConfig`` per linear layer and apply the configs with ``FqnToConfig``. This keeps the model structure unchanged and only replaces each linear weight with an ``Int8Tensor`` carrying the calibrated activation quantization parameters:

.. code:: python

fqn_to_config = OrderedDict()
for name, observer in activation_observers.items():
act_scale, act_zero_point = observer.calculate_qparams()
fqn_to_config[f"{name}.weight"] = Int8StaticActivationInt8WeightConfig(
act_quant_scale=act_scale,
act_quant_zero_point=act_zero_point,
granularity=[activation_granularity, weight_granularity],
act_mapping_type=act_mapping_type,
)

quantize_(m_static, FqnToConfig(fqn_to_config), filter_fn=None)

Now, we will see that the linear layers in our model have fixed activation scales and quantized weights:

.. code::

>>> m_static
ToyLinearModel(
(linear1): Linear(in_features=64, out_features=64, bias=False)
(linear2): Linear(in_features=64, out_features=32, bias=False)
)
>>> type(m_static.linear1.weight)
<class 'torchao.quantization.quantize_.workflows.int8.int8_tensor.Int8Tensor'>
>>> m_static.linear1.weight.act_quant_scale # fixed at calibration time
tensor(..., device='cuda:0')

The model structure is unchanged. Only the weight tensors are replaced with quantized tensor subclasses carrying fixed activation scales. All subsequent forward passes will use the same scales, unlike dynamic quantization which recomputes them per batch.

Other Approaches
~~~~~~~~~~~~~~~~

The calibration phase can be customized:

- **Observers**: You can use lower-level observer APIs (``AffineQuantizedMinMaxObserver``) to record min/max statistics over calibration data and compute scales yourself. This gives full control over scale computation (e.g., moving averages, histograms).
- **AWQ / SmoothQuant**: These algorithms provide a ``prepare`` -> calibrate -> ``convert`` flow via ``quantize_()`` that integrates activation pre-scaling and smoothing with static quantization. See the ``torchao.prototype.smoothquant`` and ``torchao.prototype.awq`` modules.
Comment thread
Ali-Ch-001 marked this conversation as resolved.
Outdated
- **Float8 static quantization**: ``Float8StaticActivationFloat8WeightConfig`` (in ``torchao.prototype.quantization``) supports a built-in observer-based prepare/convert flow for float8 dtypes.

For a list of all available quantization configs, see the :doc:`API Reference <../api_reference/api_ref_quantization>`.

In this tutorial, we walked through how to perform integer static quantization in torchao using the current ``quantize_`` API with ``Int8StaticActivationInt8WeightConfig``.