mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 08:04:33 +08:00
[Misc] Add support for new autogptq checkpoint_format (#3689)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent
93deb0b38f
commit
7d4e1b85e7
68
tests/quantization/test_autogptq_marlin_configs.py
Normal file
68
tests/quantization/test_autogptq_marlin_configs.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""Tests whether Marlin models can be loaded from the autogptq config.
|
||||||
|
|
||||||
|
Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
from typing import Tuple

import pytest

from vllm.config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelPair:
    """Pair of model identifiers: one Marlin variant and one GPTQ variant.

    NOTE(review): nothing in the visible part of this file references this
    class; it is kept for compatibility in case other code imports it.
    """

    # Identifier of the Marlin-format model (presumably a HF hub id).
    model_marlin: str
    # Identifier of the GPTQ-format model (presumably a HF hub id).
    model_gptq: str
|
||||||
|
|
||||||
|
|
||||||
|
# (model id, expected quantization method) pairs. The second element is the
# value that `ModelConfig.quantization` is expected to resolve to.
MODELS_QUANT_TYPE = [
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
    # compat: autogptq >=0.8.0 use checkpoint_format: str
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model_config(model_path, quantization):
    """Build a ModelConfig for ``model_path`` with the given quantization arg.

    All other settings are fixed and identical for every case in this test;
    the model path is also used as the tokenizer path.
    """
    return ModelConfig(
        model_path,
        model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        # presumably avoids loading real weights -- TODO confirm
        load_format="dummy",
        seed=0,
        dtype="float16",
        revision=None,
        quantization=quantization,
    )


@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
def test_auto_gptq(model_quant_type: Tuple[str, str]) -> None:
    """Check the quantization method that ModelConfig resolves for a model.

    Args:
        model_quant_type: ``(model_path, quant_type)`` entry taken from
            ``MODELS_QUANT_TYPE``; ``quant_type`` is the expected value of
            ``ModelConfig.quantization``.

    Two configurations are built per model and both must resolve to
    ``quant_type``:
      * case 1: no quantization argument (``quantization=None``)
      * case 2: an explicit ``quantization="gptq"`` argument
    """
    model_path, quant_type = model_quant_type

    # Build both configs before asserting, in the same order as before.
    model_config_no_quant_arg = _build_model_config(model_path, None)  # case 1
    model_config_quant_arg = _build_model_config(model_path, "gptq")  # case 2

    assert model_config_no_quant_arg.quantization == quant_type, (
        f"Expected quant_type == {quant_type} for {model_path}, "
        f"but found {model_config_no_quant_arg.quantization} "
        "for no --quantization None case")

    assert model_config_quant_arg.quantization == quant_type, (
        f"Expected quant_type == {quant_type} for {model_path}, "
        f"but found {model_config_quant_arg.quantization} "
        "for --quantization gptq case")
|
||||||
@ -171,26 +171,28 @@ class ModelConfig:
|
|||||||
self.quantization = self.quantization.lower()
|
self.quantization = self.quantization.lower()
|
||||||
|
|
||||||
# Parse quantization method from the HF model config, if available.
|
# Parse quantization method from the HF model config, if available.
|
||||||
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
|
quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
||||||
if hf_quant_config is not None:
|
if quant_cfg is not None:
|
||||||
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
|
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||||
|
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||||
|
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||||
|
is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
|
||||||
|
or quant_cfg.get("is_marlin_format", False))
|
||||||
|
|
||||||
# If the GPTQ model is serialized in marlin format, use marlin.
|
# Use marlin if the GPTQ model is serialized in marlin format.
|
||||||
if (hf_quant_method == "gptq"
|
if quant_method == "gptq" and is_format_marlin:
|
||||||
and "is_marlin_format" in hf_quant_config
|
|
||||||
and hf_quant_config["is_marlin_format"]):
|
|
||||||
logger.info("The model is serialized in Marlin format. "
|
logger.info("The model is serialized in Marlin format. "
|
||||||
"Using Marlin kernel.")
|
"Using Marlin kernel.")
|
||||||
hf_quant_method = "marlin"
|
quant_method = "marlin"
|
||||||
if self.quantization == "gptq":
|
if self.quantization == "gptq":
|
||||||
self.quantization = hf_quant_method
|
self.quantization = quant_method
|
||||||
|
|
||||||
if self.quantization is None:
|
if self.quantization is None:
|
||||||
self.quantization = hf_quant_method
|
self.quantization = quant_method
|
||||||
elif self.quantization != hf_quant_method:
|
elif self.quantization != quant_method:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Quantization method specified in the model config "
|
"Quantization method specified in the model config "
|
||||||
f"({hf_quant_method}) does not match the quantization "
|
f"({quant_method}) does not match the quantization "
|
||||||
f"method specified in the `quantization` argument "
|
f"method specified in the `quantization` argument "
|
||||||
f"({self.quantization}).")
|
f"({self.quantization}).")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user