monitor metrics of tokens per step using cudagraph batchsizes (#11031)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author:    youkaichao <youkaichao@gmail.com>
Date:      2024-12-09 22:35:36 -08:00
Committed: GitHub
Parent:    28b3a1c7e5
Commit:    ebf778061d
4 changed files with 23 additions and 13 deletions
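With CUDA graphs enabled, each batch is padded up to the nearest captured batch size, so the static power-of-two-ish buckets of the vllm:iteration_tokens_total histogram no longer line up with the values the metric can actually take. This change threads the full VllmConfig through the stat loggers (instead of only max_model_len) and, when enforce_eager is off, uses the CUDA graph capture sizes as the histogram's bucket boundaries.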

--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py

@@ -411,7 +411,7 @@ def test_engine_log_metrics_ray(
     logger = _RayPrometheusStatLogger(
         local_interval=0.5,
         labels=dict(model_name=engine.model_config.served_model_name),
-        max_model_len=engine.model_config.max_model_len)
+        vllm_config=engine.vllm_config)
     engine.add_logger("ray", logger)
     for i, prompt in enumerate(example_prompts):
         engine.add_request(

--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py

@@ -232,6 +232,7 @@ class LLMEngine:
         use_cached_outputs: bool = False,
     ) -> None:
+        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
@@ -385,13 +386,14 @@ class LLMEngine:
             self.stat_loggers = {
                 "logging":
                 LoggingStatLogger(
-                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
+                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                    vllm_config=vllm_config),
                 "prometheus":
                 PrometheusStatLogger(
                     local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                     labels=dict(
                         model_name=self.model_config.served_model_name),
-                    max_model_len=self.model_config.max_model_len),
+                    vllm_config=vllm_config),
             }
             self.stat_loggers["prometheus"].info("cache_config",
                                                  self.cache_config)
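For out-of-tree loggers, construction mirrors the test above: pass the engine's vllm_config instead of max_model_len. A hypothetical example (the logger name and interval below are arbitrary):

from vllm.engine.metrics import PrometheusStatLogger

# engine is an existing LLMEngine; the config is now exposed as engine.vllm_config.
logger = PrometheusStatLogger(
    local_interval=10.0,
    labels=dict(model_name=engine.model_config.served_model_name),
    vllm_config=engine.vllm_config,
)
engine.add_logger("prometheus_custom", logger)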

--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Type, Union, cast
 import numpy as np
 import prometheus_client
 
+from vllm.config import VllmConfig
 from vllm.engine.metrics_types import (StatLoggerBase, Stats,
                                        SupportsMetricsInfo)
 from vllm.executor.ray_utils import ray
@@ -44,10 +45,12 @@ class Metrics:
     _counter_cls = prometheus_client.Counter
     _histogram_cls = prometheus_client.Histogram
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         # Unregister any existing vLLM collectors (for CI/CD)
         self._unregister_vllm_metrics()
 
+        max_model_len = vllm_config.model_config.max_model_len
+
         # System stats
         #   Scheduler State
         self.gauge_scheduler_running = self._gauge_cls(
@@ -115,11 +118,15 @@ class Metrics:
             name="vllm:tokens_total",
             documentation="Number of prefill plus generation tokens processed.",
             labelnames=labelnames)
+        buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+        if not vllm_config.model_config.enforce_eager:
+            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",
             documentation="Histogram of number of tokens per engine_step.",
             labelnames=labelnames,
-            buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096])
+            buckets=buckets)
         self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -361,10 +368,10 @@ class RayMetrics(Metrics):
     _histogram_cls: Type[prometheus_client.Histogram] = cast(
         Type[prometheus_client.Histogram], _RayHistogramWrapper)
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
+    def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         if ray_metrics is None:
             raise ImportError("RayMetrics requires Ray to be installed.")
-        super().__init__(labelnames, max_model_len)
+        super().__init__(labelnames, vllm_config)
 
     def _unregister_vllm_metrics(self) -> None:
         # No-op on purpose
@@ -421,8 +428,8 @@ def get_throughput(tracked_stats: List[int], now: float,
 class LoggingStatLogger(StatLoggerBase):
     """LoggingStatLogger is used in LLMEngine to log to Stdout."""
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         self.last_prompt_throughput: Optional[float] = None
         self.last_generation_throughput: Optional[float] = None
@@ -515,12 +522,12 @@ class PrometheusStatLogger(StatLoggerBase):
     _gauge_cls = prometheus_client.Gauge
 
     def __init__(self, local_interval: float, labels: Dict[str, str],
-                 max_model_len: int) -> None:
-        super().__init__(local_interval)
+                 vllm_config: VllmConfig) -> None:
+        super().__init__(local_interval, vllm_config)
         # Prometheus metrics
         self.labels = labels
         self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
-                                         max_model_len=max_model_len)
+                                         vllm_config=vllm_config)
 
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
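As a toy demonstration (not vLLM code) of why the bucket boundaries matter: prometheus_client counts an observation under every bucket whose upper bound is at least the observed value, so buckets aligned with the capture sizes resolve exactly the batch sizes the engine can run:

import prometheus_client

hist = prometheus_client.Histogram(
    "demo_iteration_tokens_total",
    "Toy histogram of tokens per engine step.",
    labelnames=["model_name"],
    buckets=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
)
hist.labels(model_name="demo").observe(200)  # counted under le="256" and above
print(prometheus_client.generate_latest().decode())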

--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py

@@ -16,6 +16,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Protocol
 
+from vllm.config import VllmConfig
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@@ -77,7 +78,7 @@ class SupportsMetricsInfo(Protocol):
 class StatLoggerBase(ABC):
     """Base class for StatLogger."""
 
-    def __init__(self, local_interval: float) -> None:
+    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
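Downstream, any custom StatLoggerBase subclass must now accept and forward vllm_config. A minimal sketch; the log()/info() signatures below are assumed from the abstract interface and are not part of this diff:

from vllm.config import VllmConfig
from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo

class PrintingStatLogger(StatLoggerBase):
    """Prints one line per logging call; illustration only."""

    def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
        # The base class now receives the whole config, not just max_model_len.
        super().__init__(local_interval, vllm_config)
        self.max_model_len = vllm_config.model_config.max_model_len

    def log(self, stats: Stats) -> None:
        # num_prompt_tokens is tracked by the base class (see diff above).
        print(f"prompt tokens this interval: {sum(self.num_prompt_tokens)}")

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        pass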