From f46dc7135b5dd306876e76cf378366ba01d789d7 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sun, 21 Jun 2026 13:50:57 +0200 Subject: [PATCH] Use semantic version checks in release script --- .../quantize_and_upload.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py index 4f91fa611e..5a1e6d942a 100644 --- a/.github/scripts/torchao_model_releases/quantize_and_upload.py +++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py @@ -11,13 +11,15 @@ import torch import transformers from huggingface_hub import ModelCard, get_token, whoami +from packaging.version import Version, parse from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig -_transformers_version = str(transformers.__version__) -if _transformers_version >= "5": +TRANSFORMERS_VERSION = parse(transformers.__version__) +IS_TRANSFORMERS_V5_OR_GREATER = TRANSFORMERS_VERSION.major >= 5 +if IS_TRANSFORMERS_V5_OR_GREATER: from transformers.quantizers.auto import get_hf_quantizer -_huggingface_hub_version = str(huggingface_hub.__version__) +HF_HUB_VERSION = parse(huggingface_hub.__version__) from torchao.prototype.awq import ( AWQConfig, @@ -42,7 +44,7 @@ ) from torchao.quantization.quant_api import _is_linear -safe_serialization = _transformers_version >= "5" +safe_serialization = IS_TRANSFORMERS_V5_OR_GREATER def _get_username(): @@ -58,7 +60,7 @@ def _untie_weights_and_save_locally(model_id, device): tokenizer = AutoTokenizer.from_pretrained(model_id) - if _transformers_version >= "5": + if IS_TRANSFORMERS_V5_OR_GREATER: from accelerate.utils.modeling import find_tied_parameters else: from transformers.modeling_utils import find_tied_parameters @@ -925,7 +927,7 @@ def filter_fn_skip_lmhead(module, fqn): # Push to hub if push_to_hub: - if _huggingface_hub_version < "1.4.1": + if HF_HUB_VERSION < Version("1.4.1"): quantized_model.push_to_hub( quantized_model_id, safe_serialization=safe_serialization ) @@ -936,7 +938,7 @@ def filter_fn_skip_lmhead(module, fqn): if populate_model_card_template: card.push_to_hub(quantized_model_id) else: - if _huggingface_hub_version < "1.4.1": + if HF_HUB_VERSION < Version("1.4.1"): quantized_model.save_pretrained( quantized_model_id, safe_serialization=safe_serialization )