Fix AOPerModuleConfig name changes (#18869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Jerry Zhang 2025-06-05 21:51:32 -04:00 committed by GitHub
parent cb6d572e85
commit c8134bea15
3 changed files with 25 additions and 5 deletions

.buildkite/test-pipeline.yaml

@@ -424,6 +424,9 @@ steps:
   - vllm/model_executor/layers/quantization
   - tests/quantization
   commands:
+  # temporary install here since we need nightly, will move to requirements/test.in
+  # after torchao 0.12 release
+  - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
   - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
 
 - label: LM Eval Small Models # 53min
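
The nightly wheel is required because the ModuleFqnToConfig rename has not reached a stable torchao release yet; a one-line post-install sanity check (my sketch, not part of the pipeline):

# Confirm the installed nightly exposes the renamed class before running tests.
from torchao.quantization import ModuleFqnToConfig  # formerly AOPerModuleConfig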

tests/quantization/test_torchao.py

@@ -13,7 +13,7 @@ TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
 def test_pre_quantized_model(vllm_runner):
-    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
+    with vllm_runner("drisspg/fp8-opt-125m",
                      quantization="torchao",
                      dtype="bfloat16",
                      enforce_eager=True) as llm:
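
For context, the old checkpoint name spelled out the scheme (float8 dynamic activation, float8 weight); the renamed drisspg/fp8-opt-125m is presumably the same recipe. A minimal sketch of how such a checkpoint could be produced with torchao's quantize_ API; the base model and the no-argument config are assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  quantize_)

# Quantize in place: float8 dynamic activations + float8 weights.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             torch_dtype=torch.bfloat16)
quantize_(model, Float8DynamicActivationFloat8WeightConfig())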
@@ -30,10 +30,10 @@ def test_pre_quantized_model(vllm_runner):
     "cuda:0",
     # {"": "cuda"},
 ])
-def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
     torch._dynamo.reset()
-    model_name = "jerryzh168/opt-125m-int4wo"
+    model_name = "jerryzh168/opt-125m-int8wo-partial-quant"
     with vllm_runner(model_name=model_name,
                      quantization="torchao",
                      dtype="bfloat16",

vllm/model_executor/layers/quantization/torchao.py

@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,12 +14,28 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
+logger = init_logger(__name__)
+
 
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
     def __init__(self, torchao_config) -> None:
         self.torchao_config = torchao_config
+        """
+        # TorchAO quantization relies on tensor subclasses. In order
+        # to enable proper caching, this needs standalone compile
+        if is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
+            logger.info(
+                "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
+
+        # TODO: remove after the torch dependency is updated to 2.8
+        if is_torch_equal_or_newer(
+                "2.7.0") and not is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
+            logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
+        """
 
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -61,10 +78,10 @@ class TorchAOConfig(QuantizationConfig):
         if not isinstance(layer, LinearBase):
             return None
 
-        from torchao.quantization import AOPerModuleConfig
+        from torchao.quantization import ModuleFqnToConfig
 
         module_fqn = prefix
-        if isinstance(self.torchao_config, AOPerModuleConfig):
+        if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
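
The rename does not change the lookup semantics: an exact module-FQN match wins, and anything else falls back to the "_default" entry. A small illustrative sketch of that resolution (the dict values are made up):

from torchao.quantization import Int8WeightOnlyConfig

module_fqn_to_config = {
    "model.decoder.layers.0.fc1": Int8WeightOnlyConfig(),
    "_default": None,
}
# Exact FQN hit -> that module's own config.
c = module_fqn_to_config.get("model.decoder.layers.0.fc1") \
    or module_fqn_to_config.get("_default", None)
# Any other FQN falls through to "_default"; a None result presumably ends up
# as UnquantizedLinearMethod (note the import at the top of this file).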