mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 04:14:57 +08:00
Use revision when downloading the quantization config file (#2697)
Co-authored-by: Pernekhan Utemuratov <pernekhan@deepinfra.com>
This commit is contained in:
parent
bb8c697ee0
commit
c410f5d020
@ -44,10 +44,7 @@ def get_model(model_config: ModelConfig,
|
||||
# Get the (maybe quantized) linear method.
|
||||
linear_method = None
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.hf_config,
|
||||
model_config.download_dir)
|
||||
quant_config = get_quant_config(model_config)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
|
||||
@ -11,9 +11,9 @@ from huggingface_hub import snapshot_download, HfFileSystem
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (get_quantization_config,
|
||||
QuantizationConfig)
|
||||
@ -83,25 +83,22 @@ def convert_bin_to_safetensor_file(
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
quantization: str,
|
||||
model_name_or_path: str,
|
||||
hf_config: PretrainedConfig,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(quantization)
|
||||
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(hf_config, "quantization_config", None)
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
with get_lock(model_name_or_path, model_config.download_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=cache_dir,
|
||||
cache_dir=model_config.download_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
@ -112,10 +109,12 @@ def get_quant_config(
|
||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(f"Cannot find the config file for {quantization}")
|
||||
raise ValueError(
|
||||
f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(f"Found multiple config files for {quantization}: "
|
||||
f"{quant_config_files}")
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file, "r") as f:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user