diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 4ee6b499b539..b739851cb905 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -424,6 +424,9 @@ steps:
   - vllm/model_executor/layers/quantization
   - tests/quantization
   commands:
+  # temporary install here since we need nightly, will move to requirements/test.in
+  # after torchao 0.12 release
+  - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
   - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
 
 - label: LM Eval Small Models # 53min
diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index c966dc9b8152..54ec59585450 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -13,7 +13,7 @@ TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None
 
 @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
 def test_pre_quantized_model(vllm_runner):
-    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
+    with vllm_runner("drisspg/fp8-opt-125m",
                      quantization="torchao",
                      dtype="bfloat16",
                      enforce_eager=True) as llm:
@@ -30,10 +30,10 @@ def test_pre_quantized_model(vllm_runner):
     "cuda:0",
     # {"": "cuda"},
 ])
-def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
                                                    pt_load_map_location):
     torch._dynamo.reset()
-    model_name = "jerryzh168/opt-125m-int4wo"
+    model_name = "jerryzh168/opt-125m-int8wo-partial-quant"
     with vllm_runner(model_name=model_name,
                      quantization="torchao",
                      dtype="bfloat16",
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index af362f7a7d2d..a7d9332032a2 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -6,6 +6,9 @@
+import os
+
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,12 +16,27 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import is_torch_equal_or_newer
+
+logger = init_logger(__name__)
 
 
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
     def __init__(self, torchao_config) -> None:
         self.torchao_config = torchao_config
+        # TorchAO quantization relies on tensor subclasses. In order to
+        # enable proper caching, this needs standalone compile.
+        if is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
+            logger.info(
+                "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
+
+        # TODO: remove after the torch dependency is updated to 2.8
+        if is_torch_equal_or_newer(
+                "2.7.0") and not is_torch_equal_or_newer("2.8.0"):
+            os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
+            logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
 
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -61,10 +79,10 @@ class TorchAOConfig(QuantizationConfig):
         if not isinstance(layer, LinearBase):
             return None
 
-        from torchao.quantization import AOPerModuleConfig
+        from torchao.quantization import ModuleFqnToConfig
 
         module_fqn = prefix
-        if isinstance(self.torchao_config, AOPerModuleConfig):
+        if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
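
For reviewers unfamiliar with the renamed API, here is a minimal sketch (not part of the diff) of how a ModuleFqnToConfig-quantized checkpoint like the ones referenced in the updated tests could be produced and then served. It assumes torchao >= 0.12 (the nightly installed above) and a transformers version whose TorchAoConfig accepts AOBaseConfig instances; the base model and output path are illustrative.

    import torch
    from torchao.quantization import Int8WeightOnlyConfig, ModuleFqnToConfig
    from transformers import AutoModelForCausalLM, TorchAoConfig

    # "_default" applies to every linear module without an explicit FQN entry,
    # matching the module_fqn_to_config.get("_default", None) lookup above.
    quant_config = TorchAoConfig(quant_type=ModuleFqnToConfig(
        {"_default": Int8WeightOnlyConfig()}))
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        torch_dtype=torch.bfloat16,
        quantization_config=quant_config)
    # Tensor-subclass weights cannot be serialized as safetensors.
    model.save_pretrained("opt-125m-int8wo", safe_serialization=False)

    # Load the result through vLLM, mirroring the updated test:
    from vllm import LLM
    llm = LLM(model="opt-125m-int8wo", quantization="torchao",
              dtype="bfloat16")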