mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 20:45:15 +08:00
Use revision when downloading the quantization config file (#2697)
Co-authored-by: Pernekhan Utemuratov <pernekhan@deepinfra.com>
This commit is contained in:
parent
bb8c697ee0
commit
c410f5d020
@ -44,10 +44,7 @@ def get_model(model_config: ModelConfig,
|
|||||||
# Get the (maybe quantized) linear method.
|
# Get the (maybe quantized) linear method.
|
||||||
linear_method = None
|
linear_method = None
|
||||||
if model_config.quantization is not None:
|
if model_config.quantization is not None:
|
||||||
quant_config = get_quant_config(model_config.quantization,
|
quant_config = get_quant_config(model_config)
|
||||||
model_config.model,
|
|
||||||
model_config.hf_config,
|
|
||||||
model_config.download_dir)
|
|
||||||
capability = torch.cuda.get_device_capability()
|
capability = torch.cuda.get_device_capability()
|
||||||
capability = capability[0] * 10 + capability[1]
|
capability = capability[0] * 10 + capability[1]
|
||||||
if capability < quant_config.get_min_capability():
|
if capability < quant_config.get_min_capability():
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from huggingface_hub import snapshot_download, HfFileSystem
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import (get_quantization_config,
|
from vllm.model_executor.layers.quantization import (get_quantization_config,
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
@ -83,25 +83,22 @@ def convert_bin_to_safetensor_file(
|
|||||||
|
|
||||||
|
|
||||||
# TODO(woosuk): Move this to other place.
|
# TODO(woosuk): Move this to other place.
|
||||||
def get_quant_config(
|
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||||
quantization: str,
|
quant_cls = get_quantization_config(model_config.quantization)
|
||||||
model_name_or_path: str,
|
|
||||||
hf_config: PretrainedConfig,
|
|
||||||
cache_dir: Optional[str] = None,
|
|
||||||
) -> QuantizationConfig:
|
|
||||||
quant_cls = get_quantization_config(quantization)
|
|
||||||
# Read the quantization config from the HF model config, if available.
|
# Read the quantization config from the HF model config, if available.
|
||||||
hf_quant_config = getattr(hf_config, "quantization_config", None)
|
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||||
|
None)
|
||||||
if hf_quant_config is not None:
|
if hf_quant_config is not None:
|
||||||
return quant_cls.from_config(hf_quant_config)
|
return quant_cls.from_config(hf_quant_config)
|
||||||
|
model_name_or_path = model_config.model
|
||||||
is_local = os.path.isdir(model_name_or_path)
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
if not is_local:
|
if not is_local:
|
||||||
# Download the config files.
|
# Download the config files.
|
||||||
with get_lock(model_name_or_path, cache_dir):
|
with get_lock(model_name_or_path, model_config.download_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
|
revision=model_config.revision,
|
||||||
allow_patterns="*.json",
|
allow_patterns="*.json",
|
||||||
cache_dir=cache_dir,
|
cache_dir=model_config.download_dir,
|
||||||
tqdm_class=Disabledtqdm)
|
tqdm_class=Disabledtqdm)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
@ -112,10 +109,12 @@ def get_quant_config(
|
|||||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
]
|
]
|
||||||
if len(quant_config_files) == 0:
|
if len(quant_config_files) == 0:
|
||||||
raise ValueError(f"Cannot find the config file for {quantization}")
|
raise ValueError(
|
||||||
|
f"Cannot find the config file for {model_config.quantization}")
|
||||||
if len(quant_config_files) > 1:
|
if len(quant_config_files) > 1:
|
||||||
raise ValueError(f"Found multiple config files for {quantization}: "
|
raise ValueError(
|
||||||
f"{quant_config_files}")
|
f"Found multiple config files for {model_config.quantization}: "
|
||||||
|
f"{quant_config_files}")
|
||||||
|
|
||||||
quant_config_file = quant_config_files[0]
|
quant_config_file = quant_config_files[0]
|
||||||
with open(quant_config_file, "r") as f:
|
with open(quant_config_file, "r") as f:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user