Fix TorchAOConfig skip layers (#19265)

Signed-off-by: mobicham <hicham@mobiuslabs.com>

commit 96846bb360 (parent b6efafd9e4)
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+        print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
@@ -17,11 +17,30 @@ from vllm.model_executor.utils import set_weight_attrs
 logger = init_logger(__name__)
 
 
+def should_skip(prefix: str, skip_modules: list[str]) -> bool:
+    """
+    Robust skipping logic:
+    should_skip("model.model.layers.1.q_proj",
+                ["model.model.layers.1.q_proj"]) -> True
+    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
+    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
+    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
+    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
+    """
+    for s in skip_modules:
+        if prefix == s:
+            return True
+        if f".{s}." in f".{prefix}.":
+            return True
+    return False
+
+
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
-    def __init__(self, torchao_config) -> None:
-        self.torchao_config = torchao_config
+    def __init__(self,
+                 torchao_config,
+                 skip_modules: Optional[list[str]] = None) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ class TorchAOConfig(QuantizationConfig):
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        self.torchao_config = torchao_config
+        self.skip_modules = skip_modules or []
 
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
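
Side note (not part of the commit): the skip rule above is boundary-aware. The check f".{s}." in f".{prefix}." only matches whole dotted components, which is why "layers.1" skips layer 1 but not layer 11. A minimal standalone sketch, with the function body copied from the hunk above and the docstring examples replayed as asserts:

    def should_skip(prefix: str, skip_modules: list[str]) -> bool:
        # Copied from the diff above: exact match first, then a
        # dot-delimited substring match to avoid partial-token hits.
        for s in skip_modules:
            if prefix == s:
                return True
            if f".{s}." in f".{prefix}.":
                return True
        return False

    # Replays the docstring examples.
    assert should_skip("model.model.layers.1.q_proj",
                       ["model.model.layers.1.q_proj"])
    assert should_skip("model.model.layers.10.o_proj", ["o_proj"])
    assert should_skip("visual.model.layers.1.q_proj", ["visual"])
    assert should_skip("model.model.layers.1.q_proj", ["layers.1"])
    assert not should_skip("model.model.layers.11.q_proj", ["layers.1"])
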
@@ -67,11 +88,28 @@ class TorchAOConfig(QuantizationConfig):
 
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
-        assert (len(hf_config) == 1 and "default" in hf_config
-                ), "Expected only one key 'default' in quant_type dictionary"
+        assert len(hf_config) == 1 and "default" in hf_config, (
+            "Expected only one key 'default' in quant_type dictionary")
         quant_type = hf_config["default"]
         ao_config = config_from_dict(quant_type)
-        return cls(ao_config)
+
+        # Adds skipped modules defined in "modules_to_not_convert"
+        skip_modules = config.get("modules_to_not_convert", []) or []
+
+        # Adds skipped modules defined in "module_fqn_to_config"
+        _data = quant_type.get("_data", {})
+        if not isinstance(_data, dict):
+            _data = {}
+
+        module_fqn = _data.get("module_fqn_to_config", {})
+        if not isinstance(module_fqn, dict):
+            module_fqn = {}
+
+        for layer, layer_cfg in module_fqn.items():
+            if layer_cfg is None:
+                skip_modules.append(layer)
+
+        return cls(ao_config, skip_modules)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
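
For context, this is roughly the shape of the checkpoint's quantization_config that the new from_config() parsing walks. The key names (quant_type, _data, module_fqn_to_config, modules_to_not_convert) come from the hunk above; the concrete module names and values below are illustrative stand-ins, not taken from a real checkpoint:

    # Hypothetical HF quantization_config as seen by from_config();
    # only the keys read by the new skip-module parsing are shown.
    config = {
        "quant_method": "torchao",
        "modules_to_not_convert": ["lm_head"],
        "quant_type": {
            "default": {
                "_data": {
                    "module_fqn_to_config": {
                        "_default": {"...": "per-module quant settings"},
                        "visual.blocks.0.mlp.fc1": None,  # None -> do not quantize
                    },
                },
            },
        },
    }

    # Mirrors the new from_config() logic for collecting skip modules
    # (the real code also guards against non-dict values).
    quant_type = config["quant_type"]["default"]
    skip_modules = config.get("modules_to_not_convert", []) or []
    module_fqn = quant_type.get("_data", {}).get("module_fqn_to_config", {})
    for layer, layer_cfg in module_fqn.items():
        if layer_cfg is None:
            skip_modules.append(layer)

    print(skip_modules)  # ['lm_head', 'visual.blocks.0.mlp.fc1']
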
@@ -80,13 +118,16 @@ class TorchAOConfig(QuantizationConfig):
 
         from torchao.quantization import ModuleFqnToConfig
 
+        if should_skip(prefix, self.skip_modules):
+            return UnquantizedLinearMethod()
+
         module_fqn = prefix
         if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c)
+                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
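
Sketch of the resolution order in the patched get_quant_method(), using a plain dict in place of ModuleFqnToConfig and strings in place of the real method objects; all names below are stand-ins, and the skip rule is the helper added earlier in this commit:

    def should_skip(prefix, skip_modules):
        # Same rule as the should_skip() helper added in this commit.
        return any(prefix == s or f".{s}." in f".{prefix}." for s in skip_modules)

    # Stand-ins for ModuleFqnToConfig.module_fqn_to_config and the method objects.
    skip_modules = ["visual"]
    module_fqn_to_config = {
        "language_model.layers.0.self_attn.q_proj": "int8wo config",
        "_default": "default config",
    }

    def resolve(prefix):
        # Precedence in the patched get_quant_method():
        # 1) explicit skip, 2) exact per-module config, 3) "_default" fallback.
        if should_skip(prefix, skip_modules):
            return "UnquantizedLinearMethod"
        c = module_fqn_to_config.get(prefix) or module_fqn_to_config.get("_default")
        return f"TorchAOLinearMethod({c})" if c is not None else "UnquantizedLinearMethod"

    print(resolve("visual.blocks.0.attn.qkv"))                  # skipped vision tower
    print(resolve("language_model.layers.0.self_attn.q_proj"))  # per-module config
    print(resolve("language_model.layers.1.mlp.gate_proj"))     # falls back to _default
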
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    """
+    Avoid real weight allocation for faster load, since we will
+    end up setting it to param.
+    """
+    with torch.device("meta"):
+        dummy_linear = torch.nn.Linear(param.shape[1],
+                                       param.shape[0],
+                                       bias=False)
+
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
     return dummy_linear.weight
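
Why the meta device helps (plain PyTorch, independent of this commit): a Linear constructed under torch.device("meta") carries only shape and dtype metadata, so no full-size weight buffer is allocated just to be thrown away when the existing parameter is attached. A small sketch, assuming a 4096x4096 layer purely for illustration:

    import torch

    out_features, in_features = 4096, 4096

    # Under the meta device, parameters hold only metadata (shape/dtype),
    # so this does not allocate out_features * in_features floats.
    with torch.device("meta"):
        dummy_linear = torch.nn.Linear(in_features, out_features, bias=False)
    assert dummy_linear.weight.is_meta

    # The real weight (here a stand-in CPU tensor) replaces the meta one
    # before quantization, as in torchao_quantize_param_data().
    real_weight = torch.nn.Parameter(torch.randn(out_features, in_features),
                                     requires_grad=False)
    dummy_linear.weight = real_weight
    assert not dummy_linear.weight.is_meta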