diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index 725b1a2e4a580..d999e8f1c90ec 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -17,8 +17,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
-torch==2.7.0.dev20250226+cpu
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index cbd2fe6edd81b..76b6297606c33 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -7,6 +7,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
@@ -65,6 +66,15 @@ class TPUWorker:
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
+        self.profiler = None
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+            # For TPU, we can only have one active profiler session per
+            # profiler server, so we only profile on rank 0.
+            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        self.profile_dir)
+            self.profiler = xp.start_server(9012)
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         torch.set_grad_enabled(False)
@@ -152,6 +162,15 @@ class TPUWorker:
         output = self.model_runner.execute_model(scheduler_output)
         return output if self.is_driver_worker else None
 
+    def profile(self, is_start: bool = True):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            if is_start:
+                xp.start_trace(self.profile_dir)
+            else:
+                xp.stop_trace()
+
     def load_model(self) -> None:
         self.model_runner.load_model()
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 7903e81943c24..1a5eaba09b940 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
@@ -93,6 +94,27 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                                          f"tp{world_size}_rank{rank}")
             xr.initialize_cache(per_rank_path, readonly=False)
 
+        self.profiler = None
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+            # For TPU, we can only have one active profiler session per
+            # profiler server, so we only profile on rank 0.
+            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        self.profile_dir)
+            self.profiler = xp.start_server(9012)
+
+    def start_profile(self):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            xp.start_trace(self.profile_dir)
+
+    def stop_profile(self):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            xp.stop_trace()
+
     def load_model(self):
         self.model_runner.load_model()
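
Usage sketch (not part of the patch): with VLLM_TORCH_PROFILER_DIR set, rank 0 starts a torch_xla profiler server on port 9012, and the new start/stop hooks simply wrap torch_xla's trace API. A minimal standalone illustration of that flow, assuming a writable trace directory (the /tmp/tpu-traces path below is a placeholder, not from the patch):

    import torch_xla.debug.profiler as xp

    # One profiler server per process; the port matches the
    # hard-coded 9012 in the patch above.
    server = xp.start_server(9012)

    xp.start_trace("/tmp/tpu-traces")  # begin writing a trace
    # ... run a few model steps here ...
    xp.stop_trace()                    # flush the trace for TensorBoard

The resulting trace directory can then be opened with TensorBoard's profile plugin to inspect the captured TPU activity.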