[Cleanup] Refactor profiling env vars into a CLI config (#29912)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Benjamin Chislett 2025-12-09 13:29:33 -05:00 committed by GitHub
parent d471b2aff0
commit e858bfe051
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 433 additions and 252 deletions

View File

@ -96,8 +96,9 @@ start_server() {
# This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
VLLM_SERVER_DEV_MODE=1 \
vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else
# Start server without profiling
VLLM_SERVER_DEV_MODE=1 \

View File

@ -963,8 +963,7 @@ def create_argument_parser():
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--result-dir",

View File

@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.StructuredOutputsConfig][]
- [vllm.config.ProfilerConfig][]
- [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][]

View File

@ -5,16 +5,15 @@
## Profile with PyTorch Profiler
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by passing `--profiler-config`
when launching the server, setting `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. You can further control what is captured through the following additional options in the config (a full invocation sketch follows below):
- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
- `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
- `torch_profiler_with_memory` to record memory, off by default
- `torch_profiler_with_stack` to enable recording stack information, on by default
- `torch_profiler_with_flops` to enable recording FLOPs, off by default
- `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default
When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
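For illustration, a single invocation that combines several of the options above might look as follows. This is a sketch based on the field names defined in the new `ProfilerConfig`; the model and output directory are placeholders.

```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --profiler-config '{"profiler": "torch",
                      "torch_profiler_dir": "./vllm_profile",
                      "torch_profiler_record_shapes": true,
                      "torch_profiler_with_memory": true,
                      "max_iterations": 20}'
```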
@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
#### OpenAI Server
```bash
VLLM_TORCH_PROFILER_DIR=./vllm_profile \
vllm serve meta-llama/Llama-3.1-8B-Instruct
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
```
vllm bench command:
@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
```bash
# server
VLLM_TORCH_CUDA_PROFILE=1 \
nsys profile \
--trace-fork-before-exec=true \
--cuda-graph-trace=node \
--capture-range=cudaProfilerApi \
--capture-range-end repeat \
vllm serve meta-llama/Llama-3.1-8B-Instruct
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda
# client
vllm bench serve \

View File

@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
from vllm import LLM, SamplingParams
# enable torch profiler, can also be set on cmd line
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
# Sample prompts.
prompts = [
"Hello, my name is",
@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=1,
profiler_config={
"profiler": "torch",
"torch_profiler_dir": "./vllm_profile",
},
)
llm.start_profile()
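For context, the hunk ends at `start_profile()`; the rest of the example pairs it with a stop call roughly like the sketch below (`generate`, `stop_profile`, and the trailing sleep follow the existing example's pattern and are not part of this diff).

```python
llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
llm.stop_profile()

# Give the background trace writer time to flush before the process exits.
time.sleep(10)
```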

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.envs as envs
from vllm.profiler.gpu_profiler import WorkerProfiler
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler):
@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes.
"""
def __init__(self):
def __init__(self, profiler_config: ProfilerConfig):
self.start_call_count = 0
self.stop_call_count = 0
self.should_fail_start = False
super().__init__()
super().__init__(profiler_config)
def _start(self) -> None:
if self.should_fail_start:
@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self.stop_call_count += 1
@pytest.fixture(autouse=True)
def reset_mocks():
"""Fixture to reset mocks and env variables before each test."""
envs.VLLM_PROFILER_DELAY_ITERS = 0
envs.VLLM_PROFILER_MAX_ITERS = 0
@pytest.fixture
def default_profiler_config():
return ProfilerConfig(
profiler="torch",
torch_profiler_dir="/tmp/mock",
delay_iterations=0,
max_iterations=0,
)
def test_immediate_start_stop():
def test_immediate_start_stop(default_profiler_config):
"""Test standard start without delay."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
assert profiler._active is True
@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert profiler.stop_call_count == 1
def test_delayed_start():
def test_delayed_start(default_profiler_config):
"""Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# User requests start
profiler.start()
@ -71,10 +73,10 @@ def test_delayed_start():
assert profiler.start_call_count == 1
def test_max_iterations():
def test_max_iterations(default_profiler_config):
"""Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
@ -95,12 +97,11 @@ def test_max_iterations():
assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters():
def test_delayed_start_and_max_iters(default_profiler_config):
"""Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
# Step 1
@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert profiler.stop_call_count == 1
def test_idempotency():
def test_idempotency(default_profiler_config):
"""Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Double Start
profiler.start()
@ -142,10 +143,10 @@ def test_idempotency():
assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive():
def test_step_inactive(default_profiler_config):
"""Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Not started yet
profiler.step()
@ -155,9 +156,9 @@ def test_step_inactive():
assert profiler.start_call_count == 0
def test_start_failure():
def test_start_failure(default_profiler_config):
"""Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.should_fail_start = True
profiler.start()
@ -168,9 +169,9 @@ def test_start_failure():
assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown():
def test_shutdown(default_profiler_config):
"""Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Case 1: Not running
profiler.shutdown()
@ -182,10 +183,10 @@ def test_shutdown():
assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop():
def test_mixed_delay_and_stop(default_profiler_config):
"""Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 5
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
profiler.step()

View File

@ -12,7 +12,6 @@ from typing import Any
import numpy as np
from tqdm import tqdm
import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace):
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
raise OSError(
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
"Please set it to a valid path to use torch profiler."
)
engine_args = EngineArgs.from_cli_args(args)
if args.profile and engine_args.profiler_config.profiler != "torch":
raise ValueError(
"The torch profiler is not enabled. Please provide a --profiler-config "
"with profiler set to 'torch'."
)
# Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams
@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
run_to_completion(profile_dir=None)
if args.profile:
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
profile_dir = engine_args.profiler_config.torch_profiler_dir
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return

View File

@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--save-result",

View File

@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--profile",
action="store_true",
default=False,
help="Use Torch Profiler. The env variable "
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
# prefix repetition dataset

View File

@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
@ -89,6 +90,8 @@ __all__ = [
"SpeechToTextConfig",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils
"ConfigType",
"SupportsMetricsInfo",

vllm/config/profiler.py (new file, 199 lines)
View File

@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
@config
@dataclass
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
profiler: ProfilerKind | None = None
"""Which profiler to use. Defaults to None. Options are:
- 'torch': Use PyTorch profiler.\n
- 'cuda': Use CUDA profiler."""
torch_profiler_dir: str = ""
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
torch_profiler_use_gzip: bool = True
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
torch_profiler_dump_cuda_time_total: bool = True
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
torch_profiler_record_shapes: bool = False
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
torch_profiler_with_memory: bool = False
"""If `True`, enables memory profiling in the torch profiler.
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
"""
delay_iterations: int = Field(default=0, ge=0)
"""Number of engine iterations to skip before starting profiling.
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
"""
max_iterations: int = Field(default=0, ge=0)
"""Maximum number of engine iterations to profile after starting profiling.
Defaults to 0, meaning no limit.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> str | None:
"""Get field from env var if set, with deprecation warning."""
if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--profiler-config.%s command line argument or "
"ProfilerConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)
return value
return None
def _set_from_env_if_set(
self,
field_name: str,
env_var_name: str,
to_bool: bool = True,
to_int: bool = False,
) -> None:
"""Set field from env var if set, with deprecation warning."""
value = self._get_from_env_if_set(field_name, env_var_name)
if value is not None:
if to_bool:
value = value == "1"
if to_int:
value = int(value)
setattr(self, field_name, value)
@model_validator(mode="after")
def _validate_profiler_config(self) -> Self:
maybe_use_cuda_profiler = self._get_from_env_if_set(
"profiler", "VLLM_TORCH_CUDA_PROFILE"
)
if maybe_use_cuda_profiler is not None:
self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
else:
self._set_from_env_if_set(
"torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
)
if self.torch_profiler_dir:
self.profiler = "torch"
self._set_from_env_if_set(
"torch_profiler_record_shapes",
"VLLM_TORCH_PROFILER_RECORD_SHAPES",
)
self._set_from_env_if_set(
"torch_profiler_with_memory",
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
)
self._set_from_env_if_set(
"torch_profiler_with_stack",
"VLLM_TORCH_PROFILER_WITH_STACK",
)
self._set_from_env_if_set(
"torch_profiler_with_flops",
"VLLM_TORCH_PROFILER_WITH_FLOPS",
)
self._set_from_env_if_set(
"ignore_frontend",
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
)
self._set_from_env_if_set(
"torch_profiler_use_gzip",
"VLLM_TORCH_PROFILER_USE_GZIP",
)
self._set_from_env_if_set(
"torch_profiler_dump_cuda_time_total",
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
)
self._set_from_env_if_set(
"delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
)
self._set_from_env_if_set(
"max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
)
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
logger.warning_once(
"Using 'torch' profiler with delay_iterations or max_iterations "
"while ignore_frontend is False may result in high overhead."
)
profiler_dir = self.torch_profiler_dir
if profiler_dir and self.profiler != "torch":
raise ValueError(
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
)
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
return self
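As a quick illustration of the dataclass above, the config can also be constructed directly in Python and goes through the same validator; this is a sketch with arbitrary paths and iteration counts.

```python
from vllm.config import ProfilerConfig

cfg = ProfilerConfig(
    profiler="torch",
    torch_profiler_dir="/tmp/vllm_profile",  # expanded to an absolute path by the validator
    delay_iterations=2,  # skip the first two engine steps after /start_profile
    max_iterations=8,    # then capture eight steps and stop automatically
)
```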

View File

@ -39,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
@ -218,6 +219,8 @@ class VllmConfig:
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
@ -296,6 +299,8 @@ class VllmConfig:
vllm_factors.append("None")
if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash())
if self.profiler_config:
vllm_factors.append(self.profiler_config.compute_hash())
else:
vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash())

View File

@ -50,6 +50,7 @@ from vllm.config import (
ObservabilityConfig,
ParallelConfig,
PoolerConfig,
ProfilerConfig,
SchedulerConfig,
SpeculativeConfig,
StructuredOutputsConfig,
@ -532,6 +533,8 @@ class EngineArgs:
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")
kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None
@ -1164,7 +1167,7 @@ class EngineArgs:
vllm_group.add_argument(
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
)
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"]
)
@ -1782,6 +1785,7 @@ class EngineArgs:
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
profiler_config=self.profiler_config,
additional_config=self.additional_config,
optimization_level=self.optimization_level,
)
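Since `--profiler-config` is registered like the other dataclass-backed config arguments, both the JSON form and the dotted per-field form (used in the Nsight example above) should be accepted; a sketch, assuming the usual config-arg parsing applies:

```bash
# JSON form
vllm serve facebook/opt-125m \
  --profiler-config '{"profiler": "cuda", "delay_iterations": 10, "max_iterations": 20}'

# Dotted per-field form
vllm serve facebook/opt-125m \
  --profiler-config.profiler cuda \
  --profiler-config.delay_iterations 10 \
  --profiler-config.max_iterations 20
```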

View File

@ -20,6 +20,7 @@ from vllm.beam_search import (
from vllm.config import (
CompilationConfig,
PoolerConfig,
ProfilerConfig,
StructuredOutputsConfig,
is_init_field,
)
@ -211,6 +212,7 @@ class LLM:
structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig
| None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
@ -282,6 +284,20 @@ class LLM:
else:
structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
)
else:
profiler_config_instance = profiler_config
else:
profiler_config_instance = ProfilerConfig()
# warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1))
_distributed_executor_backend = kwargs.get("distributed_executor_backend")
@ -324,6 +340,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,
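Because the new parameter is typed `dict[str, Any] | ProfilerConfig | None`, a `ProfilerConfig` instance should be accepted in place of the dict shown in the updated example; a sketch:

```python
from vllm import LLM
from vllm.config import ProfilerConfig

llm = LLM(
    model="facebook/opt-125m",
    profiler_config=ProfilerConfig(
        profiler="torch",
        torch_profiler_dir="./vllm_profile",
    ),
)
```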

View File

@ -5,7 +5,7 @@
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response
import vllm.envs as envs
from vllm.config import ProfilerConfig
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request):
def attach_router(app: FastAPI):
if envs.VLLM_TORCH_PROFILER_DIR:
profiler_config = getattr(app.state.args, "profiler_config", None)
assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
if profiler_config is not None and profiler_config.profiler is not None:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
"Profiler with mode '%s' is enabled in the "
"API server. This should ONLY be used for local development!",
profiler_config.profiler,
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
app.include_router(router)

View File

@ -89,20 +89,23 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_TORCH_CUDA_PROFILE: bool = False
# Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
VLLM_PROFILER_DELAY_ITERS: str | None = None
VLLM_PROFILER_MAX_ITERS: str | None = None
# End of deprecated env variables for profiling
VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_PROFILER_DELAY_ITERS: int = 0
VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_TORCH_PROFILER_USE_GZIP: bool = True
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@ -850,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None
),
# Enables torch CUDA profiling if set.
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
"VLLM_TORCH_CUDA_PROFILE": lambda: bool(
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
),
# Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
# Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR": lambda: (
None
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None
else (
val
if val.startswith("gs://") and val[5:] and val[5] != "/"
else os.path.abspath(os.path.expanduser(val))
)
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
# Enable torch profiler to record shapes if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
),
# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
# not record shapes.
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"
# Enable torch profiler to profile memory if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
),
# Enable torch profiler to profile memory if set
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
# will not profile memory.
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"
# Enable torch profiler to profile stack if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
),
# Enable torch profiler to profile stack if set
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL
# profile stack by default.
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
# Enable torch profiler to profile flops if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
),
# Enable torch profiler to profile flops if set
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
# not profile flops.
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
# Disable torch profiling of the AsyncLLMEngine process.
# If set to 1, will not profile the engine process.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
# Disable torch profiling of the AsyncLLMEngine process if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
),
# Delay number of iterations before starting profiling when using
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
),
# Deprecated, see profiler_config.
"VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
# If set to 0, will not limit the number of iterations.
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
"VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
# Control whether torch profiler gzip-compresses profiling files.
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default).
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
),
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
# Control whether torch profiler dumps the self_cuda_time_total table.
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping
# (enabled by default).
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0"
# Set to 0 to disable dumping the table.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

View File

@ -3,26 +3,27 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Literal
import torch
from typing_extensions import override
import vllm.envs as envs
from vllm.config import ProfilerConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class WorkerProfiler(ABC):
def __init__(self) -> None:
self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
def __init__(self, profiler_config: ProfilerConfig) -> None:
self._delay_iters = profiler_config.delay_iterations
if self._delay_iters > 0:
logger.info_once(
"GPU profiling will start "
f"{self._delay_iters} steps after start_profile."
)
self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
self._max_iters = profiler_config.max_iterations
if self._max_iters > 0:
logger.info_once(
"GPU profiling will stop "
@ -133,12 +134,27 @@ class WorkerProfiler(ABC):
return nullcontext()
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
TorchProfilerActivityMap = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"CUDA": torch.profiler.ProfilerActivity.CUDA,
"XPU": torch.profiler.ProfilerActivity.XPU,
}
class TorchProfilerWrapper(WorkerProfiler):
def __init__(self, worker_name: str, local_rank: int) -> None:
super().__init__()
def __init__(
self,
profiler_config: ProfilerConfig,
worker_name: str,
local_rank: int,
activities: list[TorchProfilerActivity],
) -> None:
super().__init__(profiler_config)
self.local_rank = local_rank
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
self.profiler_config = profiler_config
torch_profiler_trace_dir = profiler_config.torch_profiler_dir
if local_rank in (None, 0):
logger.info(
"Torch profiling enabled. Traces will be saved to: %s",
@ -147,24 +163,23 @@ class TorchProfilerWrapper(WorkerProfiler):
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
profiler_config.torch_profiler_record_shapes,
profiler_config.torch_profiler_with_memory,
profiler_config.torch_profiler_with_stack,
profiler_config.torch_profiler_with_flops,
)
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
activities=[TorchProfilerActivityMap[activity] for activity in activities],
record_shapes=profiler_config.torch_profiler_record_shapes,
profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack,
with_flops=profiler_config.torch_profiler_with_flops,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
use_gzip=profiler_config.torch_profiler_use_gzip,
),
)
@ -176,9 +191,10 @@ class TorchProfilerWrapper(WorkerProfiler):
def _stop(self) -> None:
self.profiler.stop()
if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL:
rank = self.local_rank
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
profiler_config = self.profiler_config
rank = self.local_rank
if profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = profiler_config.torch_profiler_dir
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
@ -189,6 +205,12 @@ class TorchProfilerWrapper(WorkerProfiler):
# only print profiler results on rank 0
if rank == 0:
print(table)
if self.dump_cpu_time_total and rank == 0:
logger.info(
self.profiler.key_averages().table(
sort_by="self_cpu_time_total", row_limit=50
)
)
@override
def annotate_context_manager(self, name: str):
@ -196,8 +218,8 @@ class TorchProfilerWrapper(WorkerProfiler):
class CudaProfilerWrapper(WorkerProfiler):
def __init__(self) -> None:
super().__init__()
def __init__(self, profiler_config: ProfilerConfig) -> None:
super().__init__(profiler_config)
# Note: lazy import to avoid dependency issues if CUDA is not available.
import torch.cuda.profiler as cuda_profiler
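The worker hunks below construct this wrapper with an explicit activity list, roughly as in the sketch here; the worker name is a hypothetical placeholder (real workers use `f"{instance_id}-rank-{rank}"`).

```python
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import TorchProfilerWrapper

profiler_config = ProfilerConfig(profiler="torch", torch_profiler_dir="/tmp/vllm_profile")
profiler = TorchProfilerWrapper(
    profiler_config,
    worker_name="my-instance-rank-0",
    local_rank=0,
    activities=["CPU", "CUDA"],  # CPU workers pass ["CPU"]; XPU workers ["CPU", "XPU"]
)
```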

View File

@ -166,32 +166,24 @@ class AsyncLLM(EngineClient):
pass
if (
envs.VLLM_TORCH_PROFILER_DIR
and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
vllm_config.profiler_config.profiler == "torch"
and not vllm_config.profiler_config.ignore_frontend
):
profiler_dir = vllm_config.profiler_config.torch_profiler_dir
logger.info(
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
envs.VLLM_TORCH_PROFILER_DIR,
profiler_dir,
)
if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
logger.warning_once(
"Torch profiler received max_iters or delay_iters setting. These "
"are not compatible with the AsyncLLM profiler and will be ignored "
"for the AsyncLLM process. Engine process profiling will still "
"respect these settings. Consider setting "
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
"AsyncLLM profiling."
)
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
],
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
envs.VLLM_TORCH_PROFILER_DIR,
profiler_dir,
worker_name=worker_name,
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
),
)
else:

View File

@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
@ -38,30 +39,17 @@ class CPUWorker(Worker):
self.parallel_config.disable_custom_all_reduce = True
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
logger.info(
"Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU"],
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=False
),
)
else:
self.profiler = None
def init_device(self):
# Setup OpenMP threads affinity.
@ -202,9 +190,3 @@ class CPUWorker(Worker):
self.profiler.start()
else:
self.profiler.stop()
if self.local_rank == 0:
logger.info(
self.profiler.key_averages().table(
sort_by="self_cpu_time_total", row_limit=50
)
)

View File

@ -38,7 +38,7 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
@ -92,17 +92,19 @@ class Worker(WorkerBase):
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Torch/CUDA profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
# VLLM_TORCH_CUDA_PROFILE=1
# Torch/CUDA profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
worker_name=worker_name, local_rank=self.local_rank
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "CUDA"],
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
self.profiler = CudaProfilerWrapper()
elif profiler_config.profiler == "cuda":
self.profiler = CudaProfilerWrapper(profiler_config)
else:
self.profiler = None

View File

@ -98,10 +98,10 @@ class TPUWorker:
# MP runtime is initialized.
self.profiler = None
self.profile_dir = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
logger.info(
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
)

View File

@ -6,12 +6,12 @@ from typing import Any
import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_world_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
@ -36,41 +36,17 @@ class XPUWorker(Worker):
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
# Torch profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
logger.info(
"Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
local_rank=self.local_rank,
activities=["CPU", "XPU"],
)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.XPU,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
),
)
else:
self.profiler = None
# we provide this function due to `torch.xpu.mem_get_info()` doesn't
# return correct free_gpu_memory on intel client GPU. We need to