diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 49d28927d6e7..90222f2e3b0e 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -13,7 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEConfig,
                                                         FusedMoEMethodBase)
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -28,8 +29,10 @@ logger = init_logger(__name__)
 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
-    def __init__(self, ) -> None:
+    def __init__(self,
+                 unquantized_modules: Optional[list[str]] = None) -> None:
         super().__init__()
+        self.unquantized_modules = unquantized_modules or []
 
     def __repr__(self) -> str:
         return ("GGUFConfig()")
@@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+                return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
@@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
         return None
 
 
+def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]) -> bool:
+    return any(module_name in prefix for module_name in unquantized_modules)
+
+
 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
 STANDARD_QUANT_TYPES = {
     WeightType.Q4_0,
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 21655b0c69bb..9877cb3b7c06 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (
     initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
-    get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
+    get_gguf_extra_tensor_names, get_gguf_weight_type_map,
+    gguf_quant_weights_iterator)
 
 
 class GGUFModelLoader(BaseModelLoader):
@@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
                 local_model_path, gguf_weights_map):
             model_config.hf_config.update({"tie_word_embeddings": True})
 
+        weight_type_map = get_gguf_weight_type_map(model_config.model,
+                                                   gguf_weights_map)
+
+        # collect unquantized modules so their layers skip GGUF quantization
+        unquant_names = [
+            name.removesuffix(".weight")
+            for name, weight_type in weight_type_map.items()
+            if weight_type == "F32" and name.endswith(".weight")
+        ]
+        vllm_config.quant_config.unquantized_modules.extend(unquant_names)
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 7053c5bc515c..3bb47f82d2f3 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -563,6 +563,18 @@ def get_gguf_extra_tensor_names(
     return [gguf_to_hf_name_map[key] for key in extra_keys]
 
 
+def get_gguf_weight_type_map(
+        gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
+    """
+    Return each mapped GGUF weight's name and its quant type.
+    """
+    reader = gguf.GGUFReader(gguf_file)
+    return {
+        gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name
+        for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map
+    }
+
+
 def gguf_quant_weights_iterator(
     gguf_file: str, gguf_to_hf_name_map: dict[str, str]
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
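For review, here is a minimal sketch of how the three changes compose at load time. The tensor names and quant types below are hypothetical stand-ins: in the real loader the map comes from `get_gguf_weight_type_map()` over the GGUF checkpoint, and the prefix check is the one `get_quant_method()` now performs before falling back to `GGUFLinearMethod`.

```python
# Sketch only; tensor names below are hypothetical examples.

def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]) -> bool:
    # Same substring check the diff adds to gguf.py.
    return any(module_name in prefix for module_name in unquantized_modules)

# Hypothetical result of get_gguf_weight_type_map():
# HF-mapped tensor name -> GGUF quant type name.
weight_type_map = {
    "model.layers.0.self_attn.q_proj.weight": "Q4_K",
    "model.layers.0.mlp.gate.weight": "F32",  # e.g. a MoE router kept in full precision
    "model.embed_tokens.weight": "Q6_K",
}

# Mirrors the filtering added in gguf_loader.py: only F32 ".weight"
# tensors are collected as unquantized module prefixes.
unquant_names = [
    name.removesuffix(".weight")
    for name, weight_type in weight_type_map.items()
    if weight_type == "F32" and name.endswith(".weight")
]
assert unquant_names == ["model.layers.0.mlp.gate"]

# get_quant_method() then routes matching layers to UnquantizedLinearMethod
# instead of GGUFLinearMethod:
assert is_layer_skipped_gguf("model.layers.0.mlp.gate", unquant_names)
assert not is_layer_skipped_gguf("model.layers.0.self_attn.q_proj", unquant_names)
```

One design note: the skip test is a substring match on the layer prefix rather than an exact comparison, so an entry such as `model.layers.0.mlp.gate` also matches any longer prefix containing it; that is the trade-off `is_layer_skipped_gguf` makes in exchange for tolerating vLLM's fused-layer prefix rewriting.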