[TPU] Add support for online w8a8 quantization (#22425)

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
Kyuyeun Kim 2025-08-08 23:12:54 -07:00 committed by GitHub
parent 10a02535d4
commit 9a0c5ded5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 82 additions and 3 deletions

View File

@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
run_and_track_test 7 "test_tpu_int8.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then

View File

@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.
Run `pytest tests/quantization/test_tpu_int8.py`.
"""
import pytest
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
TPUInt8LinearMethod)
from vllm.platforms import current_platform
from ...models.registry import HF_EXAMPLE_MODELS
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
"hf_overrides",
[
# w8a8 dynamic activation
{
'quantization_config': {
'quant_method': 'tpu_int8',
'activation_scheme': 'dynamic'
}
}
])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
hf_overrides: dict, monkeypatch) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_transformers_version(on_fail="skip")
activation_scheme = hf_overrides.get('quantization_config',
{}).get('activation_scheme')
quantize_activation = activation_scheme == 'dynamic'
# Allows using apply_model
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
# Prevent error from re-initializing cache
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
"or, being injured, not kill, except in",
"without the heart, one can only see wrongly.",
"but in rising every time we fall. - Nelson"
]
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
def check_model(model):
for name, module in model.named_modules():
if not isinstance(module, LinearBase):
continue
quant_method = module.quant_method
assert isinstance(quant_method, TPUInt8LinearMethod)
assert quant_method.quantize_activation == quantize_activation
vllm.apply_model(check_model)
outputs = vllm.generate_greedy(prompts, max_tokens)
for (_, output), answer in zip(outputs, answers):
assert answer in output

View File

@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter
ACTIVATION_SCHEMES = ["none"]
ACTIVATION_SCHEMES = ["none", "dynamic"]
class Int8TpuConfig(QuantizationConfig):
@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config
self.quantize_activation = False
if self.quant_config.activation_scheme == 'dynamic':
self.quantize_activation = True
def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
try:
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
import torch_xla.experimental.custom_kernel # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
"to run vLLM on TPU.") from err
weight = layer.weight
scale = layer.scale
out = torch.ops.xla.quantized_matmul(x, weight, scale)
out = torch.ops.xla.quantized_matmul_int8(
x, weight, scale, quantize_activation=self.quantize_activation)
if bias is not None:
out = out + bias
return out