diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 34b648b6e99d..45a387a14adf 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -96,9 +96,14 @@ EXPECTED_VALUES = {
     [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_max_tokens":
-    [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
-     ("_count", _NUM_REQUESTS)],
+    "vllm:request_params_max_tokens": [
+        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
+        ("_count", _NUM_REQUESTS)
+    ],
+    "vllm:iteration_tokens_total":
+    [("_sum", _NUM_REQUESTS *
+      (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)),
+     ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)],
     "vllm:prompt_tokens":
     [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens": [
@@ -197,6 +202,7 @@ EXPECTED_METRICS = [
     "vllm:request_params_max_tokens_sum",
     "vllm:request_params_max_tokens_bucket",
     "vllm:request_params_max_tokens_count",
+    "vllm:iteration_tokens_total",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
@@ -223,6 +229,7 @@ EXPECTED_METRICS_V1 = [
     "vllm:gpu_prefix_cache_hits",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
+    "vllm:iteration_tokens_total",
     "vllm:request_success_total",
     "vllm:request_prompt_tokens_sum",
     "vllm:request_prompt_tokens_bucket",
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index a669c9f6267c..1920dbf7a7dc 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -57,7 +57,7 @@ class AsyncLLM(EngineClient):
         if self.log_stats:
             self.stat_loggers.extend([
                 LoggingStatLogger(),
-                PrometheusStatLogger(vllm_config.model_config),
+                PrometheusStatLogger(vllm_config),
             ])
 
         # Tokenizer (+ ensure liveness if running in another process).
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 439be38a3e79..5019e2b3f92a 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -7,7 +7,7 @@ from typing import Dict, List
 import numpy as np
 import prometheus_client
 
-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
@@ -92,13 +92,13 @@ class LoggingStatLogger(StatLoggerBase):
 
 class PrometheusStatLogger(StatLoggerBase):
 
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self, vllm_config: VllmConfig):
         self._unregister_vllm_metrics()
 
         labelnames = ["model_name"]
-        labelvalues = [model_config.served_model_name]
+        labelvalues = [vllm_config.model_config.served_model_name]
 
-        max_model_len = model_config.max_model_len
+        max_model_len = vllm_config.model_config.max_model_len
 
         self.gauge_scheduler_running = prometheus_client.Gauge(
             name="vllm:num_requests_running",
@@ -162,6 +162,13 @@ class PrometheusStatLogger(StatLoggerBase):
             buckets=build_1_2_5_buckets(max_model_len),
             labelnames=labelnames).labels(*labelvalues)
 
+        self.histogram_iteration_tokens = \
+            prometheus_client.Histogram(
+                name="vllm:iteration_tokens_total",
+                documentation="Histogram of number of tokens per engine_step.",
+                buckets=build_cudagraph_buckets(vllm_config),
+                labelnames=labelnames).labels(*labelvalues)
+
         self.histogram_time_to_first_token = \
             prometheus_client.Histogram(
                 name="vllm:time_to_first_token_seconds",
@@ -237,6 +244,9 @@ class PrometheusStatLogger(StatLoggerBase):
         self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
         self.counter_generation_tokens.inc(
             iteration_stats.num_generation_tokens)
+        self.histogram_iteration_tokens.observe(
+            iteration_stats.num_prompt_tokens + \
+            iteration_stats.num_generation_tokens)
 
         for finished_request in iteration_stats.finished_requests:
             self.counter_request_success[finished_request.finish_reason].inc()
@@ -293,3 +303,13 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
         [1, 2, 5, 10, 20, 50, 100]
     """
     return build_buckets([1, 2, 5], max_value)
+
+
+def build_cudagraph_buckets(vllm_config: VllmConfig) -> List[int]:
+    if not vllm_config.model_config.enforce_eager:
+        buckets = vllm_config.compilation_config.\
+            cudagraph_capture_sizes.copy()
+        buckets.sort()
+        return buckets
+    else:
+        return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
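
For context, below is a minimal standalone sketch (not part of the patch) of how the new vllm:iteration_tokens_total histogram behaves with the eager-mode fallback buckets from build_cudagraph_buckets(); the model name and observed token counts are made up for illustration, and only the public prometheus_client API is used.

# Illustrative sketch, not vllm code: one histogram observation per engine
# step, recording prompt tokens + generation tokens, as in the diff above.
import prometheus_client

# Fallback buckets used by build_cudagraph_buckets() when enforce_eager is set.
FALLBACK_BUCKETS = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]

iteration_tokens = prometheus_client.Histogram(
    name="vllm:iteration_tokens_total",
    documentation="Histogram of number of tokens per engine_step.",
    buckets=FALLBACK_BUCKETS,
    labelnames=["model_name"]).labels("example-model")  # hypothetical model name

# Example engine steps: (num_prompt_tokens, num_generation_tokens).
for num_prompt, num_generation in [(512, 1), (0, 32), (1024, 8)]:
    iteration_tokens.observe(num_prompt + num_generation)

# Dump the exposition-format text to see the _bucket/_sum/_count series.
print(prometheus_client.generate_latest().decode())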