[torch.compile] allow tracking forward time (#11081)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 15859f2357
commit a1c02058ba
@@ -1,9 +1,11 @@
 import time
-from collections import Counter
+from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import torch
+
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -11,9 +13,10 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
-batchsize_counter: Counter = Counter()
 last_logging_time: float = 0
+forward_start_time: float = 0
 batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
+batchsize_forward_time: defaultdict = defaultdict(list)
 
 
 @dataclass
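For reference, tracking is opt-in: the module-level globals above are derived from the VLLM_LOG_BATCHSIZE_INTERVAL environment variable, and tracking is active whenever that value is >= 0 (the interval is compared against elapsed seconds later in the diff). A minimal sketch of enabling it, assuming the variable is parsed as a number by vllm.envs and is read when this module is imported; the default used when it is unset is not shown in this diff:

import os

# Hypothetical opt-in: export the logging interval (in seconds) before the
# tracking module is imported; any value >= 0 enables batchsize tracking.
os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "30"

import vllm.envs as envs

track_batchsize = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
print(track_batchsize)  # True for the value exported above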
@@ -40,23 +43,10 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
-    global track_batchsize, batchsize_counter
-    global last_logging_time, batchsize_logging_interval
-    if track_batchsize and context is not None:
-        if hasattr(context, "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = context.num_prefill_tokens + context.num_decode_tokens
-        else:
-            # for v1 attention backends
-            batchsize = context.num_input_tokens
-        batchsize_counter[batchsize] += 1
-        if time.monotonic() - last_logging_time > batchsize_logging_interval:
-            last_logging_time = time.monotonic()
-            sorted_data = sorted(batchsize_counter.items(),
-                                 key=lambda x: x[1],
-                                 reverse=True)
-            logger.info("Batchsize distribution (batchsize, count): %s",
-                        sorted_data)
+    global forward_start_time
+    need_to_track_batchsize = track_batchsize and context is not None
+    if need_to_track_batchsize:
+        forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
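The entry side of the change is just a timestamp: when tracking is on, the context manager records time.perf_counter() before yielding, and the elapsed time is computed on the exit path (next hunk). A standalone sketch of the same pattern, with names invented for illustration:

import time
from contextlib import contextmanager


@contextmanager
def timed(label: str):
    # Stamp the start on entry; compute the elapsed wall-clock time in
    # `finally` so the measurement is taken even if the body raises.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{label}: {elapsed_ms:.2f} ms")


with timed("forward"):
    sum(i * i for i in range(100_000))  # stand-in for a model forward pass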
@@ -66,4 +56,37 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     try:
         yield
     finally:
+        global batchsize_counter
+        global last_logging_time, batchsize_logging_interval
+        if need_to_track_batchsize:
+            if hasattr(context, "num_prefill_tokens"):
+                # for v0 attention backends
+                batchsize = context.num_prefill_tokens + \
+                    context.num_decode_tokens
+            else:
+                # for v1 attention backends
+                batchsize = context.num_input_tokens
+            # we use synchronous scheduling right now,
+            # adding a sync point here should not affect
+            # scheduling of the next batch
+            torch.cuda.synchronize()
+            now = time.perf_counter()
+            # time measurement is in milliseconds
+            batchsize_forward_time[batchsize].append(
+                (now - forward_start_time) * 1000)
+            if now - last_logging_time > batchsize_logging_interval:
+                last_logging_time = now
+                forward_stats = []
+                for bs, times in batchsize_forward_time.items():
+                    if len(times) <= 1:
+                        # can be cudagraph / profiling run
+                        continue
+                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
+                    medium = round(medium, 2)
+                    forward_stats.append((bs, len(times), medium))
+                forward_stats.sort(key=lambda x: x[1], reverse=True)
+                if forward_stats:
+                    logger.info(("Batchsize forward time stats "
+                                 "(batchsize, count, median_time(ms)): %s"),
+                                forward_stats)
         _forward_context = prev_context
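The aggregation in the finally block keys raw forward times (in milliseconds) by batch size in a defaultdict(list), skips batch sizes with a single sample (likely a cudagraph capture or profiling run), and logs (batchsize, count, median) tuples sorted by count. A minimal sketch of the same summarization with synthetic numbers, outside of vLLM:

from collections import defaultdict

import torch

# Synthetic per-batchsize forward times in milliseconds.
batchsize_forward_time = defaultdict(list)
batchsize_forward_time[8].extend([3.1, 3.3, 3.0, 3.2])
batchsize_forward_time[256].extend([11.8, 12.1, 12.4])
batchsize_forward_time[512].append(25.0)  # single sample, skipped below

forward_stats = []
for bs, times in batchsize_forward_time.items():
    if len(times) <= 1:
        # mirror the diff: one-off samples are treated as warmup/profiling
        continue
    median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
    forward_stats.append((bs, len(times), median))
forward_stats.sort(key=lambda x: x[1], reverse=True)
print(forward_stats)  # [(8, 4, 3.15), (256, 3, 12.1)]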