diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index 734a817fd1a06..10d2e236498ea 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" run_and_track_test 6 "test_kv_cache_update_kernel.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" +run_and_track_test 7 "test_tpu_int8.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py new file mode 100644 index 0000000000000..991070dc9239d --- /dev/null +++ b/tests/v1/tpu/test_tpu_int8.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests whether TPU Int8 computation is enabled correctly. + +Run `pytest tests/quantization/test_tpu_int8.py`. +""" +import pytest + +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.tpu_int8 import ( + TPUInt8LinearMethod) +from vllm.platforms import current_platform + +from ...models.registry import HF_EXAMPLE_MODELS + +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="TPU Int8 is only enabled for TPUs.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize( + "hf_overrides", + [ + # w8a8 dynamic activation + { + 'quantization_config': { + 'quant_method': 'tpu_int8', + 'activation_scheme': 'dynamic' + } + } + ]) +def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, + hf_overrides: dict, monkeypatch) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_transformers_version(on_fail="skip") + + activation_scheme = hf_overrides.get('quantization_config', + {}).get('activation_scheme') + quantize_activation = activation_scheme == 'dynamic' + + # Allows using apply_model + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # Prevent error from re-initializing cache + monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "") + + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never falling,", + ] + answers = [ + "or, being injured, not kill, except in", + "without the heart, one can only see wrongly.", + "but in rising every time we fall. - Nelson" + ] + + with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm: + + def check_model(model): + for name, module in model.named_modules(): + if not isinstance(module, LinearBase): + continue + quant_method = module.quant_method + assert isinstance(quant_method, TPUInt8LinearMethod) + assert quant_method.quantize_activation == quantize_activation + + vllm.apply_model(check_model) + outputs = vllm.generate_greedy(prompts, max_tokens) + for (_, output), answer in zip(outputs, answers): + assert answer in output diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index 83c8a98eac913..38de4b54fb191 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import ModelWeightParameter -ACTIVATION_SCHEMES = ["none"] +ACTIVATION_SCHEMES = ["none", "dynamic"] class Int8TpuConfig(QuantizationConfig): @@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Int8TpuConfig): self.quant_config = quant_config + self.quantize_activation = False + if self.quant_config.activation_scheme == 'dynamic': + self.quantize_activation = True def create_weights(self, layer: Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, @@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: try: - import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + import torch_xla.experimental.custom_kernel # noqa: F401 except ImportError as err: raise ImportError( "Please install torch_xla by following the instructions at " @@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase): "to run vLLM on TPU.") from err weight = layer.weight scale = layer.scale - out = torch.ops.xla.quantized_matmul(x, weight, scale) + out = torch.ops.xla.quantized_matmul_int8( + x, weight, scale, quantize_activation=self.quantize_activation) if bias is not None: out = out + bias return out