Use revision when downloading the quantization config file (#2697)

Co-authored-by: Pernekhan Utemuratov <pernekhan@deepinfra.com>
Pernekhan Utemuratov 2024-02-01 15:41:58 -08:00 committed by GitHub
parent bb8c697ee0
commit c410f5d020
2 changed files with 15 additions and 19 deletions
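
This fix threads the model's pinned revision into the download of the quantization config files. Previously the *.json files were fetched without a revision, so a model loaded from a specific branch, tag, or commit could be paired with the quantization config from the repo's default branch. A minimal sketch of the difference, with a hypothetical repo id and revision:

from huggingface_hub import snapshot_download

# Hypothetical repo and revision, for illustration only.
repo_id = "org/some-awq-model"
pinned_revision = "v2-awq"

# Without `revision`, the *.json files come from the repo's default branch,
# which may not match the weights actually being served.
default_cfg_dir = snapshot_download(repo_id, allow_patterns="*.json")

# With `revision`, the quantization config comes from the same snapshot as
# the pinned weights.
pinned_cfg_dir = snapshot_download(repo_id,
                                   revision=pinned_revision,
                                   allow_patterns="*.json")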


@@ -44,10 +44,7 @@ def get_model(model_config: ModelConfig,
     # Get the (maybe quantized) linear method.
     linear_method = None
     if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config.quantization,
-                                        model_config.model,
-                                        model_config.hf_config,
-                                        model_config.download_dir)
+        quant_config = get_quant_config(model_config)
         capability = torch.cuda.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
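
For context on the unchanged capability check above: torch reports the CUDA compute capability as a (major, minor) tuple, and the loader folds it into a two-digit integer before comparing it with the quantization method's minimum. A small worked example (8.6 is the compute capability of an Ampere consumer GPU such as an RTX 3090):

import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 6)
capability = major * 10 + minor                     # (8, 6) -> 86
# The quantization method is rejected when `capability` falls below
# quant_config.get_min_capability().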


@@ -11,9 +11,9 @@ from huggingface_hub import snapshot_download, HfFileSystem
 import numpy as np
 from safetensors.torch import load_file, save_file, safe_open
 import torch
-from transformers import PretrainedConfig
 from tqdm.auto import tqdm
 
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (get_quantization_config,
                                                      QuantizationConfig)
@@ -83,25 +83,22 @@ def convert_bin_to_safetensor_file(
 # TODO(woosuk): Move this to other place.
-def get_quant_config(
-    quantization: str,
-    model_name_or_path: str,
-    hf_config: PretrainedConfig,
-    cache_dir: Optional[str] = None,
-) -> QuantizationConfig:
-    quant_cls = get_quantization_config(quantization)
+def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
+    quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
-    hf_quant_config = getattr(hf_config, "quantization_config", None)
+    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
+                              None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
+    model_name_or_path = model_config.model
     is_local = os.path.isdir(model_name_or_path)
     if not is_local:
         # Download the config files.
-        with get_lock(model_name_or_path, cache_dir):
+        with get_lock(model_name_or_path, model_config.download_dir):
             hf_folder = snapshot_download(model_name_or_path,
+                                          revision=model_config.revision,
                                           allow_patterns="*.json",
-                                          cache_dir=cache_dir,
+                                          cache_dir=model_config.download_dir,
                                           tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
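
With this hunk, everything the function needs (quantization method, model id, HF config, download directory, and now the revision) is read from the single ModelConfig argument. A hedged sketch of the call from the caller's side, using a stand-in object with hypothetical values because vLLM's real ModelConfig takes many more constructor arguments:

from types import SimpleNamespace

# Stand-in for vllm.config.ModelConfig; it only carries the attributes that
# get_quant_config (as defined in the diff above) actually reads.
model_config = SimpleNamespace(
    quantization="awq",                # hypothetical values throughout
    model="org/llama-2-7b-awq",
    hf_config=SimpleNamespace(),       # no quantization_config attribute
    download_dir=None,
    revision="refs/pr/3",
)

# The *.json files are now fetched from the "refs/pr/3" snapshot instead of
# the default branch.
quant_config = get_quant_config(model_config)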
@@ -112,9 +109,11 @@ def get_quant_config(
             f.endswith(x) for x in quant_cls.get_config_filenames())
     ]
     if len(quant_config_files) == 0:
-        raise ValueError(f"Cannot find the config file for {quantization}")
+        raise ValueError(
+            f"Cannot find the config file for {model_config.quantization}")
     if len(quant_config_files) > 1:
-        raise ValueError(f"Found multiple config files for {quantization}: "
+        raise ValueError(
+            f"Found multiple config files for {model_config.quantization}: "
             f"{quant_config_files}")
     quant_config_file = quant_config_files[0]
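
The error handling above hinges on the filename filter a few lines earlier: every *.json file in the snapshot is matched against quant_cls.get_config_filenames(), and exactly one hit is expected. A self-contained sketch of that filter with hypothetical filenames (the real expected names come from the quantization config class):

# Hypothetical listing of the downloaded snapshot's JSON files.
config_files = ["config.json", "generation_config.json", "quant_config.json"]
# Hypothetical result of quant_cls.get_config_filenames() for an AWQ model.
expected_names = ["quant_config.json"]

quant_config_files = [
    f for f in config_files
    if any(f.endswith(x) for x in expected_names)
]
assert quant_config_files == ["quant_config.json"]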