[Cleanup] Refactor profiling env vars into a CLI config (#29912)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent d471b2aff0
commit e858bfe051
@@ -96,8 +96,9 @@ start_server() {
    # This correctly passes each element as a separate argument.
    if [[ -n "$profile_dir" ]]; then
        # Start server with profiling enabled
-       VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
-           vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
+       local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
+       VLLM_SERVER_DEV_MODE=1 \
+           vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
    else
        # Start server without profiling
        VLLM_SERVER_DEV_MODE=1 \
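The hand-escaped JSON in the script above is easy to get wrong; a minimal Python sketch of the same payload (the directory is illustrative), handy for checking the string before wiring it into the shell:

```python
import json

# Illustrative: build the --profiler-config payload programmatically instead of
# hand-escaping it in bash. The directory below is a placeholder.
profile_dir = "/mnt/traces"
payload = json.dumps({"profiler": "torch", "torch_profiler_dir": profile_dir})
print(payload)  # {"profiler": "torch", "torch_profiler_dir": "/mnt/traces"}
```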
@@ -963,8 +963,7 @@ def create_argument_parser():
    parser.add_argument(
        "--profile",
        action="store_true",
-       help="Use Torch Profiler. The endpoint must be launched with "
-       "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+       help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )
    parser.add_argument(
        "--result-dir",
@@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
 - [vllm.config.MultiModalConfig][]
 - [vllm.config.PoolerConfig][]
 - [vllm.config.StructuredOutputsConfig][]
+- [vllm.config.ProfilerConfig][]
 - [vllm.config.ObservabilityConfig][]
 - [vllm.config.KVTransferConfig][]
 - [vllm.config.CompilationConfig][]
@@ -5,16 +5,15 @@

## Profile with PyTorch Profiler

-We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
+We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by passing `--profiler-config` when launching the server, setting its `profiler` entry to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content through the following additional entries in the config:

-- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
-- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
-- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
-- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
-- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
-- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
-
-The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
+- `torch_profiler_record_shapes` to enable recording tensor shapes, off by default
+- `torch_profiler_with_memory` to record memory, off by default
+- `torch_profiler_with_stack` to enable recording stack information, on by default
+- `torch_profiler_with_flops` to enable recording FLOPs, off by default
+- `torch_profiler_use_gzip` to control gzip compression of profiling files, on by default
+- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default

When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
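As an aside, the same entries can be passed as a plain dict to the offline `LLM` constructor; a minimal sketch assuming this change is applied (model name and path are illustrative):

```python
from vllm import LLM

# Sketch only: the extra keys mirror the config entries listed above.
llm = LLM(
    model="facebook/opt-125m",
    profiler_config={
        "profiler": "torch",
        "torch_profiler_dir": "/mnt/traces",
        "torch_profiler_record_shapes": True,  # off by default
        "torch_profiler_with_memory": True,    # off by default
    },
)
llm.start_profile()
# ... run some requests ...
llm.stop_profile()
```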
@@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
#### OpenAI Server

```bash
-VLLM_TORCH_PROFILER_DIR=./vllm_profile \
-vllm serve meta-llama/Llama-3.1-8B-Instruct
+vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
```

vllm bench command:
@@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
```bash
# server
-VLLM_TORCH_CUDA_PROFILE=1 \
nsys profile \
    --trace-fork-before-exec=true \
    --cuda-graph-trace=node \
    --capture-range=cudaProfilerApi \
    --capture-range-end repeat \
-   vllm serve meta-llama/Llama-3.1-8B-Instruct
+   vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda

# client
vllm bench serve \
@@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
import time

from vllm import LLM, SamplingParams

-# enable torch profiler, can also be set on cmd line
-os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
-
# Sample prompts.
prompts = [
    "Hello, my name is",
@@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

def main():
    # Create an LLM.
-   llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
+   llm = LLM(
+       model="facebook/opt-125m",
+       tensor_parallel_size=1,
+       profiler_config={
+           "profiler": "torch",
+           "torch_profiler_dir": "./vllm_profile",
+       },
+   )

    llm.start_profile()
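For context, a sketch of how this example typically continues after `llm.start_profile()` (not part of this diff; `stop_profile` is the existing counterpart on `LLM`):

```python
    # Illustrative continuation of the example above: only the window between
    # start_profile() and stop_profile() is captured under ./vllm_profile.
    outputs = llm.generate(prompts, sampling_params)
    llm.stop_profile()
    # Give the background profiler a moment to flush traces before exiting.
    time.sleep(10)
```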
@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.profiler.gpu_profiler import WorkerProfiler
|
||||
from vllm.config import ProfilerConfig
|
||||
from vllm.profiler.wrapper import WorkerProfiler
|
||||
|
||||
|
||||
class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
A basic implementation of a worker profiler for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, profiler_config: ProfilerConfig):
|
||||
self.start_call_count = 0
|
||||
self.stop_call_count = 0
|
||||
self.should_fail_start = False
|
||||
super().__init__()
|
||||
super().__init__(profiler_config)
|
||||
|
||||
def _start(self) -> None:
|
||||
if self.should_fail_start:
|
||||
@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
self.stop_call_count += 1
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_mocks():
|
||||
"""Fixture to reset mocks and env variables before each test."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 0
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 0
|
||||
@pytest.fixture
|
||||
def default_profiler_config():
|
||||
return ProfilerConfig(
|
||||
profiler="torch",
|
||||
torch_profiler_dir="/tmp/mock",
|
||||
delay_iterations=0,
|
||||
max_iterations=0,
|
||||
)
|
||||
|
||||
|
||||
def test_immediate_start_stop():
|
||||
def test_immediate_start_stop(default_profiler_config):
|
||||
"""Test standard start without delay."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
assert profiler._active is True
|
||||
@ -48,10 +50,10 @@ def test_immediate_start_stop():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start():
|
||||
def test_delayed_start(default_profiler_config):
|
||||
"""Test that profiler waits for N steps before actually starting."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# User requests start
|
||||
profiler.start()
|
||||
@ -71,10 +73,10 @@ def test_delayed_start():
|
||||
assert profiler.start_call_count == 1
|
||||
|
||||
|
||||
def test_max_iterations():
|
||||
def test_max_iterations(default_profiler_config):
|
||||
"""Test that profiler stops automatically after max iterations."""
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.max_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
@ -95,12 +97,11 @@ def test_max_iterations():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start_and_max_iters():
|
||||
def test_delayed_start_and_max_iters(default_profiler_config):
|
||||
"""Test combined delayed start and max iterations."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
default_profiler_config.delay_iterations = 2
|
||||
default_profiler_config.max_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.start()
|
||||
|
||||
# Step 1
|
||||
@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_idempotency():
|
||||
def test_idempotency(default_profiler_config):
|
||||
"""Test that calling start/stop multiple times doesn't break logic."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Double Start
|
||||
profiler.start()
|
||||
@ -142,10 +143,10 @@ def test_idempotency():
|
||||
assert profiler.stop_call_count == 1 # Should only stop once
|
||||
|
||||
|
||||
def test_step_inactive():
|
||||
def test_step_inactive(default_profiler_config):
|
||||
"""Test that stepping while inactive does nothing."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Not started yet
|
||||
profiler.step()
|
||||
@ -155,9 +156,9 @@ def test_step_inactive():
|
||||
assert profiler.start_call_count == 0
|
||||
|
||||
|
||||
def test_start_failure():
|
||||
def test_start_failure(default_profiler_config):
|
||||
"""Test behavior when the underlying _start method raises exception."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.should_fail_start = True
|
||||
|
||||
profiler.start()
|
||||
@ -168,9 +169,9 @@ def test_start_failure():
|
||||
assert profiler.start_call_count == 0 # Logic failed inside start
|
||||
|
||||
|
||||
def test_shutdown():
|
||||
def test_shutdown(default_profiler_config):
|
||||
"""Test that shutdown calls stop only if running."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Case 1: Not running
|
||||
profiler.shutdown()
|
||||
@ -182,10 +183,10 @@ def test_shutdown():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_mixed_delay_and_stop():
|
||||
def test_mixed_delay_and_stop(default_profiler_config):
|
||||
"""Test manual stop during the delay period."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 5
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 5
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
profiler.start()
|
||||
profiler.step()
|
||||
|
||||
@@ -12,7 +12,6 @@ from typing import Any
import numpy as np
from tqdm import tqdm

-import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
@@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):


def main(args: argparse.Namespace):
-   if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
-       raise OSError(
-           "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
-           "Please set it to a valid path to use torch profiler."
-       )
    engine_args = EngineArgs.from_cli_args(args)
+   if args.profile and not engine_args.profiler_config.profiler == "torch":
+       raise ValueError(
+           "The torch profiler is not enabled. Please provide profiler_config."
+       )

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams
@@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
    run_to_completion(profile_dir=None)

    if args.profile:
-       profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+       profile_dir = engine_args.profiler_config.torch_profiler_dir
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return
@@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--profile",
        action="store_true",
-       help="Use Torch Profiler. The endpoint must be launched with "
-       "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+       help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )
    parser.add_argument(
        "--save-result",
@@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
        "--profile",
        action="store_true",
        default=False,
-       help="Use Torch Profiler. The env variable "
-       "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
+       help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # prefix repetition dataset
@@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
+from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
@@ -89,6 +90,8 @@ __all__ = [
    "SpeechToTextConfig",
    # From vllm.config.structured_outputs
    "StructuredOutputsConfig",
+   # From vllm.config.profiler
+   "ProfilerConfig",
    # From vllm.config.utils
    "ConfigType",
    "SupportsMetricsInfo",
vllm/config/profiler.py (new file, 199 lines)
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any, Literal

from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash

logger = init_logger(__name__)

ProfilerKind = Literal["torch", "cuda"]


@config
@dataclass
class ProfilerConfig:
    """Dataclass which contains profiler config for the engine."""

    profiler: ProfilerKind | None = None
    """Which profiler to use. Defaults to None. Options are:

    - 'torch': Use PyTorch profiler.\n
    - 'cuda': Use CUDA profiler."""

    torch_profiler_dir: str = ""
    """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
    worker's traces (CPU & GPU) will be saved under this directory. Note that
    it must be an absolute path."""

    torch_profiler_with_stack: bool = True
    """If `True`, enables stack tracing in the torch profiler. Enabled by default."""

    torch_profiler_with_flops: bool = False
    """If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""

    torch_profiler_use_gzip: bool = True
    """If `True`, saves torch profiler traces in gzip format. Enabled by default."""

    torch_profiler_dump_cuda_time_total: bool = True
    """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""

    torch_profiler_record_shapes: bool = False
    """If `True`, records tensor shapes in the torch profiler. Disabled by default."""

    torch_profiler_with_memory: bool = False
    """If `True`, enables memory profiling in the torch profiler.
    Disabled by default."""

    ignore_frontend: bool = False
    """If `True`, disables the front-end profiling of AsyncLLM when using the
    'torch' profiler. This is needed to reduce overhead when using delay/limit options,
    since the front-end profiling does not track iterations and will capture the
    entire range.
    """

    delay_iterations: int = Field(default=0, ge=0)
    """Number of engine iterations to skip before starting profiling.
    Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
    """

    max_iterations: int = Field(default=0, ge=0)
    """Maximum number of engine iterations to profile after starting profiling.
    Defaults to 0, meaning no limit.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
        """Get field from env var if set, with deprecation warning."""

        if envs.is_set(env_var_name):
            value = getattr(envs, env_var_name)
            logger.warning_once(
                "Using %s environment variable is deprecated and will be removed in "
                "v0.14.0 or v1.0.0, whichever is soonest. Please use "
                "--profiler-config.%s command line argument or "
                "ProfilerConfig(%s=...) config field instead.",
                env_var_name,
                field_name,
                field_name,
            )
            return value
        return None

    def _set_from_env_if_set(
        self,
        field_name: str,
        env_var_name: str,
        to_bool: bool = True,
        to_int: bool = False,
    ) -> None:
        """Set field from env var if set, with deprecation warning."""
        value = self._get_from_env_if_set(field_name, env_var_name)
        if value is not None:
            if to_bool:
                value = value == "1"
            if to_int:
                value = int(value)
            setattr(self, field_name, value)

    @model_validator(mode="after")
    def _validate_profiler_config(self) -> Self:
        maybe_use_cuda_profiler = self._get_from_env_if_set(
            "profiler", "VLLM_TORCH_CUDA_PROFILE"
        )
        if maybe_use_cuda_profiler is not None:
            self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
        else:
            self._set_from_env_if_set(
                "torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
            )
            if self.torch_profiler_dir:
                self.profiler = "torch"
        self._set_from_env_if_set(
            "torch_profiler_record_shapes",
            "VLLM_TORCH_PROFILER_RECORD_SHAPES",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_memory",
            "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_stack",
            "VLLM_TORCH_PROFILER_WITH_STACK",
        )
        self._set_from_env_if_set(
            "torch_profiler_with_flops",
            "VLLM_TORCH_PROFILER_WITH_FLOPS",
        )
        self._set_from_env_if_set(
            "ignore_frontend",
            "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
        )
        self._set_from_env_if_set(
            "torch_profiler_use_gzip",
            "VLLM_TORCH_PROFILER_USE_GZIP",
        )
        self._set_from_env_if_set(
            "torch_profiler_dump_cuda_time_total",
            "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
        )

        self._set_from_env_if_set(
            "delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
        )
        self._set_from_env_if_set(
            "max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
        )

        has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
        if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
            logger.warning_once(
                "Using 'torch' profiler with delay_iterations or max_iterations "
                "while ignore_frontend is False may result in high overhead."
            )

        profiler_dir = self.torch_profiler_dir
        if profiler_dir and self.profiler != "torch":
            raise ValueError(
                "torch_profiler_dir is only applicable when profiler is set to 'torch'"
            )
        if self.profiler == "torch" and not profiler_dir:
            raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")

        if profiler_dir:
            is_gs_path = (
                profiler_dir.startswith("gs://")
                and profiler_dir[5:]
                and profiler_dir[5] != "/"
            )
            if not is_gs_path:
                self.torch_profiler_dir = os.path.abspath(
                    os.path.expanduser(profiler_dir)
                )

        return self
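A minimal sketch of how the new config behaves when constructed directly, assuming a checkout with this change applied and none of the deprecated profiling env vars set (paths are illustrative). The `model_validator` runs on construction, normalizes the trace directory, and rejects inconsistent combinations:

```python
from vllm.config import ProfilerConfig

cfg = ProfilerConfig(
    profiler="torch",
    torch_profiler_dir="~/vllm_traces",  # expanded to an absolute path by the validator
    delay_iterations=5,                  # skip the first 5 engine steps after /start_profile
    max_iterations=20,                   # then capture at most 20 steps
)
print(cfg.torch_profiler_dir)  # e.g. /home/<user>/vllm_traces

# These combinations raise ValueError in the validator above:
#   ProfilerConfig(profiler="cuda", torch_profiler_dir="/tmp/x")
#   ProfilerConfig(profiler="torch")  # no torch_profiler_dir
```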
@ -39,6 +39,7 @@ from .lora import LoRAConfig
|
||||
from .model import ModelConfig
|
||||
from .observability import ObservabilityConfig
|
||||
from .parallel import ParallelConfig
|
||||
from .profiler import ProfilerConfig
|
||||
from .scheduler import SchedulerConfig
|
||||
from .speculative import SpeculativeConfig
|
||||
from .structured_outputs import StructuredOutputsConfig
|
||||
@ -218,6 +219,8 @@ class VllmConfig:
|
||||
You can specify the full compilation config like so:
|
||||
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||
"""
|
||||
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
|
||||
"""Profiling configuration."""
|
||||
kv_transfer_config: KVTransferConfig | None = None
|
||||
"""The configurations for distributed KV cache transfer."""
|
||||
kv_events_config: KVEventsConfig | None = None
|
||||
@ -296,6 +299,8 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
if self.structured_outputs_config:
|
||||
vllm_factors.append(self.structured_outputs_config.compute_hash())
|
||||
if self.profiler_config:
|
||||
vllm_factors.append(self.profiler_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
vllm_factors.append(self.observability_config.compute_hash())
|
||||
|
||||
@ -50,6 +50,7 @@ from vllm.config import (
|
||||
ObservabilityConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
ProfilerConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
@ -532,6 +533,8 @@ class EngineArgs:
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")
|
||||
|
||||
kv_transfer_config: KVTransferConfig | None = None
|
||||
kv_events_config: KVEventsConfig | None = None
|
||||
|
||||
@ -1164,7 +1167,7 @@ class EngineArgs:
|
||||
vllm_group.add_argument(
|
||||
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
|
||||
)
|
||||
|
||||
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
|
||||
vllm_group.add_argument(
|
||||
"--optimization-level", **vllm_kwargs["optimization_level"]
|
||||
)
|
||||
@ -1782,6 +1785,7 @@ class EngineArgs:
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
kv_events_config=self.kv_events_config,
|
||||
ec_transfer_config=self.ec_transfer_config,
|
||||
profiler_config=self.profiler_config,
|
||||
additional_config=self.additional_config,
|
||||
optimization_level=self.optimization_level,
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.beam_search import (
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
PoolerConfig,
|
||||
ProfilerConfig,
|
||||
StructuredOutputsConfig,
|
||||
is_init_field,
|
||||
)
|
||||
@ -211,6 +212,7 @@ class LLM:
|
||||
structured_outputs_config: dict[str, Any]
|
||||
| StructuredOutputsConfig
|
||||
| None = None,
|
||||
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
|
||||
kv_cache_memory_bytes: int | None = None,
|
||||
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
|
||||
logits_processors: list[str | type[LogitsProcessor]] | None = None,
|
||||
@ -282,6 +284,20 @@ class LLM:
|
||||
else:
|
||||
structured_outputs_instance = StructuredOutputsConfig()
|
||||
|
||||
if profiler_config is not None:
|
||||
if isinstance(profiler_config, dict):
|
||||
profiler_config_instance = ProfilerConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in profiler_config.items()
|
||||
if is_init_field(ProfilerConfig, k)
|
||||
}
|
||||
)
|
||||
else:
|
||||
profiler_config_instance = profiler_config
|
||||
else:
|
||||
profiler_config_instance = ProfilerConfig()
|
||||
|
||||
# warn about single-process data parallel usage.
|
||||
_dp_size = int(kwargs.get("data_parallel_size", 1))
|
||||
_distributed_executor_backend = kwargs.get("distributed_executor_backend")
|
||||
@ -324,6 +340,7 @@ class LLM:
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
pooler_config=pooler_config,
|
||||
structured_outputs_config=structured_outputs_instance,
|
||||
profiler_config=profiler_config_instance,
|
||||
compilation_config=compilation_config_instance,
|
||||
logits_processors=logits_processors,
|
||||
**kwargs,
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ProfilerConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request):
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
profiler_config = getattr(app.state.args, "profiler_config", None)
|
||||
assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
|
||||
if profiler_config is not None and profiler_config.profiler is not None:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
"Profiler with mode '%s' is enabled in the "
|
||||
"API server. This should ONLY be used for local development!",
|
||||
profiler_config.profiler,
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
app.include_router(router)
|
||||
|
||||
vllm/envs.py (110 lines changed)
@@ -89,20 +89,23 @@ if TYPE_CHECKING:
    VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5  # seconds
    VLLM_PLUGINS: list[str] | None = None
    VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
-   VLLM_TORCH_CUDA_PROFILE: bool = False
+   # Deprecated env variables for profiling, kept for backward compatibility
+   # See also vllm/config/profiler.py and `--profiler-config` argument
+   VLLM_TORCH_CUDA_PROFILE: str | None = None
    VLLM_TORCH_PROFILER_DIR: str | None = None
-   VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
-   VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
-   VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
+   VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
+   VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
+   VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
+   VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
+   VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
+   VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
+   VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
+   VLLM_PROFILER_DELAY_ITERS: str | None = None
+   VLLM_PROFILER_MAX_ITERS: str | None = None
+   # End of deprecated env variables for profiling
    VLLM_USE_AOT_COMPILE: bool = False
    VLLM_USE_BYTECODE_HOOK: bool = False
    VLLM_FORCE_AOT_LOAD: bool = False
-   VLLM_TORCH_PROFILER_WITH_STACK: bool = True
-   VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
-   VLLM_PROFILER_DELAY_ITERS: int = 0
-   VLLM_PROFILER_MAX_ITERS: int = 0
-   VLLM_TORCH_PROFILER_USE_GZIP: bool = True
-   VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
    VLLM_USE_TRITON_AWQ: bool = False
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
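The deprecated variables above are still honored as a fallback; a small sketch of the migration path, assuming this change is applied and the variable is set before the config is built:

```python
import os

os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_traces"  # deprecated spelling

from vllm.config import ProfilerConfig

# The validator reads the env var, emits a deprecation warning, and fills the
# corresponding fields so old workflows keep working.
cfg = ProfilerConfig()
assert cfg.profiler == "torch"
assert cfg.torch_profiler_dir == "/tmp/vllm_traces"
```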
@ -850,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
|
||||
"VLLM_LORA_RESOLVER_CACHE_DIR", None
|
||||
),
|
||||
# Enables torch CUDA profiling if set.
|
||||
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
|
||||
"VLLM_TORCH_CUDA_PROFILE": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
|
||||
),
|
||||
# Enables torch CUDA profiling if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
|
||||
# Enables torch profiler if set.
|
||||
# Both AsyncLLM's CPU traces as well as workers'
|
||||
# traces (CPU & GPU) will be saved under this directory.
|
||||
# Note that it must be an absolute path.
|
||||
"VLLM_TORCH_PROFILER_DIR": lambda: (
|
||||
None
|
||||
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None
|
||||
else (
|
||||
val
|
||||
if val.startswith("gs://") and val[5:] and val[5] != "/"
|
||||
else os.path.abspath(os.path.expanduser(val))
|
||||
)
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
|
||||
# Enable torch profiler to record shapes if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
|
||||
),
|
||||
# Enable torch profiler to record shapes if set
|
||||
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
|
||||
# not record shapes.
|
||||
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"
|
||||
# Enable torch profiler to profile memory if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
|
||||
),
|
||||
# Enable torch profiler to profile memory if set
|
||||
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
|
||||
# will not profile memory.
|
||||
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"
|
||||
# Enable torch profiler to profile stack if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
|
||||
),
|
||||
# Enable torch profiler to profile stack if set
|
||||
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL
|
||||
# profile stack by default.
|
||||
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
|
||||
# Enable torch profiler to profile flops if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
|
||||
),
|
||||
# Enable torch profiler to profile flops if set
|
||||
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
|
||||
# not profile flops.
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
|
||||
),
|
||||
# Disable torch profiling of the AsyncLLMEngine process.
|
||||
# If set to 1, will not profile the engine process.
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
|
||||
# Disable torch profiling of the AsyncLLMEngine process if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
|
||||
),
|
||||
# Delay number of iterations before starting profiling when using
|
||||
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
|
||||
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
|
||||
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
|
||||
),
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
|
||||
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
|
||||
# If set to 0, will not limit the number of iterations.
|
||||
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
|
||||
"VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
|
||||
# Control whether torch profiler gzip-compresses profiling files.
|
||||
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default).
|
||||
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
|
||||
),
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
|
||||
# Control whether torch profiler dumps the self_cuda_time_total table.
|
||||
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping
|
||||
# (enabled by default).
|
||||
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0"
|
||||
# Set to 0 to disable dumping the table.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
|
||||
),
|
||||
# If set, vLLM will use Triton implementations of AWQ.
|
||||
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
|
||||
|
||||
@ -3,26 +3,27 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ProfilerConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerProfiler(ABC):
|
||||
def __init__(self) -> None:
|
||||
self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
|
||||
def __init__(self, profiler_config: ProfilerConfig) -> None:
|
||||
self._delay_iters = profiler_config.delay_iterations
|
||||
if self._delay_iters > 0:
|
||||
logger.info_once(
|
||||
"GPU profiling will start "
|
||||
f"{self._delay_iters} steps after start_profile."
|
||||
)
|
||||
|
||||
self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
|
||||
self._max_iters = profiler_config.max_iterations
|
||||
if self._max_iters > 0:
|
||||
logger.info_once(
|
||||
"GPU profiling will stop "
|
||||
@ -133,12 +134,27 @@ class WorkerProfiler(ABC):
|
||||
return nullcontext()
|
||||
|
||||
|
||||
TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"]
|
||||
TorchProfilerActivityMap = {
|
||||
"CPU": torch.profiler.ProfilerActivity.CPU,
|
||||
"CUDA": torch.profiler.ProfilerActivity.CUDA,
|
||||
"XPU": torch.profiler.ProfilerActivity.XPU,
|
||||
}
|
||||
|
||||
|
||||
class TorchProfilerWrapper(WorkerProfiler):
|
||||
def __init__(self, worker_name: str, local_rank: int) -> None:
|
||||
super().__init__()
|
||||
def __init__(
|
||||
self,
|
||||
profiler_config: ProfilerConfig,
|
||||
worker_name: str,
|
||||
local_rank: int,
|
||||
activities: list[TorchProfilerActivity],
|
||||
) -> None:
|
||||
super().__init__(profiler_config)
|
||||
|
||||
self.local_rank = local_rank
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
self.profiler_config = profiler_config
|
||||
torch_profiler_trace_dir = profiler_config.torch_profiler_dir
|
||||
if local_rank in (None, 0):
|
||||
logger.info(
|
||||
"Torch profiling enabled. Traces will be saved to: %s",
|
||||
@ -147,24 +163,23 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
profiler_config.torch_profiler_record_shapes,
|
||||
profiler_config.torch_profiler_with_memory,
|
||||
profiler_config.torch_profiler_with_stack,
|
||||
profiler_config.torch_profiler_with_flops,
|
||||
)
|
||||
|
||||
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
activities=[TorchProfilerActivityMap[activity] for activity in activities],
|
||||
record_shapes=profiler_config.torch_profiler_record_shapes,
|
||||
profile_memory=profiler_config.torch_profiler_with_memory,
|
||||
with_stack=profiler_config.torch_profiler_with_stack,
|
||||
with_flops=profiler_config.torch_profiler_with_flops,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir,
|
||||
worker_name=worker_name,
|
||||
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
|
||||
use_gzip=profiler_config.torch_profiler_use_gzip,
|
||||
),
|
||||
)
|
||||
|
||||
@ -176,9 +191,10 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
def _stop(self) -> None:
|
||||
self.profiler.stop()
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL:
|
||||
rank = self.local_rank
|
||||
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profiler_config = self.profiler_config
|
||||
rank = self.local_rank
|
||||
if profiler_config.torch_profiler_dump_cuda_time_total:
|
||||
profiler_dir = profiler_config.torch_profiler_dir
|
||||
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
|
||||
sort_key = "self_cuda_time_total"
|
||||
table = self.profiler.key_averages().table(sort_by=sort_key)
|
||||
@ -189,6 +205,12 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
# only print profiler results on rank 0
|
||||
if rank == 0:
|
||||
print(table)
|
||||
if self.dump_cpu_time_total and rank == 0:
|
||||
logger.info(
|
||||
self.profiler.key_averages().table(
|
||||
sort_by="self_cpu_time_total", row_limit=50
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def annotate_context_manager(self, name: str):
|
||||
@ -196,8 +218,8 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
|
||||
|
||||
class CudaProfilerWrapper(WorkerProfiler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, profiler_config: ProfilerConfig) -> None:
|
||||
super().__init__(profiler_config)
|
||||
# Note: lazy import to avoid dependency issues if CUDA is not available.
|
||||
import torch.cuda.profiler as cuda_profiler
|
||||
|
||||
@ -166,32 +166,24 @@ class AsyncLLM(EngineClient):
|
||||
pass
|
||||
|
||||
if (
|
||||
envs.VLLM_TORCH_PROFILER_DIR
|
||||
and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
|
||||
vllm_config.profiler_config.profiler == "torch"
|
||||
and not vllm_config.profiler_config.ignore_frontend
|
||||
):
|
||||
profiler_dir = vllm_config.profiler_config.torch_profiler_dir
|
||||
logger.info(
|
||||
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
|
||||
envs.VLLM_TORCH_PROFILER_DIR,
|
||||
profiler_dir,
|
||||
)
|
||||
if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
|
||||
logger.warning_once(
|
||||
"Torch profiler received max_iters or delay_iters setting. These "
|
||||
"are not compatible with the AsyncLLM profiler and will be ignored "
|
||||
"for the AsyncLLM process. Engine process profiling will still "
|
||||
"respect these settings. Consider setting "
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
|
||||
"AsyncLLM profiling."
|
||||
)
|
||||
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
],
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
envs.VLLM_TORCH_PROFILER_DIR,
|
||||
profiler_dir,
|
||||
worker_name=worker_name,
|
||||
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
|
||||
use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
|
||||
from vllm.profiler.wrapper import TorchProfilerWrapper
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
|
||||
@ -38,30 +39,17 @@ class CPUWorker(Worker):
|
||||
|
||||
self.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
# Torch profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU"],
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=False
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def init_device(self):
|
||||
# Setup OpenMP threads affinity.
|
||||
@ -202,9 +190,3 @@ class CPUWorker(Worker):
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
if self.local_rank == 0:
|
||||
logger.info(
|
||||
self.profiler.key_averages().table(
|
||||
sort_by="self_cpu_time_total", row_limit=50
|
||||
)
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
@ -92,17 +92,19 @@ class Worker(WorkerBase):
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch/CUDA profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
# VLLM_TORCH_CUDA_PROFILE=1
|
||||
# Torch/CUDA profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
worker_name=worker_name, local_rank=self.local_rank
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU", "CUDA"],
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
self.profiler = CudaProfilerWrapper()
|
||||
elif profiler_config.profiler == "cuda":
|
||||
self.profiler = CudaProfilerWrapper(profiler_config)
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
|
||||
@ -98,10 +98,10 @@ class TPUWorker:
|
||||
# MP runtime is initialized.
|
||||
self.profiler = None
|
||||
self.profile_dir = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
|
||||
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
|
||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
||||
# server. So we only profile on rank0.
|
||||
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
|
||||
)
|
||||
|
||||
@ -6,12 +6,12 @@ from typing import Any
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_world_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.wrapper import TorchProfilerWrapper
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
|
||||
|
||||
@ -36,41 +36,17 @@ class XPUWorker(Worker):
|
||||
assert device_config.device_type == "xpu"
|
||||
assert current_platform.is_xpu()
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
# Torch profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU", "XPU"],
|
||||
)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.XPU,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir,
|
||||
worker_name=worker_name,
|
||||
use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
# we provide this function due to `torch.xpu.mem_get_info()` doesn't
|
||||
# return correct free_gpu_memory on intel client GPU. We need to
|
||||
|
||||