[Feat] Iteration-level profiling for Torch and CUDA profiler (#28987)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benjamin Chislett 2025-11-19 22:17:48 -05:00 committed by GitHub
parent 3168285fca
commit fcbcba6c70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 437 additions and 63 deletions

View File

@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.envs as envs
from vllm.profiler.gpu_profiler import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler):
"""
A basic implementation of a worker profiler for testing purposes.
"""
def __init__(self):
self.start_call_count = 0
self.stop_call_count = 0
self.should_fail_start = False
super().__init__()
def _start(self) -> None:
if self.should_fail_start:
raise RuntimeError("Simulated start failure")
self.start_call_count += 1
def _stop(self) -> None:
self.stop_call_count += 1
@pytest.fixture(autouse=True)
def reset_mocks():
"""Fixture to reset mocks and env variables before each test."""
envs.VLLM_PROFILER_DELAY_ITERS = 0
envs.VLLM_PROFILER_MAX_ITERS = 0
def test_immediate_start_stop():
"""Test standard start without delay."""
profiler = ConcreteWorkerProfiler()
profiler.start()
assert profiler._running is True
assert profiler._active is True
assert profiler.start_call_count == 1
profiler.stop()
assert profiler._running is False
assert profiler._active is False
assert profiler.stop_call_count == 1
def test_delayed_start():
"""Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
# User requests start
profiler.start()
# Should be active (request accepted) but not running (waiting for delay)
assert profiler._active is True
assert profiler._running is False
assert profiler.start_call_count == 0
# Step 1
profiler.step()
assert profiler._running is False
# Step 2 (Threshold reached)
profiler.step()
assert profiler._running is True
assert profiler.start_call_count == 1
def test_max_iterations():
"""Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
profiler.start()
assert profiler._running is True
# Iteration 1
profiler.step() # profiling_count becomes 1
assert profiler._running is True
# Iteration 2
profiler.step() # profiling_count becomes 2
assert profiler._running is True
# Iteration 3 (Exceeds max)
profiler.step() # profiling_count becomes 3
# Should have stopped now
assert profiler._running is False
assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters():
"""Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
profiler.start()
# Step 1
profiler.step()
assert profiler._running is False
assert profiler._active is True
# Step 2 (Starts now)
profiler.step()
assert profiler._profiling_for_iters == 1
assert profiler._running is True
assert profiler._active is True
# Next iteration
profiler.step()
assert profiler._profiling_for_iters == 2
assert profiler._running is True
# Iteration 2 (exceeds max)
profiler.step()
# Should have stopped now
assert profiler._running is False
assert profiler.stop_call_count == 1
def test_idempotency():
"""Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler()
# Double Start
profiler.start()
profiler.start()
assert profiler.start_call_count == 1 # Should only start once
# Double Stop
profiler.stop()
profiler.stop()
assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive():
"""Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
# Not started yet
profiler.step()
profiler.step()
# Even though we stepped 2 times, start shouldn't happen because active=False
assert profiler.start_call_count == 0
def test_start_failure():
"""Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler()
profiler.should_fail_start = True
profiler.start()
# Exception caught in _call_start
assert profiler._running is False # Should not mark as running
assert profiler._active is True # Request is still considered active
assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown():
"""Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler()
# Case 1: Not running
profiler.shutdown()
assert profiler.stop_call_count == 0
# Case 2: Running
profiler.start()
profiler.shutdown()
assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop():
"""Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5
profiler = ConcreteWorkerProfiler()
profiler.start()
profiler.step()
profiler.step()
# User cancels before delay finishes
profiler.stop()
assert profiler._active is False
# Further steps should not trigger start
profiler.step()
profiler.step()
profiler.step()
assert profiler.start_call_count == 0

View File

@ -92,11 +92,14 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_PROFILER_DELAY_ITERS: int = 0
VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@ -872,6 +875,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
# Disable torch profiling of the AsyncLLMEngine process.
# If set to 1, will not profile the engine process.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
),
# Delay number of iterations before starting profiling when using
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
),
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
# If set to 0, will not limit the number of iterations.
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
# If set, allow loading or unloading lora adapters in runtime,

View File

