diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index 89252000f4003..8616ea2b79dbc 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -21,6 +21,8 @@ docker run --privileged --net host --shm-size=16G -it \
     && python3 -m pip install lm_eval[api]==0.4.4 \
     && export VLLM_USE_V1=1 \
     && export VLLM_XLA_CHECK_RECOMPILATION=1 \
+    && echo TEST_0 \
+    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
     && echo TEST_1 \
     && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
     && echo TEST_2 \
diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py
index 77fbb5827da9e..2bc32ace0a59d 100644
--- a/tests/entrypoints/llm/test_accuracy.py
+++ b/tests/entrypoints/llm/test_accuracy.py
@@ -58,7 +58,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
     more_args = None
     if current_platform.is_tpu():
         # Limit compilation time for TPU V1
-        more_args = "max_num_seqs=64"
+        more_args = "max_model_len=2048,max_num_seqs=64"
 
     # Add TP test (if provided)
     if TPU_TP_TEST_STR:
diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
index 0d7e8d8d7f5e9..8164952fe3823 100644
--- a/tests/v1/tpu/test_basic.py
+++ b/tests/v1/tpu/test_basic.py
@@ -32,7 +32,7 @@ TENSOR_PARALLEL_SIZES = [1]
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
-def test_models(
+def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
@@ -58,4 +58,5 @@ def test_models(
             vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                       max_tokens)
         output = vllm_outputs[0][1]
-        assert "1024" in output
+
+        assert "1024" in output or "0, 1" in output
diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py
new file mode 100644
index 0000000000000..94a1da88a2f06
--- /dev/null
+++ b/tests/v1/tpu/test_perf.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic performance regression test for TPUs
+
+Run `pytest tests/v1/tpu/test_perf.py`.
+""" +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from vllm.platforms import current_platform +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer + +if TYPE_CHECKING: + from tests.conftest import VllmRunner + + +@dataclass +class TestParams: + model: str + num_prompts: int + prefix_len: int + decode_len: int + expected_avg_time: float + err_tol: float + + +TEST_PARAMS = [ + # TODO: Cannot run a series of tests because: + # RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: + # open(/dev/vfio/0): Device or resource busy: Device or resource busy; + # Couldn't open iommu group /dev/vfio/0 + # => Investigate + + # TestParams( + # model="Qwen/Qwen2.5-1.5B-Instruct", + # num_prompts=1, + # prefix_len=10, + # decode_len=5, + # expected_avg_time=0.03, + # err_tol=0.01, + # ), + # TestParams( + # model="Qwen/Qwen2.5-1.5B-Instruct", + # num_prompts=10, + # prefix_len=100, + # decode_len=50, + # expected_avg_time=0.234, + # err_tol=0.020, + # ), + TestParams( + model="Qwen/Qwen2.5-1.5B-Instruct", + num_prompts=64, + prefix_len=500, + decode_len=50, + + # (This is the active CI/CD instance) + # commit id: ccb246776d93ef105904a8ec015b3587240a1183 + # tpu: v5lite (vllm CI/CD) + expected_avg_time=1.4, + err_tol=0.30, + + # (TODO: There is no v6e in CI/CD currently) + # commit id: ccb246776d93ef105904a8ec015b3587240a1183 + # tpu: v6e + # expected_avg_time=1.5, + # err_tol=0.20, + ), +] + +NUM_WARMUPS = 5 +NUM_RUNS = 10 + +MAX_MODEL_LEN = 1024 +MAX_NUM_SEQS = 32 +GPU_UTIL = 0.9 + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a basic performance test for TPU only") +@pytest.mark.parametrize("params", TEST_PARAMS) +def test_perf( + vllm_runner: type[VllmRunner], + monkeypatch: pytest.MonkeyPatch, + params: TestParams, +) -> None: + tokenizer = get_tokenizer(params.model, + tokenizer_mode="auto", + trust_remote_code=True) + + prompts = [] + for i in range(params.num_prompts): + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=params.prefix_len).tolist() + prompt = tokenizer.decode(prefix_token_ids) + prompts.append(prompt) + + print( + "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format( + len(prompts), params.prefix_len, params.decode_len)) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + sampling_params = SamplingParams(max_tokens=params.decode_len, + temperature=1.0, + min_p=0.0) + + with vllm_runner(params.model, + max_num_batched_tokens=MAX_MODEL_LEN, + max_model_len=MAX_MODEL_LEN, + max_num_seqs=MAX_NUM_SEQS, + gpu_memory_utilization=GPU_UTIL, + enforce_eager=False, + tensor_parallel_size=1) as vllm_model: + print(" -- Warmup / Compile") + for i in range(NUM_WARMUPS): + _ = vllm_model.generate(prompts, sampling_params) + + print(" -- Benchmarking... ") + times = [] + for i in range(NUM_RUNS): + start_time = time.time() + _ = vllm_model.generate(prompts, sampling_params) + times.append(time.time() - start_time) + + avg_time = sum(times) / len(times) + + print(" -- avg_time = {}".format(avg_time)) + print(" -- expected_avg_time = {} with err_tol = {}".format( + params.expected_avg_time, params.err_tol)) + diff = avg_time - params.expected_avg_time + ok = diff < params.err_tol + if diff < -params.err_tol: + print(" !! WARNING !! 
+                print(" !! WARNING !! Performance has improved by {}, "
+                      "it may be necessary to fine-tune the "
+                      "expected_avg_time = {}".format(
+                          -diff, params.expected_avg_time))
+
+            assert ok, " !! ERROR !! Regression detected"
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 8f6a54892a4e6..7f7318a7bdd3e 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -77,9 +77,12 @@ class TPUModelRunner:
         parallel_config = self.parallel_config
         self.device = device
         self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
-        if self.check_recompilation:
-            self.num_xla_graphs = xr.get_num_cached_compilation_graph()
+        self.enforce_eager = model_config.enforce_eager
+
+        self.num_xla_graphs = 0
+        self._update_num_xla_graphs("init")
+
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         self._hidden_states_dtype = self.dtype
@@ -180,6 +183,31 @@ class TPUModelRunner:
             max_token_size=self.max_num_tokens,
             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
+    def _update_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        total_cached_graphs = xr.get_num_cached_compilation_graph()
+        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
+        if new_compiled_graphs == 0:
+            return
+
+        logger.info("Added %d new compiled XLA graphs due to %s",
+                    new_compiled_graphs, case_str)
+        self.num_xla_graphs += new_compiled_graphs
+
+    def _verify_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        curr_cached_graph = xr.get_num_cached_compilation_graph()
+        assert self.num_xla_graphs == curr_cached_graph, (
+            "Recompilation after warm up is detected during {}."
+            " num_xla_graphs = {} curr_cached_graph = {}".format(
+                case_str, self.num_xla_graphs, curr_cached_graph))
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the
         scheduler output.
@@ -694,12 +722,11 @@ class TPUModelRunner:
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
-        # Check there is no new graph compilation, all the graphs should be
-        # captured and compiled during warming up.
-        if self.check_recompilation and not self.enforce_eager:
-            curr_cached_graph = xr.get_num_cached_compilation_graph()
-            assert self.num_xla_graphs == curr_cached_graph, (
-                "Recompilation after warm up is detected.")
+
+        # Check that no new graphs were compiled - all the graphs should be
+        # captured and compiled during warm up.
+        self._verify_num_xla_graphs("execute_model")
+
         return model_runner_output
 
     def load_model(self) -> None:
@@ -797,7 +824,9 @@ class TPUModelRunner:
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("model")
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
@@ -832,15 +861,9 @@ class TPUModelRunner:
                 num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in %.2f [secs].", end - start)
-        # Record the number cached XLA graph after warming up, this will be
-        # used for checking there is no additional graph compilation during
-        # runtime execution.
-        if self.check_recompilation:
-            total_cached_graphs = xr.get_num_cached_compilation_graph()
-            num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
-            logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
-            self.num_xla_graphs += num_compiled_graphs
+
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sampling")
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
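
For reference, the recompilation bookkeeping that this patch factors into `_update_num_xla_graphs` / `_verify_num_xla_graphs` amounts to recording a baseline count of cached compiled graphs and asserting that the count stays flat after warm-up. Below is a minimal, TPU-free sketch of that pattern; `RecompilationTracker` and the toy `cache` counter are illustrative stand-ins and not part of the patch (the real code reads the count from `xr.get_num_cached_compilation_graph()` and gates the check on `VLLM_XLA_CHECK_RECOMPILATION` and `enforce_eager`).

```python
class RecompilationTracker:
    """Track a baseline graph count and flag any compilation after warm-up."""

    def __init__(self, get_cached_graph_count, enabled=True):
        self._get_count = get_cached_graph_count
        self._enabled = enabled
        self.num_graphs = 0
        self.update("init")

    def update(self, case_str):
        # Absorb graphs compiled so far (e.g. during warm-up / capture).
        if not self._enabled:
            return
        new_graphs = self._get_count() - self.num_graphs
        if new_graphs:
            print(f"Added {new_graphs} new compiled graphs due to {case_str}")
            self.num_graphs += new_graphs

    def verify(self, case_str):
        # Fail if anything was compiled since the last update().
        if not self._enabled:
            return
        current = self._get_count()
        assert self.num_graphs == current, (
            f"Recompilation detected during {case_str}: "
            f"expected {self.num_graphs}, found {current}")


# Toy usage: the cache may grow during warm-up, then must stay flat.
cache = {"graphs": 0}
tracker = RecompilationTracker(lambda: cache["graphs"])

cache["graphs"] += 3             # pretend warm-up compiled 3 graphs
tracker.update("capture_model")

tracker.verify("execute_model")  # passes: nothing compiled since warm-up
```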