[torch.compile] add a flag to track batchsize statistics (#11059)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-12-10 12:40:52 -08:00 (committed by GitHub)
Parent: e739194926
Commit: 75f89dc44c
4 changed files with 37 additions and 1 deletion


@@ -69,6 +69,7 @@ if TYPE_CHECKING:
    VLLM_DISABLED_KERNELS: List[str] = []
    VLLM_USE_V1: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
def get_default_cache_root():
@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # If set, enable multiprocessing in LLM for the V1 code path.
    "VLLM_ENABLE_V1_MULTIPROCESSING":
    lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
    "VLLM_LOG_BATCHSIZE_INTERVAL":
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
}
# end-env-vars-definition
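
Usage note (not part of the diff): the interval is in seconds, and the forward-context module below reads the value into module-level globals at import time, so the variable has to be set before vllm is imported. A minimal sketch, with an illustrative model and prompt:

import os

os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "10"  # log roughly every 10 seconds

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))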


@@ -1,8 +1,19 @@
import time
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
batchsize_counter: Counter = Counter()
last_logging_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
@dataclass
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:

@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context,
    can be attention metadata, etc."""
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
    """
    global track_batchsize, batchsize_counter
    global last_logging_time, batchsize_logging_interval
    if track_batchsize and context is not None:
        if hasattr(context, "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = context.num_prefill_tokens + context.num_decode_tokens
        else:
            # for v1 attention backends
            batchsize = context.num_input_tokens
        batchsize_counter[batchsize] += 1
        if time.monotonic() - last_logging_time > batchsize_logging_interval:
            last_logging_time = time.monotonic()
            sorted_data = sorted(batchsize_counter.items(),
                                 key=lambda x: x[1],
                                 reverse=True)
            logger.info("Batchsize distribution (batchsize, count): %s",
                        sorted_data)
    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(
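
As a standalone illustration of the pattern used above (a Counter keyed by batch size, flushed to the log at most once per interval), here is a minimal sketch with made-up batch sizes; only the Counter/interval logic mirrors the diff:

import logging
import random
import time
from collections import Counter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("batchsize_demo")

LOG_INTERVAL = 2.0  # seconds, stand-in for VLLM_LOG_BATCHSIZE_INTERVAL
counter: Counter = Counter()
last_log = 0.0

for _ in range(5000):
    # Pretend each iteration is one forward pass with this many tokens.
    batchsize = random.choice([1, 8, 16, 16, 32])
    counter[batchsize] += 1
    if time.monotonic() - last_log > LOG_INTERVAL:
        last_log = time.monotonic()
        # Most frequent batch sizes first, same ordering as above.
        sorted_data = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        logger.info("Batchsize distribution (batchsize, count): %s", sorted_data)
    time.sleep(0.001)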


@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
    seq_start_loc: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    num_input_tokens: int = 0  # Number of tokens including padding.
class FlashAttentionImpl(AttentionImpl):
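
For context on why this field exists: the hasattr dispatch in set_forward_context above reads num_prefill_tokens/num_decode_tokens from v0 metadata and num_input_tokens from v1 metadata. A stripped-down sketch of that dispatch; the stand-in dataclasses are illustrative, only the field names come from the diff:

from dataclasses import dataclass
from typing import Any

@dataclass
class V0StyleMetadata:  # hypothetical stand-in for a v0 backend's metadata
    num_prefill_tokens: int
    num_decode_tokens: int

@dataclass
class V1StyleMetadata:  # hypothetical stand-in for FlashAttentionMetadata
    num_input_tokens: int = 0  # includes padding

def extract_batchsize(context: Any) -> int:
    # Same branch structure as set_forward_context above.
    if hasattr(context, "num_prefill_tokens"):
        return context.num_prefill_tokens + context.num_decode_tokens
    return context.num_input_tokens

assert extract_batchsize(V0StyleMetadata(3, 5)) == 8
assert extract_batchsize(V1StyleMetadata(16)) == 16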


@@ -445,6 +445,8 @@ class GPUModelRunner:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens

        attn_metadata.num_input_tokens = num_input_tokens

        # Get the inputs embeds.
        if encoder_outputs:
            inputs_embeds = self.model.get_input_embeddings(
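
Note on the value being recorded: num_input_tokens equals num_scheduled_tokens only in the eager branch shown above; when the runner pads the batch to a captured CUDA graph size, the padded value is what lands in attn_metadata.num_input_tokens and hence in the batchsize statistics. A hedged sketch of such a padding step; the helper name and capture-size list are hypothetical:

import bisect
from typing import List, Optional

def pad_to_capture_size(num_scheduled_tokens: int,
                        cudagraph_batch_sizes: List[int]) -> Optional[int]:
    """Smallest captured batch size >= num_scheduled_tokens, or None if the
    batch is too large and must run in eager mode. Illustrative only."""
    idx = bisect.bisect_left(cudagraph_batch_sizes, num_scheduled_tokens)
    return None if idx == len(cudagraph_batch_sizes) else cudagraph_batch_sizes[idx]

# Example: with capture sizes [1, 2, 4, 8, 16], 6 scheduled tokens pad to 8.
assert pad_to_capture_size(6, [1, 2, 4, 8, 16]) == 8
assert pad_to_capture_size(32, [1, 2, 4, 8, 16]) is None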