Add support for loading torchao models with AOPerModuleConfig (#17826)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang 2025-05-14 16:24:59 -07:00 committed by GitHub
parent 2fc9075b82
commit 7974736740
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 10 deletions

View File

@ -31,9 +31,6 @@ def test_pre_quantized_model(vllm_runner):
])
def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
pt_load_map_location):
"""
Test loading roberta-base model with no lm_head.
"""
torch._dynamo.reset()
model_name = "jerryzh168/opt-125m-int4wo"
with vllm_runner(model_name=model_name,
@ -47,5 +44,20 @@ def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
print(output)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
torch._dynamo.reset()
model_name = "jerryzh168/opt-125m-int4wo-per-module"
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0") as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -5,10 +5,11 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@ -55,10 +56,24 @@ class TorchAOConfig(QuantizationConfig):
return cls(ao_config)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["TorchAOLinearMethod"]:
if isinstance(layer, LinearBase):
return TorchAOLinearMethod(self)
return None
prefix: str) -> Optional["QuantizeMethodBase"]:
if not isinstance(layer, LinearBase):
return None
from torchao.quantization import AOPerModuleConfig
module_fqn = prefix
if isinstance(self.torchao_config, AOPerModuleConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(
module_fqn) or module_fqn_to_config.get("_default", None)
if c is not None:
current_torchao_config = TorchAOConfig(c)
return TorchAOLinearMethod(current_torchao_config)
else:
return UnquantizedLinearMethod()
return TorchAOLinearMethod(self)
def get_scaled_act_names(self) -> list[str]:
return []
@ -75,7 +90,7 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
assert isinstance(torchao_config, AOBaseConfig)
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
quantize_(dummy_linear, torchao_config)