[TPU][Profiler] Support start_profile/stop_profile in TPU worker (#13988)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Siyuan Liu authored 2025-03-04 11:40:06 -08:00, committed by GitHub
parent f89978ad7c
commit beebf4742a
3 changed files with 47 additions and 5 deletions
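
For reference, a minimal end-to-end sketch of how these hooks are intended to be driven. This is hedged: the model name and trace directory are illustrative, and it assumes the vLLM LLM entrypoint forwards start_profile()/stop_profile() through to the TPU worker.

import os

# Must be set before the engine starts so the worker opens a profiler server.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/tpu_traces"

from vllm import LLM

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")   # illustrative model choice
llm.start_profile()                           # rank 0 begins an XLA trace
llm.generate(["Profile this prompt on TPU."])
llm.stop_profile()                            # trace written under the directory above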


@@ -17,8 +17,9 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.7.0.dev20250226+cpu
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
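
A quick sanity check that the pinned nightlies landed together; both wheels above are pinned to the same 2.7.0.dev20250227 nightly, and a dev-date mismatch between torch and torch_xla is a common source of TPU startup failures.

import torch
import torch_xla

# Both version strings should report the same 2.7.0.dev20250227 nightly.
print(torch.__version__)
print(torch_xla.__version__)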


@@ -7,6 +7,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 import torch_xla.runtime as xr

 import vllm.envs as envs
@@ -65,6 +66,15 @@ class TPUWorker:
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

+        self.profiler = None
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+            # A TPU profiler server supports only one active profiling
+            # session at a time, so we only profile on rank 0.
+            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        self.profile_dir)
+            self.profiler = xp.start_server(9012)
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         torch.set_grad_enabled(False)
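
The torch_xla profiler flow wired in above can be sketched standalone. The port and trace directory are illustrative; the handle returned by xp.start_server must stay alive for the session to capture anything.

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)     # one profiler server per process; keep the handle alive
xp.start_trace("/tmp/tpu_traces")  # begin writing a trace to this directory
# ... run the TPU workload to be captured ...
xp.stop_trace()                    # flush the trace for TensorBoard's profile plugin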
@@ -152,6 +162,15 @@ class TPUWorker:
         output = self.model_runner.execute_model(scheduler_output)
         return output if self.is_driver_worker else None

+    def profile(self, is_start: bool = True):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            if is_start:
+                xp.start_trace(self.profile_dir)
+            else:
+                xp.stop_trace()
+
     def load_model(self) -> None:
         self.model_runner.load_model()
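
Note that the V1 worker exposes a single profile(is_start=...) toggle rather than the paired methods added to the V0 worker below. A hypothetical driver-side call site; the executor object and its collective_rpc helper are assumptions for illustration.

# Hypothetical: toggling the hook across workers by method name.
executor.collective_rpc("profile", args=(True,))   # start tracing on rank 0
# ... serve the requests to be profiled ...
executor.collective_rpc("profile", args=(False,))  # stop and flush the trace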


@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 import torch_xla.runtime as xr

 import vllm.envs as envs
@@ -93,6 +94,27 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                                          f"tp{world_size}_rank{rank}")
             xr.initialize_cache(per_rank_path, readonly=False)

+        self.profiler = None
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+            # A TPU profiler server supports only one active profiling
+            # session at a time, so we only profile on rank 0.
+            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        self.profile_dir)
+            self.profiler = xp.start_server(9012)
+
+    def start_profile(self):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            xp.start_trace(self.profile_dir)
+
+    def stop_profile(self):
+        if self.rank < 1:
+            if self.profiler is None:
+                raise RuntimeError("Profiler is not enabled.")
+            xp.stop_trace()
+
     def load_model(self):
         self.model_runner.load_model()