[Misc] Extend vLLM Metrics logging API (#5925)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Author: William Lin, 2024-06-28 19:36:06 -07:00 (committed by GitHub)
parent c4bca740e8
commit 906a19cdb0
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 225 additions and 118 deletions

View File

@@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens(
vllm_prompt_token_count = sum(prompt_token_counts)
_ = vllm_model.generate_greedy(example_prompts, max_tokens)
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
**stat_logger.labels)._value.get()
@@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens(
gpu_memory_utilization=0.4) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
tokenizer = vllm_model.model.get_tokenizer()
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
**stat_logger.labels)._value.get()
vllm_generation_count = 0
@@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
disable_log_stats=False,
gpu_memory_utilization=0.3,
served_model_name=served_model_name) as vllm_model:
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metrics_tag_content = stat_logger.labels["model_name"]
if served_model_name is None or served_model_name == []:
@@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
with pytest.raises(AttributeError):
_ = engine.stat_logger
_ = engine.stat_loggers
else:
assert (engine.stat_logger
is not None), "engine.stat_logger should be set"
assert (engine.stat_loggers
is not None), "engine.stat_loggers should be set"
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
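The hunks above update the metrics tests for the new per-backend registry: the engine no longer exposes a single `stat_logger` attribute, so tests fetch the Prometheus logger from the `stat_loggers` dict by name. A minimal sketch of the new access pattern, assuming a locally constructed engine; the `facebook/opt-125m` model is illustrative only:

```python
# Sketch of the access-pattern change exercised by these tests: per-backend
# stat loggers now live in a dict keyed by name instead of a single
# `stat_logger` attribute.
from vllm import EngineArgs, LLMEngine

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", disable_log_stats=False))

prometheus_logger = engine.stat_loggers["prometheus"]  # was: engine.stat_logger
prompt_counter = prometheus_logger.metrics.counter_prompt_tokens
print(prompt_counter.labels(**prometheus_logger.labels)._value.get())
```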

View File

@@ -13,7 +13,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
@@ -160,6 +161,7 @@ class LLMEngine:
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -292,11 +294,21 @@ class LLMEngine:
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
@@ -833,14 +845,24 @@ class LLMEngine:
return request_outputs
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if logger_name in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} already exists.")
self.stat_loggers[logger_name] = logger
def remove_logger(self, logger_name: str) -> None:
if logger_name not in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} does not exist.")
del self.stat_loggers[logger_name]
def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output))
for logger in self.stat_loggers.values():
logger.log(self._get_stats(scheduler_outputs, model_output))
def _get_stats(
self,

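`add_logger` and `remove_logger` are the new public surface of this change: callers can register additional `StatLoggerBase` implementations by name, and `do_log_stats` now fans each `Stats` snapshot out to every registered logger. A hedged sketch of plugging in a custom logger; the `RequestCountLogger` class and the `"request_count"` name are hypothetical, not part of this commit:

```python
from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import StatLoggerBase, Stats


class RequestCountLogger(StatLoggerBase):
    """Hypothetical logger that prints scheduler state each iteration."""

    def info(self, type: str, obj) -> None:
        pass  # no config info recorded in this sketch

    def log(self, stats: Stats) -> None:
        # do_log_stats() calls this with the latest Stats snapshot.
        print(f"running={stats.num_running_sys} "
              f"waiting={stats.num_waiting_sys} "
              f"swapped={stats.num_swapped_sys}")


engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", disable_log_stats=False))
engine.add_logger("request_count", RequestCountLogger(local_interval=5.0))
# ... run requests; each call to do_log_stats() now also reaches the custom logger.
engine.remove_logger("request_count")
```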
View File

@@ -1,21 +1,27 @@
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
disable_created_metrics)
import prometheus_client
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
if ray is not None:
from ray.util import metrics as ray_metrics
else:
ray_metrics = None
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
disable_created_metrics()
prometheus_client.disable_created_metrics()
# The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions.
@@ -24,56 +30,55 @@ disable_created_metrics()
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)
self._unregister_vllm_metrics()
# Config Information
self.info_cache_config = Info(
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
# System stats
# Scheduler State
self.gauge_scheduler_running = Gauge(
self.gauge_scheduler_running = self._base_library.Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
self.gauge_scheduler_waiting = self._base_library.Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
self.gauge_scheduler_swapped = self._base_library.Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge(
self.gauge_gpu_cache_usage = self._base_library.Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge(
self.gauge_cpu_cache_usage = self._base_library.Gauge(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = Counter(
self.counter_num_preemption = self._base_library.Counter(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = Counter(
self.counter_prompt_tokens = self._base_library.Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
self.counter_generation_tokens = self._base_library.Counter(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
self.histogram_time_to_first_token = self._base_library.Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
@@ -81,7 +86,7 @@ class Metrics:
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
self.histogram_time_per_output_token = self._base_library.Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
@@ -92,54 +97,77 @@ class Metrics:
# Request stats
# Latency
self.histogram_e2e_time_request = Histogram(
self.histogram_e2e_time_request = self._base_library.Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = Histogram(
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = Histogram(
self.histogram_num_generation_tokens_request = \
self._base_library.Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._base_library.Histogram(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = Histogram(
self.histogram_n_request = self._base_library.Histogram(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = Counter(
self.counter_request_success = self._base_library.Counter(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge(
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge(
self.gauge_avg_generation_throughput = self._base_library.Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
def _unregister_vllm_metrics(self) -> None:
for collector in list(self._base_library.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
self._base_library.REGISTRY.unregister(collector)
class RayMetrics(Metrics):
"""
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library = ray_metrics
def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
raise ImportError("RayMetrics requires Ray to be installed.")
super().__init__(labelnames, max_model_len)
def _unregister_vllm_metrics(self) -> None:
# No-op on purpose
pass
# end-metrics-definitions
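`RayMetrics` reuses every metric definition from `Metrics` through two class-level hooks: `_base_library` picks the client library that constructs the gauges, counters, and histograms, and `_unregister_vllm_metrics` becomes a no-op because there is no Prometheus registry to clean up. A standalone illustration of that backend-swap pattern, not taken from the diff; `prometheus_client` stands in for both backends so the snippet runs without Ray:

```python
import prometheus_client


class MetricsSketch:
    # A class attribute chooses the metrics backend; subclasses override it.
    _base_library = prometheus_client

    def __init__(self, labelnames):
        self.gauge_running = self._base_library.Gauge(
            name="demo:num_requests_running",
            documentation="Demo gauge created through the class-level backend.",
            labelnames=labelnames)


class RayBackedSketch(MetricsSketch):
    # The real RayMetrics sets this to ray.util.metrics; prometheus_client is
    # reused here only so the sketch is runnable without Ray installed.
    _base_library = prometheus_client


m = MetricsSketch(labelnames=["model_name"])
m.gauge_running.labels(model_name="demo").set(3)
```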
@@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol):
...
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def local_interval_elapsed(now: float, last_log: float,
local_interval: float) -> bool:
elapsed_time = now - last_log
return elapsed_time > local_interval
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
def get_throughput(tracked_stats: List[int], now: float,
last_log: float) -> float:
return float(np.sum(tracked_stats) / (now - last_log))
class StatLoggerBase(ABC):
"""Base class for StatLogger."""
def __init__(self, local_interval: float) -> None:
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
@abstractmethod
def log(self, stats: Stats) -> None:
raise NotImplementedError
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them
# to prometheus if applicable).
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
# Log to stdout.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used by LLMEngine to log to Prometheus."""
_metrics_cls = Metrics
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
super().__init__(local_interval)
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()),
max_model_len=max_model_len)
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _local_interval_elapsed(self, now: float) -> bool:
elapsed_time = now - self.last_local_log
return elapsed_time > self.local_interval
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int],
List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus(self, stats: Stats) -> None:
# System state data
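Both new logger classes share the module-level `local_interval_elapsed` and `get_throughput` helpers introduced at the top of this hunk, replacing the per-instance `_get_throughput` and `_local_interval_elapsed` methods of the old `StatLogger`. A small self-contained check of the arithmetic they implement, with made-up numbers:

```python
import numpy as np


def local_interval_elapsed(now: float, last_log: float,
                           local_interval: float) -> bool:
    return (now - last_log) > local_interval


def get_throughput(tracked_stats, now: float, last_log: float) -> float:
    return float(np.sum(tracked_stats) / (now - last_log))


# 448 tokens produced over a 7-second window -> 64 tokens/s.
assert local_interval_elapsed(now=12.0, last_log=5.0, local_interval=5.0)
print(get_throughput([128, 256, 64], now=12.0, last_log=5.0))  # 64.0
```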
@@ -279,26 +409,6 @@ class StatLogger:
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
@@ -313,11 +423,8 @@ class StatLogger:
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
def log(self, stats: Stats):
"""Logs to prometheus and tracked stats every iteration."""
# Log to prometheus.
self._log_prometheus(stats)
@@ -326,50 +433,28 @@ class StatLogger:
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them
# to prometheus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now=stats.now)
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
self._log_prometheus_interval(
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
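With these pieces in place there are three ways to route engine stats: the default `logging` and `prometheus` loggers built in `LLMEngine.__init__`, a custom dict passed through the new `stat_loggers` constructor argument, or loggers swapped in after construction with `add_logger`/`remove_logger`. A hedged sketch of the last option, re-registering the Prometheus logger with different settings; the extra `deployment` label and the interval value are illustrative, and the Ray-backed subclass above registers through the same API when Ray is installed:

```python
from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import PrometheusStatLogger

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", disable_log_stats=False))

# Re-register the Prometheus logger with an extra label and a shorter
# local logging interval.
engine.remove_logger("prometheus")
engine.add_logger(
    "prometheus",
    PrometheusStatLogger(
        local_interval=1.0,
        labels=dict(model_name=engine.model_config.served_model_name,
                    deployment="canary"),
        max_model_len=engine.model_config.max_model_len))
```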