# SPDX-License-Identifier: Apache-2.0
"""Global forward context that is active for the duration of a forward pass."""

import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

# Batch-size timing statistics are collected only when the configured
# interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# `time.perf_counter()` value at which statistics were last logged.
last_logging_time: float = 0
# `time.perf_counter()` value at which the most recent tracked forward pass
# was entered.
forward_start_time: float = 0
# Statistics are logged once more than this much `time.perf_counter()` time
# has elapsed since the previous log.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batch size -> list of measured forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class DPMetadata:
    """Token counts gathered across the data-parallel group."""

    # Maximum of the per-rank token counts (result of `torch.max`).
    max_tokens_across_dp_cpu: torch.Tensor
    # Cumulative sum of the per-rank token counts, in rank order.
    cu_tokens_across_dp_cpu: torch.Tensor


@dataclass
class ForwardContext:
    """State made available through `get_forward_context()`."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Type AttentionMetadata for v0,
    # Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    # attention layer to its attention metadata
    # set dynamically for each forward pass
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    # set dynamically for each forward pass
    virtual_engine: int
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _get_batchsize(attn_metadata: Any, num_tokens: int) -> int:
    """Return the number of tokens in the current batch."""
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends or no attn_metadata
    return num_tokens


def create_forward_context(attn_metadata: Any,
                           vllm_config: VllmConfig,
                           virtual_engine: int = 0,
                           num_tokens: int = 0) -> ForwardContext:
    """Build the `ForwardContext` for one forward pass.

    When the data-parallel size is greater than one, this performs an
    all-reduce over the data-parallel CPU group to collect every rank's
    token count and stores the result as `DPMetadata`.

    Args:
        attn_metadata: Attention metadata for this pass (may be None).
        vllm_config: Configuration providing the parallel settings and the
            static forward context.
        virtual_engine: Stored on the returned context unchanged.
        num_tokens: Batch size to use when `attn_metadata` does not expose
            `num_prefill_tokens`.

    Returns:
        The newly created forward context.
    """
    dp_metadata: Optional[DPMetadata] = None
    parallel_config = vllm_config.parallel_config
    if parallel_config.data_parallel_size > 1:
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        batchsize = _get_batchsize(attn_metadata, num_tokens)

        # Each rank fills in only its own slot; the all-reduce then yields
        # the token count of every rank.
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = batchsize
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device="cpu",
                                         dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)

        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
        dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
                                 cu_tokens_across_dp_cpu)

    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)


@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """A context manager that overrides the current forward context.

    This is used to override the forward context for a specific forward pass.
    The previous context is restored on exit, including when the body raises.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context


def _synchronize_device() -> None:
    """Wait for pending CUDA work, if CUDA is available.

    `torch.cuda.synchronize()` raises when CUDA is unavailable, so it is
    skipped in that case.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _record_forward_time(batchsize: int, start_time: float) -> None:
    """Record one forward-pass duration and periodically log the statistics."""
    global last_logging_time

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    _synchronize_device()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now

    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    When batch-size tracking is enabled and `attn_metadata` is not None, the
    time spent inside the context is recorded per batch size and logged
    periodically.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    # Keep the start time in a local so that a nested `set_forward_context`
    # cannot clobber the measurement of the enclosing one; the module-level
    # value is still updated for compatibility.
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        forward_start_time = start_time

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, num_tokens)
    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_forward_time(_get_batchsize(attn_metadata, num_tokens),
                                 start_time)