# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason


@dataclass
class BaseCacheStats:
    """Stores cache hit statistics."""

    reset: bool = False
    """Whether the cache was reset."""

    requests: int = 0
    """The number of requests in this update."""

    queries: int = 0
    """The number of queries in these requests."""

    hits: int = 0
    """The number of hits in these requests."""


class CachingMetrics:
    """Metrics for caching, tracking the hit rate over the most recent N requests.

    Args:
        max_recent_requests: The number of most recent requests to
            aggregate over. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats):
        """Observe the cache stats for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `max_recent_requests` requests, the oldest
        set of requests is removed from the metrics.

        Args:
            stats: The cache stats.
        """
        # The cache was reset before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do not append empty stats; otherwise useful entries could get
        # pushed out of the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests no longer
        # exceeds the limit.
        # NOTE: The most recently added stats are always preserved.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def empty(self) -> bool:
        """Return True if no requests have been observed."""
        return self.aggregated_requests == 0

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
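

# Illustrative sketch (not part of the original module): how a caller might
# feed per-step stats into `CachingMetrics` and read the sliding-window hit
# rate. The request/query/hit counts below are made up for demonstration.
def _example_caching_metrics() -> float:
    metrics = CachingMetrics(max_recent_requests=100)
    # Each update: one request that queried 16 tokens and hit 8 of them.
    for _ in range(3):
        metrics.observe(BaseCacheStats(requests=1, queries=16, hits=8))
    # 3 requests, 48 queries, 24 hits -> hit rate of 0.5.
    return metrics.hit_rate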


@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.

    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.
    """

    preempted_requests: int = 0
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""

    def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
        """Aggregate request information into the stats."""
        if preempted:
            # Previously preempted request
            self.preempted_requests += 1
            self.preempted_queries += num_tokens
            self.preempted_hits += num_hits
        else:
            # New request
            self.requests += 1
            self.queries += num_tokens
            self.hits += num_hits
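

# Illustrative sketch (hypothetical numbers): recording per-request prefix
# cache lookups, keeping resumed (previously preempted) requests separate
# from new ones so they do not skew the hit rate.
def _example_prefix_cache_stats() -> PrefixCacheStats:
    stats = PrefixCacheStats()
    # A new request that queried 32 tokens and found 16 of them cached.
    stats.record(num_tokens=32, num_hits=16, preempted=False)
    # A resumed request typically re-hits its own blocks, so it is tracked
    # in the `preempted_*` counters instead.
    stats.record(num_tokens=8, num_hits=8, preempted=True)
    return stats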


@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Stores multi-modal cache hit statistics.

    - `reset`: Whether `reset_mm_cache` was invoked.
    - `queries`: Refers to the number of multi-modal data items
      that were queried.
    """


@dataclass
class KVCacheEvictionEvent:
    """Single KV cache block eviction sample."""

    lifetime_seconds: float
    idle_seconds: float
    reuse_gaps_seconds: tuple[float, ...]
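

# Illustrative sketch (made-up numbers): summarizing a batch of eviction
# events into mean block lifetime and mean idle time before eviction.
def _example_eviction_summary() -> tuple[float, float]:
    events = [
        KVCacheEvictionEvent(
            lifetime_seconds=12.0, idle_seconds=3.0, reuse_gaps_seconds=(1.0, 2.0)
        ),
        KVCacheEvictionEvent(
            lifetime_seconds=4.0, idle_seconds=4.0, reuse_gaps_seconds=()
        ),
    ]
    mean_lifetime = sum(e.lifetime_seconds for e in events) / len(events)
    mean_idle = sum(e.idle_seconds for e in events) / len(events)
    return mean_lifetime, mean_idle  # (8.0, 3.5)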


@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    kv_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
    connector_prefix_cache_stats: PrefixCacheStats | None = None

    kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list)

    spec_decoding_stats: SpecDecodingStats | None = None
    kv_connector_stats: dict[str, Any] | None = None

    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)

    cudagraph_stats: CUDAGraphStat | None = None
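

# Illustrative sketch (made-up values): assembling a SchedulerStats snapshot
# the way a scheduler might publish it after a step.
def _example_scheduler_stats() -> SchedulerStats:
    prefix_stats = PrefixCacheStats()
    prefix_stats.record(num_tokens=32, num_hits=16, preempted=False)
    return SchedulerStats(
        num_running_reqs=2,
        num_waiting_reqs=1,
        kv_cache_usage=0.42,  # fraction of KV cache blocks in use
        prefix_cache_stats=prefix_stats,
    )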


@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # First token latency
    first_token_latency: float = 0.0

    # Track if this request is corrupted (NaNs in logits)
    is_corrupted: bool = False


@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

    finish_reason: "FinishReason"
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    max_tokens_param: int | None = None
    queued_time: float = 0.0
    prefill_time: float = 0.0
    inference_time: float = 0.0
    decode_time: float = 0.0
    mean_time_per_output_token: float = 0.0
    is_corrupted: bool = False
    num_cached_tokens: int = 0


class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []
        self.num_corrupted_reqs: int = 0

    def __repr__(self) -> str:
        field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(
        self,
        output: "EngineCoreOutput",
        engine_core_timestamp: float,
        is_prefilling: bool,
        prompt_len: int,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Track if this request is corrupted (only check once per request).
        # Early exit if already marked as corrupted to avoid redundant checks.
        if (
            envs.VLLM_COMPUTE_NANS_IN_LOGITS
            and not req_stats.is_corrupted
            and output.num_nans_in_logits > 0
        ):
            req_stats.is_corrupted = True

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(
                output.request_id,
                output.events,
                is_prefilling,
                req_stats,
                lora_states,
                lora_name,
            )

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(
        self,
        req_id: str,
        events: list["EngineCoreEvent"],
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType

        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                lora_states.request_waiting(req_id, lora_name)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                lora_states.request_running(req_id, lora_name)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                lora_states.request_waiting(req_id, lora_name)

    def update_from_finished_request(
        self,
        finish_reason: "FinishReason",
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill are included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase
        mean_time_per_output_token = (
            decode_time / (req_stats.num_generation_tokens - 1)
            if req_stats.num_generation_tokens - 1 > 0
            else 0
        )

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)

        # Count corrupted requests when they finish (only once per request)
        if req_stats.is_corrupted:
            self.num_corrupted_reqs += 1
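

# Illustrative sketch (not part of the original module): driving the
# IterationStats bookkeeping with a SimpleNamespace stand-in for
# EngineCoreOutput. The request id and all timestamps are made up, and
# `FinishReason.STOP` is assumed to be a member of the runtime enum in
# `vllm.v1.engine`.
def _example_iteration_stats() -> IterationStats:
    from types import SimpleNamespace

    from vllm.v1.engine import FinishReason

    iteration = IterationStats()
    req_stats = RequestStateStats(arrival_time=iteration.iteration_timestamp - 0.5)
    lora_states = LoRARequestStates(log_stats=False)
    # Stand-in output with one new token, no events, and clean logits.
    output = SimpleNamespace(
        request_id="req-0",
        new_token_ids=[1],
        events=None,
        num_nans_in_logits=0,
    )
    # Prefill step: records TTFT and first_token_ts.
    iteration.update_from_output(
        output,  # type: ignore[arg-type]
        engine_core_timestamp=0.6,
        is_prefilling=True,
        prompt_len=16,
        req_stats=req_stats,
        lora_states=lora_states,
        lora_name=None,
    )
    # A decode step one second later: records an inter-token latency of 1.0s.
    iteration.update_from_output(
        output,  # type: ignore[arg-type]
        engine_core_timestamp=1.6,
        is_prefilling=False,
        prompt_len=16,
        req_stats=req_stats,
        lora_states=lora_states,
        lora_name=None,
    )
    # Finish: queued=0.1, prefill=0.5, decode=1.0, inference=1.5 seconds are
    # derived from the timestamps accumulated above.
    req_stats.queued_ts = 0.0
    req_stats.scheduled_ts = 0.1
    iteration.update_from_finished_request(
        finish_reason=FinishReason.STOP,
        num_prompt_tokens=16,
        max_tokens_param=None,
        req_stats=req_stats,
    )
    return iteration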


class LoRAStats:
    """Tracks waiting and running request IDs for a single LoRA."""

    def __init__(self):
        self.waiting: set[str] = set()
        self.running: set[str] = set()

    def update(self, req_id: str, waiting: bool, running: bool):
        assert not (waiting and running)
        if waiting:
            self.waiting.add(req_id)
        else:
            self.waiting.discard(req_id)

        if running:
            self.running.add(req_id)
        else:
            self.running.discard(req_id)

    @property
    def empty(self) -> bool:
        return not (self.waiting or self.running)
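

# Illustrative sketch (hypothetical request id): a request moving through
# waiting -> running -> finished leaves the LoRAStats sets empty again.
def _example_lora_stats() -> bool:
    stats = LoRAStats()
    stats.update("req-0", waiting=True, running=False)
    stats.update("req-0", waiting=False, running=True)
    stats.update("req-0", waiting=False, running=False)
    return stats.empty  # True once the request leaves both sets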


class LoRARequestStates:
    """A per-LoRA count of running and waiting requests."""

    def __init__(self, log_stats: bool = False):
        self.log_stats = log_stats
        self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)

    def _request_update(
        self, req_id: str, lora_name: str | None, waiting: bool, running: bool
    ):
        if not self.log_stats or lora_name is None:
            return

        lora_stats = self.requests[lora_name]
        lora_stats.update(req_id, waiting, running)
        if lora_stats.empty:
            del self.requests[lora_name]

    def request_waiting(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=True, running=False)

    def request_running(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=False, running=True)

    def request_finished(self, req_id: str, lora_name: str | None):
        self._request_update(req_id, lora_name, waiting=False, running=False)

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        if not self.log_stats or scheduler_stats is None:
            return
        for lora_name, stats in self.requests.items():
            scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
            scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
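

# Illustrative sketch (hypothetical adapter name): tracking per-LoRA request
# counts and publishing them into a SchedulerStats snapshot.
def _example_lora_request_states() -> dict[str, int]:
    lora_states = LoRARequestStates(log_stats=True)
    lora_states.request_waiting("req-0", "my-adapter")
    lora_states.request_running("req-0", "my-adapter")
    stats = SchedulerStats()
    lora_states.update_scheduler_stats(stats)
    return stats.running_lora_adapters  # {"my-adapter": 1}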