From 96846bb3607370798540c7d325f8d06dbd67dcf4 Mon Sep 17 00:00:00 2001
From: mobicham <37179323+mobicham@users.noreply.github.com>
Date: Thu, 12 Jun 2025 16:22:53 +0200
Subject: [PATCH] Fix TorchAOConfig skip layers (#19265)

Signed-off-by: mobicham
---
 tests/quantization/test_torchao.py | 15 +++++
 .../layers/quantization/torchao.py | 64 +++++++++++++++++--
 2 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 54ec59585450..eef3568efea1 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
     print(output)


+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index a7d9332032a2..af50b45d44b7 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -17,11 +17,30 @@ from vllm.model_executor.utils import set_weight_attrs
 logger = init_logger(__name__)


+def should_skip(prefix: str, skip_modules: list[str]) -> bool:
+    """
+    Robust skipping logic:
+    should_skip("model.model.layers.1.q_proj",
+    ["model.model.layers.1.q_proj"]) # True
+    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
+    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
+    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
+    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
+    """
+    for s in skip_modules:
+        if prefix == s:
+            return True
+        if f".{s}." in f".{prefix}.":
+            return True
+    return False
+
+
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""

-    def __init__(self, torchao_config) -> None:
-        self.torchao_config = torchao_config
+    def __init__(self,
+                 torchao_config,
+                 skip_modules: Optional[list[str]] = None) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ class TorchAOConfig(QuantizationConfig):
             os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
             logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        self.torchao_config = torchao_config
+        self.skip_modules = skip_modules or []

     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ class TorchAOConfig(QuantizationConfig):

         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
-        assert (len(hf_config) == 1 and "default" in hf_config
-                ), "Expected only one key 'default' in quant_type dictionary"
+        assert len(hf_config) == 1 and "default" in hf_config, (
+            "Expected only one key 'default' in quant_type dictionary")
         quant_type = hf_config["default"]
         ao_config = config_from_dict(quant_type)
-        return cls(ao_config)
+
+        # Adds skipped modules defined in "modules_to_not_convert"
+        skip_modules = config.get("modules_to_not_convert", []) or []
+
+        # Adds skipped modules defined in "module_fqn_to_config"
+        _data = quant_type.get("_data", {})
+        if not isinstance(_data, dict):
+            _data = {}
+
+        module_fqn = _data.get("module_fqn_to_config", {})
+        if not isinstance(module_fqn, dict):
+            module_fqn = {}
+
+        for layer, layer_cfg in module_fqn.items():
+            if layer_cfg is None:
+                skip_modules.append(layer)
+
+        return cls(ao_config, skip_modules)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -80,13 +118,16 @@

         from torchao.quantization import ModuleFqnToConfig

+        if should_skip(prefix, self.skip_modules):
+            return UnquantizedLinearMethod()
+
         module_fqn = prefix
         if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c)
+                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
+
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    """
+    Avoid real weight allocation for faster load, since we will
+    end up setting it to param.
+    """
+    with torch.device("meta"):
+        dummy_linear = torch.nn.Linear(param.shape[1],
+                                       param.shape[0],
+                                       bias=False)
+
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
     return dummy_linear.weight
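
A quick way to sanity-check the new skip rule outside vLLM is to replay the cases
listed in its docstring. The snippet below is only a standalone sketch: it copies
the should_skip body from this patch and is not part of the change itself.

    # Standalone copy of the matcher added above: a module is skipped when its
    # prefix equals an entry or contains it as a dotted segment.
    def should_skip(prefix: str, skip_modules: list[str]) -> bool:
        for s in skip_modules:
            if prefix == s:
                return True
            if f".{s}." in f".{prefix}.":
                return True
        return False

    # Cases taken from the docstring in the patch.
    assert should_skip("model.model.layers.1.q_proj",
                       ["model.model.layers.1.q_proj"])
    assert should_skip("model.model.layers.10.o_proj", ["o_proj"])
    assert should_skip("visual.model.layers.1.q_proj", ["visual"])
    assert should_skip("model.model.layers.1.q_proj", ["layers.1"])
    assert not should_skip("model.model.layers.11.q_proj", ["layers.1"])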