diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 9a380373d461..4d9a113e39ee 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -113,9 +113,16 @@ class TPUWorker:
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it with
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
 
         # Init ModelRunner here, so that we have access to self.device.
         self.model_runner = TPUModelRunner(self.vllm_config, self.device)
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 66911790662e..71b4b38fb9d6 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -93,9 +93,16 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it with
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
 
         self.profiler = None
         if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
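
For reference, here is a minimal, self-contained sketch of the gating behavior the two hunks introduce: when `VLLM_XLA_CACHE_PATH` is empty (e.g. after `export VLLM_XLA_CACHE_PATH=`), the per-rank cache is never initialized. The helper name `maybe_init_xla_cache` and the injected `initialize_cache` callback are hypothetical, used only so the sketch runs without `torch_xla` or a TPU; the real code calls `xr.initialize_cache(per_rank_path, readonly=False)`.

```python
import os


def maybe_init_xla_cache(cache_path, world_size, rank, initialize_cache=print):
    """Mirror the patched gating: skip cache init when the path is empty."""
    if not cache_path:
        # `export VLLM_XLA_CACHE_PATH=` yields an empty string, so the
        # per-rank cache is never initialized and stale XLA compilations
        # cannot be reused during development.
        return None
    per_rank_path = os.path.join(cache_path, f"tp{world_size}_rank{rank}")
    # In the real worker this is xr.initialize_cache(per_rank_path,
    # readonly=False); a callback is injected here so the sketch runs
    # without torch_xla or a TPU.
    initialize_cache(per_rank_path)
    return per_rank_path


# Cache enabled: uses a per-rank directory, e.g. /tmp/xla_cache/tp8_rank0.
print(maybe_init_xla_cache("/tmp/xla_cache", world_size=8, rank=0))
# Cache disabled (VLLM_XLA_CACHE_PATH set to the empty string): returns None.
print(maybe_init_xla_cache("", world_size=8, rank=0))
```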