diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index 1bb698faf46df..f59e5e2a0af7a 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
     deep_compare,
     should_ignore_layer,
 )
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
@@ -57,7 +58,6 @@ class QuarkConfig(QuantizationConfig):
         self.kv_cache_group = kv_cache_group
         self.kv_cache_config = kv_cache_config
         self.pack_method = pack_method
-        self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
 
     def get_linear_method(self) -> "QuarkLinearMethod":
         return QuarkLinearMethod(self)
@@ -72,14 +72,42 @@ class QuarkConfig(QuantizationConfig):
     def get_name(self) -> QuantizationMethods:
         return "quark"
 
+    def apply_vllm_mapper(  # noqa: B027
+        self, hf_to_vllm_mapper: "WeightsMapper"
+    ):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        quant_config_with_hf_to_vllm_mapper = {}
+
+        for k, v in self.quant_config.items():
+            if isinstance(v, list):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
+            elif isinstance(v, dict):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
+            else:
+                if isinstance(v, str):
+                    mapped_v_list = hf_to_vllm_mapper.apply_list([v])
+                    if mapped_v_list:
+                        quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
+                else:
+                    quant_config_with_hf_to_vllm_mapper[k] = v
+
+        self.quant_config = quant_config_with_hf_to_vllm_mapper
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         # Check if the layer is skipped for quantization.
+        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
         if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
@@ -93,9 +121,6 @@ class QuarkConfig(QuantizationConfig):
             return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
         return None
 
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
-
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
         export_config = config.get("export")
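
For reference, a minimal, self-contained sketch of what the new apply_vllm_mapper does to a Quark quantization config. ToyMapper is a hypothetical stand-in for vLLM's WeightsMapper (only the apply_list/apply_dict methods the diff calls are modeled, assumed here to rename a known prefix and drop names that map to None); the config keys in the example ("exclude" as a list, "layer_quant_config" as a dict keyed by module name) are illustrative, mirroring the shape implied by the diff rather than a real checkpoint.

from typing import Any, Optional


class ToyMapper:
    """Hypothetical stand-in for vllm.model_executor.models.utils.WeightsMapper.

    Models only the two methods the diff uses: names matching a known
    prefix are renamed; names that map to None are dropped.
    """

    def __init__(self, orig_to_new_prefix: dict[str, Optional[str]]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> Optional[str]:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return None if new is None else name.replace(old, new, 1)
        return name

    def apply_list(self, values: list[str]) -> list[str]:
        # Remap each name, dropping entries that map to None.
        return [out for v in values if (out := self._map_name(v)) is not None]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        # Remap keys, dropping entries whose key maps to None.
        return {out: v for k, v in values.items() if (out := self._map_name(k)) is not None}


def remap_quant_config(quant_config: dict[str, Any], mapper: ToyMapper) -> dict[str, Any]:
    # Behaviorally the same walk as the new QuarkConfig.apply_vllm_mapper:
    # lists and dicts are remapped wholesale, strings individually (and
    # dropped if the mapper filters them out), everything else copied.
    remapped: dict[str, Any] = {}
    for k, v in quant_config.items():
        if isinstance(v, list):
            remapped[k] = mapper.apply_list(v)
        elif isinstance(v, dict):
            remapped[k] = mapper.apply_dict(v)
        elif isinstance(v, str):
            mapped = mapper.apply_list([v])
            if mapped:
                remapped[k] = mapped[0]
        else:
            remapped[k] = v
    return remapped


mapper = ToyMapper({"model.": "language_model.model."})
quant_config = {
    "exclude": ["model.layers.0.mlp.gate", "lm_head"],
    "layer_quant_config": {"model.layers.0.self_attn.qkv_proj": {"dtype": "fp8"}},
    "quant_method": "quark",
}
print(remap_quant_config(quant_config, mapper))
# {'exclude': ['language_model.model.layers.0.mlp.gate', 'lm_head'],
#  'layer_quant_config': {'language_model.model.layers.0.self_attn.qkv_proj': {'dtype': 'fp8'}},
#  'quant_method': 'quark'}

This also shows why the diff swaps the cached self.ignore for a fresh read of quant_config["exclude"] inside get_quant_method: once apply_vllm_mapper rewrites the whole config dict, the exclude list carries the vLLM module names, whereas a list cached in __init__ would still hold the HF names.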