@ -1,37 +1,212 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import nullcontext
import torch
from typing_extensions import override
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudaProfilerWrapper:
class WorkerProfiler(ABC):
def __init__(self) -> None:
self._profiler_running = False
self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
if self._delay_iters > 0:
logger.info_once(
"GPU profiling will start "
f"{self._delay_iters} steps after start_profile."
)
self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
if self._max_iters > 0:
logger.info_once(
"GPU profiling will stop "
f"after {self._max_iters} worker steps, "
"or when stop_profile is received."
)
# Track when the profiler gets triggered by start_profile
self._active_iteration_count = 0
self._active = False
# Track when the profiler is actually running
self._profiling_for_iters = 0
self._running = False
@abstractmethod
def _start(self) -> None:
"""Start the profiler."""
pass
@abstractmethod
def _stop(self) -> None:
"""Stop the profiler."""
pass
def _call_start(self) -> None:
"""Call _start with error handling but no safeguards."""
try:
self._start()
self._running = True # Only mark as running if start succeeds
except Exception as e:
logger.warning("Failed to start profiler: %s", e)
def _call_stop(self) -> None:
"""Call _stop with error handling but no safeguards."""
try:
self._stop()
logger.info("Profiler stopped successfully.")
except Exception as e:
logger.warning("Failed to stop profiler: %s", e)
self._running = False # Always mark as not running, assume stop worked
def start(self) -> None:
"""Attempt to start the profiler, accounting for delayed starts."""
if self._active:
logger.debug(
"start_profile received when profiler is already active. "
"Ignoring request."
)
return
self._active = True
if self._delay_iters == 0:
self._call_start()
def step(self) -> None:
"""Update the profiler state at each worker step,
to handle delayed starts and max iteration limits."""
if not self._active:
return
self._active_iteration_count += 1
if (
not self._running
and self._delay_iters > 0
and self._active_iteration_count == self._delay_iters
):
logger.info("Starting profiler after delay...")
self._call_start()
if self._running:
self._profiling_for_iters += 1
if (
self._max_iters > 0
and self._running
and self._profiling_for_iters > self._max_iters
):
# Automatically stop the profiler after max iters
# will be marked as not running, but leave as active so that stop
# can clean up properly
logger.info("Max profiling iterations reached. Stopping profiler...")
self._call_stop()
return
def stop(self) -> None:
"""Attempt to stop the profiler, accounting for overlapped calls."""
if not self._active:
logger.debug(
"stop_profile received when profiler is not active. Ignoring request."
)
return
self._active = False
self._active_iteration_count = 0
self._profiling_for_iters = 0
if self._running:
self._call_stop()
def shutdown(self) -> None:
"""Ensure profiler is stopped when shutting down."""
logger.info_once("Shutting down profiler")
if self._running:
self.stop()
def annotate_context_manager(self, name: str):
"""Return a context manager to annotate profiler traces."""
return nullcontext()
class TorchProfilerWrapper(WorkerProfiler):
def __init__(self, worker_name: str, local_rank: int) -> None:
super().__init__()
self.local_rank = local_rank
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info(
"Torch profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
),
)
@override
def _start(self) -> None:
self.profiler.start()
@override
def _stop(self) -> None:
self.profiler.stop()
rank = self.local_rank
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
with open(profiler_out_file, "w") as f:
print(table, file=f)
# only print profiler results on rank 0
if rank == 0:
print(table)
@override
def annotate_context_manager(self, name: str):
return torch.profiler.record_function(name)
class CudaProfilerWrapper(WorkerProfiler):
def __init__(self) -> None:
super().__init__()
# Note: lazy import to avoid dependency issues if CUDA is not available.
import torch.cuda.profiler as cuda_profiler
self._cuda_profiler = cuda_profiler
def start(self) -> None:
try:
self._cuda_profiler.start()
self._profiler_running = True
logger.info_once("Started CUDA profiler")
except Exception as e:
logger.warning_once("Failed to start CUDA profiler: %s", e)
@override
def _start(self) -> None:
self._cuda_profiler.start()
def stop(self) -> None:
if self._profiler_running:
try:
self._cuda_profiler.stop()
logger.info_once("Stopped CUDA profiler")
except Exception as e:
logger.warning_once("Failed to stop CUDA profiler: %s", e)
finally:
self._profiler_running = False
@override
def _stop(self) -> None:
self._cuda_profiler.stop()
def shutdown(self) -> None:
"""Ensure profiler is stopped when shutting down."""
self.stop()
@override
def annotate_context_manager(self, name: str):
return torch.cuda.nvtx.range(name)

View File

@ -160,11 +160,23 @@ class AsyncLLM(EngineClient):
except RuntimeError:
pass
if envs.VLLM_TORCH_PROFILER_DIR:
if (
envs.VLLM_TORCH_PROFILER_DIR
and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
):
logger.info(
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
envs.VLLM_TORCH_PROFILER_DIR,
)
if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
logger.warning_once(
"Torch profiler received max_iters or delay_iters setting. These "
"are not compatible with the AsyncLLM profiler and will be ignored "
"for the AsyncLLM process. Engine process profiling will still "
"respect these settings. Consider setting "
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
"AsyncLLM profiling."
)
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
self.profiler = torch.profiler.profile(
activities=[

View File

@ -36,7 +36,7 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.gpu_profiler import CudaProfilerWrapper
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
@ -90,32 +90,9 @@ class Worker(WorkerBase):
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
logger.info(
"Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
),
self.profiler = TorchProfilerWrapper(
worker_name=worker_name, local_rank=self.local_rank
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
self.profiler = CudaProfilerWrapper()
@ -526,10 +503,12 @@ class Worker(WorkerBase):
if not self.profiler:
return nullcontext()
self.profiler.step()
num_new = len(scheduler_output.scheduled_new_reqs)
num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)
return torch.profiler.record_function(
return self.profiler.annotate_context_manager(
f"execute_new_{num_new}_cached_{num_cached}"
)
@ -587,24 +566,11 @@ class Worker(WorkerBase):
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
raise RuntimeError("Profiling is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
if isinstance(self.profiler, torch.profiler.profile):
rank = self.local_rank
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
with open(profiler_out_file, "w") as f:
print(table, file=f)
# only print profiler results on rank 0
if rank == 0:
print(table)
def execute_dummy_batch(self) -> None:
self.model_runner._dummy_run(1, uniform_decode=True)
@ -865,6 +831,8 @@ class Worker(WorkerBase):
def shutdown(self) -> None:
if runner := getattr(self, "model_runner", None):
runner.ensure_kv_transfer_shutdown()
if self.profiler is not None:
self.profiler.shutdown()
def init_worker_distributed_environment(