mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:06:19 +08:00
[Quantization] Allow GGUF quantization to skip unquantized layer (#23188)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
cd7a3df26f
commit
4645024d3a
@ -13,7 +13,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
@ -28,8 +29,10 @@ logger = init_logger(__name__)
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
def __init__(self,
|
||||
unquantized_modules: Optional[list[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.unquantized_modules = unquantized_modules or []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
    """Report whether the layer named by ``prefix`` matches a skipped module.

    The check is a plain substring test: the result is True as soon as any
    entry of ``unquantized_modules`` occurs anywhere inside ``prefix``, and
    False when no entry does (including when the list is empty).
    """
    for skipped_name in unquantized_modules:
        if skipped_name in prefix:
            return True
    return False
|
||||
|
||||
|
||||
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
WeightType.Q4_0,
|
||||
|
||||
@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
|
||||
get_gguf_extra_tensor_names, get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
local_model_path, gguf_weights_map):
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
weight_type_map = get_gguf_weight_type_map(model_config.model,
|
||||
gguf_weights_map)
|
||||
|
||||
# filter out unquantized modules to skip
|
||||
unquant_names = [
|
||||
name.removesuffix(".weight")
|
||||
for name, weight_type in weight_type_map.items()
|
||||
if weight_type == "F32" and name.endswith(".weight")
|
||||
]
|
||||
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
|
||||
@ -563,6 +563,18 @@ def get_gguf_extra_tensor_names(
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def get_gguf_weight_type_map(
        gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
    """
    Build a mapping from mapped weight name to the name of its GGUF
    tensor type.

    Every tensor read from ``gguf_file`` is looked up in
    ``gguf_to_hf_name_map``; tensors whose names are not keys of that map
    are left out of the result.
    """
    weight_types: dict[str, str] = {}
    for tensor in gguf.GGUFReader(gguf_file).tensors:
        # Only tensors that have a mapped name are reported.
        if tensor.name not in gguf_to_hf_name_map:
            continue
        mapped_name = gguf_to_hf_name_map[tensor.name]
        weight_types[mapped_name] = tensor.tensor_type.name
    return weight_types
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user