[v1][stats][1/n] Add RequestStatsUpdate and RequestStats types (#10907)

Signed-off-by: rickyx <rickyx@anyscale.com>
Ricky Xu 2025-01-21 11:51:13 -08:00 committed by GitHub
parent 1e60f87bb3
commit 132a132100
3 changed files with 749 additions and 0 deletions
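
For orientation (not part of the diff): a minimal sketch of how the two new types compose, following the lifecycle exercised in tests/v1/test_stats.py below. The engine wiring that actually emits these updates is not included in this commit, so the caller here is hypothetical.

from vllm.sampling_params import SamplingParams
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate

# Hypothetical caller; the real producers are the engine components.
stats = RequestStats(request_id="req-0")
stats.update_from(
    RequestStatsUpdate(request_id="req-0",
                       type=RequestStatsUpdate.Type.ARRIVED))
stats.update_from(
    RequestStatsUpdate(request_id="req-0",
                       type=RequestStatsUpdate.Type.INPUT_PROCESSED,
                       sampling_params=SamplingParams(n=1),
                       num_prompt_tokens=16))
# Derived metrics (queue_duration_s, e2e_latency_s, ...) become available as
# later updates (QUEUED, PREFILLING, DECODING, DETOKENIZED, FINISHED) are
# folded in; update_from validates each transition and its timestamp.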

tests/v1/test_stats.py (new file, 300 lines added)

@@ -0,0 +1,300 @@
import pytest
from vllm.sampling_params import SamplingParams
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate
def make_update(
request_id: str,
update_type: RequestStatsUpdate.Type,
monotonic_ts_s: float,
**kwargs,
):
if update_type == RequestStatsUpdate.Type.INPUT_PROCESSED:
kwargs.setdefault("sampling_params", SamplingParams(n=1))
kwargs.setdefault("num_prompt_tokens", 10)
elif update_type == RequestStatsUpdate.Type.PREFILLING:
kwargs.setdefault("num_computed_tokens", 10)
kwargs.setdefault("num_cached_tokens", 10)
elif update_type == RequestStatsUpdate.Type.DETOKENIZED:
kwargs.setdefault("num_new_tokens", 10)
elif update_type == RequestStatsUpdate.Type.FINISHED:
kwargs.setdefault("finish_reason", "test_reason")
return RequestStatsUpdate(
request_id=request_id,
type=update_type,
monotonic_ts_s=monotonic_ts_s,
**kwargs,
)
def test_invalid_request_update():
request_id = "test_request"
update_specific_required_fields = {
RequestStatsUpdate.Type.INPUT_PROCESSED: [
"sampling_params",
"num_prompt_tokens",
],
RequestStatsUpdate.Type.PREFILLING: [
"num_computed_tokens",
"num_cached_tokens",
],
RequestStatsUpdate.Type.DETOKENIZED: ["num_new_tokens"],
RequestStatsUpdate.Type.FINISHED: ["finish_reason"],
}
# Missing a required field should raise a ValueError.
for update_type in RequestStatsUpdate.Type:
required_fields = update_specific_required_fields.get(update_type, [])
# Try to miss one of the required fields.
kwargs = {field: object() for field in required_fields}
for field in required_fields:
copy_kwargs = kwargs.copy()
copy_kwargs.pop(field)
with pytest.raises(ValueError):
RequestStatsUpdate(
request_id=request_id,
type=update_type,
**copy_kwargs,
)
def test_invalid_request_update_transition():
# Test invalid transition type.
for src in RequestStatsUpdate.Type:
for dst in RequestStatsUpdate.Type:
if dst not in RequestStatsUpdate._VALID_TRANSITIONS[src]:
with pytest.raises(AssertionError):
RequestStatsUpdate.check_valid_update(
make_update(
update_type=dst,
request_id="test_request",
monotonic_ts_s=1,
),
last_update_type=src,
last_updated_ts_s=0,
)
else:
RequestStatsUpdate.check_valid_update(
make_update(
request_id="test_request",
update_type=dst,
monotonic_ts_s=1,
),
last_update_type=src,
last_updated_ts_s=0,
)
# Test invalid timestamp.
with pytest.raises(AssertionError):
RequestStatsUpdate.check_valid_update(
make_update(
request_id="test_request",
update_type=RequestStatsUpdate.Type.ARRIVED,
monotonic_ts_s=1,
),
last_update_type=None,
last_updated_ts_s=2,
)
def test_lifecycle_updates():
request_id = "test_request"
stats = RequestStats(request_id=request_id)
# Test the below scenario:
arrived_ts = 0
input_processed_ts = 1
queued_ts = 2
prefilling_ts = 3
decoded_ts = 5
detokenized_ts = 6
decoded_2_ts = 7
detokenized_2_ts = 8
preempted_ts = 9
resumed_ts = 10
decoded_3_ts = 11
detokenized_3_ts = 12
finished_ts = 13
# Test ARRIVED
arrived_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.ARRIVED,
monotonic_ts_s=arrived_ts,
)
stats.update_from(arrived_update)
assert stats.arrival_ts_s == arrived_ts
assert stats.last_updated_ts_s == arrived_ts
# Test INPUT_PROCESSED
sampling_params = SamplingParams(n=1)
input_processed_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.INPUT_PROCESSED,
monotonic_ts_s=input_processed_ts,
sampling_params=sampling_params,
num_prompt_tokens=6,
)
stats.update_from(input_processed_update)
assert stats.input_processor_end_ts_s == input_processed_ts
assert stats.last_updated_ts_s == input_processed_ts
assert stats.num_prompt_tokens == 6
assert stats.sampling_params == sampling_params
assert stats.first_token_ts_s is None
assert stats.prefill_ts_s is None
# Test QUEUED
queued_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.QUEUED,
monotonic_ts_s=queued_ts,
)
stats.update_from(queued_update)
assert stats.queued_ts_s == queued_ts
assert stats.last_updated_ts_s == queued_ts
# Test PREFILLING
prefilling_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.PREFILLING,
monotonic_ts_s=prefilling_ts,
num_computed_tokens=3,
num_cached_tokens=1,
)
stats.update_from(prefilling_update)
assert stats.prefill_ts_s == prefilling_ts
assert stats.num_computed_tokens == 3
assert stats.num_cached_tokens == 1
assert stats.queue_duration_s == prefilling_ts - queued_ts
# Test DECODING
decoded_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DECODING,
monotonic_ts_s=decoded_ts,
)
stats.update_from(decoded_update)
assert stats.last_updated_ts_s == decoded_ts
# Test DETOKENIZED
detokenized_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DETOKENIZED,
monotonic_ts_s=detokenized_ts,
num_new_tokens=1,
)
stats.update_from(detokenized_update)
assert stats.last_updated_ts_s == detokenized_ts
assert stats.num_output_tokens == 1
# Since arrival
assert stats.first_token_latency_s == detokenized_ts - arrived_ts
# Since first scheduled
assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
# Test that another DECODING/DETOKENIZED pair yields the correct
# inter-token latency.
decoded_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DECODING,
monotonic_ts_s=decoded_2_ts,
)
stats.update_from(decoded_update)
detokenized_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DETOKENIZED,
monotonic_ts_s=detokenized_2_ts,
num_new_tokens=1,
)
stats.update_from(detokenized_update)
assert stats.output_token_latency_s_lst == [
detokenized_2_ts - detokenized_ts,
]
assert stats.num_output_tokens == 2
# Test PREEMPTED
preempted_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.PREEMPTED,
monotonic_ts_s=preempted_ts,
)
stats.update_from(preempted_update)
assert stats.last_updated_ts_s == preempted_ts
assert stats.preempted_ts_s_lst == [preempted_ts]
# States should be reset
assert stats.num_computed_tokens == 0
assert stats.num_cached_tokens == 0
# These states should not be reset
assert stats.num_output_tokens == 2
assert stats.output_token_latency_s_lst == [
detokenized_2_ts - detokenized_ts,
]
assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
assert stats.num_prompt_tokens == 6
assert stats.prefill_start_ts_s_lst == [prefilling_ts]
# Test resumed
resumed_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.PREFILLING,
monotonic_ts_s=resumed_ts,
num_computed_tokens=6,
num_cached_tokens=2,
)
stats.update_from(resumed_update)
# prefill timestamp should not be updated since it's a resumed prefill
assert stats.prefill_ts_s == prefilling_ts
assert stats.num_computed_tokens == 6
assert stats.num_cached_tokens == 2
assert stats.prefill_start_ts_s_lst == [
prefilling_ts,
resumed_ts,
]
assert stats.last_updated_ts_s == resumed_ts
# Test that another DECODING/DETOKENIZED pair after preemption does not
# change the first token timestamp.
decoded_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DECODING,
monotonic_ts_s=decoded_3_ts,
)
detokenized_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.DETOKENIZED,
monotonic_ts_s=detokenized_3_ts,
num_new_tokens=1,
)
stats.update_from(decoded_update)
stats.update_from(detokenized_update)
assert stats.first_token_ts_s == detokenized_ts
assert stats.num_output_tokens == 3
assert stats.output_token_latency_s_lst == [
detokenized_2_ts - detokenized_ts,
detokenized_3_ts - detokenized_2_ts,
]
# Test FINISHED
finished_update = RequestStatsUpdate(
request_id=request_id,
type=RequestStatsUpdate.Type.FINISHED,
monotonic_ts_s=finished_ts,
finish_reason="test_reason",
)
stats.update_from(finished_update)
assert stats.last_updated_ts_s == finished_ts
assert stats.e2e_latency_s == finished_ts - arrived_ts
assert stats.inference_latency_s == finished_ts - prefilling_ts
assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
assert stats.decode_latency_s == finished_ts - detokenized_ts
assert stats.first_token_latency_s == detokenized_ts - arrived_ts
assert stats.queue_duration_s == prefilling_ts - queued_ts
assert stats.is_finished
assert stats.finish_reason == "test_reason"
# TODO(rickyx): Add model forward/execute time.
assert stats.model_forward_duration_s == 0.0
assert stats.model_execute_duration_s == 0.0
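
For reference (not part of the diff), the derived latencies asserted at the end of test_lifecycle_updates follow directly from the timeline constants defined at the top of that test:

# arrived_ts = 0, queued_ts = 2, prefilling_ts = 3,
# detokenized_ts = 6 (first token), finished_ts = 13
#
# first_token_latency_s = detokenized_ts - arrived_ts           =  6 - 0 =  6
# prefill_latency_s     = detokenized_ts - prefilling_ts        =  6 - 3 =  3
# queue_duration_s      = prefilling_ts - queued_ts             =  3 - 2 =  1
# inference_latency_s   = finished_ts - prefilling_ts           = 13 - 3 = 10
# e2e_latency_s         = finished_ts - arrived_ts              = 13 - 0 = 13
# decode_latency_s      = e2e_latency_s - first_token_latency_s = 13 - 6 =  7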


vllm/v1/stats/common.py (new file, 449 lines added)

@@ -0,0 +1,449 @@
import time
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import IntEnum
from typing import ClassVar, Dict, List, Optional, Set
import msgspec
from msgspec import field as msgspec_field
from vllm.sampling_params import SamplingParams
class RequestStatsUpdate(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
"""
An update to the request stats.
This represents a stats update at a specific timestamp with metadata
associated with the update.
NOTE: since updates are generated by multiple components of the engine
(e.g. input processor, scheduler, engine core, etc.), each update is
recorded with a monotonic timestamp, which is only used to compute
intervals between events. An explicit wall-clock timestamp should be used
wherever an absolute time is needed.
WARNING: this assumes all stats are generated on a single machine, since
monotonic clocks are not comparable across machines. If multiple machines
are involved, all stats updates should be generated on one machine, or a
different timing scheme should be used.
"""
class Type(IntEnum):
"""See `RequestStats` for the lifecycle of a request."""
# Request arrived at the engine frontend.
ARRIVED = 0
# Input processed by the input processor.
INPUT_PROCESSED = 1
# Queued on the engine core.
QUEUED = 2
# Scheduled running prefill by the scheduler.
# A request could be running a new prefill on the prompt tokens or
# a resumed prefill on the original prefill tokens + generated output
# tokens before preemption.
PREFILLING = 3
# Preempted by the scheduler.
PREEMPTED = 4
# Output token is generated by the engine core.
DECODING = 5
# Token detokenized by the detokenizer.
# We will record the timestamp for each output token, as well as the
# finish reason.
DETOKENIZED = 6
# Request finishes (or aborts).
FINISHED = 7
"""
Valid state updates:
ARRIVED
INPUT_PROCESSED QUEUED PREFILLING
- DECODING
|
|
DETOKENIZED
PREEMPTED
FINISHED (All could go to FINISHED)
"""
_VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = {
Type.ARRIVED: {
Type.INPUT_PROCESSED,
Type.FINISHED,
},
Type.INPUT_PROCESSED: {
Type.QUEUED,
Type.FINISHED,
},
Type.QUEUED: {
Type.PREFILLING,
Type.FINISHED,
},
Type.PREFILLING: {
Type.DECODING,
Type.PREEMPTED,
Type.FINISHED,
},
Type.DECODING: {
Type.DETOKENIZED,
Type.FINISHED,
},
Type.DETOKENIZED: {
Type.DECODING,
Type.PREEMPTED,
Type.FINISHED,
},
Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED},
Type.FINISHED: set(),
}
request_id: str
type: Type
# Timestamp when the update is recorded. This is used to record time
# intervals between events rather than wall clock time.
monotonic_ts_s: float = msgspec_field(
default_factory=lambda: time.monotonic())
############################################################
# Metadata associated with the update.
############################################################
# For input_processed. Metadata needed for stats logging.
num_prompt_tokens: Optional[int] = None
sampling_params: Optional[SamplingParams] = None
# For running.
# Number of tokens computed when scheduled to run.
num_computed_tokens: Optional[int] = None
# Number of cached tokens when scheduled to run.
num_cached_tokens: Optional[int] = None
# For detokenized.
# The number of new output tokens generated.
num_new_tokens: Optional[int] = None
# For both detokenized and finished.
# Finish reason.
finish_reason: Optional[str] = None
# Non-optional fields for each update type.
_REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = {
Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
Type.DETOKENIZED: ["num_new_tokens"],
Type.FINISHED: ["finish_reason"],
}
def __post_init__(self):
required_fields = self._REQUIRED_FIELDS.get(self.type, [])
for field in required_fields:
if getattr(self, field) is None:
raise ValueError(
f"Field {field} is required for update type {self.type}.")
@staticmethod
def check_valid_update(
update: "RequestStatsUpdate",
last_update_type: Optional[Type],
last_updated_ts_s: Optional[float],
):
if last_update_type is None:
assert update.type == RequestStatsUpdate.Type.ARRIVED
else:
valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[
last_update_type]
assert update.type in valid_cur_update_types, (
f"Invalid update type: {update.type} for last_update_type: "
f"{last_update_type}.")
if last_updated_ts_s is not None:
assert update.monotonic_ts_s >= last_updated_ts_s, (
"Update timestamp must be monotonically increasing, but "
f"last_updated_ts_s={last_updated_ts_s} and "
f"update.monotonic_ts_s={update.monotonic_ts_s}.")
@dataclass
class RequestStats:
"""Stats associated with a request (`Request`)."""
############################################################
# Metadata
############################################################
request_id: str
sampling_params: Optional[SamplingParams] = None
num_prompt_tokens: Optional[int] = None
############################################################
# Metrics and Stats
############################################################
# Timestamp when the request was last updated.
last_updated_ts_s: Optional[float] = None
# Last update stats type.
last_update_type: Optional[RequestStatsUpdate.Type] = None
# Timestamp when the request arrived at the llm engine.
arrival_ts_s: Optional[float] = None
# Number of tokens cached. When part of the request prefix is cached,
# this will be set.
num_cached_tokens: int = 0
# Number of tokens computed.
num_computed_tokens: int = 0
# The timestamp when the request started waiting in the queue.
queued_ts_s: Optional[float] = None
# When the input processor is completed.
input_processor_end_ts_s: Optional[float] = None
# A sorted list of timestamps when the request was scheduled to prefill.
# This could be when:
# 1. the request is newly scheduled, so it's a new prefill.
# 2. the request was preempted and resumed. It is equivalent to running
# a prefill of the original prefill tokens + generated output tokens
# before preemption.
prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list)
# A list of timestamps when a token is decoded by the engine core.
decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list)
# A sorted list of timestamps for each output token.
output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list)
# First token's timestamp.
first_token_ts_s: Optional[float] = None
# TODO(rickyx): we need model runner to surface these.
model_forward_duration_s: float = 0.0
# Includes model forward, blocking/sync across workers, CPU-GPU sync
# time, and sampling time.
model_execute_duration_s: float = 0.0
# A sorted list of timestamps when the request was preempted at the
# scheduler.
# TODO(rickyx): right now, we don't actually have a good high-level
# metric to measure the impact of preemption other than observation of
# large P99 TPOT. Ideally we could quantify the impact of preemption by
# measuring the number of tokens re-computed due to preemption.
preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list)
# Timestamp when the request was finished at the engine core.
finished_ts_s: Optional[float] = None
# Finish reason.
finish_reason: Optional[str] = None
############################################################
# Derived properties.
############################################################
@property
def prefill_ts_s(self) -> Optional[float]:
"""The timestamp when the request started prefilling.
Since a request could be preempted in decoding and later resumed
to prefill the decoded tokens, we use the first prefill start timestamp.
"""
return (self.prefill_start_ts_s_lst[0]
if self.prefill_start_ts_s_lst else None)
@property
def e2e_latency_s(self) -> Optional[float]:
if self.finished_ts_s is None or self.arrival_ts_s is None:
return None
assert self.finished_ts_s >= self.arrival_ts_s
return self.finished_ts_s - self.arrival_ts_s
@property
def queue_duration_s(self) -> Optional[float]:
"""How long the request was waiting to run."""
if self.queued_ts_s is None or self.prefill_ts_s is None:
# Either not queued or not running yet.
return None
assert self.queued_ts_s <= self.prefill_ts_s
return self.prefill_ts_s - self.queued_ts_s
@property
def inference_latency_s(self) -> Optional[float]:
"""How long the request was running inference
(prefill and decode)."""
if self.finished_ts_s is None or self.prefill_ts_s is None:
return None
assert self.finished_ts_s >= self.prefill_ts_s
return self.finished_ts_s - self.prefill_ts_s
@property
def first_token_latency_s(self) -> Optional[float]:
if self.first_token_ts_s is None or self.arrival_ts_s is None:
return None
assert self.first_token_ts_s >= self.arrival_ts_s
return self.first_token_ts_s - self.arrival_ts_s
@property
def prefill_latency_s(self) -> Optional[float]:
if self.first_token_ts_s is None or self.prefill_ts_s is None:
return None
assert self.first_token_ts_s >= self.prefill_ts_s
return self.first_token_ts_s - self.prefill_ts_s
@property
def decode_latency_s(self) -> Optional[float]:
if self.e2e_latency_s is None or self.first_token_latency_s is None:
return None
assert self.e2e_latency_s >= self.first_token_latency_s
return self.e2e_latency_s - self.first_token_latency_s
@property
def output_token_latency_s_lst(self) -> List[float]:
if len(self.output_token_ts_s_lst) == 0:
return []
latency_s_lst = []
for i in range(1, len(self.output_token_ts_s_lst)):
assert (self.output_token_ts_s_lst[i] >=
self.output_token_ts_s_lst[i - 1])
latency_s = (self.output_token_ts_s_lst[i] -
self.output_token_ts_s_lst[i - 1])
latency_s_lst.append(latency_s)
return latency_s_lst
@property
def num_output_tokens(self) -> int:
return len(self.output_token_ts_s_lst)
@property
def is_finished(self) -> bool:
return self.finished_ts_s is not None
def update_from(self, update: "RequestStatsUpdate"):
RequestStatsUpdate.check_valid_update(update, self.last_update_type,
self.last_updated_ts_s)
ts = update.monotonic_ts_s
self.last_updated_ts_s = ts
self.last_update_type = update.type
if update.type == RequestStatsUpdate.Type.ARRIVED:
self.arrival_ts_s = ts
elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED:
self.input_processor_end_ts_s = ts
self.sampling_params = update.sampling_params
self.num_prompt_tokens = update.num_prompt_tokens
elif update.type == RequestStatsUpdate.Type.QUEUED:
self.queued_ts_s = ts
elif update.type == RequestStatsUpdate.Type.PREFILLING:
self.prefill_start_ts_s_lst.append(ts)
self.num_cached_tokens = update.num_cached_tokens
self.num_computed_tokens = update.num_computed_tokens
elif update.type == RequestStatsUpdate.Type.PREEMPTED:
self._reset_for_preemption(ts)
elif update.type == RequestStatsUpdate.Type.DECODING:
self.decoding_ts_s_lst.append(ts)
elif update.type == RequestStatsUpdate.Type.DETOKENIZED:
self._record_detokenized_output(
ts,
update.num_new_tokens,
)
elif update.type == RequestStatsUpdate.Type.FINISHED:
self.finished_ts_s = ts
self.finish_reason = update.finish_reason
else:
raise ValueError(f"Unknown update type: {update.type}")
def _record_detokenized_output(
self,
ts_s: float,
num_new_tokens: int,
):
# Update if first output token is generated.
if len(self.output_token_ts_s_lst) == 0:
self.first_token_ts_s = ts_s
assert (
self.prefill_ts_s is not None
), "Request must be running before generating output tokens."
# num_new_tokens output tokens were generated at this timestamp.
self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens)
def _reset_for_preemption(self, ts_s: float):
self.preempted_ts_s_lst.append(ts_s)
# Reset the computed tokens since it might restart the prefill.
self.num_computed_tokens = 0
# Cached token count might also change when resumed.
self.num_cached_tokens = 0
# These stats don't change since they were recorded before the request
# started running.
# - arrival_ts_s
# - input_processor_end_ts_s
# - sampling_params
# - num_prompt_tokens
# - first_token_ts_s
#
# These stats are accumulated over preemptions:
# - output_token_ts_s_lst
# - prefill_start_ts_s_lst (after preemption, it will prefill the
# original prefill tokens and any output tokens generated before
# preemption.)
@dataclass
class KVCacheStats:
# KV Cache Usage in %
gpu_cache_usage_sys: float = 0.0
gpu_prefix_cache_hit_rate: float = 0.0
@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
# Number of requests currently running.
num_running_reqs: int = 0
# Number of requests currently waiting.
num_waiting_reqs: int = 0
kv_cache_stats: KVCacheStats = dataclass_field(
default_factory=KVCacheStats)
@dataclass
class EngineCoreProcessStats:
"""Stats associated with the engine core process."""
# Number of requests currently in the input queue. None if the engine core
# is not running in multiprocess mode.
input_queue_size: Optional[int] = None
# Number of outputs currently in the output queue. None if the engine core
# is not running in multiprocess mode.
output_queue_size: Optional[int] = None
class EngineCoreStatsSnapshot(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
"""
A snapshot of the EngineCore's current stats over a period of time.
"""
# Snapshot of the scheduler stats.
scheduler_stats: SchedulerStats = msgspec_field(
default_factory=SchedulerStats)
# Per request stats updates.
requests_stats_updates: List[RequestStatsUpdate] = msgspec_field(
default_factory=list)
# Engine core's queue stats.
engine_core_process_stats: EngineCoreProcessStats = msgspec_field(
default_factory=EngineCoreProcessStats)
# TODO(rickyx): Add other components' stats,
# e.g. model runner/worker, etc.
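
These stats types are msgspec Structs with array_like=True and gc=False, presumably so snapshots can be serialized cheaply across the engine-core process boundary. A minimal round-trip sketch (the actual transport is not part of this commit, and msgpack is an assumed choice of encoding):

import msgspec

from vllm.v1.stats.common import EngineCoreStatsSnapshot, RequestStatsUpdate

snapshot = EngineCoreStatsSnapshot(requests_stats_updates=[
    RequestStatsUpdate(request_id="req-0",
                       type=RequestStatsUpdate.Type.ARRIVED),
])
# array_like=True Structs encode as compact msgpack arrays (no field names).
buf = msgspec.msgpack.encode(snapshot)
decoded = msgspec.msgpack.decode(buf, type=EngineCoreStatsSnapshot)
assert decoded.requests_stats_updates[0].request_id == "req-0"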