From 951fdd66d36ffed75a55e9bfa9f2221dce4fbcdc Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 14 Aug 2024 14:47:51 -0700
Subject: [PATCH] [TPU] Set per-rank XLA cache (#7533)

---
 vllm/worker/tpu_worker.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index df1a65f6efad..35f8ecdb8126 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -102,12 +102,12 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # 30-40 graphs for decode. 128 is an arbitrary safe number.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
-        # NOTE(woosuk): This does not completely eliminate the recompilation
-        # overhead because dynamo does not cache the compiled results.
-        # NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid
-        # race conditions.
-        xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH,
-                            readonly=not self.is_driver_worker)
+        # NOTE(woosuk): Set per-rank cache path since different ranks
+        # can have slightly different XLA graphs.
+        world_size = self.parallel_config.world_size
+        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                     f"tp{world_size}_rank{self.rank}")
+        xr.initialize_cache(per_rank_path, readonly=False)
 
     def load_model(self):
         self.model_runner.load_model()
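
For illustration only (not part of the patch): a minimal, self-contained sketch of the per-rank cache path scheme the diff introduces. The cache root, world size, and rank values below are assumed for the example; in vLLM they come from envs.VLLM_XLA_CACHE_PATH, self.parallel_config.world_size, and self.rank.

    import os

    # Assumed example values; at runtime vLLM reads these from
    # envs.VLLM_XLA_CACHE_PATH, self.parallel_config.world_size, and self.rank.
    cache_root = os.path.expanduser("~/.cache/vllm/xla_cache")
    world_size = 8
    rank = 3

    # Each rank gets its own subdirectory, e.g. .../tp8_rank3, so ranks whose
    # XLA graphs differ slightly never share (or race on) a cache directory,
    # which is why readonly=False becomes safe for every rank.
    per_rank_path = os.path.join(cache_root, f"tp{world_size}_rank{rank}")
    print(per_rank_path)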