Fix TorchAOConfig skip layers (#19265)

Signed-off-by: mobicham <hicham@mobiuslabs.com>
Authored by mobicham on 2025-06-12 16:22:53 +02:00; committed by GitHub
parent b6efafd9e4
commit 96846bb360
2 changed files with 72 additions and 7 deletions

View File

@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
        print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
    torch._dynamo.reset()
    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
    with vllm_runner(model_name=model_name,
                     quantization="torchao",
                     dtype="bfloat16",
                     pt_load_map_location="cuda:0") as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        assert output
        print(output)


if __name__ == "__main__":
    pytest.main([__file__])
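For context only (not part of the commit), here is a minimal standalone sketch of roughly what the vllm_runner fixture exercises in this test, written against the public vllm.LLM API; the fixture's internals and its pt_load_map_location handling are assumptions and may differ:

from vllm import LLM, SamplingParams

# Hypothetical stand-in for the fixture: load the int8 weight-only TorchAO
# checkpoint and run a greedy (temperature 0) generation, as the test does.
llm = LLM(model="mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao",
          quantization="torchao",
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)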

View File

@@ -17,11 +17,30 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)


def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    """
    Robust skipping logic:
    should_skip("model.model.layers.1.q_proj",
                ["model.model.layers.1.q_proj"]) -> True
    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
    """
    for s in skip_modules:
        if prefix == s:
            return True
        if f".{s}." in f".{prefix}.":
            return True
    return False
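A quick illustration (not part of the diff) of the dotted-boundary matching above; these assertions restate the docstring examples and hold for should_skip as written:

# "layers.1" only matches whole dotted segments, so layer index 11 is not caught.
assert should_skip("model.model.layers.1.q_proj", ["layers.1"])
assert not should_skip("model.model.layers.11.q_proj", ["layers.1"])
assert should_skip("visual.model.layers.1.q_proj", ["visual"])
assert not should_skip("model.model.layers.1.q_proj", ["k_proj"])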

class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config

    def __init__(self,
                 torchao_config,
                 skip_modules: Optional[list[str]] = None) -> None:
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ class TorchAOConfig(QuantizationConfig):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []
def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ class TorchAOConfig(QuantizationConfig):
        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        assert len(hf_config) == 1 and "default" in hf_config, (
            "Expected only one key 'default' in quant_type dictionary")
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

        # Adds skipped modules defined in "modules_to_not_convert"
        skip_modules = config.get("modules_to_not_convert", []) or []

        # Adds skipped modules defined in "module_fqn_to_config"
        _data = quant_type.get("_data", {})
        if not isinstance(_data, dict):
            _data = {}

        module_fqn = _data.get("module_fqn_to_config", {})
        if not isinstance(module_fqn, dict):
            module_fqn = {}

        for layer, layer_cfg in module_fqn.items():
            if layer_cfg is None:
                skip_modules.append(layer)

        return cls(ao_config, skip_modules)
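To make the parsing above concrete, here is a hypothetical quantization_config fragment (layer names invented for this sketch; a real config would also carry the fields config_from_dict needs) and the skip list from_config derives from it:

# Two sources of skipped modules read by from_config():
#   1. the top-level "modules_to_not_convert" list
#   2. entries in "module_fqn_to_config" whose value is None
config = {
    "quant_type": {
        "default": {
            "_data": {
                "module_fqn_to_config": {
                    "model.layers.0.self_attn.q_proj": None,  # None -> skip
                },
            },
        },
    },
    "modules_to_not_convert": ["visual"],
}
# Resulting skip_modules: ["visual", "model.layers.0.self_attn.q_proj"]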

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -80,13 +118,16 @@ class TorchAOConfig(QuantizationConfig):
        from torchao.quantization import ModuleFqnToConfig

        if should_skip(prefix, self.skip_modules):
            return UnquantizedLinearMethod()

        module_fqn = prefix
        if isinstance(self.torchao_config, ModuleFqnToConfig):
            module_fqn_to_config = self.torchao_config.module_fqn_to_config
            c = module_fqn_to_config.get(
                module_fqn) or module_fqn_to_config.get("_default", None)
            if c is not None:
                current_torchao_config = TorchAOConfig(c)
                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                return TorchAOLinearMethod(current_torchao_config)
            else:
                return UnquantizedLinearMethod()
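In summary (an annotation, not code from the commit), the per-layer resolution order implemented above is:

# For a linear layer with fully qualified name `prefix`:
#   1. should_skip(prefix, skip_modules)        -> UnquantizedLinearMethod
#   2. module_fqn_to_config[prefix] present     -> TorchAOLinearMethod with that config
#   3. else module_fqn_to_config["_default"]    -> TorchAOLinearMethod with the default
#   4. otherwise                                -> UnquantizedLinearMethod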
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
"""
Avoid real weight allocation for faster load, since we will
end up setting it to param.
"""
with torch.device("meta"):
dummy_linear = torch.nn.Linear(param.shape[1],
param.shape[0],
bias=False)
dummy_linear.weight = param
quantize_(dummy_linear, torchao_config)
return dummy_linear.weight
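A self-contained sketch (plain PyTorch, independent of vLLM) of the meta-device trick used above: the throwaway Linear gets no real storage, and an existing parameter is swapped in before quantize_ runs:

import torch

# Tensors created under torch.device("meta") carry shape and dtype but no storage,
# so building the placeholder module is essentially free.
with torch.device("meta"):
    dummy = torch.nn.Linear(4096, 4096, bias=False)
assert dummy.weight.is_meta

# Swap in a real parameter; quantization afterwards sees actual data.
dummy.weight = torch.nn.Parameter(torch.randn(4096, 4096, dtype=torch.bfloat16))
assert not dummy.weight.is_meta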