[Quantization] Allow GGUF quantization to skip unquantized layer (#23188)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent cd7a3df26f
commit 4645024d3a
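For context, this change affects models served from GGUF checkpoints. As a reminder of the user-facing path it touches, a GGUF file is loaded in vLLM roughly like the sketch below; the file path and tokenizer repo are placeholders for illustration, not values taken from this commit.

# Minimal sketch of loading a GGUF checkpoint with vLLM; the .gguf path and
# tokenizer repo below are placeholders, not part of this patch.
from vllm import LLM

llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",      # local GGUF file (placeholder)
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",      # HF tokenizer (placeholder)
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)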
@@ -13,7 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEConfig,
                                                         FusedMoEMethodBase)
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -28,8 +29,10 @@ logger = init_logger(__name__)
 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""

-    def __init__(self, ) -> None:
+    def __init__(self,
+                 unquantized_modules: Optional[list[str]] = None) -> None:
         super().__init__()
+        self.unquantized_modules = unquantized_modules or []

     def __repr__(self) -> str:
         return ("GGUFConfig()")
@@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+                return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
@@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
         return None


+def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
+    return any(module_name in prefix for module_name in unquantized_modules)
+
+
 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
 STANDARD_QUANT_TYPES = {
     WeightType.Q4_0,
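The hunks above (presumably in vllm/model_executor/layers/quantization/gguf.py, judging by the imports) add an unquantized_modules list to GGUFConfig and a substring-based skip check, so that any linear layer whose prefix matches a listed module gets UnquantizedLinearMethod instead of GGUFLinearMethod. A standalone sketch of that matching behaviour; the prefixes and module names below are made up for illustration.

# Standalone re-statement of the skip check added above: a layer is treated as
# unquantized when any configured module name is a substring of its prefix.
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]) -> bool:
    return any(module_name in prefix for module_name in unquantized_modules)

# Hypothetical layer prefixes, for illustration only.
unquantized = ["lm_head", "model.layers.0.mlp.gate"]
assert is_layer_skipped_gguf("model.layers.0.mlp.gate", unquantized)
assert is_layer_skipped_gguf("lm_head", unquantized)
assert not is_layer_skipped_gguf("model.layers.1.self_attn.qkv_proj", unquantized)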
@@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (
     initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
-    get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
+    get_gguf_extra_tensor_names, get_gguf_weight_type_map,
+    gguf_quant_weights_iterator)


 class GGUFModelLoader(BaseModelLoader):
@@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
                 local_model_path, gguf_weights_map):
             model_config.hf_config.update({"tie_word_embeddings": True})

+        weight_type_map = get_gguf_weight_type_map(model_config.model,
+                                                   gguf_weights_map)
+
+        # filter out unquantized modules to skip
+        unquant_names = [
+            name.removesuffix(".weight")
+            for name, weight_type in weight_type_map.items()
+            if weight_type == "F32" and name.endswith(".weight")
+        ]
+        vllm_config.quant_config.unquantized_modules.extend(unquant_names)
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
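The loader hunk above builds the skip list at load time: it reads each mapped tensor's GGUF quant type and treats every ".weight" tensor stored as F32 as unquantized, stripping the suffix to get the module prefix. A small sketch of just that filtering step, using a hand-written weight_type_map; the tensor names and types are illustrative, not taken from a real checkpoint.

# Sketch of the filtering step above: keep the module names of ".weight"
# tensors whose GGUF type is F32, stripping the ".weight" suffix.
weight_type_map = {
    "model.layers.0.self_attn.qkv_proj.weight": "Q4_K",  # stays quantized
    "model.layers.0.mlp.gate.weight": "F32",             # skipped (unquantized)
    "lm_head.weight": "Q6_K",                            # stays quantized
}  # illustrative names/types only

unquant_names = [
    name.removesuffix(".weight")
    for name, weight_type in weight_type_map.items()
    if weight_type == "F32" and name.endswith(".weight")
]
print(unquant_names)  # ['model.layers.0.mlp.gate']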
@@ -563,6 +563,18 @@ def get_gguf_extra_tensor_names(
     return [gguf_to_hf_name_map[key] for key in extra_keys]


+def get_gguf_weight_type_map(
+        gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
+    """
+    Return GGUF mapped weight's name and its quant type
+    """
+    reader = gguf.GGUFReader(gguf_file)
+    return {
+        gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name
+        for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map
+    }
+
+
 def gguf_quant_weights_iterator(
     gguf_file: str, gguf_to_hf_name_map: dict[str, str]
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
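The new get_gguf_weight_type_map helper leans on the gguf package's GGUFReader, whose tensors expose a name and a tensor_type enum whose .name is the quant-type string compared against "F32" in the loader. A minimal sketch of inspecting a file directly; the file path is a placeholder.

# Minimal sketch of the GGUFReader usage behind get_gguf_weight_type_map;
# the file path is a placeholder.
import gguf

reader = gguf.GGUFReader("./model.Q4_K_M.gguf")
for tensor in reader.tensors:
    # tensor_type is a GGMLQuantizationType enum; .name gives e.g. "Q4_K" or "F32"
    print(tensor.name, tensor.tensor_type.name)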