vllm/tests/v1/tpu/test_tpu_int8.py
Kyuyeun Kim 9a0c5ded5a
[TPU] Add support for online w8a8 quantization (#22425)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
2025-08-08 23:12:54 -07:00

74 lines
2.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.
Run `pytest tests/v1/tpu/test_tpu_int8.py`.
"""
import pytest
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
TPUInt8LinearMethod)
from vllm.platforms import current_platform
from ...models.registry import HF_EXAMPLE_MODELS
# Model identifiers that the test in this file is parametrized over.
MODELS = [
    "Qwen/Qwen2.5-0.5B-Instruct",
]
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    """Check that the ``tpu_int8`` override reaches every linear layer.

    The model is loaded with ``hf_overrides`` and two things are verified:

    1. Every ``LinearBase`` module uses ``TPUInt8LinearMethod`` and its
       ``quantize_activation`` flag matches the requested
       ``activation_scheme`` (``True`` only for ``'dynamic'``).
    2. Greedy generation for a fixed set of prompts contains the expected
       continuation for each prompt.

    Args:
        vllm_runner: Fixture used as a context manager to build the runner.
        model: Model identifier to load.
        dtype: Dtype passed to the runner.
        max_tokens: Number of tokens requested from greedy generation.
        hf_overrides: Overrides carrying the ``quantization_config`` to test.
        monkeypatch: pytest fixture used to set environment variables.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get('quantization_config',
                                         {}).get('activation_scheme')
    quantize_activation = activation_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        "or, being injured, not kill, except in",
        "without the heart, one can only see wrongly.",
        "but in rising every time we fall. - Nelson"
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            # Count the layers actually inspected so the check cannot pass
            # vacuously when no LinearBase module is encountered.
            num_linear = 0
            for name, module in model.named_modules():
                if not isinstance(module, LinearBase):
                    continue
                num_linear += 1
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod), (
                    f"{name}: unexpected quant method "
                    f"{type(quant_method).__name__}")
                assert (quant_method.quantize_activation ==
                        quantize_activation), (
                            f"{name}: quantize_activation is "
                            f"{quant_method.quantize_activation}, "
                            f"expected {quantize_activation}")
            assert num_linear > 0, "no LinearBase modules were checked"

        vllm.apply_model(check_model)

        outputs = vllm.generate_greedy(prompts, max_tokens)
        # zip() stops at the shorter input, so guard against a silent
        # truncation that would leave some prompts unchecked.
        assert len(outputs) == len(answers), (
            f"expected {len(answers)} outputs, got {len(outputs)}")
        for prompt, (_, output), answer in zip(prompts, outputs, answers):
            assert answer in output, (
                f"prompt {prompt!r}: expected {answer!r} in {output!r}")