Fix AOPerModuleConfig name changes (#18869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>

parent cb6d572e85
commit c8134bea15
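The torchao rename this commit tracks: AOPerModuleConfig became ModuleFqnToConfig, with the same dict-of-per-module-configs semantics. A minimal import-compatibility sketch, assuming only the name changed; the fallback branch is an assumption for pre-rename torchao builds, not something this commit adds:

# Hedged sketch: prefer the new torchao name, fall back to the old
# one on torchao releases that predate the rename.
try:
    from torchao.quantization import ModuleFqnToConfig
except ImportError:
    from torchao.quantization import AOPerModuleConfig as ModuleFqnToConfig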
@@ -424,6 +424,9 @@ steps:
   - vllm/model_executor/layers/quantization
   - tests/quantization
   commands:
+  # temporary install here since we need nightly, will move to requirements/test.in
+  # after torchao 0.12 release
+  - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
   - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
 
 - label: LM Eval Small Models # 53min
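A quick sanity check for the nightly install step above; this sketch only assumes torchao exposes __version__ and the renamed class that the code below imports:

# Sketch: verify the nightly torchao pulled in by the CI step is new
# enough to export the renamed config class used elsewhere in this commit.
import importlib.util

if importlib.util.find_spec("torchao") is not None:
    import torchao
    print("torchao version:", torchao.__version__)
    # Raises ImportError on builds that predate the rename.
    from torchao.quantization import ModuleFqnToConfig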
@@ -13,7 +13,7 @@ TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
 def test_pre_quantized_model(vllm_runner):
-    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
+    with vllm_runner("drisspg/fp8-opt-125m",
                      quantization="torchao",
                      dtype="bfloat16",
                      enforce_eager=True) as llm:
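Outside the test harness, loading the renamed fp8 checkpoint looks like this with vLLM's public LLM API; a sketch mirroring what the vllm_runner fixture does above, with an illustrative prompt:

# Sketch: load the torchao pre-quantized checkpoint used by the test
# above, via the public vLLM API rather than the test fixture.
from vllm import LLM

llm = LLM(model="drisspg/fp8-opt-125m",
          quantization="torchao",
          dtype="bfloat16",
          enforce_eager=True)
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)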
@@ -30,10 +30,10 @@ def test_pre_quantized_model(vllm_runner):
     "cuda:0",
     # {"": "cuda"},
 ])
-def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
     torch._dynamo.reset()
-    model_name = "jerryzh168/opt-125m-int4wo"
+    model_name = "jerryzh168/opt-125m-int8wo-partial-quant"
     with vllm_runner(model_name=model_name,
                      quantization="torchao",
                      dtype="bfloat16",
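For context on how a checkpoint like the one above is produced: torchao quantizes a model in place with quantize_, and the result is saved as an ordinary checkpoint. A sketch under assumptions; the base model name is real, but the exact recipe behind jerryzh168/opt-125m-int8wo-partial-quant is not specified in this diff:

# Hedged sketch of producing an int8 weight-only torchao checkpoint.
# Int8WeightOnlyConfig is torchao's config-object API; whether the
# published checkpoint used exactly this recipe is an assumption.
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             torch_dtype=torch.bfloat16)
quantize_(model, Int8WeightOnlyConfig())  # rewrites Linear weights in place
# Tensor-subclass weights need pickle-based serialization, hence the
# pt_load_map_location parameter exercised by the test above.
model.save_pretrained("opt-125m-int8wo", safe_serialization=False)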
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,12 +14,28 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
+logger = init_logger(__name__)
+
 
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
     def __init__(self, torchao_config) -> None:
         self.torchao_config = torchao_config
+        """
+        # TorchAO quantization relies on tensor subclasses. In order
+        # to enable proper caching this needs standalone compile
+        if is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
+            logger.info(
+                "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
+
+        # TODO: remove after the torch dependency is updated to 2.8
+        if is_torch_equal_or_newer(
+                "2.7.0") and not is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
+            logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
+        """
 
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -61,10 +78,10 @@ class TorchAOConfig(QuantizationConfig):
         if not isinstance(layer, LinearBase):
             return None
 
-        from torchao.quantization import AOPerModuleConfig
+        from torchao.quantization import ModuleFqnToConfig
 
         module_fqn = prefix
-        if isinstance(self.torchao_config, AOPerModuleConfig):
+        if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
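The lookup in the hunk above resolves a module's fully-qualified name against the config dict, falling back to the "_default" entry. A sketch of that precedence; the ModuleFqnToConfig constructor call is assumed from the module_fqn_to_config attribute the diff reads:

# Sketch of the per-module precedence the diff implements: an exact
# module FQN wins, otherwise the "_default" entry applies (None means
# the layer stays unquantized and vLLM falls back to
# UnquantizedLinearMethod).
from torchao.quantization import ModuleFqnToConfig, Int8WeightOnlyConfig

cfg = ModuleFqnToConfig({
    "model.decoder.layers.0.self_attn.q_proj": Int8WeightOnlyConfig(),
    "_default": None,
})

module_fqn = "model.decoder.layers.1.fc1"
c = cfg.module_fqn_to_config.get(
    module_fqn) or cfg.module_fqn_to_config.get("_default", None)
print(c)  # None -> unquantized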