[torch.compile] add a flag to track batchsize statistics (#11059)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 75f89dc44c (parent e739194926)
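For reference, a minimal usage sketch of the flag this commit introduces. The interval value and model name below are illustrative placeholders, not part of the change; the environment variable has to be set before vLLM is imported, since the tracking globals added in forward_context read it at import time.

# Minimal sketch; assumes vLLM is installed, the model name and the
# 5-second interval are illustrative only.
import os

# Enable batch-size tracking. Must be set before importing vllm because
# forward_context reads VLLM_LOG_BATCHSIZE_INTERVAL at import time.
os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "5"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
out = llm.generate(["Hello, my name is"],
                   SamplingParams(max_tokens=16))
# With the interval >= 0, the engine periodically logs lines like
# "Batchsize distribution (batchsize, count): [...]".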
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
+    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1


 def get_default_cache_root():
@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
     lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
+    "VLLM_LOG_BATCHSIZE_INTERVAL":
+    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
 }

 # end-env-vars-definition
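A hedged sketch of how the new entry is consumed: attributes of vllm.envs are resolved by calling the registered lambda, so the float is parsed from os.environ at access time. Treat that lazy-lookup detail as an assumption about envs.py rather than something shown in this diff.

# Hedged sketch; assumes vLLM is importable and that vllm.envs resolves
# attributes by invoking the lambdas registered above.
import os

os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "10"

import vllm.envs as envs

interval = envs.VLLM_LOG_BATCHSIZE_INTERVAL  # 10.0 (float); default is -1.0
print(interval, interval >= 0)               # tracking is enabled when >= 0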
@@ -1,8 +1,19 @@
+import time
+from collections import Counter
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional

+import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
+batchsize_counter: Counter = Counter()
+last_logging_time: float = 0
+batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL


 @dataclass
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
-    can be attention metadata, etc."""
+    can be attention metadata, etc.
+    Here we can inject common logic for every model forward pass.
+    """
+    global track_batchsize, batchsize_counter
+    global last_logging_time, batchsize_logging_interval
+    if track_batchsize and context is not None:
+        if hasattr(context, "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        else:
+            # for v1 attention backends
+            batchsize = context.num_input_tokens
+        batchsize_counter[batchsize] += 1
+        if time.monotonic() - last_logging_time > batchsize_logging_interval:
+            last_logging_time = time.monotonic()
+            sorted_data = sorted(batchsize_counter.items(),
+                                 key=lambda x: x[1],
+                                 reverse=True)
+            logger.info("Batchsize distribution (batchsize, count): %s",
+                        sorted_data)
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
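The tracking added to set_forward_context reduces to a small pattern: count each observed batch size in a Counter and dump the sorted distribution at most once per interval, using a monotonic clock. A self-contained sketch of that pattern, with illustrative names (record_batchsize, LOG_INTERVAL are not vLLM API):

# Self-contained sketch of the counting/logging pattern used above.
import time
from collections import Counter

LOG_INTERVAL = 5.0            # seconds between distribution dumps
_counter: Counter = Counter()
_last_log = 0.0

def record_batchsize(batchsize: int) -> None:
    global _last_log
    _counter[batchsize] += 1
    now = time.monotonic()
    if now - _last_log > LOG_INTERVAL:
        _last_log = now
        # Most frequent batch sizes first, as in the diff.
        dist = sorted(_counter.items(), key=lambda x: x[1], reverse=True)
        print("Batchsize distribution (batchsize, count):", dist)

for bs in (8, 8, 16, 8, 32):
    record_batchsize(bs)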
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
     seq_start_loc: torch.Tensor
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    num_input_tokens: int = 0  # Number of tokens including padding.


 class FlashAttentionImpl(AttentionImpl):
@@ -445,6 +445,8 @@ class GPUModelRunner:
             # Eager mode.
             num_input_tokens = num_scheduled_tokens

+        attn_metadata.num_input_tokens = num_input_tokens
+
         # Get the inputs embeds.
         if encoder_outputs:
             inputs_embeds = self.model.get_input_embeddings(
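The num_input_tokens stored here is what the V1 branch of the tracker reads: in the eager branch shown above it simply equals num_scheduled_tokens, while the CUDA-graph branch (outside this hunk) pads it up to the nearest captured size. A hedged sketch of that padding idea, with an illustrative helper and capture-size list rather than GPUModelRunner's actual fields:

# Hedged sketch of how a padded input-token count can arise; the helper name
# and the capture-size list are illustrative, not GPUModelRunner internals.
from bisect import bisect_left
from typing import List

def pad_for_cudagraph(num_scheduled_tokens: int,
                      capture_sizes: List[int],
                      use_cuda_graph: bool) -> int:
    if not use_cuda_graph or num_scheduled_tokens > capture_sizes[-1]:
        # Eager mode: no padding, mirroring the hunk above.
        return num_scheduled_tokens
    # Pad up to the smallest captured batch size that fits.
    return capture_sizes[bisect_left(capture_sizes, num_scheduled_tokens)]

sizes = [1, 2, 4, 8, 16, 32]               # illustrative capture sizes
print(pad_for_cudagraph(5, sizes, True))   # -> 8
print(pad_for_cudagraph(5, sizes, False))  # -> 5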