-
Notifications
You must be signed in to change notification settings - Fork 551
[docs] Update static quantization tutorial to use Int8StaticActivationInt8WeightConfi issue fixed #3637 #4533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Ali-Ch-001
wants to merge
2
commits into
pytorch:main
Choose a base branch
from
Ali-Ch-001:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+129
−218
Open
Changes from 1 commit
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,239 +1,150 @@ | ||
| Static Quantization | ||
| -------------------- | ||
|
|
||
| Static quantization refers to using a fixed quantization range for all inputs during inference or generation. Unlike dynamic quantization, which dynamically computes new quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly. | ||
| Static quantization refers to using a fixed quantization range for all inputs during inference. Unlike dynamic quantization, which recomputes quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly. | ||
|
|
||
| In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we first insert observers into the model to "observe" the distribution of the inputs to be quantized, and use this distribution to decide what scales and zero points to ultimately use when quantizing the model. | ||
| In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we determine what scales and zero points to use, then lock them in for all future inference. | ||
|
|
||
| In this tutorial, we walk through an example of how to achieve this in torchao. Let's start with our toy linear model: | ||
| In this tutorial, we walk through an example of how to achieve this in torchao using ``Int8StaticActivationInt8WeightConfig`` and ``quantize_``. Let's start with our toy linear model: | ||
|
|
||
| .. code:: py | ||
| .. code:: python | ||
|
|
||
| import copy | ||
| import torch | ||
| from collections import OrderedDict | ||
| import copy | ||
| import torch | ||
|
|
||
| class ToyLinearModel(torch.nn.Module): | ||
| def __init__(self, m=64, n=32, k=64): | ||
| super().__init__() | ||
| self.linear1 = torch.nn.Linear(m, k, bias=False) | ||
| self.linear2 = torch.nn.Linear(k, n, bias=False) | ||
| from torchao.quantization import ( | ||
| AffineQuantizedMinMaxObserver, | ||
| FqnToConfig, | ||
| Int8StaticActivationInt8WeightConfig, | ||
| MappingType, | ||
| PerRow, | ||
| PerTensor, | ||
| quantize_, | ||
| ) | ||
|
|
||
| def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): | ||
| return ( | ||
| torch.randn( | ||
| batch_size, self.linear1.in_features, dtype=dtype, device=device | ||
| ), | ||
| ) | ||
| class ToyLinearModel(torch.nn.Module): | ||
| def __init__(self, m=64, n=32, k=64): | ||
| super().__init__() | ||
| self.linear1 = torch.nn.Linear(m, k, bias=False) | ||
| self.linear2 = torch.nn.Linear(k, n, bias=False) | ||
|
|
||
| def forward(self, x): | ||
| x = self.linear1(x) | ||
| x = self.linear2(x) | ||
| return x | ||
| def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): | ||
| return ( | ||
| torch.randn( | ||
| batch_size, self.linear1.in_features, dtype=dtype, device=device | ||
| ), | ||
| ) | ||
|
|
||
| dtype = torch.bfloat16 | ||
| m = ToyLinearModel().eval().to(dtype).to("cuda") | ||
| m = torch.compile(m, mode="max-autotune") | ||
| def forward(self, x): | ||
| x = self.linear1(x) | ||
| x = self.linear2(x) | ||
| return x | ||
|
|
||
|
|
||
| Calibration Phase | ||
| ~~~~~~~~~~~~~~~~~ | ||
|
|
||
| torchao comes with a a simple observer implementation, `AffineQuantizedMinMaxObserver`, that records the min and max values that have flowed through the observer during the calibration phase. Users are welcome to implement their own desired, more advanced observation techniques, such as those relying on moving averages or histograms, and these may be added to torchao in the future. | ||
|
|
||
| .. code:: py | ||
|
|
||
| from torchao.quantization.granularity import PerAxis, PerTensor | ||
| from torchao.quantization.observer import AffineQuantizedMinMaxObserver | ||
| from torchao.quantization.quant_primitives import MappingType | ||
|
|
||
| # per tensor input activation asymmetric quantization | ||
| act_obs = AffineQuantizedMinMaxObserver( | ||
| MappingType.ASYMMETRIC, | ||
| torch.uint8, | ||
| granularity=PerTensor(), | ||
| eps=torch.finfo(torch.float32).eps, | ||
| scale_dtype=torch.float32, | ||
| zero_point_dtype=torch.float32, | ||
| ) | ||
|
|
||
| # per channel weight asymmetric quantization | ||
| weight_obs = AffineQuantizedMinMaxObserver( | ||
| MappingType.ASYMMETRIC, | ||
| torch.uint8, | ||
| granularity=PerAxis(axis=0), | ||
| eps=torch.finfo(torch.float32).eps, | ||
| scale_dtype=torch.float32, | ||
| zero_point_dtype=torch.float32, | ||
| ) | ||
|
|
||
| Next, we define our observed linear that we will swap our `torch.nn.Linear` with. This is a high precision (e.g. fp32) linear module with the above observers inserted to record the input activation and weight values during calibration: | ||
|
|
||
| .. code:: py | ||
|
|
||
| import torch.nn.functional as F | ||
|
|
||
| class ObservedLinear(torch.nn.Linear): | ||
| def __init__( | ||
| self, | ||
| in_features: int, | ||
| out_features: int, | ||
| act_obs: torch.nn.Module, | ||
| weight_obs: torch.nn.Module, | ||
| bias: bool = True, | ||
| device=None, | ||
| dtype=None, | ||
| ): | ||
| super().__init__(in_features, out_features, bias, device, dtype) | ||
| self.act_obs = act_obs | ||
| self.weight_obs = weight_obs | ||
|
|
||
| def forward(self, input: torch.Tensor): | ||
| observed_input = self.act_obs(input) | ||
| observed_weight = self.weight_obs(self.weight) | ||
| return F.linear(observed_input, observed_weight, self.bias) | ||
|
|
||
| @classmethod | ||
| def from_float(cls, float_linear, act_obs, weight_obs): | ||
| observed_linear = cls( | ||
| float_linear.in_features, | ||
| float_linear.out_features, | ||
| act_obs, | ||
| weight_obs, | ||
| False, | ||
| device=float_linear.weight.device, | ||
| dtype=float_linear.weight.dtype, | ||
| ) | ||
| observed_linear.weight = float_linear.weight | ||
| observed_linear.bias = float_linear.bias | ||
| return observed_linear | ||
|
|
||
| To actually insert these observers into our toy model: | ||
|
|
||
| .. code:: py | ||
|
|
||
| from torchao.quantization.quant_api import ( | ||
| _replace_with_custom_fn_if_matches_filter, | ||
| ) | ||
|
|
||
| def insert_observers_(model, act_obs, weight_obs): | ||
| _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) | ||
|
|
||
| def replacement_fn(m): | ||
| copied_act_obs = copy.deepcopy(act_obs) | ||
| copied_weight_obs = copy.deepcopy(weight_obs) | ||
| return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs) | ||
|
|
||
| _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) | ||
|
|
||
| insert_observers_(m, act_obs, weight_obs) | ||
|
|
||
| Now we are ready to calibrate the model, which populates the observers we inserted with statistics recorded during the calibration. We can do this simply by feeding some example inputs to our "observed" model: | ||
|
|
||
| .. code:: py | ||
|
|
||
| for _ in range(10): | ||
| example_inputs = m.example_inputs(dtype=dtype, device="cuda") | ||
| m(*example_inputs) | ||
| The goal of calibration is to determine fixed activation quantization parameters for each linear layer. We insert activation observers with forward pre-hooks, then run representative inputs through the original floating point model: | ||
|
|
||
| .. code:: python | ||
|
|
||
| dtype = torch.bfloat16 | ||
| m = ToyLinearModel().eval().to(dtype).to("cuda") | ||
| m_static = copy.deepcopy(m) | ||
|
|
||
| activation_granularity = PerTensor() | ||
| weight_granularity = PerRow() | ||
| act_mapping_type = MappingType.SYMMETRIC | ||
|
|
||
| activation_observers = OrderedDict() | ||
| observer_handles = [] | ||
|
|
||
| def make_activation_observer(): | ||
| return AffineQuantizedMinMaxObserver( | ||
| act_mapping_type, | ||
| torch.int8, | ||
| granularity=activation_granularity, | ||
| eps=torch.finfo(torch.float32).eps, | ||
| scale_dtype=torch.float32, | ||
| zero_point_dtype=torch.int8, | ||
| keepdim=True, | ||
| ) | ||
|
|
||
| def observe_input(module, inputs, observer): | ||
| observer(inputs[0]) | ||
|
|
||
| for name, module in m.named_modules(): | ||
| if isinstance(module, torch.nn.Linear): | ||
| observer = make_activation_observer() | ||
| activation_observers[name] = observer | ||
| observer_handles.append( | ||
| module.register_forward_pre_hook( | ||
| lambda module, inputs, observer=observer: observe_input( | ||
| module, inputs, observer | ||
| ) | ||
| ) | ||
| ) | ||
|
|
||
| with torch.no_grad(): | ||
| for _ in range(10): | ||
| example_inputs = m.example_inputs(dtype=dtype, device="cuda") | ||
| m(*example_inputs) | ||
|
|
||
| for handle in observer_handles: | ||
| handle.remove() | ||
|
|
||
| After calibration, each observer can compute the activation scale and zero point for its corresponding layer: | ||
|
|
||
| .. code:: python | ||
|
|
||
| act_scale, act_zero_point = activation_observers["linear1"].calculate_qparams() | ||
|
|
||
|
|
||
| Quantization Phase | ||
| ~~~~~~~~~~~~~~~~~~ | ||
|
|
||
| There are multiple ways to actually quantize the model. Here we walk through the simpler alternative, which is to define a `QuantizedLinear` class that we will swap our `ObservedLinear` to. | ||
|
|
||
| .. code:: py | ||
|
|
||
| from torchao.quantization import Int8Tensor | ||
| from torchao.quantization import PerRow, PerTensor | ||
|
|
||
| class QuantizedLinear(torch.nn.Module): | ||
| def __init__( | ||
| self, | ||
| in_features: int, | ||
| out_features: int, | ||
| act_obs: torch.nn.Module, | ||
| weight_obs: torch.nn.Module, | ||
| weight: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| ): | ||
| super().__init__() | ||
| self.act_scale, self.act_zero_point = act_obs.calculate_qparams() | ||
| weight_scale, weight_zero_point = weight_obs.calculate_qparams() | ||
| self.bias = bias | ||
| self.qweight = Int8Tensor.from_hp( | ||
| weight, granularity=PerRow(), | ||
| scale=weight_scale, zero_point=weight_zero_point, | ||
| ) | ||
|
|
||
| def forward(self, input: torch.Tensor): | ||
| qinput = Int8Tensor.from_hp( | ||
| input, | ||
| granularity=PerTensor(), | ||
| scale=self.act_scale, | ||
| zero_point=self.act_zero_point, | ||
| ) | ||
| return F.linear(qinput, self.qweight, self.bias) | ||
|
|
||
| @classmethod | ||
| def from_observed(cls, observed_linear, target_dtype): | ||
| quantized_linear = cls( | ||
| observed_linear.in_features, | ||
| observed_linear.out_features, | ||
| observed_linear.act_obs, | ||
| observed_linear.weight_obs, | ||
| observed_linear.weight, | ||
| observed_linear.bias, | ||
| target_dtype, | ||
| ) | ||
| return quantized_linear | ||
|
|
||
| This linear class computes the scales and zero points for both input activations and weights in the beginning, effectively fixing the quantization range for future forward calls. Now, to actually quantize the model using this linear class, we can define the following config and pass it to torchao's main `quantize_` API: | ||
|
|
||
| .. code:: py | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| from torchao.core.config import AOBaseConfig | ||
| from torchao.quantization import quantize_ | ||
| from torchao.quantization.transform_module import ( | ||
| register_quantize_module_handler, | ||
| ) | ||
|
|
||
| @dataclass | ||
| class StaticQuantConfig(AOBaseConfig): | ||
| target_dtype: torch.dtype | ||
|
|
||
| @register_quantize_module_handler(StaticQuantConfig) | ||
| def _apply_static_quant( | ||
| module: torch.nn.Module, | ||
| config: StaticQuantConfig, | ||
| ): | ||
| """ | ||
| Define a transformation associated with `StaticQuantConfig`. | ||
| This is called by `quantize_`, not by the user directly. | ||
| """ | ||
| return QuantizedLinear.from_observed(module, config.target_dtype) | ||
|
|
||
| # filter function to identify which modules to swap | ||
| is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) | ||
|
|
||
| # perform static quantization | ||
| quantize_(m, StaticQuantConfig(torch.uint8), is_observed_linear) | ||
|
|
||
| Now, we will see that the linear layers in our model are swapped to our `QuantizedLinear` class, with a fixed input activation scale and a fixed quantized weight: | ||
|
|
||
| .. code:: py | ||
|
|
||
| >>> m | ||
| OptimizedModule( | ||
| (_orig_mod): ToyLinearModel( | ||
| (linear1): QuantizedLinear() | ||
| (linear2): QuantizedLinear() | ||
| ) | ||
| ) | ||
| >>> m.linear1.act_scale | ||
| tensor([0.0237], device='cuda:0') | ||
| >>> m.linear1.qweight # quantized weight tensor with scale and zero_point | ||
| IntxUnpackedToInt8Tensor(...) # actual repr depends on quantization config | ||
|
|
||
| In this tutorial, we walked through a basic example of how to perform integer static quantization in torchao. | ||
| Now we create one ``Int8StaticActivationInt8WeightConfig`` per linear layer and apply the configs with ``FqnToConfig``. This keeps the model structure unchanged and only replaces each linear weight with an ``Int8Tensor`` carrying the calibrated activation quantization parameters: | ||
|
|
||
| .. code:: python | ||
|
|
||
| fqn_to_config = OrderedDict() | ||
| for name, observer in activation_observers.items(): | ||
| act_scale, act_zero_point = observer.calculate_qparams() | ||
| fqn_to_config[f"{name}.weight"] = Int8StaticActivationInt8WeightConfig( | ||
| act_quant_scale=act_scale, | ||
| act_quant_zero_point=act_zero_point, | ||
| granularity=[activation_granularity, weight_granularity], | ||
| act_mapping_type=act_mapping_type, | ||
| ) | ||
|
|
||
| quantize_(m_static, FqnToConfig(fqn_to_config), filter_fn=None) | ||
|
|
||
| Now, we will see that the linear layers in our model have fixed activation scales and quantized weights: | ||
|
|
||
| .. code:: | ||
|
|
||
| >>> m_static | ||
| ToyLinearModel( | ||
| (linear1): Linear(in_features=64, out_features=64, bias=False) | ||
| (linear2): Linear(in_features=64, out_features=32, bias=False) | ||
| ) | ||
| >>> type(m_static.linear1.weight) | ||
| <class 'torchao.quantization.quantize_.workflows.int8.int8_tensor.Int8Tensor'> | ||
| >>> m_static.linear1.weight.act_quant_scale # fixed at calibration time | ||
| tensor(..., device='cuda:0') | ||
|
|
||
| The model structure is unchanged. Only the weight tensors are replaced with quantized tensor subclasses carrying fixed activation scales. All subsequent forward passes will use the same scales, unlike dynamic quantization which recomputes them per batch. | ||
|
|
||
| Other Approaches | ||
| ~~~~~~~~~~~~~~~~ | ||
|
|
||
| The calibration phase can be customized: | ||
|
|
||
| - **Observers**: You can use lower-level observer APIs (``AffineQuantizedMinMaxObserver``) to record min/max statistics over calibration data and compute scales yourself. This gives full control over scale computation (e.g., moving averages, histograms). | ||
| - **AWQ / SmoothQuant**: These algorithms provide a ``prepare`` -> calibrate -> ``convert`` flow via ``quantize_()`` that integrates activation pre-scaling and smoothing with static quantization. See the ``torchao.prototype.smoothquant`` and ``torchao.prototype.awq`` modules. | ||
| - **Float8 static quantization**: ``Float8StaticActivationFloat8WeightConfig`` (in ``torchao.prototype.quantization``) supports a built-in observer-based prepare/convert flow for float8 dtypes. | ||
|
|
||
| For a list of all available quantization configs, see the :doc:`API Reference <../api_reference/api_ref_quantization>`. | ||
|
|
||
| In this tutorial, we walked through how to perform integer static quantization in torchao using the current ``quantize_`` API with ``Int8StaticActivationInt8WeightConfig``. | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.