diff --git a/tests/v1/worker/test_gpu_profiler.py b/tests/v1/worker/test_gpu_profiler.py
new file mode 100644
index 0000000000000..f7255fae05a4e
--- /dev/null
+++ b/tests/v1/worker/test_gpu_profiler.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+import vllm.envs as envs
+from vllm.profiler.gpu_profiler import WorkerProfiler
+
+
+class ConcreteWorkerProfiler(WorkerProfiler):
+    """
+    A basic implementation of a worker profiler for testing purposes.
+    """
+
+    def __init__(self):
+        self.start_call_count = 0
+        self.stop_call_count = 0
+        self.should_fail_start = False
+        super().__init__()
+
+    def _start(self) -> None:
+        if self.should_fail_start:
+            raise RuntimeError("Simulated start failure")
+        self.start_call_count += 1
+
+    def _stop(self) -> None:
+        self.stop_call_count += 1
+
+
+@pytest.fixture(autouse=True)
+def reset_mocks():
+    """Fixture to reset mocks and env variables before each test."""
+    envs.VLLM_PROFILER_DELAY_ITERS = 0
+    envs.VLLM_PROFILER_MAX_ITERS = 0
+
+
+def test_immediate_start_stop():
+    """Test standard start without delay."""
+    profiler = ConcreteWorkerProfiler()
+
+    profiler.start()
+    assert profiler._running is True
+    assert profiler._active is True
+    assert profiler.start_call_count == 1
+
+    profiler.stop()
+    assert profiler._running is False
+    assert profiler._active is False
+    assert profiler.stop_call_count == 1
+
+
+def test_delayed_start():
+    """Test that profiler waits for N steps before actually starting."""
+    envs.VLLM_PROFILER_DELAY_ITERS = 2
+    profiler = ConcreteWorkerProfiler()
+
+    # User requests start
+    profiler.start()
+
+    # Should be active (request accepted) but not running (waiting for delay)
+    assert profiler._active is True
+    assert profiler._running is False
+    assert profiler.start_call_count == 0
+
+    # Step 1
+    profiler.step()
+    assert profiler._running is False
+
+    # Step 2 (Threshold reached)
+    profiler.step()
+    assert profiler._running is True
+    assert profiler.start_call_count == 1
+
+
+def test_max_iterations():
+    """Test that profiler stops automatically after max iterations."""
+    envs.VLLM_PROFILER_MAX_ITERS = 2
+    profiler = ConcreteWorkerProfiler()
+
+    profiler.start()
+    assert profiler._running is True
+
+    # Iteration 1
+    profiler.step()  # profiling_count becomes 1
+    assert profiler._running is True
+
+    # Iteration 2
+    profiler.step()  # profiling_count becomes 2
+    assert profiler._running is True
+
+    # Iteration 3 (Exceeds max)
+    profiler.step()  # profiling_count becomes 3
+
+    # Should have stopped now
+    assert profiler._running is False
+    assert profiler.stop_call_count == 1
+
+
+def test_delayed_start_and_max_iters():
+    """Test combined delayed start and max iterations."""
+    envs.VLLM_PROFILER_DELAY_ITERS = 2
+    envs.VLLM_PROFILER_MAX_ITERS = 2
+    profiler = ConcreteWorkerProfiler()
+
+    profiler.start()
+
+    # Step 1
+    profiler.step()
+    assert profiler._running is False
+    assert profiler._active is True
+
+    # Step 2 (Starts now)
+    profiler.step()
+    assert profiler._profiling_for_iters == 1
+    assert profiler._running is True
+    assert profiler._active is True
+
+    # Next iteration
+    profiler.step()
+    assert profiler._profiling_for_iters == 2
+    assert profiler._running is True
+
+    # Iteration 3 (exceeds max)
+    profiler.step()
+
+    # Should have stopped now
+    assert profiler._running is False
+    assert profiler.stop_call_count == 1
+
+
+def test_idempotency():
+    """Test that calling start/stop multiple times doesn't break logic."""
+    profiler = ConcreteWorkerProfiler()
+
+    # Double Start
+    profiler.start()
+    profiler.start()
+    assert profiler.start_call_count == 1  # Should only start once
+
+    # Double Stop
+    profiler.stop()
+    profiler.stop()
+    assert profiler.stop_call_count == 1  # Should only stop once
+
+
+def test_step_inactive():
+    """Test that stepping while inactive does nothing."""
+    envs.VLLM_PROFILER_DELAY_ITERS = 2
+    profiler = ConcreteWorkerProfiler()
+
+    # Not started yet
+    profiler.step()
+    profiler.step()
+
+    # Even though we stepped 2 times, start shouldn't happen because active=False
+    assert profiler.start_call_count == 0
+
+
+def test_start_failure():
+    """Test behavior when the underlying _start method raises exception."""
+    profiler = ConcreteWorkerProfiler()
+    profiler.should_fail_start = True
+
+    profiler.start()
+
+    # Exception caught in _call_start
+    assert profiler._running is False  # Should not mark as running
+    assert profiler._active is True  # Request is still considered active
+    assert profiler.start_call_count == 0  # Logic failed inside start
+
+
+def test_shutdown():
+    """Test that shutdown calls stop only if running."""
+    profiler = ConcreteWorkerProfiler()
+
+    # Case 1: Not running
+    profiler.shutdown()
+    assert profiler.stop_call_count == 0
+
+    # Case 2: Running
+    profiler.start()
+    profiler.shutdown()
+    assert profiler.stop_call_count == 1
+
+
+def test_mixed_delay_and_stop():
+    """Test manual stop during the delay period."""
+    envs.VLLM_PROFILER_DELAY_ITERS = 5
+    profiler = ConcreteWorkerProfiler()
+
+    profiler.start()
+    profiler.step()
+    profiler.step()
+
+    # User cancels before delay finishes
+    profiler.stop()
+    assert profiler._active is False
+
+    # Further steps should not trigger start
+    profiler.step()
+    profiler.step()
+    profiler.step()
+
+    assert profiler.start_call_count == 0
diff --git a/vllm/envs.py b/vllm/envs.py
index 614bc94b978bd..888a09cf6d3ec 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -92,11 +92,14 @@ if TYPE_CHECKING:
     VLLM_TORCH_PROFILER_DIR: str | None = None
     VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
     VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
+    VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
     VLLM_USE_AOT_COMPILE: bool = False
     VLLM_USE_BYTECODE_HOOK: bool = False
     VLLM_FORCE_AOT_LOAD: bool = False
     VLLM_TORCH_PROFILER_WITH_STACK: bool = True
     VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
+    VLLM_PROFILER_DELAY_ITERS: int = 0
+    VLLM_PROFILER_MAX_ITERS: int = 0
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
@@ -872,6 +875,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
         os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
     ),
+    # Disable torch profiling of the AsyncLLM (frontend) process.
+    # If set to 1, the AsyncLLM process will not be profiled.
+    "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
+        os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
+    ),
+    # Number of iterations to wait before starting profiling when using
+    # the torch or torch CUDA profiler. If set to 0, will start profiling immediately.
+    "VLLM_PROFILER_DELAY_ITERS": lambda: int(
+        os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
+    ),
+    # Maximum number of iterations to profile when using the torch or torch CUDA
+    # profiler. If set to 0, will not limit the number of iterations.
+    "VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
     # If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), # If set, allow loading or unloading lora adapters in runtime, diff --git a/vllm/profiler/gpu_profiler.py b/vllm/profiler/gpu_profiler.py index 58c6689531615..2155b67a3db4b 100644 --- a/vllm/profiler/gpu_profiler.py +++ b/vllm/profiler/gpu_profiler.py @@ -1,37 +1,212 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from contextlib import nullcontext + +import torch +from typing_extensions import override + +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) -class CudaProfilerWrapper: +class WorkerProfiler(ABC): def __init__(self) -> None: - self._profiler_running = False + self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS + if self._delay_iters > 0: + logger.info_once( + "GPU profiling will start " + f"{self._delay_iters} steps after start_profile." + ) + + self._max_iters = envs.VLLM_PROFILER_MAX_ITERS + if self._max_iters > 0: + logger.info_once( + "GPU profiling will stop " + f"after {self._max_iters} worker steps, " + "or when stop_profile is received." + ) + + # Track when the profiler gets triggered by start_profile + self._active_iteration_count = 0 + self._active = False + + # Track when the profiler is actually running + self._profiling_for_iters = 0 + self._running = False + + @abstractmethod + def _start(self) -> None: + """Start the profiler.""" + pass + + @abstractmethod + def _stop(self) -> None: + """Stop the profiler.""" + pass + + def _call_start(self) -> None: + """Call _start with error handling but no safeguards.""" + try: + self._start() + self._running = True # Only mark as running if start succeeds + except Exception as e: + logger.warning("Failed to start profiler: %s", e) + + def _call_stop(self) -> None: + """Call _stop with error handling but no safeguards.""" + try: + self._stop() + logger.info("Profiler stopped successfully.") + except Exception as e: + logger.warning("Failed to stop profiler: %s", e) + self._running = False # Always mark as not running, assume stop worked + + def start(self) -> None: + """Attempt to start the profiler, accounting for delayed starts.""" + if self._active: + logger.debug( + "start_profile received when profiler is already active. " + "Ignoring request." + ) + return + self._active = True + if self._delay_iters == 0: + self._call_start() + + def step(self) -> None: + """Update the profiler state at each worker step, + to handle delayed starts and max iteration limits.""" + if not self._active: + return + + self._active_iteration_count += 1 + + if ( + not self._running + and self._delay_iters > 0 + and self._active_iteration_count == self._delay_iters + ): + logger.info("Starting profiler after delay...") + self._call_start() + + if self._running: + self._profiling_for_iters += 1 + + if ( + self._max_iters > 0 + and self._running + and self._profiling_for_iters > self._max_iters + ): + # Automatically stop the profiler after max iters + # will be marked as not running, but leave as active so that stop + # can clean up properly + logger.info("Max profiling iterations reached. Stopping profiler...") + self._call_stop() + return + + def stop(self) -> None: + """Attempt to stop the profiler, accounting for overlapped calls.""" + if not self._active: + logger.debug( + "stop_profile received when profiler is not active. Ignoring request." 
+            )
+            return
+        self._active = False
+        self._active_iteration_count = 0
+        self._profiling_for_iters = 0
+
+        if self._running:
+            self._call_stop()
+
+    def shutdown(self) -> None:
+        """Ensure profiler is stopped when shutting down."""
+        logger.info_once("Shutting down profiler")
+        if self._running:
+            self.stop()
+
+    def annotate_context_manager(self, name: str):
+        """Return a context manager to annotate profiler traces."""
+        return nullcontext()
+
+
+class TorchProfilerWrapper(WorkerProfiler):
+    def __init__(self, worker_name: str, local_rank: int) -> None:
+        super().__init__()
+
+        self.local_rank = local_rank
+        torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+        logger.info(
+            "Torch profiling enabled. Traces will be saved to: %s",
+            torch_profiler_trace_dir,
+        )
+        logger.debug(
+            "Profiler config: record_shapes=%s,"
+            "profile_memory=%s,with_stack=%s,with_flops=%s",
+            envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+            envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+            envs.VLLM_TORCH_PROFILER_WITH_STACK,
+            envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+        )
+        self.profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+            profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+            with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+            ),
+        )
+
+    @override
+    def _start(self) -> None:
+        self.profiler.start()
+
+    @override
+    def _stop(self) -> None:
+        self.profiler.stop()
+
+        rank = self.local_rank
+        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
+        profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+        sort_key = "self_cuda_time_total"
+        table = self.profiler.key_averages().table(sort_by=sort_key)
+
+        with open(profiler_out_file, "w") as f:
+            print(table, file=f)
+
+        # only print profiler results on rank 0
+        if rank == 0:
+            print(table)
+
+    @override
+    def annotate_context_manager(self, name: str):
+        return torch.profiler.record_function(name)
+
+
+class CudaProfilerWrapper(WorkerProfiler):
+    def __init__(self) -> None:
+        super().__init__()
         # Note: lazy import to avoid dependency issues if CUDA is not available.
         import torch.cuda.profiler as cuda_profiler

         self._cuda_profiler = cuda_profiler

-    def start(self) -> None:
-        try:
-            self._cuda_profiler.start()
-            self._profiler_running = True
-            logger.info_once("Started CUDA profiler")
-        except Exception as e:
-            logger.warning_once("Failed to start CUDA profiler: %s", e)
+    @override
+    def _start(self) -> None:
+        self._cuda_profiler.start()

-    def stop(self) -> None:
-        if self._profiler_running:
-            try:
-                self._cuda_profiler.stop()
-                logger.info_once("Stopped CUDA profiler")
-            except Exception as e:
-                logger.warning_once("Failed to stop CUDA profiler: %s", e)
-            finally:
-                self._profiler_running = False
+    @override
+    def _stop(self) -> None:
+        self._cuda_profiler.stop()

-    def shutdown(self) -> None:
-        """Ensure profiler is stopped when shutting down."""
-        self.stop()
+    @override
+    def annotate_context_manager(self, name: str):
+        return torch.cuda.nvtx.range(name)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index c160c7cbcab4a..abf2c8cfa4539 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -160,11 +160,23 @@ class AsyncLLM(EngineClient):
         except RuntimeError:
             pass

-        if envs.VLLM_TORCH_PROFILER_DIR:
+        if (
+            envs.VLLM_TORCH_PROFILER_DIR
+            and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
+        ):
             logger.info(
                 "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                 envs.VLLM_TORCH_PROFILER_DIR,
             )
+            if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
+                logger.warning_once(
+                    "Torch profiler received max_iters or delay_iters setting. These "
+                    "are not compatible with the AsyncLLM profiler and will be ignored "
+                    "for the AsyncLLM process. Engine process profiling will still "
+                    "respect these settings. Consider setting "
+                    "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
+                    "AsyncLLM profiling."
+                )
             worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
             self.profiler = torch.profiler.profile(
                 activities=[
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 7f9cdd221224b..18cbc38262793 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -36,7 +36,7 @@ from vllm.model_executor import set_random_seed
 from vllm.model_executor.models.interfaces import is_mixture_of_experts
 from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
-from vllm.profiler.gpu_profiler import CudaProfilerWrapper
+from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
@@ -90,32 +90,9 @@ class Worker(WorkerBase):
         # Torch profiler. Enabled and configured through env vars:
         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
         if envs.VLLM_TORCH_PROFILER_DIR:
-            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
             worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
-            logger.info(
-                "Profiling enabled. Traces will be saved to: %s",
-                torch_profiler_trace_dir,
-            )
-            logger.debug(
-                "Profiler config: record_shapes=%s,"
-                "profile_memory=%s,with_stack=%s,with_flops=%s",
-                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
-                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
-                envs.VLLM_TORCH_PROFILER_WITH_STACK,
-                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
-            )
-            self.profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
-                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
-                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
-                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
-                ),
+            self.profiler = TorchProfilerWrapper(
+                worker_name=worker_name, local_rank=self.local_rank
             )
         elif envs.VLLM_TORCH_CUDA_PROFILE:
             self.profiler = CudaProfilerWrapper()
@@ -526,10 +503,12 @@ class Worker(WorkerBase):
         if not self.profiler:
             return nullcontext()

+        self.profiler.step()
+
         num_new = len(scheduler_output.scheduled_new_reqs)
         num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)

-        return torch.profiler.record_function(
+        return self.profiler.annotate_context_manager(
             f"execute_new_{num_new}_cached_{num_cached}"
         )

@@ -587,24 +566,11 @@ class Worker(WorkerBase):
     def profile(self, is_start: bool = True):
         if self.profiler is None:
-            raise RuntimeError("Profiler is not enabled.")
+            raise RuntimeError("Profiling is not enabled.")

         if is_start:
             self.profiler.start()
         else:
             self.profiler.stop()
-            if isinstance(self.profiler, torch.profiler.profile):
-                rank = self.local_rank
-                profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
-                profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
-                sort_key = "self_cuda_time_total"
-                table = self.profiler.key_averages().table(sort_by=sort_key)
-
-                with open(profiler_out_file, "w") as f:
-                    print(table, file=f)
-
-                # only print profiler results on rank 0
-                if rank == 0:
-                    print(table)

     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1, uniform_decode=True)
@@ -865,6 +831,8 @@ class Worker(WorkerBase):
     def shutdown(self) -> None:
         if runner := getattr(self, "model_runner", None):
             runner.ensure_kv_transfer_shutdown()
+        if self.profiler is not None:
+            self.profiler.shutdown()


 def init_worker_distributed_environment(
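Usage sketch (illustrative, not part of the patch): the snippet below mirrors the lifecycle exercised in tests/v1/worker/test_gpu_profiler.py to show how the delayed-start and max-iterations controls compose. Only WorkerProfiler, its start/step/stop/shutdown methods, and the two environment variables come from the patch; the PrintingProfiler subclass and the driver loop are hypothetical stand-ins for a real worker.

    import os

    # WorkerProfiler reads these through vllm.envs when it is constructed,
    # so set them before creating the profiler.
    os.environ["VLLM_PROFILER_DELAY_ITERS"] = "2"  # start 2 steps after start_profile
    os.environ["VLLM_PROFILER_MAX_ITERS"] = "3"    # auto-stop after 3 profiled steps

    from vllm.profiler.gpu_profiler import WorkerProfiler


    class PrintingProfiler(WorkerProfiler):
        """Hypothetical subclass: prints instead of driving a real profiler."""

        def _start(self) -> None:
            print("profiler started")

        def _stop(self) -> None:
            print("profiler stopped")


    profiler = PrintingProfiler()
    profiler.start()        # request accepted; actual start is deferred 2 steps
    for _ in range(8):
        profiler.step()     # starts on step 2, auto-stops after 3 profiled steps
    profiler.stop()         # clears the request; the underlying profiler already stopped
    profiler.shutdown()     # no-op here because nothing is still running

In the real worker, the same sequence is driven by the start_profile/stop_profile requests and by Worker calling profiler.step() once per execute_model call, as the gpu_worker.py hunk above shows.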