diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh index 82f40c650f8cf..6562942ea3f8c 100755 --- a/.buildkite/run-tpu-v1-test.sh +++ b/.buildkite/run-tpu-v1-test.sh @@ -19,17 +19,19 @@ docker run --privileged --net host --shm-size=16G -it \ vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ && python3 -m pip install pytest \ && python3 -m pip install lm_eval[api]==0.4.4 \ + && export VLLM_USE_V1=1 \ + && export VLLM_XLA_CHECK_RECOMPILATION=1 \ && echo TEST_1 \ - && VLLM_USE_V1=1 python3 /workspace/vllm/tests/tpu/test_compilation.py \ + && python3 /workspace/vllm/tests/tpu/test_compilation.py \ && echo TEST_2 \ - && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ + && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ && echo TEST_3 \ - && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ + && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ && echo TEST_4 \ - && VLLM_USE_V1=1 pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ && echo TEST_5 \ - && VLLM_USE_V1=1 python3 /workspace/vllm/examples/offline_inference/tpu.py" \ - + && python3 /workspace/vllm/examples/offline_inference/tpu.py" \ + # TODO: This test fails because it uses RANDOM_SEED sampling # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ diff --git a/vllm/envs.py b/vllm/envs.py index d88ab3b5e7d06..d54de9da25315 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") + VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False @@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_XLA_CACHE_PATH", os.path.join(get_default_cache_root(), "vllm", "xla_cache"), )), + + # If set, assert on XLA recompilation after each execution step. + "VLLM_XLA_CHECK_RECOMPILATION": + lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ec3dcbc064cba..d772a3ee13ec3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -11,6 +11,7 @@ import torch.nn as nn import torch_xla.core.xla_model as xm import torch_xla.runtime as xr +import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig @@ -73,6 +74,10 @@ class TPUModelRunner: scheduler_config = self.scheduler_config parallel_config = self.parallel_config self.device = device + self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION + if self.check_recompilation: + self.num_xla_graphs = xr.get_num_cached_compilation_graph() + self.enforce_eager = model_config.enforce_eager self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -671,6 +676,12 @@ class TPUModelRunner: logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, ) + # Check there is no new graph compilation, all the graphs should be + # captured and compiled during warming up. + if self.check_recompilation and not self.enforce_eager: + curr_cached_graph = xr.get_num_cached_compilation_graph() + assert self.num_xla_graphs == curr_cached_graph, ( + "Recompilation after warm up is detected.") return model_runner_output def load_model(self) -> None: @@ -810,6 +821,14 @@ class TPUModelRunner: xm.wait_device_ops() end = time.perf_counter() logger.info("Compilation finished in in %.2f [secs].", end - start) + # Record the number cached XLA graph after warming up, this will be + # used for checking there is no additional graph compilation during + # runtime execution. + if self.check_recompilation: + total_cached_graphs = xr.get_num_cached_compilation_graph() + num_compiled_graphs = total_cached_graphs - self.num_xla_graphs + logger.info("Compiled %d XLA graphs.", num_compiled_graphs) + self.num_xla_graphs += num_compiled_graphs def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """