[Misc] Extend vLLM Metrics logging API (#5925)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
William Lin 2024-06-28 19:36:06 -07:00 committed by GitHub
parent c4bca740e8
commit 906a19cdb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 225 additions and 118 deletions

View File

@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens(
vllm_prompt_token_count = sum(prompt_token_counts) vllm_prompt_token_count = sum(prompt_token_counts)
_ = vllm_model.generate_greedy(example_prompts, max_tokens) _ = vllm_model.generate_greedy(example_prompts, max_tokens)
stat_logger = vllm_model.model.llm_engine.stat_logger stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_prompt_tokens.labels( metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
**stat_logger.labels)._value.get() **stat_logger.labels)._value.get()
@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens(
gpu_memory_utilization=0.4) as vllm_model: gpu_memory_utilization=0.4) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
tokenizer = vllm_model.model.get_tokenizer() tokenizer = vllm_model.model.get_tokenizer()
stat_logger = vllm_model.model.llm_engine.stat_logger stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_generation_tokens.labels( metric_count = stat_logger.metrics.counter_generation_tokens.labels(
**stat_logger.labels)._value.get() **stat_logger.labels)._value.get()
vllm_generation_count = 0 vllm_generation_count = 0
@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
disable_log_stats=False, disable_log_stats=False,
gpu_memory_utilization=0.3, gpu_memory_utilization=0.3,
served_model_name=served_model_name) as vllm_model: served_model_name=served_model_name) as vllm_model:
stat_logger = vllm_model.model.llm_engine.stat_logger stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metrics_tag_content = stat_logger.labels["model_name"] metrics_tag_content = stat_logger.labels["model_name"]
if served_model_name is None or served_model_name == []: if served_model_name is None or served_model_name == []:
@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None: num_requests: int) -> None:
if disable_log_stats: if disable_log_stats:
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
_ = engine.stat_logger _ = engine.stat_loggers
else: else:
assert (engine.stat_logger assert (engine.stat_loggers
is not None), "engine.stat_logger should be set" is not None), "engine.stat_loggers should be set"
# Ensure the count bucket of request-level histogram metrics matches # Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are # the number of requests as a simple sanity check to ensure metrics are
# generated # generated

View File

@ -13,7 +13,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs) SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
@ -160,6 +161,7 @@ class LLMEngine:
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
@ -292,11 +294,21 @@ class LLMEngine:
# Metric Logging. # Metric Logging.
if self.log_stats: if self.log_stats:
self.stat_logger = StatLogger( if stat_loggers is not None:
local_interval=_LOCAL_LOGGING_INTERVAL_SEC, self.stat_loggers = stat_loggers
labels=dict(model_name=model_config.served_model_name), else:
max_model_len=self.model_config.max_model_len) self.stat_loggers = {
self.stat_logger.info("cache_config", self.cache_config) "logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
self.tracer = None self.tracer = None
if self.observability_config.otlp_traces_endpoint: if self.observability_config.otlp_traces_endpoint:
@ -833,14 +845,24 @@ class LLMEngine:
return request_outputs return request_outputs
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
    """Register an additional stat logger under ``logger_name``.

    Args:
        logger_name: Key under which the logger is stored in
            ``self.stat_loggers``.
        logger: The stat logger to register.

    Raises:
        RuntimeError: If stat logging is disabled for this engine.
        KeyError: If a logger is already registered under ``logger_name``.
    """
    # self.stat_loggers is only created when log_stats is enabled, so fail
    # with an explicit message instead of an incidental AttributeError.
    if not self.log_stats:
        raise RuntimeError(
            "Stat logging is disabled. Set `disable_log_stats=False` "
            "argument to enable.")
    if logger_name in self.stat_loggers:
        raise KeyError(f"Logger with name {logger_name} already exists.")
    self.stat_loggers[logger_name] = logger
def remove_logger(self, logger_name: str) -> None:
    """Unregister the stat logger stored under ``logger_name``.

    Args:
        logger_name: Key of the logger to drop from ``self.stat_loggers``.

    Raises:
        RuntimeError: If stat logging is disabled for this engine.
        KeyError: If no logger is registered under ``logger_name``.
    """
    # self.stat_loggers is only created when log_stats is enabled, so fail
    # with an explicit message instead of an incidental AttributeError.
    if not self.log_stats:
        raise RuntimeError(
            "Stat logging is disabled. Set `disable_log_stats=False` "
            "argument to enable.")
    if logger_name not in self.stat_loggers:
        raise KeyError(f"Logger with name {logger_name} does not exist.")
    del self.stat_loggers[logger_name]
def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None) -> None:
    """Forced log when no requests active.

    Builds one stats snapshot and hands it to every registered stat logger.
    Does nothing when stat logging is disabled.

    Args:
        scheduler_outputs: Forwarded to ``self._get_stats``.
        model_output: Forwarded to ``self._get_stats``.
    """
    if not self.log_stats:
        return
    # Collect the stats once: the snapshot does not depend on the logger, and
    # sharing it keeps every logger looking at the same values.
    stats = self._get_stats(scheduler_outputs, model_output)
    for stat_logger in self.stat_loggers.values():
        stat_logger.log(stats)
