mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-19 01:14:31 +08:00
Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: mgoin <michael@neuralmagic.com>
390 lines
16 KiB
Python
390 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import fnmatch
|
|
import re
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
|
QuarkMoEMethod)
|
|
from vllm.model_executor.layers.quantization.quark.schemes import (
|
|
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
|
|
from vllm.model_executor.layers.quantization.quark.utils import (
|
|
deep_compare, should_ignore_layer)
|
|
from vllm.platforms import current_platform
|
|
|
|
# Names exported by a star-import of this module; only the linear method is
# listed.
__all__ = [
    "QuarkLinearMethod",
]
|
|
|
|
|
|
class QuarkConfig(QuantizationConfig):
    """Quantization config for checkpoints exported by Quark.

    Holds the raw Quark quantization dict and maps each vLLM layer to a
    quantize method: ``QuarkLinearMethod`` for linear layers,
    ``QuarkKVCacheMethod`` for attention, and a ``QuarkMoEMethod`` for fused
    MoE layers.
    """

    def __init__(self,
                 quant_config: Dict[str, Any],
                 kv_cache_group: Optional[List[str]] = None,
                 kv_cache_config: Optional[Dict[str, Any]] = None,
                 pack_method: str = "reorder"):
        """
        :param quant_config: the raw Quark quantization config dict
        :param kv_cache_group: layer-name patterns whose outputs form the KV
            cache; ``None`` is normalized to an empty list
        :param kv_cache_config: KV-cache quantization settings, or ``None``
            when the KV cache is not quantized
        :param pack_method: the ``pack_method`` value from the Quark export
            section
        """
        # This class reads self.packed_modules_mapping but never assigns it,
        # so it must come from the base class. Run the base initializer so
        # any state it sets up exists. NOTE(review): presumably that is where
        # packed_modules_mapping is created; confirm against
        # QuantizationConfig.
        super().__init__()
        if kv_cache_group is None:
            kv_cache_group = []
        self.quant_config = quant_config
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method

    def get_linear_method(self) -> "QuarkLinearMethod":
        """Return a linear method bound to this config."""
        return QuarkLinearMethod(self)

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported by Quark schemes."""
        # The first parameter is ``cls``; without @classmethod, calling this
        # on the class itself would fail with a missing-argument error.
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (as an int) for Quark."""
        return 70

    def get_name(self) -> str:
        """Return the quantization method name, ``"quark"``."""
        return "quark"

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """
        Select the quantize method for ``layer``.

        :param layer: the module being constructed
        :param prefix: the layer's fully qualified name
        :return: ``UnquantizedLinearMethod`` for excluded layers;
            ``QuarkLinearMethod`` for linear layers (after attaching the
            matched scheme as ``layer.scheme``); ``QuarkKVCacheMethod`` for
            attention; a MoE method for ``FusedMoE``; otherwise ``None``
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        # Check if the layer is skipped for quantization. A config without an
        # "exclude" entry (or with a null one) excludes nothing.
        exclude_layers = cast(List[str],
                              self.quant_config.get("exclude") or [])
        if should_ignore_layer(prefix,
                               ignore=exclude_layers,
                               fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
            return QuarkLinearMethod(self)
        if isinstance(layer, Attention):
            return QuarkKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return QuarkMoEMethod.get_moe_method(self,
                                                 module=layer,
                                                 layer_name=prefix)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
        """
        Build a QuarkConfig from a Quark checkpoint's quantization dict.

        :param config: the Quark quantization config; it must contain an
            ``export`` section
        :return: the parsed config
        :raises ValueError: if ``export`` is missing; if any kv_cache_group
            entry has no ``layer_quant_config`` entry; if the KV-cache layers
            use differing configs; or if their ``output_tensors`` config is
            empty
        """
        export_config = config.get("export")
        if export_config is None:
            raise ValueError("The export key should be included in "
                             "the configurations of Quark quantized model")
        # A missing or null kv_cache_group means KV-cache quantization is off.
        kv_cache_group = cast(List[str],
                              export_config.get("kv_cache_group") or [])
        pack_method = cast(str, export_config.get("pack_method"))

        # In the export model of quark, the quantization configuration
        # of kv_cache is stored in layer_quant_config. First, it is
        # judged whether kv_cache_group exists, and then it is judged
        # whether layer_quant_config has a quantization configuration
        # that matches kv_cache.
        if not kv_cache_group:
            kv_cache_config = None
        else:
            kv_cache_set = set(kv_cache_group)
            # A missing layer_quant_config now trips the explicit error below
            # instead of an AttributeError on None.
            layer_quant_config = cast(Dict[str, Any],
                                      config.get("layer_quant_config") or {})
            layer_quant_set = set(layer_quant_config)

            if not kv_cache_set.issubset(layer_quant_set):
                raise ValueError("The Quark quantized model has the "
                                 "kv_cache_group parameter setting, "
                                 "but no kv_cache quantization settings "
                                 "were found in the quantization "
                                 "configuration.")

            q_configs = [
                cast(Dict[str, Any], layer_quant_config.get(name))
                for name in kv_cache_group
            ]
            if not all(
                    deep_compare(q_config, q_configs[0])
                    for q_config in q_configs):
                raise ValueError(
                    "The quantization method used for kv_cache should "
                    "be the same, but the quantization method for the "
                    "kv_cache layer in the config is different.")
            kv_cache_config = q_configs[0].get("output_tensors")
            if kv_cache_config is None:
                raise ValueError(
                    "The kv_cache quantization configuration is empty.")

            # Since we have already set kv_cache quantization configurations,
            # we will remove the quantization configuration for the
            # output_tensors corresponding to the kv_cache layer. This
            # mutates ``config`` in place, so _get_scheme_from_config does
            # not later reject these layers for having output quantization.
            for q_config in q_configs:
                q_config["output_tensors"] = None

        return cls(quant_config=config,
                   kv_cache_group=kv_cache_group,
                   kv_cache_config=kv_cache_config,
                   pack_method=pack_method)

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Quark declares no standalone config filenames."""
        return []

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
        """
        Check whether the current device meets ``min_capability``.

        :param min_capability: required capability, as an int
        :param error: raise instead of returning False when the capability
            is too low
        :return: True if supported; False if not supported (with
            ``error=False``) or if no capability is reported
        :raises RuntimeError: if ``error`` is set and the capability is
            too low
        """
        capability_tuple = current_platform.get_device_capability()
        if capability_tuple is None:
            return False

        capability = capability_tuple.to_int()
        supported = capability >= min_capability
        if error and not supported:
            # Adjacent literals are concatenated into ONE message; the
            # previous comma-separated form passed three arguments and
            # rendered the message as a tuple.
            raise RuntimeError(
                "Quantization scheme is not supported for "
                f"the current GPU. Min capability: {min_capability}. "
                f"Current capability: {capability}.")
        return supported

    def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                     input_quant: Optional[Dict[str, Any]]) -> bool:
        """
        Return True for fp8_e4m3 weights and inputs. Weights must be static
        and per-tensor or per-channel. Inputs must be dynamic, or static
        per-tensor.
        """
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported
        is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
                        and input_quant.get("dtype") == "fp8_e4m3")
        is_static_weight = not weight_quant.get("is_dynamic")
        is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
                                           in ["per_tensor", "per_channel"])

        if not (is_fp8_dtype and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.get("is_dynamic"):
            return True

        # Confirm activation scheme is supported.
        is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
        return is_per_tensor_activation

    def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                               input_quant: Optional[Dict[str, Any]]) -> bool:
        """
        Return True for static int8 quantization with symmetric per-tensor or
        per-channel weights and per-tensor inputs.
        """
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        is_int8_dtype = (weight_quant.get("dtype") == "int8"
                         and input_quant.get("dtype") == "int8")

        is_tensor = (weight_quant.get("qscheme")
                     in ["per_tensor", "per_channel"]
                     and input_quant.get("qscheme") == "per_tensor")

        is_static = (not weight_quant.get("is_dynamic")
                     and not input_quant.get("is_dynamic"))

        is_weight_symmetric = (weight_quant.get("symmetric") is True)

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

    def _find_matched_config(self, layer_name: str,
                             module: torch.nn.Module) -> Dict[str, Any]:
        """
        Find the Quark config entry that applies to ``layer_name``.

        For fused layers (per ``packed_modules_mapping``), each shard is
        resolved recursively, and all shards must share one config.
        Otherwise the lookup tries, in order: the first fnmatch pattern in
        ``layer_quant_config``, then ``layer_type_quant_config`` keyed by the
        module type, then ``global_quant_config``.

        :raises ValueError: if the shards of a fused layer disagree
        """
        proj_name = layer_name.split(".")[-1]
        if proj_name in self.packed_modules_mapping:
            shard_proj_names = self.packed_modules_mapping[proj_name]

            # Convert fused_name --> [shard_names] by swapping only the
            # trailing component. str.replace would also rewrite any earlier
            # path segment that happens to contain proj_name.
            parent_prefix = layer_name[:len(layer_name) - len(proj_name)]
            shard_names = [
                parent_prefix + shard_proj_name
                for shard_proj_name in shard_proj_names
            ]
            shard_configs = [
                self._find_matched_config(shard_name, module)
                for shard_name in shard_names
            ]
            if not all(
                    deep_compare(q_config, shard_configs[0])
                    for q_config in shard_configs):
                raise ValueError(
                    f"Found a different quantization configuration for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme.")
            return shard_configs[0]

        # Missing or null sections are treated as empty.
        layer_quant_config = cast(
            Dict[str, Any],
            self.quant_config.get("layer_quant_config") or {})
        for name_pattern in layer_quant_config:
            if fnmatch.fnmatch(layer_name, name_pattern):
                return layer_quant_config[name_pattern]

        # NOTE(review): this looks up the type *object*; keys loaded from a
        # JSON config would be strings, so this may never match. Confirm the
        # expected key format.
        layer_type = cast(str, type(module))
        layer_type_quant_config = cast(
            Dict[str, Any],
            self.quant_config.get("layer_type_quant_config") or {})
        if layer_type in layer_type_quant_config:
            return layer_type_quant_config[layer_type]

        global_quant_config = cast(
            Dict[str, Any], self.quant_config.get("global_quant_config"))
        return global_quant_config

    def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
        """
        Map a layer's Quark config to a QuarkScheme.

        :raises NotImplementedError: if output tensors or bias are quantized,
            or if no supported scheme matches. This includes an fp8 config on
            a device below QuarkW8A8Fp8's minimum capability.
        """
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with output_tensors "
                "and bias quantized are not supported")
        weight_config = cast(Dict[str, Any], config.get("weight"))
        input_config = cast(Dict[str, Any], config.get("input_tensors"))

        if self._is_fp8_w8a8(weight_config, input_config):
            is_fp8_w8a8_supported = self._check_scheme_supported(
                QuarkW8A8Fp8.get_min_capability(), error=False)
            if is_fp8_w8a8_supported:
                weight_qscheme = cast(str, weight_config.get("qscheme"))
                input_static = (input_config is not None and
                                not cast(bool, input_config.get("is_dynamic")))
                return QuarkW8A8Fp8(qscheme=weight_qscheme,
                                    is_static_input_scheme=input_static)
        elif self._is_static_tensor_w8a8(weight_config, input_config):
            weight_qscheme = cast(str, weight_config.get("qscheme"))
            return QuarkW8A8Int8(qscheme=weight_qscheme,
                                 is_static_input_scheme=True,
                                 input_symmetric=input_config.get("symmetric"))

        raise NotImplementedError("No quark compatible scheme was found. "
                                  f"Weight config: {weight_config}, "
                                  f"Input config: {input_config}")

    def get_scheme(self, layer: torch.nn.Module,
                   layer_name: str) -> "QuarkScheme":
        """
        Resolve the QuarkScheme for a layer and verify device support.

        :raises RuntimeError: if the device capability is below the scheme's
            minimum
        """
        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)
        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in quark. If this is the case, return its equivalent param name
        expected by vLLM

        With a single kv_cache_group entry, its ``output_scale`` is mapped to
        ``attn.k_scale``. With two entries, ``k_proj`` / ``v_proj`` output
        scales are mapped to ``attn.k_scale`` / ``attn.v_scale``.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
            return None

        # Reduce each group pattern (e.g. "*k_proj") to its last component.
        kv_proj_names = [
            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
        ]
        if name.endswith(".output_scale"):
            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
                return name.replace(kv_output_scale_name, ".attn.k_scale")

            elif len(kv_proj_names) == 2:
                for kv_proj_name in kv_proj_names:
                    if kv_proj_name in name and kv_proj_name == "k_proj":
                        return name.replace(".k_proj.output_scale",
                                            ".attn.k_scale")
                    elif kv_proj_name in name and kv_proj_name == "v_proj":
                        return name.replace(".v_proj.output_scale",
                                            ".attn.v_scale")

        # If no matches, return None
        return None
