mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 11:43:05 +08:00
[Quantization] Improve AWQ logic (#19431)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
c9280e6346
commit
73e2e0118f
@ -1,19 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
# Lazy import to avoid circular import.
|
||||
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .utils.marlin_utils import check_moe_marlin_supports_layer
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"zero_point": self.zero_point,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
marlin_compatible_config_dict = {
|
||||
"quant_method": "awq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"zero_point": self.zero_point,
|
||||
"lm_head": False,
|
||||
"modules_to_not_convert": self.modules_to_not_convert,
|
||||
}
|
||||
awq_marlin_config = AWQMarlinConfig.from_config(
|
||||
marlin_compatible_config_dict)
|
||||
return AWQMoEMethod(awq_marlin_config)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user