diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 84759c5c354d..de62bf5c63c7 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -6,13 +6,12 @@ import dataclasses
 import json
 import os
 import time
-from pathlib import Path
 from typing import Any, Optional
 
 import numpy as np
-import torch
 from tqdm import tqdm
 
+import vllm.envs as envs
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -80,17 +79,9 @@ def main(args: argparse.Namespace):
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    str(profile_dir)
-                ),
-            ) as p:
-                llm_generate()
-            print(p.key_averages().table(sort_by="self_cuda_time_total"))
+            llm.start_profile()
+            llm_generate()
+            llm.stop_profile()
         else:
             start_time = time.perf_counter()
             llm_generate()
@@ -103,11 +94,7 @@ def main(args: argparse.Namespace):
     run_to_completion(profile_dir=None)
 
     if args.profile:
-        profile_dir = args.profile_result_dir
-        if not profile_dir:
-            profile_dir = (
-                Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
-            )
+        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return
@@ -164,15 +151,6 @@ if __name__ == "__main__":
         action="store_true",
         help="profile the generation process of a single batch",
     )
-    parser.add_argument(
-        "--profile-result-dir",
-        type=str,
-        default=None,
-        help=(
-            "path to save the pytorch profiler output. Can be visualized "
-            "with ui.perfetto.dev or Tensorboard."
-        ),
-    )
     parser.add_argument(
         "--output-json",
         type=str,
@@ -193,4 +171,9 @@ if __name__ == "__main__":
     # numbers. We need to disable prefix caching by default.
     parser.set_defaults(enable_prefix_caching=False)
     args = parser.parse_args()
+    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
+        raise OSError(
+            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
+            "Please set it to a valid path to use torch profiler."
+        )
     main(args)
diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py
index 0dd938e75129..c9e03cc3bf78 100644
--- a/vllm/benchmarks/latency.py
+++ b/vllm/benchmarks/latency.py
@@ -6,13 +6,12 @@ import dataclasses
 import json
 import os
 import time
-from pathlib import Path
 from typing import Any, Optional
 
 import numpy as np
-import torch
 from tqdm import tqdm
 
+import vllm.envs as envs
 from vllm import LLM, SamplingParams
 from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
                                    write_to_json)
@@ -59,13 +58,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         action="store_true",
         help="profile the generation process of a single batch",
     )
-    parser.add_argument(
-        "--profile-result-dir",
-        type=str,
-        default=None,
-        help=("path to save the pytorch profiler output. Can be visualized "
-              "with ui.perfetto.dev or Tensorboard."),
-    )
     parser.add_argument(
         "--output-json",
         type=str,
@@ -87,7 +79,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
 
 def main(args: argparse.Namespace):
     print(args)
-
+    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
+        raise OSError(
+            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
+            "Please set it to a valid path to use torch profiler.")
     engine_args = EngineArgs.from_cli_args(args)
 
     # NOTE(woosuk): If the request cannot be processed in a single batch,
@@ -131,16 +126,9 @@ def main(args: argparse.Namespace):
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    str(profile_dir)),
-            ) as p:
-                llm_generate()
-            print(p.key_averages().table(sort_by="self_cuda_time_total"))
+            llm.start_profile()
+            llm_generate()
+            llm.stop_profile()
         else:
             start_time = time.perf_counter()
             llm_generate()
@@ -153,10 +141,7 @@ def main(args: argparse.Namespace):
     run_to_completion(profile_dir=None)
 
     if args.profile:
-        profile_dir = args.profile_result_dir
-        if not profile_dir:
-            profile_dir = (Path(".") / "vllm_benchmark_result" /
-                           f"latency_result_{time.time()}")
+        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index bce5cbb5f9d0..dd06e729673f 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -292,6 +292,8 @@ class Worker(WorkerBase):
             self.profiler.start()
         else:
             self.profiler.stop()
+            print(self.profiler.key_averages().table(
+                sort_by="self_cuda_time_total"))
 
     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 6e45b8423e5e..2a4317271934 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -128,6 +128,8 @@ class Worker(LocalOrDistributedWorkerBase):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
+        print(
+            self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
 
     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]