monitor metrics of tokens per step using cudagraph batchsizes (#11031)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored on 2024-12-09 22:35:36 -08:00, committed by GitHub
parent 28b3a1c7e5, commit ebf778061d
4 changed files with 23 additions and 13 deletions
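Summary of the change: in eager mode the vllm:iteration_tokens_total histogram keeps its fixed default bucket edges, but when CUDA graphs are enabled (enforce_eager off) the engine only launches graphs at the captured batch sizes, so the sorted capture sizes become the bucket edges and each bucket corresponds to exactly one cudagraph batch size. A minimal standalone sketch of that selection logic, using illustrative capture sizes rather than vLLM's actual defaults:

def iteration_token_buckets(enforce_eager: bool,
                            capture_sizes: list[int]) -> list[int]:
    # Eager mode: keep the fixed default edges (values as in the diff).
    buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
    if not enforce_eager:
        # CUDA graphs only run at the captured batch sizes, so the
        # sorted capture sizes are the natural histogram edges.
        buckets = sorted(capture_sizes)
    return buckets

# Illustrative capture sizes (not vLLM's real defaults):
print(iteration_token_buckets(False, [256, 128, 64, 32, 16, 8, 4, 2, 1]))
# -> [1, 2, 4, 8, 16, 32, 64, 128, 256]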

tests/metrics/test_metrics.py

@@ -411,7 +411,7 @@ def test_engine_log_metrics_ray(
     logger = _RayPrometheusStatLogger(
         local_interval=0.5,
         labels=dict(model_name=engine.model_config.served_model_name),
-        max_model_len=engine.model_config.max_model_len)
+        vllm_config=engine.vllm_config)
     engine.add_logger("ray", logger)
     for i, prompt in enumerate(example_prompts):
         engine.add_request(

vllm/engine/llm_engine.py

@@ -232,6 +232,7 @@ class LLMEngine:
         use_cached_outputs: bool = False,
     ) -> None:
+        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
@@ -385,13 +386,14 @@ class LLMEngine:
             self.stat_loggers = {
                 "logging":
                 LoggingStatLogger(
-                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
+                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                    vllm_config=vllm_config),
                 "prometheus":
                 PrometheusStatLogger(
                     local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                     labels=dict(
                         model_name=self.model_config.served_model_name),
-                    max_model_len=self.model_config.max_model_len),
+                    vllm_config=vllm_config),
             }
             self.stat_loggers["prometheus"].info("cache_config",
                                                  self.cache_config)

vllm/engine/metrics.py

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Type, Union, cast
 import numpy as np
 import prometheus_client
 
+from vllm.config import VllmConfig
 from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                        SupportsMetricsInfo)
 from vllm.executor.ray_utils import ray
@@ -44,10 +45,12 @@ class Metrics:
     _counter_cls = prometheus_client.Counter
     _histogram_cls = prometheus_client.Histogram
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         # Unregister any existing vLLM collectors (for CI/CD)
         self._unregister_vllm_metrics()
 
+        max_model_len = vllm_config.model_config.max_model_len
+
         # System stats
         #   Scheduler State
         self.gauge_scheduler_running = self._gauge_cls(
@@ -115,11 +118,15 @@ class Metrics:
             name="vllm:tokens_total",
             documentation="Number of prefill plus generation tokens processed.",
             labelnames=labelnames)
+        buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+        if not vllm_config.model_config.enforce_eager:
+            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",
             documentation="Histogram of number of tokens per engine_step.",
             labelnames=labelnames,
-            buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096])
+            buckets=buckets)
         self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -361,10 +368,10 @@ class RayMetrics(Metrics):
     _histogram_cls: Type[prometheus_client.Histogram] = cast(
         Type[prometheus_client.Histogram], _RayHistogramWrapper)
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         if ray_metrics is None:
             raise ImportError("RayMetrics requires Ray to be installed.")
-        super().__init__(labelnames, max_model_len)
+        super().__init__(labelnames, vllm_config)
 
     def _unregister_vllm_metrics(self) -> None:
         # No-op on purpose
@@ -421,8 +428,8 @@ def get_throughput(tracked_stats: List[int], now: float,
 class LoggingStatLogger(StatLoggerBase):
     """LoggingStatLogger is used in LLMEngine to log to Stdout."""
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         self.last_prompt_throughput: Optional[float] = None
         self.last_generation_throughput: Optional[float] = None
@@ -515,12 +522,12 @@ class PrometheusStatLogger(StatLoggerBase):
     _gauge_cls = prometheus_client.Gauge
 
     def __init__(self, local_interval: float, labels: Dict[str, str],
-                 max_model_len: int) -> None:
-        super().__init__(local_interval)
+                 vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         # Prometheus metrics
         self.labels = labels
         self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
-                                         max_model_len=max_model_len)
+                                         vllm_config=vllm_config)
 
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
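With these signature changes, callers construct the Prometheus logger from the whole VllmConfig instead of passing a bare max_model_len. A hedged sketch of the new call shape; obtaining the config via EngineArgs.create_engine_config() is an assumption about the surrounding setup, not part of this diff:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import PrometheusStatLogger

# Assumed setup: build a VllmConfig outside the engine.
vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()

logger = PrometheusStatLogger(
    local_interval=0.5,
    labels=dict(model_name=vllm_config.model_config.served_model_name),
    vllm_config=vllm_config)  # previously: max_model_len=...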

vllm/engine/metrics_types.py

@@ -16,6 +16,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Protocol
 
+from vllm.config import VllmConfig
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@@ -77,7 +78,7 @@ class SupportsMetricsInfo(Protocol):
 class StatLoggerBase(ABC):
     """Base class for StatLogger."""
 
-    def __init__(self, local_interval: float) -> None:
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
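Since StatLoggerBase now requires vllm_config, any custom logger registered via engine.add_logger has to thread it through to super().__init__. A hedged sketch of the adjusted subclass shape (the class name and print body are illustrative, not from this diff):

from vllm.config import VllmConfig
from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                       SupportsMetricsInfo)

class PrintingStatLogger(StatLoggerBase):
    """Toy logger: prints per-iteration generation token counts."""

    def __init__(self, local_interval: float,
                 vllm_config: VllmConfig) -> None:
        # vllm_config must now be forwarded to the base class.
        super().__init__(local_interval, vllm_config)

    def log(self, stats: Stats) -> None:
        print(stats.num_generation_tokens_iter)

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        pass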