Add support for loading torchao models with AOPerModuleConfig (#17826)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
parent 2fc9075b82
commit 7974736740
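Context for reviewers: torchao's AOPerModuleConfig maps module fully-qualified names (FQNs) to quantization configs, with an optional "_default" entry acting as a fallback. A minimal sketch of constructing one (Int4WeightOnlyConfig is used purely for illustration; any AOBaseConfig subclass should work):

    from torchao.quantization import AOPerModuleConfig, Int4WeightOnlyConfig

    config = AOPerModuleConfig({
        # exact-FQN override for one projection
        "model.decoder.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(),
        # fallback applied to every other linear module
        "_default": Int4WeightOnlyConfig(group_size=128),
    })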
tests/quantization/test_torchao.py
@@ -31,9 +31,6 @@ def test_pre_quantized_model(vllm_runner):
 ])
 def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
-    """
-    Test loading roberta-base model with no lm_head.
-    """
     torch._dynamo.reset()
     model_name = "jerryzh168/opt-125m-int4wo"
     with vllm_runner(model_name=model_name,
@@ -47,5 +44,20 @@ def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
     print(output)


+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo-per-module"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
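The new test loads a checkpoint that was quantized ahead of time with AOPerModuleConfig. For context, a rough sketch of how such a checkpoint could be produced (hypothetical script, not part of this commit; assumes transformers plus a torchao version that ships AOPerModuleConfig, and a placeholder hub repo id):

    import torch
    from torchao.quantization import (AOPerModuleConfig, Int4WeightOnlyConfig,
                                      quantize_)
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                                 torch_dtype=torch.bfloat16)
    # quantize every linear via the "_default" entry
    quantize_(model, AOPerModuleConfig({"_default": Int4WeightOnlyConfig()}))
    # torchao tensor subclasses are not safetensors-serializable,
    # hence safe_serialization=False
    model.push_to_hub("your-username/opt-125m-int4wo-per-module",
                      safe_serialization=False)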
vllm/model_executor/layers/quantization/torchao.py
@@ -5,10 +5,11 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -55,10 +56,24 @@ class TorchAOConfig(QuantizationConfig):
         return cls(ao_config)

     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["TorchAOLinearMethod"]:
-        if isinstance(layer, LinearBase):
-            return TorchAOLinearMethod(self)
-        return None
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if not isinstance(layer, LinearBase):
+            return None
+
+        from torchao.quantization import AOPerModuleConfig
+
+        module_fqn = prefix
+        if isinstance(self.torchao_config, AOPerModuleConfig):
+            module_fqn_to_config = self.torchao_config.module_fqn_to_config
+            c = module_fqn_to_config.get(
+                module_fqn) or module_fqn_to_config.get("_default", None)
+            if c is not None:
+                current_torchao_config = TorchAOConfig(c)
+                return TorchAOLinearMethod(current_torchao_config)
+            else:
+                return UnquantizedLinearMethod()
+
+        return TorchAOLinearMethod(self)

     def get_scaled_act_names(self) -> list[str]:
         return []
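The resolution order implemented above: an exact FQN match wins, otherwise the "_default" entry applies, and a resolved config of None leaves the layer unquantized. A self-contained toy restating that order with plain dicts (no vllm or torchao imports; "int4" stands in for a config object):

    def resolve(module_fqn_to_config: dict, module_fqn: str) -> str:
        # mirrors get_quant_method: exact FQN first, then "_default"
        c = (module_fqn_to_config.get(module_fqn)
             or module_fqn_to_config.get("_default", None))
        return c if c is not None else "unquantized"

    cfg = {"decoder.layers.0.q_proj": "int4", "_default": None}
    assert resolve(cfg, "decoder.layers.0.q_proj") == "int4"
    assert resolve(cfg, "decoder.layers.1.q_proj") == "unquantized"

Note that because the lookup uses `or`, an FQN explicitly mapped to None also falls through to "_default".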
@@ -75,7 +90,7 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
-    assert isinstance(torchao_config, AOBaseConfig)
+    assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
     dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
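For context, quantize_ operates on modules rather than bare tensors, which is why torchao_quantize_param_data wraps the weight in a throwaway nn.Linear. A standalone sketch of the same pattern (Int4WeightOnlyConfig is illustrative; int4 weight packing typically requires a CUDA device):

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    param = torch.nn.Parameter(
        torch.randn(256, 128, dtype=torch.bfloat16, device="cuda"),
        requires_grad=False)
    dummy = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy.weight = param
    quantize_(dummy, Int4WeightOnlyConfig())
    # dummy.weight is now a torchao tensor subclass holding int4-packed data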