diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index d56c25dd9da24..9a380373d4617 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -66,14 +66,18 @@ class TPUWorker: from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + # Delay profiler initialization to the start of the profiling. + # This is because in vLLM V1, MP runtime is initialized before the + # TPU Worker is initialized. The profiler server needs to start after + # MP runtime is initialized. self.profiler = None + self.profile_dir = None if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR logger.info("Profiling enabled. Traces will be saved to: %s", self.profile_dir) - self.profiler = xp.start_server(9012) if self.model_config.seed is None: self.model_config.seed = 0 @@ -168,9 +172,11 @@ class TPUWorker: def profile(self, is_start: bool = True): if self.rank < 1: - if self.profiler is None: + if self.profile_dir is None: raise RuntimeError("Profiler is not enabled.") if is_start: + if self.profiler is None: + self.profiler = xp.start_server(9012) xp.start_trace(self.profile_dir) else: xp.stop_trace()