|
|
|
|
|
|
class QuarkLinearMethod(LinearMethodBase):
    """Linear method that forwards every hook to the layer's QuarkScheme.

    The scheme is expected on ``layer.scheme``; ``QuarkConfig`` attaches it
    in ``get_quant_method`` before returning this method.
    """

    def __init__(self, quantization_config: QuarkConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate post-load weight processing to the layer's scheme."""
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the QuarkScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details. Only ``weight_loader`` is forwarded from
        ``extra_weight_attrs``.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the QuarkScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details

        :raises ValueError: if the layer has no scheme attached
        """
        # getattr with a default ensures a layer that never received a scheme
        # reaches the intended ValueError rather than an AttributeError.
        scheme = getattr(layer, "scheme", None)
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)
|
|
|
|
|
|
class QuarkKVCacheMethod(BaseKVCacheMethod):
    """
    KV-cache method for Quark checkpoints: validates the checkpoint's
    KV-cache scheme, then defers scale loading to BaseKVCacheMethod.
    """

    def __init__(self, quant_config: QuarkConfig):
        # Fail fast on unsupported KV-cache schemes before base setup runs.
        kv_cfg = quant_config.kv_cache_config
        self.validate_kv_cache_config(kv_cfg)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
        """
        Ensure the Quark KV-cache scheme is one vLLM can run.

        A ``None`` config (KV cache not quantized) is accepted as-is.
        Otherwise only dtype ``fp8_e4m3`` with qscheme ``per_tensor`` is
        accepted.

        :param kv_cache_config: the quark kv cache scheme, or None
        :raises NotImplementedError: on any other dtype or qscheme
        """
        if kv_cache_config is not None:
            kv_dtype = kv_cache_config.get("dtype")
            if kv_dtype != "fp8_e4m3":
                raise NotImplementedError(
                    "Currently supported kv cache quantization is "
                    f"dtype=fp8_e4m3, however received {kv_dtype}")

            kv_qscheme = kv_cache_config.get("qscheme")
            if kv_qscheme != "per_tensor":
                raise NotImplementedError(
                    "Only support per-tensor scaling factor "
                    "for quark KV cache. "
                    f"Expected qscheme: per_tensor, found qscheme: {kv_qscheme}")
|