mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 23:37:52 +08:00
Fix GPTQ model loading in Transformers backend (#25770)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
7977e5027c
commit
ec152c8748
@ -100,10 +100,9 @@ def test_distributed(
|
|||||||
kwargs_test=kwargs)
|
kwargs_test=kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
current_platform.is_rocm(),
|
|
||||||
reason="bitsandbytes quantization is currently not supported in rocm.")
|
|
||||||
@pytest.mark.parametrize("model, quantization_kwargs", [
|
@pytest.mark.parametrize("model, quantization_kwargs", [
|
||||||
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
|
||||||
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
|
||||||
(
|
(
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
{
|
{
|
||||||
@ -121,6 +120,11 @@ def test_quantization(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if (current_platform.is_rocm()
|
||||||
|
and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
|
||||||
|
pytest.skip(
|
||||||
|
"bitsandbytes quantization is currently not supported in rocm.")
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model, model_impl="auto", enforce_eager=True,
|
model, model_impl="auto", enforce_eager=True,
|
||||||
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||||
|
|||||||
@ -447,7 +447,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
self.device_config: DeviceConfig = vllm_config.device_config
|
self.device_config: DeviceConfig = vllm_config.device_config
|
||||||
self.model_config: ModelConfig = vllm_config.model_config
|
self.model_config: ModelConfig = vllm_config.model_config
|
||||||
self.parallel_config: ParallelConfig = vllm_config.parallel_config
|
self.parallel_config: ParallelConfig = vllm_config.parallel_config
|
||||||
self.quant_config: QuantizationConfig = vllm_config.quant_config
|
self.quant_config: Optional[
|
||||||
|
QuantizationConfig] = vllm_config.quant_config
|
||||||
|
|
||||||
self.pp_group = get_pp_group()
|
self.pp_group = get_pp_group()
|
||||||
self.pp_size = self.pp_group.world_size
|
self.pp_size = self.pp_group.world_size
|
||||||
@ -456,7 +457,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
# Weights to skip in `self.load_weights`
|
# Weights to skip in `self.load_weights`
|
||||||
self.skip_prefixes: list[str] = []
|
self.skip_prefixes: list[str] = []
|
||||||
|
"""Skip loading weights whose qualname starts with these prefixes."""
|
||||||
self.skip_substrs: list[str] = []
|
self.skip_substrs: list[str] = []
|
||||||
|
"""Skip loading weights whose qualname contains these substrings."""
|
||||||
|
self.ignore_unexpected_prefixes: list[str] = []
|
||||||
|
"""Ignore unexpected weights whose qualname starts with these prefixes.
|
||||||
|
"""
|
||||||
|
self.ignore_unexpected_suffixes: list[str] = []
|
||||||
|
"""Ignore unexpected weights whose qualname ends with these suffixes."""
|
||||||
|
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if self.quant_config and "gptq" in self.quant_config.get_name():
|
||||||
|
self.ignore_unexpected_suffixes.append(".bias")
|
||||||
|
|
||||||
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
||||||
# TODO: @raushan, use the public `model.set_attn_implementation()`
|
# TODO: @raushan, use the public `model.set_attn_implementation()`
|
||||||
@ -563,9 +575,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{type(self.model)} does not support tensor parallel. {tip}")
|
f"{type(self.model)} does not support tensor parallel. {tip}")
|
||||||
|
|
||||||
def _tensor_parallel(module: nn.Module,
|
def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
|
||||||
prefix: str = "",
|
|
||||||
tp_plan=None):
|
|
||||||
tp_plan = tp_plan or {}
|
tp_plan = tp_plan or {}
|
||||||
|
|
||||||
# If the current module is a PreTrainedModel, set the tp_plan for
|
# If the current module is a PreTrainedModel, set the tp_plan for
|
||||||
@ -597,7 +607,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
prefix=qual_name,
|
prefix=qual_name,
|
||||||
tp_plan=tp_plan)
|
tp_plan=tp_plan)
|
||||||
|
|
||||||
_tensor_parallel(self.model)
|
_tensor_parallel(self.model, prefix="model")
|
||||||
|
|
||||||
def create_attention_instances(
|
def create_attention_instances(
|
||||||
self,
|
self,
|
||||||
@ -696,6 +706,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
self,
|
self,
|
||||||
skip_prefixes=self.skip_prefixes,
|
skip_prefixes=self.skip_prefixes,
|
||||||
skip_substrs=self.skip_substrs,
|
skip_substrs=self.skip_substrs,
|
||||||
|
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
|
||||||
|
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
|
||||||
)
|
)
|
||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|||||||
@ -109,6 +109,7 @@ class AutoWeightsLoader:
|
|||||||
skip_prefixes: Optional[list[str]] = None,
|
skip_prefixes: Optional[list[str]] = None,
|
||||||
skip_substrs: Optional[list[str]] = None,
|
skip_substrs: Optional[list[str]] = None,
|
||||||
ignore_unexpected_prefixes: Optional[list[str]] = None,
|
ignore_unexpected_prefixes: Optional[list[str]] = None,
|
||||||
|
ignore_unexpected_suffixes: Optional[list[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -116,6 +117,7 @@ class AutoWeightsLoader:
|
|||||||
self.skip_prefixes = skip_prefixes or []
|
self.skip_prefixes = skip_prefixes or []
|
||||||
self.skip_substrs = skip_substrs or []
|
self.skip_substrs = skip_substrs or []
|
||||||
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
|
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
|
||||||
|
self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
|
||||||
# update default skip_substrs
|
# update default skip_substrs
|
||||||
self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
|
self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
|
||||||
|
|
||||||
@ -149,8 +151,9 @@ class AutoWeightsLoader:
|
|||||||
or any(substr in qualname for substr in self.skip_substrs))
|
or any(substr in qualname for substr in self.skip_substrs))
|
||||||
|
|
||||||
def _can_ignore_unexpected(self, qualname: str) -> bool:
|
def _can_ignore_unexpected(self, qualname: str) -> bool:
|
||||||
return any(
|
iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
|
||||||
qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
|
ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
|
||||||
|
return any(iup) or any(ius)
|
||||||
|
|
||||||
def _load_param(
|
def _load_param(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user