Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-29 05:20:54 +08:00)
Fix GPTQ model loading in Transformers backend (#25770)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 1cb6005627 · commit c7ae7edb33
```diff
@@ -100,10 +100,9 @@ def test_distributed(
         kwargs_test=kwargs)

-@pytest.mark.skipif(
-    current_platform.is_rocm(),
-    reason="bitsandbytes quantization is currently not supported in rocm.")
+
 @pytest.mark.parametrize("model, quantization_kwargs", [
     ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
+    ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
     (
         "meta-llama/Llama-3.2-1B-Instruct",
         {
```
```diff
@@ -121,6 +120,11 @@ def test_quantization(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+    if (current_platform.is_rocm()
+            and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
+        pytest.skip(
+            "bitsandbytes quantization is currently not supported in rocm.")
+
     with vllm_runner(
             model, model_impl="auto", enforce_eager=True,
             **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
```
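Note on the test change: the ROCm guard moves from a `@pytest.mark.skipif` decorator into the test body, so only the bitsandbytes parametrization is skipped on ROCm while the AWQ and GPTQ cases still run. A minimal sketch of the pattern, with placeholder parametrizations rather than the real model list:

```python
# Sketch: skip a single parametrization at runtime instead of skipping
# every case with @pytest.mark.skipif. IS_ROCM is a placeholder for
# current_platform.is_rocm().
import pytest

IS_ROCM = False  # stand-in for the real platform check


@pytest.mark.parametrize("quantization_kwargs", [
    {},                                # runs everywhere
    {"quantization": "bitsandbytes"},  # skipped only on ROCm
])
def test_quantization_sketch(quantization_kwargs: dict) -> None:
    if IS_ROCM and quantization_kwargs.get("quantization",
                                           "") == "bitsandbytes":
        pytest.skip(
            "bitsandbytes quantization is currently not supported in rocm.")
    assert isinstance(quantization_kwargs, dict)
```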
```diff
@@ -447,7 +447,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.device_config: DeviceConfig = vllm_config.device_config
         self.model_config: ModelConfig = vllm_config.model_config
         self.parallel_config: ParallelConfig = vllm_config.parallel_config
-        self.quant_config: QuantizationConfig = vllm_config.quant_config
+        self.quant_config: Optional[
+            QuantizationConfig] = vllm_config.quant_config

         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
```
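The `Optional` annotation reflects reality: `vllm_config.quant_config` is `None` when the model is not quantized, which is what makes the truthiness guard added in the next hunk (`if self.quant_config and ...`) necessary. A small sketch of the pattern, using a stand-in config class rather than vLLM's:

```python
# Sketch: quant_config may be None for unquantized models, so guard
# before calling methods on it. QuantizationConfig here is a stand-in.
from typing import Optional


class QuantizationConfig:
    def get_name(self) -> str:
        return "gptq"


def wants_gptq_bias_skip(quant_config: Optional[QuantizationConfig]) -> bool:
    return quant_config is not None and "gptq" in quant_config.get_name()


assert not wants_gptq_bias_skip(None)        # unquantized model
assert wants_gptq_bias_skip(QuantizationConfig())
```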
```diff
@@ -456,7 +457,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):

         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
+        """Skip loading weights whose qualname starts with these prefixes."""
         self.skip_substrs: list[str] = []
+        """Skip loading weights whose qualname contains these substrings."""
+        self.ignore_unexpected_prefixes: list[str] = []
+        """Ignore unexpected weights whose qualname starts with these prefixes.
+        """
+        self.ignore_unexpected_suffixes: list[str] = []
+        """Ignore unexpected weights whose qualname ends with these suffixes."""
+
+        # Skip loading extra bias for GPTQ models.
+        if self.quant_config and "gptq" in self.quant_config.get_name():
+            self.ignore_unexpected_suffixes.append(".bias")

         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
```
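This hunk is the core of the fix: GPTQ checkpoints can carry extra `.bias` tensors that the corresponding vLLM modules do not register, which previously surfaced as unexpected-weight failures while loading. Registering `.bias` as an ignorable suffix lets the loader drop them. A minimal sketch of the suffix filtering, with illustrative tensor names:

```python
# Sketch: filter out unexpected checkpoint tensors whose qualname ends
# with an ignorable suffix (e.g. the extra ".bias" in GPTQ checkpoints).
ignore_unexpected_suffixes = [".bias"]

checkpoint = {
    "model.layers.0.self_attn.qkv_proj.qweight": "tensor-a",
    "model.layers.0.self_attn.qkv_proj.bias": "tensor-b",  # unregistered
}

loadable = {
    name: tensor
    for name, tensor in checkpoint.items()
    if not any(name.endswith(s) for s in ignore_unexpected_suffixes)
}
assert list(loadable) == ["model.layers.0.self_attn.qkv_proj.qweight"]
```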
```diff
@@ -563,9 +575,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel. {tip}")

-        def _tensor_parallel(module: nn.Module,
-                             prefix: str = "",
-                             tp_plan=None):
+        def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
             tp_plan = tp_plan or {}

             # If the current module is a PreTrainedModel, set the tp_plan for
```
```diff
@@ -597,7 +607,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                     prefix=qual_name,
                     tp_plan=tp_plan)

-        _tensor_parallel(self.model)
+        _tensor_parallel(self.model, prefix="model")

     def create_attention_instances(
         self,
```
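These two hunks make the recursion's `prefix` explicit: `_tensor_parallel` no longer defaults to `prefix=""`, and the root call now passes `prefix="model"`, presumably so the qualified names built while walking submodules line up with the `model.`-prefixed weight names. A toy sketch of the qualname recursion (the module structure is illustrative, not vLLM's):

```python
# Sketch: build qualified names while recursing over submodules, seeded
# with an explicit root prefix instead of an empty-string default.
import torch.nn as nn


def walk(module: nn.Module, prefix: str) -> list[str]:
    names = []
    for child_name, child in module.named_children():
        qual_name = f"{prefix}.{child_name}"
        names.append(qual_name)
        names.extend(walk(child, prefix=qual_name))
    return names


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
assert walk(model, prefix="model") == ["model.0", "model.1"]
```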
```diff
@@ -696,6 +706,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             self,
             skip_prefixes=self.skip_prefixes,
             skip_substrs=self.skip_substrs,
+            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
+            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
         )
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
```
```diff
@@ -109,6 +109,7 @@ class AutoWeightsLoader:
         skip_prefixes: Optional[list[str]] = None,
         skip_substrs: Optional[list[str]] = None,
         ignore_unexpected_prefixes: Optional[list[str]] = None,
+        ignore_unexpected_suffixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
```
```diff
@@ -116,6 +117,7 @@ class AutoWeightsLoader:
         self.skip_prefixes = skip_prefixes or []
         self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
         # update default skip_substrs
         self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
```
```diff
@@ -149,8 +151,9 @@ class AutoWeightsLoader:
                 or any(substr in qualname for substr in self.skip_substrs))

     def _can_ignore_unexpected(self, qualname: str) -> bool:
-        return any(
-            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
+        return any(iup) or any(ius)

     def _load_param(
         self,
```
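`_can_ignore_unexpected` now accepts a weight if its qualname matches an ignorable prefix or an ignorable suffix. A standalone sketch of the combined check; the `lm_head.` prefix below is a hypothetical example, while `.bias` is the suffix the GPTQ path actually registers:

```python
# Sketch: an unexpected weight is ignorable if its qualname starts with
# a known prefix OR ends with a known suffix.
ignore_unexpected_prefixes = ["lm_head."]  # hypothetical example
ignore_unexpected_suffixes = [".bias"]     # registered for GPTQ models


def can_ignore_unexpected(qualname: str) -> bool:
    iup = (qualname.startswith(p) for p in ignore_unexpected_prefixes)
    ius = (qualname.endswith(s) for s in ignore_unexpected_suffixes)
    return any(iup) or any(ius)


assert can_ignore_unexpected("model.layers.0.qkv_proj.bias")   # suffix hit
assert can_ignore_unexpected("lm_head.weight")                 # prefix hit
assert not can_ignore_unexpected("model.layers.0.qkv_proj.qweight")
```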