From c410f5d020216df2dfedde52bcae24ae7f0ac7ec Mon Sep 17 00:00:00 2001
From: Pernekhan Utemuratov
Date: Thu, 1 Feb 2024 15:41:58 -0800
Subject: [PATCH] Use revision when downloading the quantization config file (#2697)

Co-authored-by: Pernekhan Utemuratov
---
 vllm/model_executor/model_loader.py |  5 +----
 vllm/model_executor/weight_utils.py | 29 ++++++++++++++---------------
 2 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index bf13ebf57d42..cd21c7788fc7 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -44,10 +44,7 @@ def get_model(model_config: ModelConfig,
     # Get the (maybe quantized) linear method.
     linear_method = None
     if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config.quantization,
-                                        model_config.model,
-                                        model_config.hf_config,
-                                        model_config.download_dir)
+        quant_config = get_quant_config(model_config)
         capability = torch.cuda.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 8e6f7a174f21..3570366887e7 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -11,9 +11,9 @@
 from huggingface_hub import snapshot_download, HfFileSystem
 import numpy as np
 from safetensors.torch import load_file, save_file, safe_open
 import torch
-from transformers import PretrainedConfig
 from tqdm.auto import tqdm
 
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (get_quantization_config,
                                                      QuantizationConfig)
@@ -83,25 +83,22 @@ def convert_bin_to_safetensor_file(
 
 
 # TODO(woosuk): Move this to other place.
-def get_quant_config(
-    quantization: str,
-    model_name_or_path: str,
-    hf_config: PretrainedConfig,
-    cache_dir: Optional[str] = None,
-) -> QuantizationConfig:
-    quant_cls = get_quantization_config(quantization)
+def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
+    quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
-    hf_quant_config = getattr(hf_config, "quantization_config", None)
+    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
+                              None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
-
+    model_name_or_path = model_config.model
     is_local = os.path.isdir(model_name_or_path)
     if not is_local:
         # Download the config files.
-        with get_lock(model_name_or_path, cache_dir):
+        with get_lock(model_name_or_path, model_config.download_dir):
             hf_folder = snapshot_download(model_name_or_path,
+                                          revision=model_config.revision,
                                           allow_patterns="*.json",
-                                          cache_dir=cache_dir,
+                                          cache_dir=model_config.download_dir,
                                           tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
@@ -112,10 +109,12 @@ def get_quant_config(
             f.endswith(x) for x in quant_cls.get_config_filenames())
     ]
     if len(quant_config_files) == 0:
-        raise ValueError(f"Cannot find the config file for {quantization}")
+        raise ValueError(
+            f"Cannot find the config file for {model_config.quantization}")
     if len(quant_config_files) > 1:
-        raise ValueError(f"Found multiple config files for {quantization}: "
-                         f"{quant_config_files}")
+        raise ValueError(
+            f"Found multiple config files for {model_config.quantization}: "
+            f"{quant_config_files}")
     quant_config_file = quant_config_files[0]
     with open(quant_config_file, "r") as f:
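--
A minimal sketch of the effect for reviewers: get_quant_config() now forwards
model_config.revision to snapshot_download, so the quantization config JSON is
fetched from the same pinned snapshot as the weights instead of always from
the default branch. The standalone equivalent of the new download step is
shown below; the repo id and revision are illustrative, not from the patch.

    from huggingface_hub import snapshot_download

    # Fetch only the JSON config files from the pinned snapshot. Before this
    # patch, the revision argument was omitted, so the quantization config
    # always came from the repo's default branch.
    hf_folder = snapshot_download(
        "TheBloke/Llama-2-7B-AWQ",  # model_name_or_path (illustrative)
        revision="main",            # any branch, tag, or commit SHA
        allow_patterns="*.json",    # skip the weight files
        cache_dir=None,             # model_config.download_dir in vLLM
    )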