[Hardware][XPU] AWQ/GPTQ support for xpu backend (#10107)

Signed-off-by: yan ma <yan.ma@intel.com>
Yan Ma 2024-11-19 02:18:05 +08:00 committed by GitHub
parent 281cc4b3cd
commit 6b2d25efc7
7 changed files with 146 additions and 52 deletions

View File

@@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations
       - ✅︎
       - ✅︎
       - ✗
-      - ✗
+      - ✅︎
       - ✅︎
       - ✗
       - ✗
@@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations
       - ✅︎
       - ✅︎
       - ✗
-      - ✗
-      - ✗
+      - ✅︎
+      - ✅︎
       - ✗
       - ✗
     * - Marlin (GPTQ/AWQ/FP8)
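The flipped cells mark AWQ and GPTQ as supported on the Intel backends served by IPEX. As a rough offline-inference sketch, assuming a vLLM build with the XPU/IPEX backend installed and using the AWQ checkpoint named in the updated test file below (no explicit quantization flag is needed, since the IPEX config claims awq/gptq checkpoints via override_quantization_method):

from vllm import LLM, SamplingParams

# Hedged sketch: model name taken from the test file in this commit.
llm = LLM(model="AMead10/Llama-3.2-1B-Instruct-AWQ", dtype="bfloat16")
outputs = llm.generate(["What is AWQ quantization?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)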

View File

@@ -1,5 +1,5 @@
 """Test model set-up and inference for quantized HF models supported
- on the CPU backend using IPEX (including AWQ).
+ on the CPU/GPU backend using IPEX (including AWQ/GPTQ).
 
 Validating the configuration and printing results for manual checking.
@@ -11,13 +11,15 @@ import pytest
 from vllm.platforms import current_platform
 
 MODELS = [
-    "casperhansen/llama-3-8b-instruct-awq",
+    "AMead10/Llama-3.2-1B-Instruct-AWQ",
+    "shuyuej/Llama-3.2-1B-Instruct-GPTQ",  # with g_idx
 ]
 DTYPE = ["bfloat16"]
 
 
-@pytest.mark.skipif(not current_platform.is_cpu(),
-                    reason="only supports the CPU backend.")
+@pytest.mark.skipif(not current_platform.is_cpu()
+                    and not current_platform.is_xpu(),
+                    reason="only supports Intel CPU/XPU backend.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", DTYPE)
 def test_ipex_quant(vllm_runner, model, dtype):
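The test body itself is outside this hunk. For orientation, a typical vllm_runner-based body for such a check might look like the following hypothetical sketch (it follows the usual pattern of vLLM's quantization tests, not the code in this commit):

def test_ipex_quant(vllm_runner, model, dtype):
    # Hypothetical body sketch: load the quantized model via the test fixture,
    # run a greedy generation, and print it for manual checking.
    with vllm_runner(model, dtype=dtype) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)
    assert output
    print(output)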

View File

@@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
-    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
+    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
 ]
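WEIGHT_LOADER_V2_SUPPORTED is a plain allow-list keyed on class names, which is why the new IPEXGPTQLinearMethod must be registered by name to pick up the newer weight-loading path. A hedged sketch of how such a list is typically consulted, assuming it sits next to the list above (the helper name is illustrative, not vLLM's actual function):

from typing import Any

def uses_weight_loader_v2(quant_method: Any) -> bool:
    # Illustrative helper: membership is decided by the method's class name,
    # so every new linear method has to be appended to the list explicitly.
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED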

View File

@@ -210,7 +210,6 @@ class GPTQLinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # for torch.compile
-        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
         layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
         layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
         layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)

View File

@@ -23,6 +23,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -134,6 +135,9 @@ class GPTQMarlinConfig(QuantizationConfig):
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")
 
+        if not current_platform.is_cuda():
+            return False
+
         if quant_method != "gptq":
             return False
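With the early is_cuda() guard, GPTQ-Marlin now declines non-CUDA platforms before inspecting the checkpoint, which lets the IPEX config below claim plain gptq/awq checkpoints on CPU/XPU. A rough sketch of the resulting routing, under the assumption that these are the only configs competing for the checkpoint (the function is illustrative only):

from vllm.platforms import current_platform

def pick_gptq_backend(hf_quant_cfg: dict) -> str:
    # Illustrative-only routing that mirrors the checks touched in this commit.
    if not current_platform.is_cuda():
        # gptq_marlin's compatibility check now bails out here.
        if current_platform.is_cpu() or current_platform.is_xpu():
            return "ipex"      # IPEXConfig.override_quantization_method() claims it
        return "gptq"          # otherwise fall back to the plain GPTQ kernels
    return "gptq_marlin"       # CUDA keeps the Marlin fast path (if compatible)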

View File

@@ -2,21 +2,26 @@ from typing import Any, Dict, List, Optional
 
 import torch
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
+                                                         is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.platforms import current_platform
 
+MIN_IPEX_VERSION = "2.5.0"
+
 
 class IPEXConfig(QuantizationConfig):
-    """INT8 quantization config class using IPEX for the CPU backend,
-    including AWQ.
+    """INT8 quantization config class using IPEX for the CPU/XPU backend,
+    including AWQ, GPTQ.
     """
 
     IPEX_QUANT_METHOD_MAP = {
         "awq": 1,
-        "gptq": 2,
+        "gptq": 0,
     }
 
     def __init__(
@@ -24,29 +29,30 @@ class IPEXConfig(QuantizationConfig):
         method: str,
         weight_bits: int,
         group_size: int,
+        modules_to_not_convert: Optional[List[str]] = None,
+        desc_act: Optional[bool] = None,
+        lm_head_quantized: Optional[bool] = None,
     ) -> None:
         self.method = method
         self.weight_bits = weight_bits
         self.group_size = group_size
+        self.modules_to_not_convert = modules_to_not_convert or []
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
         self.pack_factor = 32 // self.weight_bits
 
         if self.weight_bits not in [4]:
             raise ValueError(f"IPEX quantization supports weight bits [4], "
                              f"but got {self.weight_bits}.")
 
-        if self.method == "awq":
-            self.quant_method = IPEXAWQLinearMethod
-        else:
-            raise ValueError(f"IPEX quantization supports [awq], "
+        if self.method not in ["awq", "gptq"]:
+            raise ValueError(f"IPEX quantization supports [awq, gptq], "
                              f"but got {self.method}.")
 
     def __repr__(self) -> str:
-        return (f"IPEXConfig(method={self.method}"
+        return (f"IPEXConfig(method={self.method},"
                 f"weight_bits={self.weight_bits}, "
-                f"group_size={self.group_size}")
-
-    def get_ipex_quant_method_id(self) -> int:
-        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
+                f"group_size={self.group_size})")
 
     @classmethod
     def get_name(cls) -> str:
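A small worked example of the constructor contract above, assuming the module is importable (values are illustrative):

cfg = IPEXConfig("gptq", weight_bits=4, group_size=128,
                 modules_to_not_convert=None, desc_act=True,
                 lm_head_quantized=False)
print(cfg.pack_factor)   # 8 -> eight INT4 weights packed into one int32 word
print(cfg)               # IPEXConfig(method=gptq,weight_bits=4, group_size=128)

# Anything other than 4-bit AWQ/GPTQ is rejected up front:
IPEXConfig("smoothquant", weight_bits=4, group_size=128)   # raises ValueError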
@@ -70,19 +76,32 @@ class IPEXConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
         method = cls.get_from_keys(config, ["quant_method"]).lower()
-        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
-        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
-        return cls(method, weight_bits, group_size)
+        if method == "awq":
+            weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+            group_size = cls.get_from_keys(config,
+                                           ["q_group_size", "group_size"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None)
+            return cls(method, weight_bits, group_size, modules_to_not_convert,
+                       False, False)
+        # otherwise for gptq
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
+        return cls(method, weight_bits, group_size, [], desc_act,
+                   lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
                                      user_quant) -> Optional[str]:
-        if not current_platform.is_cpu():
+        if not current_platform.is_cpu() and not current_platform.is_xpu():
             return None
 
         quant_method = hf_quant_cfg.get("quant_method", "").lower()
-        if quant_method in ["awq"]:
+        if quant_method in ["awq", "gptq"]:
             return cls.get_name()
 
         return None
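For reference, the two shapes of HuggingFace quantization_config this parser now accepts look roughly like the following (field values are illustrative; the key names are exactly the ones queried above):

awq_cfg = {
    "quant_method": "awq",
    "w_bit": 4,                   # "bits" is accepted as an alias
    "q_group_size": 128,          # "group_size" is accepted as an alias
    "modules_to_not_convert": None,
}
gptq_cfg = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": True,             # act-order checkpoints carry g_idx
    "lm_head": False,
}

IPEXConfig.from_config(awq_cfg)   # -> IPEXConfig(method=awq, ...)
IPEXConfig.from_config(gptq_cfg)  # -> IPEXConfig(method=gptq, ...)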
@@ -90,12 +109,81 @@ class IPEXConfig(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
-            return self.quant_method(self)
+            if self.method == "awq":
+                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                    return UnquantizedLinearMethod()
+                return IPEXAWQLinearMethod(self)
+            if self.method == "gptq":
+                return IPEXGPTQLinearMethod(self)
         return None
 
 
+class IPEXGPTQLinearMethod(GPTQLinearMethod):
+    """GPTQ linear method using IPEX for the CPU/XPU backend.
+    """
+
+    def __init__(self, quant_config: IPEXConfig):
+        self.quant_config = quant_config  # type: ignore
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        bias = layer.bias if not layer.skip_bias_add else None
+
+        try:
+            import intel_extension_for_pytorch as ipex
+            if ipex.__version__ < MIN_IPEX_VERSION:
+                raise ImportError(
+                    "intel_extension_for_pytorch version is "
+                    "wrong. Please install "
+                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
+        except ImportError as err:
+            raise ImportError(
+                "Please install "
+                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
+                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
+                " to use IPEX-AWQ linear method.") from err
+        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
+        # with better performance.
+        lowp_mode = ipex.quantization.WoqLowpMode.INT8
+        # The weight will be de-packed from INT4 to INT8.
+        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
+        # The float activation will be quantized (dynamic, per-token) to INT8.
+        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
+
+        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+            weight_dtype=weight_dtype,
+            lowp_mode=lowp_mode,
+            act_quant_mode=act_quant_mode,
+            group_size=self.quant_config.group_size,
+        )
+        layer.ipex_output_size = layer.qweight.shape[-1]
+        g_idx = layer.g_idx if self.quant_config.desc_act else None
+        layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
+            IPEXWeightOnlyQuantizedLinear.from_weight(
+                layer.qweight,
+                layer.scales,
+                layer.qzeros,
+                layer.qweight.size(0),
+                layer.ipex_output_size,
+                qconfig=qconfig,
+                g_idx=g_idx,
+                bias=bias,
+                group_size=self.quant_config.group_size,
+                quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
+            )
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = layer.ipex_qlinear(reshaped_x)
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
+
+
 class IPEXAWQLinearMethod(AWQLinearMethod):
-    """AWQ linear method using IPEX for the CPU backend.
+    """AWQ linear method using IPEX for the CPU/XPU backend.
     """
 
     def __init__(self, quant_config: IPEXConfig):
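Note that the GPTQ method above only forwards g_idx when the checkpoint was produced with desc_act=True. A short sketch of what that tensor encodes, assuming the usual GPTQ act-order convention (the tensors here are synthetic):

import torch

in_features, group_size = 512, 128
# Without act-order, input channel k simply belongs to group k // group_size:
g_idx_plain = torch.arange(in_features, dtype=torch.int32) // group_size
# With desc_act (act-order), channels were quantized in a permuted order, so
# the mapping is no longer monotone and the kernel needs the explicit index:
perm = torch.randperm(in_features)
g_idx_act_order = g_idx_plain[perm]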
@@ -108,15 +196,16 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
 
         try:
             import intel_extension_for_pytorch as ipex
-            if ipex.__version__ < "2.4.0":
-                raise ImportError("intel_extension_for_pytorch version is "
-                                  "wrong. Please install "
-                                  "intel_extension_for_pytorch>=2.4.0.")
+            if ipex.__version__ < MIN_IPEX_VERSION:
+                raise ImportError(
+                    "intel_extension_for_pytorch version is "
+                    "wrong. Please install "
+                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
         except ImportError as err:
             raise ImportError(
                 "Please install "
-                "intel_extension_for_pytorch>=2.4.0 via "
-                "`pip install intel_extension_for_pytorch>=2.4.0`"
+                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
+                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
                 " to use IPEX-AWQ linear method.") from err
 
         # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
@@ -136,8 +225,8 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
 
         layer.ipex_output_size = layer.qweight.size(
             1) * self.quant_config.pack_factor
-        layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
-            WeightOnlyQuantizedLinear.from_weight(
+        layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
+            IPEXWeightOnlyQuantizedLinear.from_weight(
                 layer.qweight,
                 layer.scales,
                 layer.qzeros,
@@ -146,8 +235,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
                 qconfig=qconfig,
                 bias=bias,
                 group_size=self.quant_config.group_size,
-                quant_method=
-                self.quant_config.get_ipex_quant_method_id()  # type: ignore
+                quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"]  # type: ignore
             )
 
     def apply(self,
@@ -156,5 +244,4 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = layer.ipex_qlinear(reshaped_x)
         return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
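Both apply() implementations flatten the activation to 2-D for the IPEX GEMM and then restore the leading dimensions. A self-contained sketch of that reshape contract, with the quantized IPEX kernel replaced by a plain torch.nn.Linear stand-in:

import torch

batch, seq, hidden, out_features = 2, 8, 64, 256
x = torch.randn(batch, seq, hidden)
ipex_qlinear = torch.nn.Linear(hidden, out_features, bias=False)  # stand-in

reshaped_x = x.reshape(-1, x.shape[-1])            # [batch*seq, hidden]
out = ipex_qlinear(reshaped_x)                     # [batch*seq, out_features]
out = out.reshape(x.shape[:-1] + (out_features, ))
assert out.shape == (batch, seq, out_features)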

View File

@@ -29,6 +29,8 @@ from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizeMethodBase)
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
@@ -348,7 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
-            if quant_method is not None:
+            if isinstance(quant_method, QuantizeMethodBase):
                 # When quant methods need to process weights after loading
                 # (for repacking, quantizing, etc), they expect parameters
                 # to be on the global target device. This scope is for the
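The stricter isinstance guard only runs the post-load hook on objects that actually implement QuantizeMethodBase. A toy illustration of the difference (the module here is hypothetical and not from this commit):

from torch import nn
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)

class ToyModule(nn.Module):
    # Hypothetical module exposing `quant_method` as something other than a
    # QuantizeMethodBase instance.
    quant_method = "awq"

qm = getattr(ToyModule(), "quant_method", None)
print(qm is not None)                        # True  -> old check would proceed
print(isinstance(qm, QuantizeMethodBase))    # False -> new check skips it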