From e858bfe05167a3bbb064e283da5a1a7709dee24e Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Dec 2025 13:29:33 -0500 Subject: [PATCH] [Cleanup] Refactor profiling env vars into a CLI config (#29912) Signed-off-by: Benjamin Chislett Signed-off-by: Benjamin Chislett Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- benchmarks/auto_tune/auto_tune.sh | 5 +- .../benchmark_serving_structured_output.py | 3 +- docs/api/README.md | 1 + docs/contributing/profiling.md | 23 +- .../offline_inference/simple_profiling.py | 13 +- tests/v1/worker/test_gpu_profiler.py | 71 ++++--- vllm/benchmarks/latency.py | 12 +- vllm/benchmarks/serve.py | 3 +- vllm/benchmarks/throughput.py | 3 +- vllm/config/__init__.py | 3 + vllm/config/profiler.py | 199 ++++++++++++++++++ vllm/config/vllm.py | 5 + vllm/engine/arg_utils.py | 6 +- vllm/entrypoints/llm.py | 17 ++ vllm/entrypoints/serve/profile/api_router.py | 17 +- vllm/envs.py | 110 +++++----- vllm/profiler/{gpu_profiler.py => wrapper.py} | 72 ++++--- vllm/v1/engine/async_llm.py | 22 +- vllm/v1/worker/cpu_worker.py | 36 +--- vllm/v1/worker/gpu_worker.py | 18 +- vllm/v1/worker/tpu_worker.py | 4 +- vllm/v1/worker/xpu_worker.py | 42 +--- 22 files changed, 433 insertions(+), 252 deletions(-) create mode 100644 vllm/config/profiler.py rename vllm/profiler/{gpu_profiler.py => wrapper.py} (73%) diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 56b721cbb4021..25baa9cbda39c 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -96,8 +96,9 @@ start_server() { # This correctly passes each element as a separate argument. if [[ -n "$profile_dir" ]]; then # Start server with profiling enabled - VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ - vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}" + VLLM_SERVER_DEV_MODE=1 \ + vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 & else # Start server without profiling VLLM_SERVER_DEV_MODE=1 \ diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index df122b4c5e8db..a4e1b163dcca9 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -963,8 +963,7 @@ def create_argument_parser(): parser.add_argument( "--profile", action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", ) parser.add_argument( "--result-dir", diff --git a/docs/api/README.md b/docs/api/README.md index d3a141f327308..d51329ec2faa3 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes. 
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.StructuredOutputsConfig][]
+- [vllm.config.ProfilerConfig][]
- [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][]
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 65382afbe4f21..cbce14ce992ec 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -5,16 +5,15 @@

 ## Profile with PyTorch Profiler

-We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
+We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by passing `--profiler-config`
+when launching the server, setting `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. You can further control what is captured with the following entries in the config (a combined example is shown after the list):

-- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
-- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
-- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
-- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
-- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
-- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
-
-The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
+- `torch_profiler_record_shapes` to enable recording tensor shapes, off by default
+- `torch_profiler_with_memory` to record memory, off by default
+- `torch_profiler_with_stack` to enable recording stack information, on by default
+- `torch_profiler_with_flops` to enable recording FLOPs, off by default
+- `torch_profiler_use_gzip` to control gzip compression of profiling files, on by default
+- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default

 When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
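As a combined sketch of these entries (the model name, trace directory, and iteration counts below are illustrative, not prescriptive), a server launch might look like the following. `ignore_frontend` is included because the delay/limit options count engine iterations, which the AsyncLLM front-end trace does not track:

```bash
# Illustrative only: enable the torch profiler with shape and memory recording,
# skip the first 2 engine iterations, then stop automatically after 8 more.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --profiler-config '{
        "profiler": "torch",
        "torch_profiler_dir": "./vllm_profile",
        "torch_profiler_record_shapes": true,
        "torch_profiler_with_memory": true,
        "delay_iterations": 2,
        "max_iterations": 8,
        "ignore_frontend": true
    }'
```

The capture itself is still triggered the usual way, via the `/start_profile` and `/stop_profile` endpoints or the `--profile` flag of `vllm bench serve` mentioned above.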
@@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline #### OpenAI Server ```bash -VLLM_TORCH_PROFILER_DIR=./vllm_profile \ - vllm serve meta-llama/Llama-3.1-8B-Instruct +vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' ``` vllm bench command: @@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with ` ```bash # server -VLLM_TORCH_CUDA_PROFILE=1 \ nsys profile \ --trace-fork-before-exec=true \ --cuda-graph-trace=node \ --capture-range=cudaProfilerApi \ --capture-range-end repeat \ - vllm serve meta-llama/Llama-3.1-8B-Instruct + vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda # client vllm bench serve \ diff --git a/examples/offline_inference/simple_profiling.py b/examples/offline_inference/simple_profiling.py index 46858fffadc52..e8a75cd03befb 100644 --- a/examples/offline_inference/simple_profiling.py +++ b/examples/offline_inference/simple_profiling.py @@ -1,14 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import time from vllm import LLM, SamplingParams -# enable torch profiler, can also be set on cmd line -os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile" - # Sample prompts. prompts = [ "Hello, my name is", @@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) + llm = LLM( + model="facebook/opt-125m", + tensor_parallel_size=1, + profiler_config={ + "profiler": "torch", + "torch_profiler_dir": "./vllm_profile", + }, + ) llm.start_profile() diff --git a/tests/v1/worker/test_gpu_profiler.py b/tests/v1/worker/test_gpu_profiler.py index f7255fae05a4e..933ea42f18cd5 100644 --- a/tests/v1/worker/test_gpu_profiler.py +++ b/tests/v1/worker/test_gpu_profiler.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -import vllm.envs as envs -from vllm.profiler.gpu_profiler import WorkerProfiler +from vllm.config import ProfilerConfig +from vllm.profiler.wrapper import WorkerProfiler class ConcreteWorkerProfiler(WorkerProfiler): @@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler): A basic implementation of a worker profiler for testing purposes. 
""" - def __init__(self): + def __init__(self, profiler_config: ProfilerConfig): self.start_call_count = 0 self.stop_call_count = 0 self.should_fail_start = False - super().__init__() + super().__init__(profiler_config) def _start(self) -> None: if self.should_fail_start: @@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler): self.stop_call_count += 1 -@pytest.fixture(autouse=True) -def reset_mocks(): - """Fixture to reset mocks and env variables before each test.""" - envs.VLLM_PROFILER_DELAY_ITERS = 0 - envs.VLLM_PROFILER_MAX_ITERS = 0 +@pytest.fixture +def default_profiler_config(): + return ProfilerConfig( + profiler="torch", + torch_profiler_dir="/tmp/mock", + delay_iterations=0, + max_iterations=0, + ) -def test_immediate_start_stop(): +def test_immediate_start_stop(default_profiler_config): """Test standard start without delay.""" - profiler = ConcreteWorkerProfiler() - + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() assert profiler._running is True assert profiler._active is True @@ -48,10 +50,10 @@ def test_immediate_start_stop(): assert profiler.stop_call_count == 1 -def test_delayed_start(): +def test_delayed_start(default_profiler_config): """Test that profiler waits for N steps before actually starting.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) # User requests start profiler.start() @@ -71,10 +73,10 @@ def test_delayed_start(): assert profiler.start_call_count == 1 -def test_max_iterations(): +def test_max_iterations(default_profiler_config): """Test that profiler stops automatically after max iterations.""" - envs.VLLM_PROFILER_MAX_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.max_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() assert profiler._running is True @@ -95,12 +97,11 @@ def test_max_iterations(): assert profiler.stop_call_count == 1 -def test_delayed_start_and_max_iters(): +def test_delayed_start_and_max_iters(default_profiler_config): """Test combined delayed start and max iterations.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - envs.VLLM_PROFILER_MAX_ITERS = 2 - profiler = ConcreteWorkerProfiler() - + default_profiler_config.delay_iterations = 2 + default_profiler_config.max_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() # Step 1 @@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters(): assert profiler.stop_call_count == 1 -def test_idempotency(): +def test_idempotency(default_profiler_config): """Test that calling start/stop multiple times doesn't break logic.""" - profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) # Double Start profiler.start() @@ -142,10 +143,10 @@ def test_idempotency(): assert profiler.stop_call_count == 1 # Should only stop once -def test_step_inactive(): +def test_step_inactive(default_profiler_config): """Test that stepping while inactive does nothing.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) # Not started yet profiler.step() @@ -155,9 +156,9 @@ def test_step_inactive(): assert profiler.start_call_count == 0 -def test_start_failure(): +def test_start_failure(default_profiler_config): """Test behavior when the underlying _start method raises exception.""" 
- profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.should_fail_start = True profiler.start() @@ -168,9 +169,9 @@ def test_start_failure(): assert profiler.start_call_count == 0 # Logic failed inside start -def test_shutdown(): +def test_shutdown(default_profiler_config): """Test that shutdown calls stop only if running.""" - profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) # Case 1: Not running profiler.shutdown() @@ -182,10 +183,10 @@ def test_shutdown(): assert profiler.stop_call_count == 1 -def test_mixed_delay_and_stop(): +def test_mixed_delay_and_stop(default_profiler_config): """Test manual stop during the delay period.""" - envs.VLLM_PROFILER_DELAY_ITERS = 5 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 5 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() profiler.step() diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index b4f1751837f48..99c1c846f19af 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -12,7 +12,6 @@ from typing import Any import numpy as np from tqdm import tqdm -import vllm.envs as envs from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType @@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace): - if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: - raise OSError( - "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler." - ) engine_args = EngineArgs.from_cli_args(args) + if args.profile and not engine_args.profiler_config.profiler == "torch": + raise ValueError( + "The torch profiler is not enabled. Please provide profiler_config." + ) # Lazy import to avoid importing LLM when the bench command is not selected. from vllm import LLM, SamplingParams @@ -144,7 +142,7 @@ def main(args: argparse.Namespace): run_to_completion(profile_dir=None) if args.profile: - profile_dir = envs.VLLM_TORCH_PROFILER_DIR + profile_dir = engine_args.profiler_config.torch_profiler_dir print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 568290aa894ff..2e2054a8a4b13 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--profile", action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", ) parser.add_argument( "--save-result", diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index ea693613fdd16..d824e982b7489 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--profile", action="store_true", default=False, - help="Use Torch Profiler. The env variable " - "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.", + help="Use vLLM Profiling. 
--profiler-config must be provided on the server.", ) # prefix repetition dataset diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 0f84f3ca9d3e3..0e91dd57420a8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.config.observability import ObservabilityConfig from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.pooler import PoolerConfig +from vllm.config.profiler import ProfilerConfig from vllm.config.scheduler import SchedulerConfig from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig @@ -89,6 +90,8 @@ __all__ = [ "SpeechToTextConfig", # From vllm.config.structured_outputs "StructuredOutputsConfig", + # From vllm.config.profiler + "ProfilerConfig", # From vllm.config.utils "ConfigType", "SupportsMetricsInfo", diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py new file mode 100644 index 0000000000000..76cc546f3c9e2 --- /dev/null +++ b/vllm/config/profiler.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Any, Literal + +from pydantic import Field, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash + +logger = init_logger(__name__) + +ProfilerKind = Literal["torch", "cuda"] + + +@config +@dataclass +class ProfilerConfig: + """Dataclass which contains profiler config for the engine.""" + + profiler: ProfilerKind | None = None + """Which profiler to use. Defaults to None. Options are: + + - 'torch': Use PyTorch profiler.\n + - 'cuda': Use CUDA profiler.""" + + torch_profiler_dir: str = "" + """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and + worker's traces (CPU & GPU) will be saved under this directory. Note that + it must be an absolute path.""" + + torch_profiler_with_stack: bool = True + """If `True`, enables stack tracing in the torch profiler. Enabled by default.""" + + torch_profiler_with_flops: bool = False + """If `True`, enables FLOPS counting in the torch profiler. Disabled by default.""" + + torch_profiler_use_gzip: bool = True + """If `True`, saves torch profiler traces in gzip format. Enabled by default""" + + torch_profiler_dump_cuda_time_total: bool = True + """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default.""" + + torch_profiler_record_shapes: bool = False + """If `True`, records tensor shapes in the torch profiler. Disabled by default.""" + + torch_profiler_with_memory: bool = False + """If `True`, enables memory profiling in the torch profiler. + Disabled by default.""" + + ignore_frontend: bool = False + """If `True`, disables the front-end profiling of AsyncLLM when using the + 'torch' profiler. This is needed to reduce overhead when using delay/limit options, + since the front-end profiling does not track iterations and will capture the + entire range. + """ + + delay_iterations: int = Field(default=0, ge=0) + """Number of engine iterations to skip before starting profiling. + Defaults to 0, meaning profiling starts immediately after receiving /start_profile. + """ + + max_iterations: int = Field(default=0, ge=0) + """Maximum number of engine iterations to profile after starting profiling. 
+ Defaults to 0, meaning no limit. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None: + """Get field from env var if set, with deprecation warning.""" + + if envs.is_set(env_var_name): + value = getattr(envs, env_var_name) + logger.warning_once( + "Using %s environment variable is deprecated and will be removed in " + "v0.14.0 or v1.0.0, whichever is soonest. Please use " + "--profiler-config.%s command line argument or " + "ProfilerConfig(%s=...) config field instead.", + env_var_name, + field_name, + field_name, + ) + return value + return None + + def _set_from_env_if_set( + self, + field_name: str, + env_var_name: str, + to_bool: bool = True, + to_int: bool = False, + ) -> None: + """Set field from env var if set, with deprecation warning.""" + value = self._get_from_env_if_set(field_name, env_var_name) + if value is not None: + if to_bool: + value = value == "1" + if to_int: + value = int(value) + setattr(self, field_name, value) + + @model_validator(mode="after") + def _validate_profiler_config(self) -> Self: + maybe_use_cuda_profiler = self._get_from_env_if_set( + "profiler", "VLLM_TORCH_CUDA_PROFILE" + ) + if maybe_use_cuda_profiler is not None: + self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None + else: + self._set_from_env_if_set( + "torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False + ) + if self.torch_profiler_dir: + self.profiler = "torch" + self._set_from_env_if_set( + "torch_profiler_record_shapes", + "VLLM_TORCH_PROFILER_RECORD_SHAPES", + ) + self._set_from_env_if_set( + "torch_profiler_with_memory", + "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", + ) + self._set_from_env_if_set( + "torch_profiler_with_stack", + "VLLM_TORCH_PROFILER_WITH_STACK", + ) + self._set_from_env_if_set( + "torch_profiler_with_flops", + "VLLM_TORCH_PROFILER_WITH_FLOPS", + ) + self._set_from_env_if_set( + "ignore_frontend", + "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", + ) + self._set_from_env_if_set( + "torch_profiler_use_gzip", + "VLLM_TORCH_PROFILER_USE_GZIP", + ) + self._set_from_env_if_set( + "torch_profiler_dump_cuda_time_total", + "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", + ) + + self._set_from_env_if_set( + "delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True + ) + self._set_from_env_if_set( + "max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True + ) + + has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0 + if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend: + logger.warning_once( + "Using 'torch' profiler with delay_iterations or max_iterations " + "while ignore_frontend is False may result in high overhead." 
+ ) + + profiler_dir = self.torch_profiler_dir + if profiler_dir and self.profiler != "torch": + raise ValueError( + "torch_profiler_dir is only applicable when profiler is set to 'torch'" + ) + if self.profiler == "torch" and not profiler_dir: + raise ValueError("torch_profiler_dir must be set when profiler is 'torch'") + + if profiler_dir: + is_gs_path = ( + profiler_dir.startswith("gs://") + and profiler_dir[5:] + and profiler_dir[5] != "/" + ) + if not is_gs_path: + self.torch_profiler_dir = os.path.abspath( + os.path.expanduser(profiler_dir) + ) + + return self diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a74413536407b..614a3226cb711 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -39,6 +39,7 @@ from .lora import LoRAConfig from .model import ModelConfig from .observability import ObservabilityConfig from .parallel import ParallelConfig +from .profiler import ProfilerConfig from .scheduler import SchedulerConfig from .speculative import SpeculativeConfig from .structured_outputs import StructuredOutputsConfig @@ -218,6 +219,8 @@ class VllmConfig: You can specify the full compilation config like so: `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` """ + profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig) + """Profiling configuration.""" kv_transfer_config: KVTransferConfig | None = None """The configurations for distributed KV cache transfer.""" kv_events_config: KVEventsConfig | None = None @@ -296,6 +299,8 @@ class VllmConfig: vllm_factors.append("None") if self.structured_outputs_config: vllm_factors.append(self.structured_outputs_config.compute_hash()) + if self.profiler_config: + vllm_factors.append(self.profiler_config.compute_hash()) else: vllm_factors.append("None") vllm_factors.append(self.observability_config.compute_hash()) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ceac5407af6e2..2f307a7ccf16d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -50,6 +50,7 @@ from vllm.config import ( ObservabilityConfig, ParallelConfig, PoolerConfig, + ProfilerConfig, SchedulerConfig, SpeculativeConfig, StructuredOutputsConfig, @@ -532,6 +533,8 @@ class EngineArgs: worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls + profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config") + kv_transfer_config: KVTransferConfig | None = None kv_events_config: KVEventsConfig | None = None @@ -1164,7 +1167,7 @@ class EngineArgs: vllm_group.add_argument( "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) - + vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"]) vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) @@ -1782,6 +1785,7 @@ class EngineArgs: kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, + profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 913324fd5f9c3..5d5c4a1cdb77b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -20,6 +20,7 @@ from vllm.beam_search import ( from vllm.config import ( CompilationConfig, PoolerConfig, + ProfilerConfig, StructuredOutputsConfig, is_init_field, ) @@ -211,6 +212,7 @@ class LLM: structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, + 
profiler_config: dict[str, Any] | ProfilerConfig | None = None, kv_cache_memory_bytes: int | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None, @@ -282,6 +284,20 @@ class LLM: else: structured_outputs_instance = StructuredOutputsConfig() + if profiler_config is not None: + if isinstance(profiler_config, dict): + profiler_config_instance = ProfilerConfig( + **{ + k: v + for k, v in profiler_config.items() + if is_init_field(ProfilerConfig, k) + } + ) + else: + profiler_config_instance = profiler_config + else: + profiler_config_instance = ProfilerConfig() + # warn about single-process data parallel usage. _dp_size = int(kwargs.get("data_parallel_size", 1)) _distributed_executor_backend = kwargs.get("distributed_executor_backend") @@ -324,6 +340,7 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, pooler_config=pooler_config, structured_outputs_config=structured_outputs_instance, + profiler_config=profiler_config_instance, compilation_config=compilation_config_instance, logits_processors=logits_processors, **kwargs, diff --git a/vllm/entrypoints/serve/profile/api_router.py b/vllm/entrypoints/serve/profile/api_router.py index 166f13764eb36..eeed6b45ef4e9 100644 --- a/vllm/entrypoints/serve/profile/api_router.py +++ b/vllm/entrypoints/serve/profile/api_router.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, FastAPI, Request from fastapi.responses import Response -import vllm.envs as envs +from vllm.config import ProfilerConfig from vllm.engine.protocol import EngineClient from vllm.logger import init_logger @@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request): def attach_router(app: FastAPI): - if envs.VLLM_TORCH_PROFILER_DIR: + profiler_config = getattr(app.state.args, "profiler_config", None) + assert profiler_config is None or isinstance(profiler_config, ProfilerConfig) + if profiler_config is not None and profiler_config.profiler is not None: logger.warning_once( - "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!" + "Profiler with mode '%s' is enabled in the " + "API server. This should ONLY be used for local development!", + profiler_config.profiler, ) - elif envs.VLLM_TORCH_CUDA_PROFILE: - logger.warning_once( - "CUDA Profiler is enabled in the API server. This should ONLY be " - "used for local development!" 
- ) - if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: app.include_router(router) diff --git a/vllm/envs.py b/vllm/envs.py index bda9e6e423356..8246109eb73af 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -89,20 +89,23 @@ if TYPE_CHECKING: VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_PLUGINS: list[str] | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None - VLLM_TORCH_CUDA_PROFILE: bool = False + # Deprecated env variables for profiling, kept for backward compatibility + # See also vllm/config/profiler.py and `--profiler-config` argument + VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_PROFILER_DIR: str | None = None - VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False - VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False - VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False + VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None + VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None + VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None + VLLM_TORCH_PROFILER_WITH_STACK: str | None = None + VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None + VLLM_TORCH_PROFILER_USE_GZIP: str | None = None + VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None + VLLM_PROFILER_DELAY_ITERS: str | None = None + VLLM_PROFILER_MAX_ITERS: str | None = None + # End of deprecated env variables for profiling VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_BYTECODE_HOOK: bool = False VLLM_FORCE_AOT_LOAD: bool = False - VLLM_TORCH_PROFILER_WITH_STACK: bool = True - VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False - VLLM_PROFILER_DELAY_ITERS: int = 0 - VLLM_PROFILER_MAX_ITERS: int = 0 - VLLM_TORCH_PROFILER_USE_GZIP: bool = True - VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False @@ -850,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR", None ), - # Enables torch CUDA profiling if set. - # On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered. - "VLLM_TORCH_CUDA_PROFILE": lambda: bool( - os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0" - ), + # Enables torch CUDA profiling if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), # Enables torch profiler if set. - # Both AsyncLLM's CPU traces as well as workers' - # traces (CPU & GPU) will be saved under this directory. - # Note that it must be an absolute path. - "VLLM_TORCH_PROFILER_DIR": lambda: ( - None - if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None - else ( - val - if val.startswith("gs://") and val[5:] and val[5] != "/" - else os.path.abspath(os.path.expanduser(val)) - ) + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"), + # Enable torch profiler to record shapes if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES") ), - # Enable torch profiler to record shapes if set - # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will - # not record shapes. - "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" + # Enable torch profiler to profile memory if set to 1. + # Deprecated, see profiler_config. 
+ "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY") ), - # Enable torch profiler to profile memory if set - # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler - # will not profile memory. - "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" + # Enable torch profiler to profile stack if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_STACK": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_STACK") ), - # Enable torch profiler to profile stack if set - # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL - # profile stack by default. - "VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0" + # Enable torch profiler to profile flops if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS") ), - # Enable torch profiler to profile flops if set - # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will - # not profile flops. - "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0" - ), - # Disable torch profiling of the AsyncLLMEngine process. - # If set to 1, will not profile the engine process. - "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0" + # Disable torch profiling of the AsyncLLMEngine process if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM") ), # Delay number of iterations before starting profiling when using # the torch/torch CUDA profiler. If set to 0, will start profiling immediately. - "VLLM_PROFILER_DELAY_ITERS": lambda: int( - os.getenv("VLLM_PROFILER_DELAY_ITERS", "0") - ), + # Deprecated, see profiler_config. + "VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")), # Maximum number of iterations to profile when using the torch/torch CUDA profiler. # If set to 0, will not limit the number of iterations. - "VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")), + "VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"), # Control whether torch profiler gzip-compresses profiling files. - # Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default). - "VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0" - ), + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"), # Control whether torch profiler dumps the self_cuda_time_total table. - # Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping - # (enabled by default). - "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0" + # Set to 0 to disable dumping the table. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL") ), # If set, vLLM will use Triton implementations of AWQ. 
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), diff --git a/vllm/profiler/gpu_profiler.py b/vllm/profiler/wrapper.py similarity index 73% rename from vllm/profiler/gpu_profiler.py rename to vllm/profiler/wrapper.py index 798c615221b9f..a44a6a5eea0dd 100644 --- a/vllm/profiler/gpu_profiler.py +++ b/vllm/profiler/wrapper.py @@ -3,26 +3,27 @@ from abc import ABC, abstractmethod from contextlib import nullcontext +from typing import Literal import torch from typing_extensions import override -import vllm.envs as envs +from vllm.config import ProfilerConfig from vllm.logger import init_logger logger = init_logger(__name__) class WorkerProfiler(ABC): - def __init__(self) -> None: - self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS + def __init__(self, profiler_config: ProfilerConfig) -> None: + self._delay_iters = profiler_config.delay_iterations if self._delay_iters > 0: logger.info_once( "GPU profiling will start " f"{self._delay_iters} steps after start_profile." ) - self._max_iters = envs.VLLM_PROFILER_MAX_ITERS + self._max_iters = profiler_config.max_iterations if self._max_iters > 0: logger.info_once( "GPU profiling will stop " @@ -133,12 +134,27 @@ class WorkerProfiler(ABC): return nullcontext() +TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"] +TorchProfilerActivityMap = { + "CPU": torch.profiler.ProfilerActivity.CPU, + "CUDA": torch.profiler.ProfilerActivity.CUDA, + "XPU": torch.profiler.ProfilerActivity.XPU, +} + + class TorchProfilerWrapper(WorkerProfiler): - def __init__(self, worker_name: str, local_rank: int) -> None: - super().__init__() + def __init__( + self, + profiler_config: ProfilerConfig, + worker_name: str, + local_rank: int, + activities: list[TorchProfilerActivity], + ) -> None: + super().__init__(profiler_config) self.local_rank = local_rank - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + self.profiler_config = profiler_config + torch_profiler_trace_dir = profiler_config.torch_profiler_dir if local_rank in (None, 0): logger.info( "Torch profiling enabled. 
Traces will be saved to: %s", @@ -147,24 +163,23 @@ class TorchProfilerWrapper(WorkerProfiler): logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", - envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - envs.VLLM_TORCH_PROFILER_WITH_STACK, - envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + profiler_config.torch_profiler_record_shapes, + profiler_config.torch_profiler_with_memory, + profiler_config.torch_profiler_with_stack, + profiler_config.torch_profiler_with_flops, ) + + self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1 self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + activities=[TorchProfilerActivityMap[activity] for activity in activities], + record_shapes=profiler_config.torch_profiler_record_shapes, + profile_memory=profiler_config.torch_profiler_with_memory, + with_stack=profiler_config.torch_profiler_with_stack, + with_flops=profiler_config.torch_profiler_with_flops, on_trace_ready=torch.profiler.tensorboard_trace_handler( torch_profiler_trace_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=profiler_config.torch_profiler_use_gzip, ), ) @@ -176,9 +191,10 @@ class TorchProfilerWrapper(WorkerProfiler): def _stop(self) -> None: self.profiler.stop() - if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: - rank = self.local_rank - profiler_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = self.profiler_config + rank = self.local_rank + if profiler_config.torch_profiler_dump_cuda_time_total: + profiler_dir = profiler_config.torch_profiler_dir profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" sort_key = "self_cuda_time_total" table = self.profiler.key_averages().table(sort_by=sort_key) @@ -189,6 +205,12 @@ class TorchProfilerWrapper(WorkerProfiler): # only print profiler results on rank 0 if rank == 0: print(table) + if self.dump_cpu_time_total and rank == 0: + logger.info( + self.profiler.key_averages().table( + sort_by="self_cpu_time_total", row_limit=50 + ) + ) @override def annotate_context_manager(self, name: str): @@ -196,8 +218,8 @@ class TorchProfilerWrapper(WorkerProfiler): class CudaProfilerWrapper(WorkerProfiler): - def __init__(self) -> None: - super().__init__() + def __init__(self, profiler_config: ProfilerConfig) -> None: + super().__init__(profiler_config) # Note: lazy import to avoid dependency issues if CUDA is not available. import torch.cuda.profiler as cuda_profiler diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index fd7e04dc02082..931d13be3d9b6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -166,32 +166,24 @@ class AsyncLLM(EngineClient): pass if ( - envs.VLLM_TORCH_PROFILER_DIR - and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM + vllm_config.profiler_config.profiler == "torch" + and not vllm_config.profiler_config.ignore_frontend ): + profiler_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Torch profiler enabled. 
AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, ) - if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0: - logger.warning_once( - "Torch profiler received max_iters or delay_iters setting. These " - "are not compatible with the AsyncLLM profiler and will be ignored " - "for the AsyncLLM process. Engine process profiling will still " - "respect these settings. Consider setting " - "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable " - "AsyncLLM profiling." - ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" self.profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, ], - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + with_stack=vllm_config.profiler_config.torch_profiler_with_stack, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip, ), ) else: diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index b080fea1d2dd6..e54b995ab908f 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.utils import set_random_seed from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo +from vllm.profiler.wrapper import TorchProfilerWrapper from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment @@ -38,30 +39,17 @@ class CPUWorker(Worker): self.parallel_config.disable_custom_all_reduce = True - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + # Torch profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" - logger.info( - "Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir, + self.profiler = TorchProfilerWrapper( + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU"], ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, worker_name=worker_name, use_gzip=False - ), - ) - else: - self.profiler = None def init_device(self): # Setup OpenMP threads affinity. 
@@ -202,9 +190,3 @@ class CPUWorker(Worker): self.profiler.start() else: self.profiler.stop() - if self.local_rank == 0: - logger.info( - self.profiler.key_averages().table( - sort_by="self_cpu_time_total", row_limit=50 - ) - ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 24a3533a169f0..f2b6a1f76b0b9 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -38,7 +38,7 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform -from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper +from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes @@ -92,17 +92,19 @@ class Worker(WorkerBase): # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} - # Torch/CUDA profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - # VLLM_TORCH_CUDA_PROFILE=1 + # Torch/CUDA profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" self.profiler = TorchProfilerWrapper( - worker_name=worker_name, local_rank=self.local_rank + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU", "CUDA"], ) - elif envs.VLLM_TORCH_CUDA_PROFILE: - self.profiler = CudaProfilerWrapper() + elif profiler_config.profiler == "cuda": + self.profiler = CudaProfilerWrapper(profiler_config) else: self.profiler = None diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ce18ca6c37165..7a10ac198985e 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -98,10 +98,10 @@ class TPUWorker: # MP runtime is initialized. self.profiler = None self.profile_dir = None - if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + if vllm_config.profiler_config.profiler == "torch" and self.rank < 1: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. - self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + self.profile_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Profiling enabled. Traces will be saved to: %s", self.profile_dir ) diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 267369c730368..1faa1a24ff0ea 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -6,12 +6,12 @@ from typing import Any import torch import torch.distributed -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_world_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.profiler.wrapper import TorchProfilerWrapper from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.xpu_model_runner import XPUModelRunner @@ -36,41 +36,17 @@ class XPUWorker(Worker): assert device_config.device_type == "xpu" assert current_platform.is_xpu() - # Torch profiler. 
Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + # Torch profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" - logger.info( - "Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir, + self.profiler = TorchProfilerWrapper( + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU", "XPU"], ) - logger.debug( - "Profiler config: record_shapes=%s," - "profile_memory=%s,with_stack=%s,with_flops=%s", - envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - envs.VLLM_TORCH_PROFILER_WITH_STACK, - envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.XPU, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, - worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, - ), - ) - else: - self.profiler = None # we provide this function due to `torch.xpu.mem_get_info()` doesn't # return correct free_gpu_memory on intel client GPU. We need to