[TPU] support disabling xla compilation cache (#15567)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Chengji Yao 2025-03-26 17:09:28 -07:00 committed by GitHub
parent 7a888271f5
commit e74ff409e0
2 changed files with 20 additions and 6 deletions


@@ -113,9 +113,16 @@ class TPUWorker:
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it by
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
         # Init ModelRunner here, so that we have access to self.device.
         self.model_runner = TPUModelRunner(self.vllm_config, self.device)
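
For reference, a minimal standalone sketch of the guard pattern both hunks introduce. The helper name `maybe_init_xla_cache` and its arguments are illustrative, not part of vLLM; `xr.initialize_cache` is the actual torch_xla runtime call used above.

import os

import torch_xla.runtime as xr


def maybe_init_xla_cache(cache_path: str, world_size: int, rank: int) -> None:
    # Each rank gets its own cache directory because different ranks can
    # trace slightly different XLA graphs.
    # An empty cache_path (e.g. after `export VLLM_XLA_CACHE_PATH=`) is falsy,
    # so compilation-cache initialization is skipped entirely.
    if cache_path:
        per_rank_path = os.path.join(cache_path, f"tp{world_size}_rank{rank}")
        xr.initialize_cache(per_rank_path, readonly=False)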


@@ -93,9 +93,16 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # can have slightly different XLA graphs.
         world_size = self.parallel_config.world_size
         rank = xr.global_ordinal()
-        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
-                                     f"tp{world_size}_rank{rank}")
-        xr.initialize_cache(per_rank_path, readonly=False)
+        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
+        # Consequently, changes in optimization flags, which affect compilation
+        # results, don't change the cache key. This can result in the wrong
+        # compilation being used. To prevent this, disabling the XLA compilation
+        # cache during development is recommended. We can disable it by
+        # `export VLLM_XLA_CACHE_PATH=`
+        if envs.VLLM_XLA_CACHE_PATH:
+            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                         f"tp{world_size}_rank{rank}")
+            xr.initialize_cache(per_rank_path, readonly=False)
         self.profiler = None
         if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
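
As a quick illustration of why the empty export disables the cache: a sketch using plain os.environ rather than vLLM's envs module, with a placeholder fallback path.

import os

os.environ["VLLM_XLA_CACHE_PATH"] = ""  # same effect as `export VLLM_XLA_CACHE_PATH=`
cache_path = os.environ.get("VLLM_XLA_CACHE_PATH", "/tmp/xla_cache")  # returns "" since the var is set, just empty
print(bool(cache_path))  # False -> the `if envs.VLLM_XLA_CACHE_PATH:` guard never calls xr.initialize_cache()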