mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 11:07:06 +08:00
Use quantization_config in hf config (#1695)
This commit is contained in:
parent
e87557b069
commit
bb00f66e19
@ -104,14 +104,30 @@ class ModelConfig:
|
|||||||
|
|
||||||
def _verify_quantization(self) -> None:
|
def _verify_quantization(self) -> None:
|
||||||
supported_quantization = ["awq", "squeezellm"]
|
supported_quantization = ["awq", "squeezellm"]
|
||||||
if self.quantization is None:
|
if self.quantization is not None:
|
||||||
return
|
self.quantization = self.quantization.lower()
|
||||||
quantization = self.quantization.lower()
|
|
||||||
if quantization not in supported_quantization:
|
# Parse quantization method from the HF model config, if available.
|
||||||
raise ValueError(
|
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||||
f"Unknown quantization: {self.quantization}. Must be one of "
|
if hf_quant_config is not None:
|
||||||
f"{supported_quantization}.")
|
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
|
||||||
self.quantization = quantization
|
if self.quantization is None:
|
||||||
|
self.quantization = hf_quant_method
|
||||||
|
elif self.quantization != hf_quant_method:
|
||||||
|
raise ValueError(
|
||||||
|
"Quantization method specified in the model config "
|
||||||
|
f"({hf_quant_method}) does not match the quantization "
|
||||||
|
f"method specified in the `quantization` argument "
|
||||||
|
f"({self.quantization}).")
|
||||||
|
|
||||||
|
if self.quantization is not None:
|
||||||
|
if self.quantization not in supported_quantization:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown quantization method: {self.quantization}. Must "
|
||||||
|
f"be one of {supported_quantization}.")
|
||||||
|
logger.warning(f"{self.quantization} quantization is not fully "
|
||||||
|
"optimized yet. The speed can be slower than "
|
||||||
|
"non-quantized models.")
|
||||||
|
|
||||||
def verify_with_parallel_config(
|
def verify_with_parallel_config(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||||||
if model_config.quantization is not None:
|
if model_config.quantization is not None:
|
||||||
quant_config = get_quant_config(model_config.quantization,
|
quant_config = get_quant_config(model_config.quantization,
|
||||||
model_config.model,
|
model_config.model,
|
||||||
|
model_config.hf_config,
|
||||||
model_config.download_dir)
|
model_config.download_dir)
|
||||||
capability = torch.cuda.get_device_capability()
|
capability = torch.cuda.get_device_capability()
|
||||||
capability = capability[0] * 10 + capability[1]
|
capability = capability[0] * 10 + capability[1]
|
||||||
|
|||||||
@ -7,9 +7,10 @@ from collections import defaultdict
|
|||||||
from typing import Any, Iterator, List, Optional, Tuple
|
from typing import Any, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import PretrainedConfig
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
|
|||||||
def get_quant_config(
|
def get_quant_config(
|
||||||
quantization: str,
|
quantization: str,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
) -> QuantizationConfig:
|
) -> QuantizationConfig:
|
||||||
|
quant_cls = get_quantization_config(quantization)
|
||||||
|
# Read the quantization config from the HF model config, if available.
|
||||||
|
hf_quant_config = getattr(hf_config, "quantization_config", None)
|
||||||
|
if hf_quant_config is not None:
|
||||||
|
return quant_cls.from_config(hf_quant_config)
|
||||||
|
|
||||||
is_local = os.path.isdir(model_name_or_path)
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
if not is_local:
|
if not is_local:
|
||||||
# Download the config files.
|
# Download the config files.
|
||||||
@ -98,7 +106,6 @@ def get_quant_config(
|
|||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||||
|
|
||||||
quant_cls = get_quantization_config(quantization)
|
|
||||||
quant_config_files = [
|
quant_config_files = [
|
||||||
f for f in config_files if any(
|
f for f in config_files if any(
|
||||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user