mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 17:27:13 +08:00
70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
import time
|
|
from collections import Counter
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config import VllmConfig
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)

# Logging interval taken from VLLM_LOG_BATCHSIZE_INTERVAL. It is compared
# against differences of time.monotonic() readings; a negative value turns
# batch-size tracking off.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# Number of forward passes observed for each batch size.
batchsize_counter: Counter = Counter()
# time.monotonic() reading taken the last time the distribution was logged.
last_logging_time: float = 0
|
|
|
|
|
|
@dataclass
class ForwardContext:
    """State exposed to code that runs while a forward context is active.

    Instances are built by `set_forward_context` and handed out by
    `get_forward_context`.
    """

    # String-keyed mapping; `set_forward_context` fills it from
    # `vllm_config.compilation_config.static_forward_context`.
    static_forward_context: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    # The `context` object passed to `set_forward_context`, stored as-is.
    dynamic_forward_context: Any
|
|
|
|
|
|
# Forward context currently in effect. It starts as None, is replaced while a
# `set_forward_context` block is running, and is put back to its previous
# value when that block exits.
_forward_context: Optional[ForwardContext] = None
|
|
|
|
|
|
def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The `ForwardContext` held in the module-level `_forward_context`.

    Raises:
        AssertionError: If no forward context is currently set.
    """
    # Raised explicitly rather than through an `assert` statement so the
    # check still runs under `python -O`, which strips asserts and would
    # otherwise let this function return None. The exception type and the
    # message are the same as before.
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context
|
|
|
|
|
|
def _record_batchsize(context: Any) -> None:
    """Count one forward pass's batch size and periodically log the tally.

    The batch size is `num_prefill_tokens + num_decode_tokens` when
    `context` has a `num_prefill_tokens` attribute, otherwise
    `num_input_tokens`. The accumulated distribution is logged once more
    than `batchsize_logging_interval` has elapsed (by `time.monotonic()`)
    since the previous log.
    """
    global last_logging_time
    if hasattr(context, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = context.num_prefill_tokens + context.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = context.num_input_tokens
    batchsize_counter[batchsize] += 1

    # Read the clock once so the comparison and the stored timestamp agree.
    now = time.monotonic()
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        # most_common() returns (batchsize, count) pairs ordered by count,
        # highest first, with ties kept in first-seen order.
        logger.info("Batchsize distribution (batchsize, count): %s",
                    batchsize_counter.most_common())


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        context: Object stored as `dynamic_forward_context`. When batch-size
            tracking is enabled and `context` is not None, its token counts
            are also added to the batch-size statistics.
        vllm_config: Its `compilation_config.static_forward_context` is
            stored as `static_forward_context`.

    The forward context that was active before entry is restored on exit,
    whether the body finishes normally or raises.
    """
    global _forward_context
    if track_batchsize and context is not None:
        _record_batchsize(context)
    prev_context = _forward_context
    _forward_context = ForwardContext(
        static_forward_context=vllm_config.compilation_config.
        static_forward_context,
        dynamic_forward_context=context)
    try:
        yield
    finally:
        _forward_context = prev_context
|