def _get_stats( def _get_stats(
self, self,

View File

@ -1,21 +1,27 @@
import time import time
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union from typing import Dict, List, Optional, Protocol, Union
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, import prometheus_client
disable_created_metrics)
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger from vllm.logger import init_logger
if ray is not None:
from ray.util import metrics as ray_metrics
else:
ray_metrics = None
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
disable_created_metrics() prometheus_client.disable_created_metrics()
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions. # to extract the metrics definitions.
@ -24,56 +30,55 @@ disable_created_metrics()
# begin-metrics-definitions # begin-metrics-definitions
class Metrics: class Metrics:
labelname_finish_reason = "finished_reason" labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
def __init__(self, labelnames: List[str], max_model_len: int): def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors # Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names): self._unregister_vllm_metrics()
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)
# Config Information # Config Information
self.info_cache_config = Info( self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config', name='vllm:cache_config',
documentation='information of cache_config') documentation='information of cache_config')
# System stats # System stats
# Scheduler State # Scheduler State
self.gauge_scheduler_running = Gauge( self.gauge_scheduler_running = self._base_library.Gauge(
name="vllm:num_requests_running", name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.", documentation="Number of requests currently running on GPU.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge( self.gauge_scheduler_waiting = self._base_library.Gauge(
name="vllm:num_requests_waiting", name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge( self.gauge_scheduler_swapped = self._base_library.Gauge(
name="vllm:num_requests_swapped", name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.", documentation="Number of requests swapped to CPU.",
labelnames=labelnames) labelnames=labelnames)
# KV Cache Usage in % # KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge( self.gauge_gpu_cache_usage = self._base_library.Gauge(
name="vllm:gpu_cache_usage_perc", name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge( self.gauge_cpu_cache_usage = self._base_library.Gauge(
name="vllm:cpu_cache_usage_perc", name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.", documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames) labelnames=labelnames)
# Iteration stats # Iteration stats
self.counter_num_preemption = Counter( self.counter_num_preemption = self._base_library.Counter(
name="vllm:num_preemptions_total", name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.", documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames) labelnames=labelnames)
self.counter_prompt_tokens = Counter( self.counter_prompt_tokens = self._base_library.Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames) labelnames=labelnames)
self.counter_generation_tokens = Counter( self.counter_generation_tokens = self._base_library.Counter(
name="vllm:generation_tokens_total", name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames) labelnames=labelnames)
self.histogram_time_to_first_token = Histogram( self.histogram_time_to_first_token = self._base_library.Histogram(
name="vllm:time_to_first_token_seconds", name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.", documentation="Histogram of time to first token in seconds.",
labelnames=labelnames, labelnames=labelnames,
@ -81,7 +86,7 @@ class Metrics:
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
]) ])
self.histogram_time_per_output_token = Histogram( self.histogram_time_per_output_token = self._base_library.Histogram(
name="vllm:time_per_output_token_seconds", name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.", documentation="Histogram of time per output token in seconds.",
labelnames=labelnames, labelnames=labelnames,
@ -92,54 +97,77 @@ class Metrics:
# Request stats # Request stats
# Latency # Latency
self.histogram_e2e_time_request = Histogram( self.histogram_e2e_time_request = self._base_library.Histogram(
name="vllm:e2e_request_latency_seconds", name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.", documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata # Metadata
self.histogram_num_prompt_tokens_request = Histogram( self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
name="vllm:request_prompt_tokens", name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames, labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
) )
self.histogram_num_generation_tokens_request = Histogram( self.histogram_num_generation_tokens_request = \
name="vllm:request_generation_tokens", self._base_library.Histogram(
documentation="Number of generation tokens processed.", name="vllm:request_generation_tokens",
labelnames=labelnames, documentation="Number of generation tokens processed.",
buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames,
) buckets=build_1_2_5_buckets(max_model_len),
self.histogram_best_of_request = Histogram( )
self.histogram_best_of_request = self._base_library.Histogram(
name="vllm:request_params_best_of", name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.", documentation="Histogram of the best_of request parameter.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1, 2, 5, 10, 20], buckets=[1, 2, 5, 10, 20],
) )
self.histogram_n_request = Histogram( self.histogram_n_request = self._base_library.Histogram(
name="vllm:request_params_n", name="vllm:request_params_n",
documentation="Histogram of the n request parameter.", documentation="Histogram of the n request parameter.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1, 2, 5, 10, 20], buckets=[1, 2, 5, 10, 20],
) )
self.counter_request_success = Counter( self.counter_request_success = self._base_library.Counter(
name="vllm:request_success_total", name="vllm:request_success_total",
documentation="Count of successfully processed requests.", documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason]) labelnames=labelnames + [Metrics.labelname_finish_reason])
# Deprecated in favor of vllm:prompt_tokens_total # Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge( self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s", name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.", documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
) )
# Deprecated in favor of vllm:generation_tokens_total # Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge( self.gauge_avg_generation_throughput = self._base_library.Gauge(
name="vllm:avg_generation_throughput_toks_per_s", name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.", documentation="Average generation throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
) )
def _unregister_vllm_metrics(self) -> None:
    """Unregister every collector whose ``_name`` contains "vllm".

    Collectors are read from ``self._base_library.REGISTRY``; collectors
    without a ``_name`` attribute are left untouched.
    """
    registry = self._base_library.REGISTRY
    # Walk a snapshot of the collectors, as unregister() presumably mutates
    # the underlying mapping while we iterate.
    for registered in tuple(registry._collector_to_names):
        if "vllm" in getattr(registered, "_name", ""):
            registry.unregister(registered)
class RayMetrics(Metrics):
    """Metrics variant that looks its metric classes up on Ray's metrics module.

    RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
    It reuses the inherited initializer, so it exposes the same metric
    attributes as Metrics; the metric classes are resolved through
    ``_base_library``, except for the cache-config Info metric, which the
    inherited initializer creates through prometheus_client directly.
    """
    # NOTE(review): the inherited initializer calls the metric constructors
    # with prometheus_client-style keywords (`documentation`, `labelnames`,
    # `buckets`) - confirm that ray.util.metrics accepts these.
    _base_library = ray_metrics

    def __init__(self, labelnames: List[str], max_model_len: int):
        """Build the metrics, failing early when Ray is unavailable.

        Raises:
            ImportError: If the Ray metrics module could not be imported.
        """
        if ray_metrics is None:
            raise ImportError("RayMetrics requires Ray to be installed.")
        super().__init__(labelnames, max_model_len)

    def _unregister_vllm_metrics(self) -> None:
        """Deliberately do nothing: no collectors are unregistered here."""
# end-metrics-definitions # end-metrics-definitions
@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol):
... ...
def local_interval_elapsed(now: float, last_log: float,
                           local_interval: float) -> bool:
    """Tell whether the local logging interval has run out.

    Returns True only when strictly more than ``local_interval`` separates
    ``last_log`` from ``now``.
    """
    return (now - last_log) > local_interval
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
def get_throughput(tracked_stats: List[int], now: float,
                   last_log: float) -> float:
    """Return the sum of ``tracked_stats`` divided by ``now - last_log``.

    Args:
        tracked_stats: Per-iteration counts collected since ``last_log``.
        now: Timestamp closing the interval.
        last_log: Timestamp opening the interval.
    """
    elapsed = now - last_log
    total = np.sum(tracked_stats)
    return float(total / elapsed)
class StatLoggerBase(ABC):
    """Abstract base class for stat loggers.

    Keeps the state shared by every logger: the token counts tracked over the
    current local logging interval, the interval length, and the time of the
    last local log. Subclasses implement ``info`` and ``log``.
    """

    def __init__(self, local_interval: float) -> None:
        # Timing state for the local logging interval.
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Record the metrics info of ``obj`` under the given ``type``."""
        raise NotImplementedError

    @abstractmethod
    def log(self, stats: Stats) -> None:
        """Consume one ``Stats`` snapshot."""
        raise NotImplementedError
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
    """Unsupported for this logger: always raises NotImplementedError."""
    raise NotImplementedError
def log(self, stats: Stats) -> None:
"""Track one iteration's token counts and periodically log a summary.

Called by LLMEngine. Appends the iteration's prompt and generation token
counts to the tracked lists and, when the local interval has elapsed,
logs throughput, request counts and KV-cache usage through the module
logger, then resets the tracked lists.
"""
# NOTE(review): indentation was lost in this rendering, so the nesting of
# the statements below (in particular the spec-decode block at the end)
# cannot be confirmed from here - verify against the real file.
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary throughput for the tracked stats over the interval
# that started at self.last_local_log.
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
# Log the summary line through the module logger.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
# Speculative-decoding metrics are only logged when the snapshot has them.
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
        self, metrics: "SpecDecodeWorkerMetrics") -> str:
    """Render speculative-decoding metrics as a single log line.

    Args:
        metrics: Object exposing ``draft_acceptance_rate``,
            ``system_efficiency``, ``num_spec_tokens``, ``accepted_tokens``,
            ``draft_tokens`` and ``emitted_tokens``.

    Returns:
        The formatted message; the two rates are printed with three decimals.
    """
    # The draft/emitted labels previously repeated the word "tokens".
    return ("Speculative metrics: "
            f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
            f"System efficiency: {metrics.system_efficiency:.3f}, "
            f"Number of speculative tokens: {metrics.num_spec_tokens}, "
            f"Number of accepted tokens: {metrics.accepted_tokens}, "
            f"Number of draft tokens: {metrics.draft_tokens}, "
            f"Number of emitted tokens: {metrics.emitted_tokens}.")
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
super().__init__(local_interval)
# Prometheus metrics # Prometheus metrics
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()), self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
max_model_len=max_model_len) max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info()) self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
return float(np.sum(tracked_stats) / (now - self.last_local_log)) # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _local_interval_elapsed(self, now: float) -> bool: def _log_counter(self, counter, data: Union[int, float]) -> None:
elapsed_time = now - self.last_local_log # Convenience function for logging to counter.
return elapsed_time > self.local_interval counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int],
List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus(self, stats: Stats) -> None: def _log_prometheus(self, stats: Stats) -> None:
# System state data # System state data
@ -279,26 +409,6 @@ class StatLogger:
self._log_histogram(self.metrics.histogram_best_of_request, self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests) stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval. # Logs metrics to prometheus that are computed every logging_interval.
@ -313,11 +423,8 @@ class StatLogger:
self.metrics.gauge_avg_generation_throughput.labels( self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput) **self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None: def log(self, stats: Stats):
"""Called by LLMEngine. """Logs to prometheus and tracked stats every iteration."""
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus. # Log to prometheus.
self._log_prometheus(stats) self._log_prometheus(stats)
@ -326,50 +433,28 @@ class StatLogger:
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now): if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them # Compute summary metrics for tracked stats (and log them
# to promethus if applicable). # to promethus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens, prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now) now=stats.now,
generation_throughput = self._get_throughput( last_log=self.last_local_log)
self.num_generation_tokens, now=stats.now) generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
self._log_prometheus_interval( self._log_prometheus_interval(
prompt_throughput=prompt_throughput, prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput) generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval. # Reset tracked stats for next interval.
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
class RayPrometheusStatLogger(PrometheusStatLogger):
    """PrometheusStatLogger variant that builds its metrics with RayMetrics."""

    # Metrics container class; the inherited initializer instantiates
    # self._metrics_cls to create self.metrics.
    _metrics_cls = RayMetrics
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")