vllm/tests/v1/tpu/test_basic.py

# SPDX-License-Identifier: Apache-2.0
"""A basic correctness check for TPUs
Run `pytest tests/v1/tpu/test_basic.py`.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from vllm.platforms import current_platform

if TYPE_CHECKING:
    from tests.conftest import VllmRunner

MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable these models with v6e
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD has a multi-TPU instance
# TENSOR_PARALLEL_SIZES = [1, 4]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
    max_num_seqs: int,
) -> None:
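    # The prompt enumerates 0..1023, so the prefill is well over
    # max_num_batched_tokens (1024) and is processed in chunks.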
prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(
model,
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens=1024,
max_model_len=8192,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
output = vllm_outputs[0][1]
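        # With greedy decoding, the model should either continue the sequence
        # ("1024") or start it over from the beginning ("0, 1").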
assert "1024" in output or "0, 1" in output