diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index a4571a554572..1c0210b6a814 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import TYPE_CHECKING import pytest +from torch_xla._internal import tpu from vllm.platforms import current_platform @@ -63,3 +64,45 @@ def test_basic( output = vllm_outputs[0][1] assert "1024" in output or "0, 1" in output + + +TP_SIZE_8 = 8 + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a test for TPU only") +@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8, + reason=f"This test requires {TP_SIZE_8} TPU chips.") +def test_gemma3_27b_with_text_input_and_tp( + vllm_runner: type[VllmRunner], + monkeypatch: pytest.MonkeyPatch, +) -> None: + model = "google/gemma-3-27b-it" + max_tokens = 16 + tensor_parallel_size = TP_SIZE_8 + max_num_seqs = 4 + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never falling,", + ] + answers = [ + " or, through inaction, allow a human being to come to harm.", + " what is essential is invisible to the eye.", + " but in rising every time we fall.", + ] + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + with vllm_runner( + model, + max_num_batched_tokens=256, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + # vllm_outputs is a list of tuples whose first element is the token id + # and the second element is the output (including the prompt). + for output, answer in zip(vllm_outputs, answers): + generated_text = output[1] + assert answer in generated_text diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 9c95e6d3fa08..52deaf12248a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -30,6 +30,7 @@ class TpuPlatform(Platform): dispatch_key: str = "XLA" ray_device_key: str = "TPU" device_control_env_var: str = "TPU_VISIBLE_CHIPS" + simple_compile_backend: str = "openxla" supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]