Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import torch
import transformers
from huggingface_hub import ModelCard, get_token, whoami
from packaging.version import Version, parse
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

_transformers_version = str(transformers.__version__)
if _transformers_version >= "5":
TRANSFORMERS_VERSION = parse(transformers.__version__)
IS_TRANSFORMERS_V5_OR_GREATER = TRANSFORMERS_VERSION.major >= 5
if IS_TRANSFORMERS_V5_OR_GREATER:
from transformers.quantizers.auto import get_hf_quantizer

_huggingface_hub_version = str(huggingface_hub.__version__)
HF_HUB_VERSION = parse(huggingface_hub.__version__)

from torchao.prototype.awq import (
AWQConfig,
Expand All @@ -42,7 +44,7 @@
)
from torchao.quantization.quant_api import _is_linear

safe_serialization = _transformers_version >= "5"
safe_serialization = IS_TRANSFORMERS_V5_OR_GREATER


def _get_username():
Expand All @@ -58,7 +60,7 @@ def _untie_weights_and_save_locally(model_id, device):

tokenizer = AutoTokenizer.from_pretrained(model_id)

if _transformers_version >= "5":
if IS_TRANSFORMERS_V5_OR_GREATER:
from accelerate.utils.modeling import find_tied_parameters
else:
from transformers.modeling_utils import find_tied_parameters
Expand Down Expand Up @@ -925,7 +927,7 @@ def filter_fn_skip_lmhead(module, fqn):

# Push to hub
if push_to_hub:
if _huggingface_hub_version < "1.4.1":
if HF_HUB_VERSION < Version("1.4.1"):
quantized_model.push_to_hub(
quantized_model_id, safe_serialization=safe_serialization
)
Expand All @@ -936,7 +938,7 @@ def filter_fn_skip_lmhead(module, fqn):
if populate_model_card_template:
card.push_to_hub(quantized_model_id)
else:
if _huggingface_hub_version < "1.4.1":
if HF_HUB_VERSION < Version("1.4.1"):
quantized_model.save_pretrained(
quantized_model_id, safe_serialization=safe_serialization
)
Expand Down