[CORE] [QUANT] Support for GPTQModel's dynamic quantization per module override/control (#7086)
This commit is contained in:
parent 2c2b560f48
commit 36a08630e8
tests/quantization/test_gptq_dynamic.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests whether gptq models with dynamic quantization config can be loaded.

Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
"""

import pytest
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_dynamic_override)

PROMPT = "On the surface of Mars, we found"

# The first layer is quantized using bits=4, group_size=128
# The second layer is quantized using bits=8, group_size=32
# All other layers (layer index >= 2) are not quantized
MODEL_QUANT = [
    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
     True),
    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
     False),
]


@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str,
                           use_marlin_kernel: bool):

    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)

    linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
        GPTQLinearMethod)

    for name, submodule in (vllm_model.model.llm_engine.model_executor.
                            driver_worker.model_runner.model.named_modules()):
        if name == "lm_head":
            assert isinstance(submodule.quant_method, linear_method_cls)
        elif name == 'model.layers.0.self_attn.qkv_proj':
            # The first layer is quantized using bits=4, group_size=128
            # desc_act=True
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            assert config.weight_bits == 4
            assert config.group_size == 128
            assert config.desc_act
        elif name == 'model.layers.1.self_attn.qkv_proj':
            # The second layer is quantized using bits=8, group_size=32
            # desc_act=False
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            assert get_dynamic_override(config, layer_name=name,
                                        key="bits") == 8
            assert get_dynamic_override(config,
                                        layer_name=name,
                                        key="group_size") == 32
            assert not get_dynamic_override(
                config, layer_name=name, key="desc_act")
        elif (name == 'model.layers.2.self_attn.qkv_proj'
              or name == 'model.layers.2.mlp.gate_up_proj'):
            # All other layers (layer index >= 2) are not quantized
            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)

    del vllm_model
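The per-layer behavior this test asserts comes from the `dynamic` rules baked into the checkpoint's quantization config. Those exact rules are not shown in this diff, but a rule set along the lines of the sketch below would produce the asserted pattern (layer 0 at 4-bit/group_size 128 with desc_act, layer 1 at 8-bit/group_size 32 without, everything from layer 2 on skipped); the regexes and values here are illustrative assumptions only.

# Illustrative only -- not taken from the ModelCloud test checkpoints above.
dynamic = {
    # layer 0: bits=4, group_size=128, desc_act=True
    r"+:.*\.layers\.0\..*": {"bits": 4, "group_size": 128, "desc_act": True},
    # layer 1: bits=8, group_size=32, desc_act=False
    r"+:.*\.layers\.1\..*": {"bits": 8, "group_size": 32, "desc_act": False},
    # layers >= 2: negative match, skip quantization entirely
    r"-:.*\.layers\.([2-9]|[1-9][0-9]+)\..*": {},
}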
@@ -3,7 +3,6 @@
 Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
 """
-from typing import Tuple

 import pytest
 import torch
@@ -17,31 +16,31 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (

 PROMPT = "On the surface of Mars, we found"

-MODELS_QUANT = [(
-    "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
-    True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
-           ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
+MODELS_QUANT = [
+    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True),
+    ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False),
+    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
+    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)
+]


-@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
+@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT)
 def test_lm_head(
     vllm_runner,
-    model_lm_head_quant: Tuple[str, bool],
+    model_id: str,
+    lm_head_quantized: bool,
 ) -> None:
-    model, lm_head_quantized = model_lm_head_quant
-
-    with vllm_runner(model, dtype=torch.float16,
+    with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:

         def check_model(model):
             lm_head_layer = model.lm_head

             if lm_head_quantized:
-                assert isinstance(lm_head_layer.linear_method,
+                assert isinstance(lm_head_layer.quant_method,
                                   (GPTQLinearMethod, GPTQMarlinLinearMethod,
                                    MarlinLinearMethod))
             else:
-                assert isinstance(lm_head_layer.linear_method,
+                assert isinstance(lm_head_layer.quant_method,
                                   UnquantizedEmbeddingMethod)

         vllm_model.apply_model(check_model)
@@ -1039,7 +1039,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = lm_head.linear_method.apply(lm_head, hidden_states)
+        logits = lm_head.quant_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
@@ -108,7 +108,7 @@ class LogitsProcessor(nn.Module):
         embedding_bias: Optional[torch.Tensor],
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = lm_head.linear_method.apply(lm_head,
+        logits = lm_head.quant_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
@@ -3,16 +3,17 @@
 import enum
 from enum import Enum
 from fractions import Fraction
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
 from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -32,7 +33,33 @@ class GPTQConfig(QuantizationConfig):
         group_size: int,
         desc_act: bool,
         lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
     ) -> None:
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        #  #`.*\.` matches the layers_node prefix
+        #  # positive match layer 10-15
+        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #  # positive match layer 16-21
+        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic

         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
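To make the shape of that field concrete: in a GPTQModel checkpoint the `dynamic` rules sit in the quantization config next to `bits`, `group_size`, `desc_act` and `lm_head`, and the `from_config` change further down reads them via `get_from_keys_or`. A minimal sketch of such a config dict follows; the override values are illustrative, not taken from this diff.

# Sketch only: the kind of quantization config dict that from_config (below)
# accepts once `dynamic` is supported. Override values are illustrative.
gptq_cfg = {
    "bits": 4,
    "group_size": 128,
    "desc_act": True,
    "lm_head": True,
    "dynamic": {
        # positive match: layers 10-15 get 8-bit weights
        r"+:.*\.(?:1[0-5])\..*": {"bits": 8},
        # negative match: skip quantization for any `moe` modules
        r"-:.*\.moe\..*": {},
    },
}
# config = GPTQConfig.from_config(gptq_cfg)  # yields a GPTQConfig with .dynamic set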
@@ -47,7 +74,8 @@ class GPTQConfig(QuantizationConfig):
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}),"
-                f"lm_head_quantized={self.lm_head_quantized}")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")

     @classmethod
     def get_name(cls) -> str:
@@ -68,19 +96,20 @@ class GPTQConfig(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized,
+                   dynamic)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["GPTQLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
-            return GPTQLinearMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)


 class ExllamaState(Enum):
@@ -9,17 +9,21 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -47,12 +51,41 @@ class GPTQMarlinConfig(QuantizationConfig):
         desc_act: bool,
         is_sym: bool,
         lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
     ) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
             desc_act = False

+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        #  #`.*\.` matches the layers_node prefix
+        #  # positive match layer 10-15
+        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #  # positive match layer 16-21
+        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.is_sym = is_sym
+
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.desc_act = desc_act
@@ -68,7 +101,8 @@ class GPTQMarlinConfig(QuantizationConfig):
         return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")

     @classmethod
     def get_name(cls) -> str:
@@ -88,6 +122,9 @@ class GPTQMarlinConfig(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
@@ -95,7 +132,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized)
+                   lm_head_quantized, dynamic)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -120,17 +157,15 @@ class GPTQMarlinConfig(QuantizationConfig):

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
-        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
-                                             and self.lm_head_quantized):
-            return GPTQMarlinLinearMethod(self)
-        elif isinstance(layer, FusedMoE):
+    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
+                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+        if isinstance(layer, FusedMoE):
             return GPTQMarlinMoEMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix,
+                                       GPTQMarlinLinearMethod)

     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        # Extract data from quant config.
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
@@ -143,7 +178,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         if quant_method != "gptq":
             return False

-        # If we cannot find the info needed in the config, cannot convert.
+        # Marlin conversion is only valid if required properties are found
         if (num_bits is None or group_size is None or sym is None
                 or desc_act is None):
             return False
vllm/model_executor/layers/quantization/utils/gptq_utils.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union

import torch

from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, UnquantizedEmbeddingMethod)


# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits",
                                       config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size",
                                      config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act",
                                    config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")


def get_dynamic_override(
        config: QuantizationConfig,
        layer_name: str,
        key: Optional[str] = None,
        default_value: Union[int, bool,
                             None] = None) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant properties overrides
        # base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = isinstance(
        layer, ParallelLMHead) and cloned_config.lm_head_quantized
    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = Positive match
        if get_dynamic_override(  # noqa: E712
                cloned_config,  # noqa: E712
                layer_name=prefix) == False:  # noqa: E712
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
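To make the matching semantics concrete, here is a minimal sketch (not part of the diff) of how `get_dynamic_override` resolves module names against a `dynamic` rule set. Only the `.dynamic` attribute is consulted by this helper, so a tiny stand-in object is enough for illustration; the rules and module names below are made up.

from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_dynamic_override)


class _FakeConfig:
    # Stand-in for a quant config; only `.dynamic` is read by the helper.
    dynamic = {
        r"+:.*\.layers\.0\..*": {"bits": 8, "group_size": 32},  # positive override
        r"-:lm_head": {},  # negative match: exclude from quantized init
    }


cfg = _FakeConfig()
# Positive match returns the override value for the requested key.
assert get_dynamic_override(cfg, "model.layers.0.self_attn.qkv_proj",
                            key="bits") == 8
# A key absent from the matched rule falls back to the supplied default.
assert get_dynamic_override(cfg, "model.layers.0.self_attn.qkv_proj",
                            key="desc_act", default_value=True) is True
# Negative match returns False, which get_linear_quant_method treats as
# "skip quantized init for this module".
assert get_dynamic_override(cfg, "lm_head") is False
# No matching rule returns the default (None here).
assert get_dynamic_override(cfg, "model.layers.5.mlp.down_proj") is None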
@@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module):
                                    self.tp_size)
         self.embedding_dim = embedding_dim

-        linear_method = None
+        quant_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self, prefix=prefix)
-        if linear_method is None:
-            linear_method = UnquantizedEmbeddingMethod()
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()

         # If we are making an embedding layer, then our quantization linear
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
-        linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
-        if is_embedding_layer and not linear_method_implements_embedding:
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
             raise NotImplementedError(
-                f"The class {type(linear_method).__name__} must implement "
+                f"The class {type(quant_method).__name__} must implement "
                 "the 'embedding' method, see UnquantizedEmbeddingMethod.")

-        self.linear_method: QuantizeMethodBase = linear_method
+        self.quant_method: QuantizeMethodBase = quant_method

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -260,7 +260,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)

-        self.linear_method.create_weights(self,
+        self.quant_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_per_partition],
                                          self.embedding_dim,
@@ -412,7 +412,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
+        output_parallel = self.quant_method.embedding(self,
                                                       masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1: