diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 4a824c7acef2..b3c7850556f9 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -411,7 +411,7 @@ def test_engine_log_metrics_ray(
     logger = _RayPrometheusStatLogger(
         local_interval=0.5,
         labels=dict(model_name=engine.model_config.served_model_name),
-        max_model_len=engine.model_config.max_model_len)
+        vllm_config=engine.vllm_config)
     engine.add_logger("ray", logger)
     for i, prompt in enumerate(example_prompts):
         engine.add_request(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8fc69d96d321..6eca304b45f0 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -232,6 +232,7 @@ class LLMEngine:
         use_cached_outputs: bool = False,
     ) -> None:
 
+        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
@@ -385,13 +386,14 @@ class LLMEngine:
                 self.stat_loggers = {
                     "logging":
                     LoggingStatLogger(
-                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
+                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                        vllm_config=vllm_config),
                     "prometheus":
                     PrometheusStatLogger(
                         local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                         labels=dict(
                             model_name=self.model_config.served_model_name),
-                        max_model_len=self.model_config.max_model_len),
+                        vllm_config=vllm_config),
                 }
                 self.stat_loggers["prometheus"].info("cache_config",
                                                      self.cache_config)
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index a5ae21c3966a..c8aec8dd3afa 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Type, Union, cast
 import numpy as np
 import prometheus_client
 
+from vllm.config import VllmConfig
 from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                        SupportsMetricsInfo)
 from vllm.executor.ray_utils import ray
@@ -44,10 +45,12 @@ class Metrics:
     _counter_cls = prometheus_client.Counter
     _histogram_cls = prometheus_client.Histogram
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         # Unregister any existing vLLM collectors (for CI/CD)
         self._unregister_vllm_metrics()
 
+        max_model_len = vllm_config.model_config.max_model_len
+
         # System stats
         #   Scheduler State
         self.gauge_scheduler_running = self._gauge_cls(
@@ -115,11 +118,15 @@ class Metrics:
             name="vllm:tokens_total",
             documentation="Number of prefill plus generation tokens processed.",
             labelnames=labelnames)
+        buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+        if not vllm_config.model_config.enforce_eager:
+            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",
             documentation="Histogram of number of tokens per engine_step.",
             labelnames=labelnames,
-            buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096])
+            buckets=buckets)
         self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -361,10 +368,10 @@ class RayMetrics(Metrics):
     _histogram_cls: Type[prometheus_client.Histogram] = cast(
         Type[prometheus_client.Histogram], _RayHistogramWrapper)
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         if ray_metrics is None:
             raise ImportError("RayMetrics requires Ray to be installed.")
-        super().__init__(labelnames, max_model_len)
+        super().__init__(labelnames, vllm_config)
 
     def _unregister_vllm_metrics(self) -> None:
         # No-op on purpose
@@ -421,8 +428,8 @@ def get_throughput(tracked_stats: List[int], now: float,
 class LoggingStatLogger(StatLoggerBase):
     """LoggingStatLogger is used in LLMEngine to log to Stdout."""
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         self.last_prompt_throughput: Optional[float] = None
         self.last_generation_throughput: Optional[float] = None
 
@@ -515,12 +522,12 @@ class PrometheusStatLogger(StatLoggerBase):
     _gauge_cls = prometheus_client.Gauge
 
     def __init__(self, local_interval: float, labels: Dict[str, str],
-                 max_model_len: int) -> None:
-        super().__init__(local_interval)
+                 vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         # Prometheus metrics
         self.labels = labels
         self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
-                                         max_model_len=max_model_len)
+                                         vllm_config=vllm_config)
 
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py
index 5f7ec3bbcb26..5c7a430d11c5 100644
--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py
@@ -16,6 +16,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Protocol
 
+from vllm.config import VllmConfig
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 
@@ -77,7 +78,7 @@ class SupportsMetricsInfo(Protocol):
 class StatLoggerBase(ABC):
     """Base class for StatLogger."""
 
-    def __init__(self, local_interval: float) -> None:
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
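
For code that constructs a stat logger directly (as the updated test above does), the call site changes from passing max_model_len to passing the whole VllmConfig. A minimal sketch of the new calling convention, assuming an already-initialized LLMEngine named engine; the logger name "prometheus_custom" is arbitrary:

from vllm.engine.metrics import PrometheusStatLogger

# The logger now receives the full VllmConfig. Per this change, it derives
# max_model_len from model_config and, when CUDA graphs are enabled
# (enforce_eager=False), takes the iteration-token histogram buckets from
# compilation_config.capture_sizes instead of the static default list.
logger = PrometheusStatLogger(
    local_interval=0.5,
    labels=dict(model_name=engine.model_config.served_model_name),
    vllm_config=engine.vllm_config)  # replaces max_model_len=...
engine.add_logger("prometheus_custom", logger)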