[V1][Metrics] Add several request timing histograms (#12644)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2025-02-11 15:14:00 +00:00 committed by GitHub
parent 110f59a33e
commit 75e6e14516
16 changed files with 335 additions and 85 deletions

View File

@ -85,6 +85,10 @@ EXPECTED_VALUES = {
"vllm:time_per_output_token_seconds":
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prompt_tokens":
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
@ -169,6 +173,18 @@ EXPECTED_METRICS = [
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_queue_time_seconds_sum",
"vllm:request_queue_time_seconds_bucket",
"vllm:request_queue_time_seconds_count",
"vllm:request_inference_time_seconds_sum",
"vllm:request_inference_time_seconds_bucket",
"vllm:request_inference_time_seconds_count",
"vllm:request_prefill_time_seconds_sum",
"vllm:request_prefill_time_seconds_bucket",
"vllm:request_prefill_time_seconds_count",
"vllm:request_decode_time_seconds_sum",
"vllm:request_decode_time_seconds_bucket",
"vllm:request_decode_time_seconds_count",
"vllm:request_prompt_tokens_sum",
"vllm:request_prompt_tokens_bucket",
"vllm:request_prompt_tokens_count",
@ -220,6 +236,21 @@ EXPECTED_METRICS_V1 = [
"vllm:time_per_output_token_seconds_sum",
"vllm:time_per_output_token_seconds_bucket",
"vllm:time_per_output_token_seconds_count",
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_queue_time_seconds_sum",
"vllm:request_queue_time_seconds_bucket",
"vllm:request_queue_time_seconds_count",
"vllm:request_inference_time_seconds_sum",
"vllm:request_inference_time_seconds_bucket",
"vllm:request_inference_time_seconds_count",
"vllm:request_prefill_time_seconds_sum",
"vllm:request_prefill_time_seconds_bucket",
"vllm:request_prefill_time_seconds_count",
"vllm:request_decode_time_seconds_sum",
"vllm:request_decode_time_seconds_bucket",
"vllm:request_decode_time_seconds_count",
]
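
The new per-request phase timings are standard Prometheus histograms, so each metric exports _sum, _bucket and _count series as listed above. A standalone sketch (not part of this diff) of checking that, assuming a vLLM OpenAI-compatible server is already running on http://localhost:8000 and has served at least one request; the endpoint and port are assumptions:

# Hypothetical check script, not vLLM test code.
import requests

EXPECTED_HISTOGRAMS = [
    "vllm:request_queue_time_seconds",
    "vllm:request_inference_time_seconds",
    "vllm:request_prefill_time_seconds",
    "vllm:request_decode_time_seconds",
]

def check_timing_histograms(base_url: str = "http://localhost:8000") -> None:
    text = requests.get(f"{base_url}/metrics").text
    for name in EXPECTED_HISTOGRAMS:
        # Every Prometheus histogram exposes _sum, _bucket and _count series.
        for suffix in ("_sum", "_bucket", "_count"):
            assert f"{name}{suffix}" in text, f"missing {name}{suffix}"

check_timing_histograms()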

View File

@ -38,7 +38,8 @@ def create_scheduler(
return Scheduler(scheduler_config,
model_config,
cache_config,
lora_config=None,
log_stats=True)
def create_requests(

View File

@ -50,7 +50,8 @@ def test_engine_core(monkeypatch):
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch):
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()

View File

@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
MAX_TOKENS = 20
@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
MAX_TOKENS = 20

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import math
import time
from typing import Dict, List, Optional
import pytest
@ -15,6 +16,7 @@ from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats
def _ref_convert_id_to_token(
@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Make N requests.
requests = [
@ -630,8 +633,9 @@ def test_iteration_stats(dummy_test_vectors):
# First iteration has 2 prefills.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
total_prompt_tokens = sum([
len(prompt_tokens)
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
@ -642,8 +646,9 @@ def test_iteration_stats(dummy_test_vectors):
# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active
@ -652,8 +657,9 @@ def test_iteration_stats(dummy_test_vectors):
output_processor.add_request(inactive_request)
num_active += 1
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
@ -661,8 +667,9 @@ def test_iteration_stats(dummy_test_vectors):
# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active

View File

@ -26,6 +26,7 @@ class KVCacheManager:
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
@ -33,6 +34,8 @@ class KVCacheManager:
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# FIXME: make prefix cache stats conditional on log_stats
self.log_stats = log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
@ -10,7 +11,8 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@ -26,10 +28,12 @@ class Scheduler:
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.log_stats = log_stats
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@ -45,7 +49,8 @@ class Scheduler:
num_gpu_blocks=num_gpu_blocks,
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching,
log_stats=self.log_stats)
self.block_size = self.cache_config.block_size
# req_id -> Request
@ -107,6 +112,8 @@ class Scheduler:
scheduled_encoder_inputs: Dict[str, List[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
@ -246,6 +253,7 @@ class Scheduler:
self.running.append(request)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
self.request_scheduled(request, scheduled_timestamp)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
@ -508,7 +516,8 @@ class Scheduler:
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
if not stopped:
new_running.append(request)
@ -541,6 +550,7 @@ class Scheduler:
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
self.request_queued(request)
def finish_requests(
self,
@ -588,7 +598,22 @@ class Scheduler:
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def request_queued(self, request: Request):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))
def request_scheduled(self, request: Request, timestamp: float):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp))
def make_stats(self) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
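
The scheduler records these events only when stats logging is enabled; the frontend later derives queue time as the interval between a request's QUEUED and SCHEDULED timestamps. A self-contained sketch of that bookkeeping, using stand-in classes rather than vLLM's own Request and EngineCoreEvent types:

# Illustrative stand-ins only; the real types live in vllm.v1.engine
# and vllm.v1.request.
import enum
import time
from dataclasses import dataclass, field
from typing import List

class EventType(enum.IntEnum):
    QUEUED = 1
    SCHEDULED = 2

@dataclass
class Event:
    type: EventType
    timestamp: float

@dataclass
class MiniRequest:
    request_id: str
    events: List[Event] = field(default_factory=list)

def request_queued(req: MiniRequest, log_stats: bool) -> None:
    # Mirrors Scheduler.request_queued(): no event unless stats are on.
    if log_stats:
        req.events.append(Event(EventType.QUEUED, time.monotonic()))

def request_scheduled(req: MiniRequest, timestamp: float,
                      log_stats: bool) -> None:
    # One timestamp is taken per scheduling pass and shared by all
    # requests scheduled in that pass.
    if log_stats:
        req.events.append(Event(EventType.SCHEDULED, timestamp))

req = MiniRequest("req-0")
request_queued(req, log_stats=True)
request_scheduled(req, time.monotonic(), log_stats=True)
queue_time = req.events[-1].timestamp - req.events[0].timestamp
print(f"queue time: {queue_time:.6f}s")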

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import time
from typing import List, Optional, Union
import msgspec
@ -60,6 +61,30 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest]
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
QUEUED = 1
SCHEDULED = 2
class EngineCoreEvent(msgspec.Struct):
"""A timestamped engine core event associated with a request.
The timestamp is a monotonic timestamp, used by the engine
frontend to calculate intervals between engine core events. These
timestamps should not be compared with timestamps from other processes.
"""
type: EngineCoreEventType
timestamp: float
@classmethod
def new_event(cls,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None) -> "EngineCoreEvent":
timestamp = time.monotonic() if timestamp is None else timestamp
return cls(event_type, timestamp)
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
@ -74,6 +99,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[List[EngineCoreEvent]] = None
@property
def finished(self) -> bool:
@ -91,7 +117,12 @@ class EngineCoreOutputs(
# [num_reqs]
outputs: List[EngineCoreOutput]
scheduler_stats: Optional[SchedulerStats]
timestamp: float = 0.0
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()
class EngineCoreRequestType(enum.Enum):
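
Because EngineCoreEvent is a msgspec.Struct, it travels with EngineCoreOutputs over the ZMQ hop as plain msgpack data, and the monotonic timestamps it carries are only ever compared against other engine-core timestamps. A standalone round-trip sketch, assuming msgspec is installed:

import enum
import time
from typing import Optional

import msgspec

class EngineCoreEventType(enum.IntEnum):
    QUEUED = 1
    SCHEDULED = 2

class EngineCoreEvent(msgspec.Struct):
    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(cls,
                  event_type: EngineCoreEventType,
                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
        # Default to "now" on the engine core's monotonic clock.
        return cls(event_type,
                   time.monotonic() if timestamp is None else timestamp)

event = EngineCoreEvent.new_event(EngineCoreEventType.QUEUED)
wire = msgspec.msgpack.encode(event)                  # bytes on the wire
decoded = msgspec.msgpack.decode(wire, type=EngineCoreEvent)
assert decoded.timestamp == event.timestamp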

View File

@ -53,10 +53,12 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers: List[StatLoggerBase] = []
if self.log_stats:
self.stat_loggers.extend([
LoggingStatLogger(),
PrometheusStatLogger(vllm_config.model_config),
])
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@ -85,6 +87,7 @@ class AsyncLLM(EngineClient):
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
self.output_handler: Optional[asyncio.Task] = None
@ -246,6 +249,8 @@ class AsyncLLM(EngineClient):
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async()
iteration_stats = IterationStats() if self.log_stats else None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
@ -257,14 +262,12 @@ class AsyncLLM(EngineClient):
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
iteration_stats = None
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
iteration_stats = processed_outputs.iteration_stats
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
@ -277,7 +280,6 @@ class AsyncLLM(EngineClient):
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
assert iteration_stats is not None
self._log_stats(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
@ -299,12 +301,14 @@ class AsyncLLM(EngineClient):
def _log_stats(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
if not self.log_stats:
return
assert scheduler_stats is not None
assert iteration_stats is not None
for logger in self.stat_loggers:
logger.log(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
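
The output handler now builds a single IterationStats per EngineCoreOutputs batch (only when log_stats is enabled) and threads it, along with the batch's engine-core timestamp, through every processed chunk; _log_stats() then becomes a no-op when logging is off. A self-contained sketch of that shared-accumulator pattern, with a hypothetical Stats class standing in for IterationStats and a placeholder chunk size:

import asyncio
from typing import List, Optional

CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

class Stats:
    def __init__(self) -> None:
        self.num_generation_tokens = 0

async def process_batch(outputs: List[List[int]],
                        log_stats: bool) -> Optional[Stats]:
    # One stats object for the whole batch, shared across all chunks.
    stats = Stats() if log_stats else None
    slices = [outputs[i:i + CHUNK_SIZE]
              for i in range(0, len(outputs), CHUNK_SIZE)]
    for i, outputs_slice in enumerate(slices):
        if stats is not None:
            stats.num_generation_tokens += sum(map(len, outputs_slice))
        if i + 1 < len(slices):
            await asyncio.sleep(0)  # yield to the event loop between chunks
    return stats

stats = asyncio.run(process_batch([[1, 2], [3]] * 200, log_stats=True))
print(stats.num_generation_tokens)  # 600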

View File

@ -38,12 +38,15 @@ class EngineCore:
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool,
):
assert vllm_config.model_config.runner_type != "pooling"
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
self.log_stats = log_stats
# Setup Model.
self.model_executor = executor_class(vllm_config)
@ -59,6 +62,7 @@ class EngineCore:
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
log_stats=self.log_stats,
)
self.mm_input_mapper_server = MMInputMapperServer(
@ -148,11 +152,9 @@ class EngineCoreProc(EngineCore):
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool,
):
super().__init__(vllm_config, executor_class, log_stats)
self.log_stats = log_stats
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,

View File

@ -41,6 +41,7 @@ class EngineCoreClient(ABC):
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool,
) -> "EngineCoreClient":
# TODO: support this for debugging purposes.
@ -50,12 +51,12 @@
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats)
return InprocClient(vllm_config, executor_class, log_stats)
@abstractmethod
def shutdown(self):
@ -204,13 +205,13 @@ class MPClient(EngineCoreClient):
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
log_stats: bool):
super().__init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
)
def get_output(self) -> EngineCoreOutputs:
@ -245,13 +246,13 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
log_stats: bool):
super().__init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
)
self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
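
With this change the client factory forwards an explicit log_stats flag to whichever client it builds, instead of the previous hard-coded False for SyncMPClient and True for AsyncMPClient. A small stand-in sketch of that dispatch; the class and function names here are placeholders, not vLLM's:

from dataclasses import dataclass

@dataclass
class FakeClient:
    multiprocess: bool
    asyncio_mode: bool
    log_stats: bool

def make_client(multiprocess_mode: bool, asyncio_mode: bool,
                log_stats: bool) -> FakeClient:
    # The flag is forwarded unchanged on every branch.
    if multiprocess_mode and asyncio_mode:
        return FakeClient(True, True, log_stats)
    if multiprocess_mode:
        return FakeClient(True, False, log_stats)
    return FakeClient(False, False, log_stats)

# AsyncLLM passes its own log_stats; the V1 LLMEngine currently passes False.
assert make_client(True, True, log_stats=True).log_stats
assert not make_client(True, False, log_stats=False).log_stats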

View File

@ -73,6 +73,7 @@ class LLMEngine:
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False, # FIXME: implement
)
@classmethod

View File

@ -19,7 +19,6 @@ class OutputProcessorOutput:
request_outputs: List[RequestOutput]
reqs_to_abort: List[str]
iteration_stats: IterationStats
class RequestState:
@ -34,6 +33,7 @@ class RequestState:
detokenizer: IncrementalDetokenizer,
arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]],
log_stats: bool,
):
self.request_id = request_id
self.output_kind = output_kind
@ -45,14 +45,16 @@ class RequestState:
self.is_prefilling = True
self.queue = queue
self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: EngineCoreRequest,
queue: Optional[asyncio.Queue[RequestOutput]],
log_stats: bool,
) -> "RequestState":
return cls(
request_id=request.request_id,
@ -69,6 +71,7 @@ class RequestState:
),
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
)
@ -112,11 +115,13 @@ class OutputProcessor:
self.request_states[request_id] = RequestState.from_new_request(
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
request=request,
queue=queue,
log_stats=self.log_stats)
def process_outputs(
self,
engine_core_outputs: List[EngineCoreOutput],
engine_core_timestamp: Optional[float] = None,
iteration_stats: Optional[IterationStats] = None,
) -> OutputProcessorOutput:
"""
@ -145,8 +150,6 @@ class OutputProcessor:
request_outputs: List[RequestOutput] = []
reqs_to_abort: List[str] = []
if not iteration_stats:
iteration_stats = IterationStats(self.log_stats)
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
@ -155,10 +158,9 @@ class OutputProcessor:
continue
# 1) Compute stats for this iteration.
self._update_stats_from_output(req_state, engine_core_output,
engine_core_timestamp,
iteration_stats)
req_state.stats)
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
@ -205,17 +207,44 @@ class OutputProcessor:
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(req_state, request_output,
finish_reason,
iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
iteration_stats=iteration_stats,
)
def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: Optional[float],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
assert engine_core_timestamp is not None
assert req_state.stats is not None
iteration_stats.update_from_output(engine_core_output,
engine_core_timestamp,
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats)
def _update_stats_from_finished(self, req_state: RequestState,
request_output: RequestOutput,
finish_reason: Optional[FinishReason],
iteration_stats: Optional[IterationStats]):
if iteration_stats is None:
return
assert finish_reason is not None
assert req_state.stats is not None
iteration_stats.update_from_finished_request(finish_reason,
request_output,
req_state.stats)
@staticmethod
def _make_request_output(
request_state: RequestState,
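
Both helpers return immediately when iteration_stats is None, so requests pay no timing bookkeeping cost unless stats logging was requested; the asserts document that an engine-core timestamp and per-request stats must exist whenever stats are being collected. A reduced sketch of that gating pattern with stand-in types:

from typing import Optional

class MiniRequestStats:          # stand-in for RequestStateStats
    num_generation_tokens = 0

class MiniIterationStats:        # stand-in for IterationStats
    def __init__(self) -> None:
        self.num_generation_tokens = 0

def update_stats_from_output(iteration_stats: Optional[MiniIterationStats],
                             req_stats: Optional[MiniRequestStats],
                             engine_core_timestamp: Optional[float],
                             num_new_tokens: int) -> None:
    if iteration_stats is None:
        return                    # stats logging disabled: do nothing
    # When stats are on, both of these are guaranteed to be present.
    assert engine_core_timestamp is not None
    assert req_stats is not None
    iteration_stats.num_generation_tokens += num_new_tokens
    req_stats.num_generation_tokens += num_new_tokens

update_stats_from_output(None, None, None, 4)                      # no-op
update_stats_from_output(MiniIterationStats(), MiniRequestStats(), 1.0, 4)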

View File

@ -182,6 +182,45 @@ class PrometheusStatLogger(StatLoggerBase):
],
labelnames=labelnames).labels(*labelvalues)
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0
]
self.histogram_e2e_time_request = \
prometheus_client.Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of e2e request latency in seconds.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_queue_time_request = \
prometheus_client.Histogram(
name="vllm:request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_inference_time_request = \
prometheus_client.Histogram(
name="vllm:request_inference_time_seconds",
documentation=
"Histogram of time spent in RUNNING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_prefill_time_request = \
prometheus_client.Histogram(
name="vllm:request_prefill_time_seconds",
documentation=
"Histogram of time spent in PREFILL phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_decode_time_request = \
prometheus_client.Histogram(
name="vllm:request_decode_time_seconds",
documentation=
"Histogram of time spent in DECODE phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
def log(self, scheduler_stats: SchedulerStats,
iteration_stats: IterationStats):
"""Log to prometheus."""
@ -201,6 +240,12 @@ class PrometheusStatLogger(StatLoggerBase):
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason].inc()
self.histogram_e2e_time_request.observe(
finished_request.e2e_latency)
self.histogram_inference_time_request.observe(
finished_request.inference_time)
self.histogram_decode_time_request.observe(
finished_request.decode_time)
self.histogram_num_prompt_tokens_request.observe(
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
@ -210,6 +255,10 @@ class PrometheusStatLogger(StatLoggerBase):
self.histogram_time_to_first_token.observe(ttft)
for tpot in iteration_stats.time_per_output_tokens_iter:
self.histogram_time_per_output_token.observe(tpot)
for queue_time in iteration_stats.queue_times_iter:
self.histogram_queue_time_request.observe(queue_time)
for prefill_time in iteration_stats.prefill_times_iter:
self.histogram_prefill_time_request.observe(prefill_time)
@staticmethod
def _unregister_vllm_metrics():
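
All five request-phase histograms share one bucket list and are pre-bound to the model-name label. A standalone example of the same wiring with prometheus_client, using a single metric, an assumed "my-model" label value, and one observation per finished request:

import prometheus_client

REQUEST_LATENCY_BUCKETS = [
    0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0,
    50.0, 60.0
]

# .labels(...) binds the child once so log() can call observe() directly.
histogram_queue_time = prometheus_client.Histogram(
    name="vllm:request_queue_time_seconds",
    documentation="Histogram of time spent in WAITING phase for request.",
    buckets=REQUEST_LATENCY_BUCKETS,
    labelnames=["model_name"]).labels("my-model")

histogram_queue_time.observe(0.42)   # seconds spent queued for one request
print(prometheus_client.generate_latest().decode()[:300])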

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from vllm.outputs import RequestOutput
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
@dataclass
@ -41,7 +41,15 @@ class RequestStateStats:
"""Stats that need to be tracked across delta updates."""
num_generation_tokens: int = 0
last_token_time: float = 0.0
# This is an engine frontend timestamp (wall-clock)
arrival_time: float = 0.0
# These are engine core timestamps (monotonic)
queued_ts: float = 0.0
scheduled_ts: float = 0.0
first_token_ts: float = 0.0
last_token_ts: float = 0.0
@dataclass
@ -49,33 +57,37 @@
"""Stats associated with a finished request."""
finish_reason: "FinishReason"
e2e_latency: float = 0.0
num_prompt_tokens: int = 0
num_generation_tokens: int = 0
inference_time: float = 0.0
decode_time: float = 0.0
class IterationStats:
"""Stats associated with a single set of EngineCoreOutputs."""
def __init__(self):
self.iteration_timestamp = time.time()
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.finished_requests: List[FinishedRequestStats] = []
self.time_to_first_tokens_iter: List[float] = []
self.time_per_output_tokens_iter: List[float] = []
self.queue_times_iter: List[float] = []
self.prefill_times_iter: List[float] = []
def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp."""
return self.iteration_timestamp - start
def update_from_output(self, output: "EngineCoreOutput",
engine_core_timestamp: float, is_prefilling: bool,
prompt_len: int, req_stats: RequestStateStats):
if not self.log_stats:
return
num_new_generation_tokens = len(output.new_token_ids)
now = time.time()
last_token_latency = now - request_state_stats.last_token_time
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
@ -84,19 +96,58 @@ class IterationStats:
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if num_new_generation_tokens > 0:
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
else:
self.time_per_output_tokens_iter.append(last_token_latency)
first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.num_generation_tokens += num_new_generation_tokens
# Process request-level engine core events
if output.events is not None:
self.update_from_events(output.events, is_prefilling, req_stats)
# Process the batch-level "new tokens" engine core event
if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
prefill_interval = \
engine_core_timestamp - req_stats.scheduled_ts
self.prefill_times_iter.append(prefill_interval)
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, events: List["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats):
# Avoid circular dependency
from vllm.v1.engine import EngineCoreEventType
for event in events:
if event.type == EngineCoreEventType.QUEUED:
req_stats.queued_ts = event.timestamp
elif event.type == EngineCoreEventType.SCHEDULED:
queued_interval = event.timestamp - req_stats.queued_ts
self.queue_times_iter.append(queued_interval)
req_stats.scheduled_ts = event.timestamp
def update_from_finished_request(self, finish_reason: "FinishReason",
request_output: "RequestOutput",
req_stats: RequestStateStats):
e2e_latency = self._time_since(req_stats.arrival_time)
FinishedRequestStats(finish_reason,
inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
decode_time = req_stats.last_token_ts - req_stats.first_token_ts
finished_req = \
FinishedRequestStats(finish_reason=finish_reason,
e2e_latency=e2e_latency,
num_prompt_tokens=len(request_output.prompt_token_ids),
num_generation_tokens=req_stats.num_generation_tokens,
inference_time=inference_time,
decode_time=decode_time)
self.finished_requests.append(finished_req)
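
Worked example of the interval arithmetic above: queue, prefill, decode and inference times are differences of monotonic engine-core timestamps, while TTFT and e2e latency are measured against the frontend's wall-clock arrival time. The numbers below are made up for illustration:

arrival_time = 1000.00         # frontend wall clock (time.time())
iteration_timestamp = 1003.50  # frontend wall clock when stats were built

queued_ts = 50.10              # engine-core monotonic clock
scheduled_ts = 50.35
first_token_ts = 51.15
last_token_ts = 53.40

queue_time = scheduled_ts - queued_ts             # 0.25 s in WAITING
prefill_time = first_token_ts - scheduled_ts      # 0.80 s to first token
decode_time = last_token_ts - first_token_ts      # 2.25 s generating
inference_time = last_token_ts - scheduled_ts     # 3.05 s in RUNNING
e2e_latency = iteration_timestamp - arrival_time  # 3.50 s end to end

assert abs(inference_time - (prefill_time + decode_time)) < 1e-9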

View File

@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, List, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
@ -33,14 +33,10 @@ class Request:
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.status = RequestStatus.WAITING
self.events: List[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
@ -83,6 +79,21 @@ class Request:
lora_request=request.lora_request,
)
def queued(self, timestamp: Optional[float] = None) -> None:
self.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, timestamp))
def scheduled(self, timestamp: Optional[float] = None) -> None:
self.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp))
def take_events(self) -> Optional[List[EngineCoreEvent]]:
if not self.events:
return None
events, self.events = self.events, []
return events
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
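
take_events() drains the per-request buffer, so each QUEUED or SCHEDULED event is shipped to the frontend in exactly one EngineCoreOutput and never re-reported. A small usage sketch with a stand-in class:

from typing import List, Optional

class MiniRequest:               # stand-in for vllm.v1.request.Request
    def __init__(self) -> None:
        self.events: List[str] = []

    def take_events(self) -> Optional[List[str]]:
        if not self.events:
            return None
        events, self.events = self.events, []  # hand off and reset
        return events

req = MiniRequest()
req.events += ["QUEUED", "SCHEDULED"]
assert req.take_events() == ["QUEUED", "SCHEDULED"]
assert req.take_events() is None  # buffer already drained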