Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 01:25:33 +08:00)
[torch.compile] add a flag to track batchsize statistics (#11059)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 75f89dc44c (parent e739194926)
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
+    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1


 def get_default_cache_root():
@@ -452,6 +453,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
     lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
+    "VLLM_LOG_BATCHSIZE_INTERVAL":
+    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
 }

 # end-env-vars-definition
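The two hunks above (from vllm/envs.py, judging by the environment-variable table) register the new VLLM_LOG_BATCHSIZE_INTERVAL flag, which defaults to -1 (disabled). A minimal sketch of how one might try it; the model name, prompt, and interval value below are placeholders, not part of the commit:

# Sketch: enable batch-size logging before vLLM is imported, so the
# module-level check of envs.VLLM_LOG_BATCHSIZE_INTERVAL sees the value.
import os

os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "10"  # minimum seconds between logs

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
# With the flag set, the log should periodically contain lines such as:
#   Batchsize distribution (batchsize, count): [(17, 42), (1, 3)]

Exporting the variable in the shell before launching a vLLM server has the same effect, since vllm.envs evaluates these lambdas lazily on attribute access.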
@@ -1,8 +1,19 @@
+import time
+from collections import Counter
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional

+import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
+batchsize_counter: Counter = Counter()
+last_logging_time: float = 0
+batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL


 @dataclass
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
-    can be attention metadata, etc."""
+    can be attention metadata, etc.
+    Here we can inject common logic for every model forward pass.
+    """
+    global track_batchsize, batchsize_counter
+    global last_logging_time, batchsize_logging_interval
+    if track_batchsize and context is not None:
+        if hasattr(context, "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        else:
+            # for v1 attention backends
+            batchsize = context.num_input_tokens
+        batchsize_counter[batchsize] += 1
+        if time.monotonic() - last_logging_time > batchsize_logging_interval:
+            last_logging_time = time.monotonic()
+            sorted_data = sorted(batchsize_counter.items(),
+                                 key=lambda x: x[1],
+                                 reverse=True)
+            logger.info("Batchsize distribution (batchsize, count): %s",
+                        sorted_data)
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
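The hunks above (apparently vllm/forward_context.py) keep the statistics in module-level globals and gate the logging on wall-clock time rather than on call count. The same pattern, reduced to a self-contained sketch with made-up names and batch sizes:

# Standalone sketch of the tracking pattern above: count every observed
# batch size and print the distribution at most once per `interval` seconds.
import time
from collections import Counter

counter: Counter = Counter()
last_log: float = 0.0
interval: float = 10.0  # plays the role of VLLM_LOG_BATCHSIZE_INTERVAL

def observe(batchsize: int) -> None:
    global last_log
    counter[batchsize] += 1
    if time.monotonic() - last_log > interval:
        last_log = time.monotonic()
        # Most frequent batch sizes first, like the logger.info call above.
        print(sorted(counter.items(), key=lambda x: x[1], reverse=True))

for bs in [8, 8, 16, 8, 32, 16]:  # made-up batch sizes
    observe(bs)

Note that the counter is never cleared, so each logged distribution is cumulative over the lifetime of the process.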
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
     seq_start_loc: torch.Tensor
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    num_input_tokens: int = 0  # Number of tokens including padding.


 class FlashAttentionImpl(AttentionImpl):
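The hunk above adds a num_input_tokens field to the v1 FlashAttentionMetadata dataclass; this is the attribute set_forward_context falls back to when the v0-only num_prefill_tokens is absent. A minimal sketch of that dispatch, using stand-in metadata classes rather than the real backend classes:

# Sketch of the attribute-based dispatch in set_forward_context, with
# stand-in metadata classes (the real ones live in the attention backends).
from dataclasses import dataclass

@dataclass
class V0Metadata:
    num_prefill_tokens: int
    num_decode_tokens: int

@dataclass
class V1Metadata:
    num_input_tokens: int = 0  # includes padding, as in FlashAttentionMetadata

def batchsize_of(context) -> int:
    if hasattr(context, "num_prefill_tokens"):
        # v0 backends expose prefill and decode token counts separately.
        return context.num_prefill_tokens + context.num_decode_tokens
    # v1 backends expose a single, possibly padded, token count.
    return context.num_input_tokens

print(batchsize_of(V0Metadata(num_prefill_tokens=5, num_decode_tokens=3)))  # 8
print(batchsize_of(V1Metadata(num_input_tokens=16)))  # 16

Presumably the hasattr check is used so forward_context.py does not have to import either backend's metadata class.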
@@ -445,6 +445,8 @@ class GPUModelRunner:
             # Eager mode.
             num_input_tokens = num_scheduled_tokens

+        attn_metadata.num_input_tokens = num_input_tokens
+
         # Get the inputs embeds.
         if encoder_outputs:
             inputs_embeds = self.model.get_input_embeddings(
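In the final hunk the v1 GPUModelRunner stores the possibly padded token count on the attention metadata before the forward pass, which is why the new field is documented as "Number of tokens including padding": with piecewise CUDA graphs the batch is padded up to a captured size, and that padded size is what the forward context logs. The padding branch itself is outside this hunk; a rough sketch of what such padding looks like, with illustrative capture sizes that are not vLLM's defaults:

# Sketch: pad the scheduled token count up to the nearest captured batch size.
# The capture sizes below are made up for illustration.
import bisect

cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64]  # assumed, ascending

def padded_batch_size(num_scheduled_tokens: int) -> int:
    # Smallest captured size that fits the scheduled tokens; fall back to the
    # unpadded count (eager mode, as in the hunk above) if none is big enough.
    i = bisect.bisect_left(cudagraph_batch_sizes, num_scheduled_tokens)
    if i == len(cudagraph_batch_sizes):
        return num_scheduled_tokens
    return cudagraph_batch_sizes[i]

print(padded_batch_size(3))    # 4
print(padded_batch_size(100))  # 100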