[V1] fix torch profiling for V1 offline scenarios (#18445)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>

parent 9a21e331ff
commit 774c5fde30
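
In short, the hunks below move offline torch profiling onto vLLM's own start_profile()/stop_profile() hooks: the two benchmark-latency scripts stop building a driver-side torch.profiler.profile() context, the trace directory is now sourced from the VLLM_TORCH_PROFILER_DIR environment variable instead of --profile-result-dir, and both the V1 and V0 workers print the profiler's key_averages() table when profiling stops. This matters for V1 because model execution happens in worker processes, which a profiler context opened in the benchmark process cannot observe. A minimal usage sketch under these changes (the model name and trace path are illustrative assumptions, not part of the diff):

    import os

    # Must be set before the LLM (and its workers) is created; path assumed.
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_profile"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any small model; assumed here
    params = SamplingParams(temperature=0.0, max_tokens=16)

    llm.start_profile()  # each worker starts its own torch.profiler
    llm.generate(["Hello, my name is"], params)
    llm.stop_profile()   # traces land under VLLM_TORCH_PROFILER_DIR
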
@@ -6,13 +6,12 @@ import dataclasses
 import json
 import os
 import time
-from pathlib import Path
 from typing import Any, Optional
 
 import numpy as np
 import torch
 from tqdm import tqdm
 
+import vllm.envs as envs
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
@@ -80,17 +79,9 @@ def main(args: argparse.Namespace):
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    str(profile_dir)
-                ),
-            ) as p:
-                llm_generate()
-            print(p.key_averages().table(sort_by="self_cuda_time_total"))
+            llm.start_profile()
+            llm_generate()
+            llm.stop_profile()
         else:
             start_time = time.perf_counter()
             llm_generate()
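
With the context manager gone, start_profile()/stop_profile() carry the profiling request from the benchmark process into the workers. The sketch below shows the delegation shape these calls rely on; aside from start_profile/stop_profile and the worker-side profile(is_start) hook visible in this diff, the plumbing is an assumption for illustration, not the exact vLLM call chain:

    # Hedged sketch of the delegation the benchmark now depends on.
    class LLMSketch:
        def __init__(self, engine):
            # In V1 the engine core and workers may live in other processes.
            self.engine = engine

        def start_profile(self) -> None:
            # Fan out to every worker, each of which owns a torch.profiler.
            self.engine.collective_rpc("profile", args=(True,))

        def stop_profile(self) -> None:
            self.engine.collective_rpc("profile", args=(False,))
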
@@ -103,11 +94,7 @@ def main(args: argparse.Namespace):
         run_to_completion(profile_dir=None)
 
     if args.profile:
-        profile_dir = args.profile_result_dir
-        if not profile_dir:
-            profile_dir = (
-                Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
-            )
+        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return
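
The script no longer invents a latency_result_<timestamp> directory; it simply reports whatever vllm.envs exposes. vllm.envs resolves variables when the attribute is read, so setting the variable any time before profiling starts behaves as expected (the path below is an assumed example):

    import os

    os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_profile"  # assumed

    import vllm.envs as envs

    # Evaluated on access, so it reflects the value set above.
    print(envs.VLLM_TORCH_PROFILER_DIR)  # -> /tmp/vllm_profile
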
@@ -164,15 +151,6 @@ if __name__ == "__main__":
         action="store_true",
         help="profile the generation process of a single batch",
     )
-    parser.add_argument(
-        "--profile-result-dir",
-        type=str,
-        default=None,
-        help=(
-            "path to save the pytorch profiler output. Can be visualized "
-            "with ui.perfetto.dev or Tensorboard."
-        ),
-    )
     parser.add_argument(
         "--output-json",
         type=str,
@@ -193,4 +171,9 @@ if __name__ == "__main__":
     # numbers. We need to disable prefix caching by default.
     parser.set_defaults(enable_prefix_caching=False)
     args = parser.parse_args()
+    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
+        raise OSError(
+            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
+            "Please set it to a valid path to use torch profiler."
+        )
     main(args)
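
Because the directory is sourced from the environment, the script now fails fast when --profile is passed without it. The same check, extracted as a standalone helper for reference (a sketch, not vLLM API):

    import os

    def require_profiler_dir() -> str:
        # Mirrors the guard added above; raises before any engine is built.
        trace_dir = os.getenv("VLLM_TORCH_PROFILER_DIR")
        if not trace_dir:
            raise OSError(
                "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not "
                "set. Please set it to a valid path to use torch profiler.")
        return trace_dir

The hunks that follow apply the same change to the in-tree benchmark module (the file importing from vllm.benchmarks.utils); only the wrapping style differs.
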
@@ -6,13 +6,12 @@ import dataclasses
 import json
 import os
 import time
-from pathlib import Path
 from typing import Any, Optional
 
 import numpy as np
 import torch
 from tqdm import tqdm
 
+import vllm.envs as envs
 from vllm import LLM, SamplingParams
 from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
                                    write_to_json)
@@ -59,13 +58,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         action="store_true",
         help="profile the generation process of a single batch",
     )
-    parser.add_argument(
-        "--profile-result-dir",
-        type=str,
-        default=None,
-        help=("path to save the pytorch profiler output. Can be visualized "
-              "with ui.perfetto.dev or Tensorboard."),
-    )
     parser.add_argument(
         "--output-json",
         type=str,
@@ -87,7 +79,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
 
 def main(args: argparse.Namespace):
     print(args)
-
+    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
+        raise OSError(
+            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
+            "Please set it to a valid path to use torch profiler.")
     engine_args = EngineArgs.from_cli_args(args)
 
     # NOTE(woosuk): If the request cannot be processed in a single batch,
@@ -131,16 +126,9 @@ def main(args: argparse.Namespace):
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
-            with torch.profiler.profile(
-                    activities=[
-                        torch.profiler.ProfilerActivity.CPU,
-                        torch.profiler.ProfilerActivity.CUDA,
-                    ],
-                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                        str(profile_dir)),
-            ) as p:
-                llm_generate()
-            print(p.key_averages().table(sort_by="self_cuda_time_total"))
+            llm.start_profile()
+            llm_generate()
+            llm.stop_profile()
         else:
             start_time = time.perf_counter()
             llm_generate()
@@ -153,10 +141,7 @@ def main(args: argparse.Namespace):
         run_to_completion(profile_dir=None)
 
     if args.profile:
-        profile_dir = args.profile_result_dir
-        if not profile_dir:
-            profile_dir = (Path(".") / "vllm_benchmark_result" /
-                           f"latency_result_{time.time()}")
+        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return
@@ -292,6 +292,8 @@ class Worker(WorkerBase):
             self.profiler.start()
         else:
             self.profiler.stop()
+            print(self.profiler.key_averages().table(
+                sort_by="self_cuda_time_total"))
 
     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1)
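
This hunk sits in the V1 GPU worker (the Worker(WorkerBase) class with execute_dummy_batch). Together with the pre-existing setup code around it, the worker-side profiler now looks roughly like the sketch below; constructor details outside this diff are condensed assumptions:

    import torch

    class WorkerProfilerSketch:
        def __init__(self, trace_dir: str | None):
            self.profiler = None
            if trace_dir:
                self.profiler = torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        trace_dir))

        def profile(self, is_start: bool = True) -> None:
            if self.profiler is None:
                raise RuntimeError("Profiler is not enabled.")
            if is_start:
                self.profiler.start()
            else:
                self.profiler.stop()
                # Added by this commit: the summary the benchmark used to
                # print on the driver side is printed where the profiler
                # actually lives.
                print(self.profiler.key_averages().table(
                    sort_by="self_cuda_time_total"))
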
@@ -128,6 +128,8 @@ class Worker(LocalOrDistributedWorkerBase):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
+        print(
+            self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
 
     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
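
The V0 worker gets the matching change, so both engine versions emit the table on stop_profile(). The traces themselves are written by tensorboard_trace_handler into the configured directory and, as the removed --profile-result-dir help text noted, can be loaded at ui.perfetto.dev or viewed with TensorBoard. A quick way to confirm a trace landed (the directory path is an assumed example):

    from pathlib import Path

    trace_dir = Path("/tmp/vllm_profile")  # assumed; use your trace directory
    for trace in sorted(trace_dir.glob("*.trace.json*")):
        print(trace.name, trace.stat().st_size, "bytes")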