mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-05 06:04:00 +08:00
[Misc] Extend vLLM Metrics logging API (#5925)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
c4bca740e8
commit
906a19cdb0
@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens(
|
||||
vllm_prompt_token_count = sum(prompt_token_counts)
|
||||
|
||||
_ = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
stat_logger = vllm_model.model.llm_engine.stat_logger
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
|
||||
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
|
||||
**stat_logger.labels)._value.get()
|
||||
|
||||
@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens(
|
||||
gpu_memory_utilization=0.4) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
stat_logger = vllm_model.model.llm_engine.stat_logger
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
|
||||
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
|
||||
**stat_logger.labels)._value.get()
|
||||
vllm_generation_count = 0
|
||||
@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.3,
|
||||
served_model_name=served_model_name) as vllm_model:
|
||||
stat_logger = vllm_model.model.llm_engine.stat_logger
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
|
||||
metrics_tag_content = stat_logger.labels["model_name"]
|
||||
|
||||
if served_model_name is None or served_model_name == []:
|
||||
@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
|
||||
num_requests: int) -> None:
|
||||
if disable_log_stats:
|
||||
with pytest.raises(AttributeError):
|
||||
_ = engine.stat_logger
|
||||
_ = engine.stat_loggers
|
||||
else:
|
||||
assert (engine.stat_logger
|
||||
is not None), "engine.stat_logger should be set"
|
||||
assert (engine.stat_loggers
|
||||
is not None), "engine.stat_loggers should be set"
|
||||
# Ensure the count bucket of request-level histogram metrics matches
|
||||
# the number of requests as a simple sanity check to ensure metrics are
|
||||
# generated
|
||||
|
||||
@ -13,7 +13,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
||||
SchedulerOutputs)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics import StatLogger, Stats
|
||||
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
|
||||
StatLoggerBase, Stats)
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
@ -160,6 +161,7 @@ class LLMEngine:
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Initializing an LLM engine (v%s) with config: "
|
||||
@ -292,11 +294,21 @@ class LLMEngine:
|
||||
|
||||
# Metric Logging.
|
||||
if self.log_stats:
|
||||
self.stat_logger = StatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
||||
labels=dict(model_name=model_config.served_model_name),
|
||||
max_model_len=self.model_config.max_model_len)
|
||||
self.stat_logger.info("cache_config", self.cache_config)
|
||||
if stat_loggers is not None:
|
||||
self.stat_loggers = stat_loggers
|
||||
else:
|
||||
self.stat_loggers = {
|
||||
"logging":
|
||||
LoggingStatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
|
||||
"prometheus":
|
||||
PrometheusStatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
||||
labels=dict(model_name=model_config.served_model_name),
|
||||
max_model_len=self.model_config.max_model_len),
|
||||
}
|
||||
self.stat_loggers["prometheus"].info("cache_config",
|
||||
self.cache_config)
|
||||
|
||||
self.tracer = None
|
||||
if self.observability_config.otlp_traces_endpoint:
|
||||
@ -833,14 +845,24 @@ class LLMEngine:
|
||||
|
||||
return request_outputs
|
||||
|
||||
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
|
||||
if logger_name in self.stat_loggers:
|
||||
raise KeyError(f"Logger with name {logger_name} already exists.")
|
||||
self.stat_loggers[logger_name] = logger
|
||||
|
||||
def remove_logger(self, logger_name: str) -> None:
|
||||
if logger_name not in self.stat_loggers:
|
||||
raise KeyError(f"Logger with name {logger_name} does not exist.")
|
||||
del self.stat_loggers[logger_name]
|
||||
|
||||
def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||
model_output: Optional[List[SamplerOutput]] = None) -> None:
|
||||
"""Forced log when no requests active."""
|
||||
if self.log_stats:
|
||||
self.stat_logger.log(
|
||||
self._get_stats(scheduler_outputs, model_output))
|
||||
for logger in self.stat_loggers.values():
|
||||
logger.log(self._get_stats(scheduler_outputs, model_output))
|
||||
|
||||
def _get_stats(
|
||||
self,
|
||||
|
||||
@ -1,21 +1,27 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Counter as CollectionsCounter
|
||||
from typing import Dict, List, Optional, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||
disable_created_metrics)
|
||||
import prometheus_client
|
||||
|
||||
from vllm.executor.ray_utils import ray
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if ray is not None:
|
||||
from ray.util import metrics as ray_metrics
|
||||
else:
|
||||
ray_metrics = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
disable_created_metrics()
|
||||
prometheus_client.disable_created_metrics()
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the metrics definitions.
|
||||
@ -24,56 +30,55 @@ disable_created_metrics()
|
||||
# begin-metrics-definitions
|
||||
class Metrics:
|
||||
labelname_finish_reason = "finished_reason"
|
||||
_base_library = prometheus_client
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
# Unregister any existing vLLM collectors
|
||||
for collector in list(REGISTRY._collector_to_names):
|
||||
if hasattr(collector, "_name") and "vllm" in collector._name:
|
||||
REGISTRY.unregister(collector)
|
||||
self._unregister_vllm_metrics()
|
||||
|
||||
# Config Information
|
||||
self.info_cache_config = Info(
|
||||
self.info_cache_config = prometheus_client.Info(
|
||||
name='vllm:cache_config',
|
||||
documentation='information of cache_config')
|
||||
|
||||
# System stats
|
||||
# Scheduler State
|
||||
self.gauge_scheduler_running = Gauge(
|
||||
self.gauge_scheduler_running = self._base_library.Gauge(
|
||||
name="vllm:num_requests_running",
|
||||
documentation="Number of requests currently running on GPU.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_waiting = Gauge(
|
||||
self.gauge_scheduler_waiting = self._base_library.Gauge(
|
||||
name="vllm:num_requests_waiting",
|
||||
documentation="Number of requests waiting to be processed.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_swapped = Gauge(
|
||||
self.gauge_scheduler_swapped = self._base_library.Gauge(
|
||||
name="vllm:num_requests_swapped",
|
||||
documentation="Number of requests swapped to CPU.",
|
||||
labelnames=labelnames)
|
||||
# KV Cache Usage in %
|
||||
self.gauge_gpu_cache_usage = Gauge(
|
||||
self.gauge_gpu_cache_usage = self._base_library.Gauge(
|
||||
name="vllm:gpu_cache_usage_perc",
|
||||
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_cpu_cache_usage = Gauge(
|
||||
self.gauge_cpu_cache_usage = self._base_library.Gauge(
|
||||
name="vllm:cpu_cache_usage_perc",
|
||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames)
|
||||
|
||||
# Iteration stats
|
||||
self.counter_num_preemption = Counter(
|
||||
self.counter_num_preemption = self._base_library.Counter(
|
||||
name="vllm:num_preemptions_total",
|
||||
documentation="Cumulative number of preemption from the engine.",
|
||||
labelnames=labelnames)
|
||||
self.counter_prompt_tokens = Counter(
|
||||
self.counter_prompt_tokens = self._base_library.Counter(
|
||||
name="vllm:prompt_tokens_total",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.counter_generation_tokens = Counter(
|
||||
self.counter_generation_tokens = self._base_library.Counter(
|
||||
name="vllm:generation_tokens_total",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.histogram_time_to_first_token = Histogram(
|
||||
self.histogram_time_to_first_token = self._base_library.Histogram(
|
||||
name="vllm:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
labelnames=labelnames,
|
||||
@ -81,7 +86,7 @@ class Metrics:
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
|
||||
])
|
||||
self.histogram_time_per_output_token = Histogram(
|
||||
self.histogram_time_per_output_token = self._base_library.Histogram(
|
||||
name="vllm:time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
labelnames=labelnames,
|
||||
@ -92,54 +97,77 @@ class Metrics:
|
||||
|
||||
# Request stats
|
||||
# Latency
|
||||
self.histogram_e2e_time_request = Histogram(
|
||||
self.histogram_e2e_time_request = self._base_library.Histogram(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of end to end request latency in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
|
||||
# Metadata
|
||||
self.histogram_num_prompt_tokens_request = Histogram(
|
||||
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_num_generation_tokens_request = Histogram(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_best_of_request = Histogram(
|
||||
self.histogram_num_generation_tokens_request = \
|
||||
self._base_library.Histogram(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_best_of_request = self._base_library.Histogram(
|
||||
name="vllm:request_params_best_of",
|
||||
documentation="Histogram of the best_of request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.histogram_n_request = Histogram(
|
||||
self.histogram_n_request = self._base_library.Histogram(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.counter_request_success = Counter(
|
||||
self.counter_request_success = self._base_library.Counter(
|
||||
name="vllm:request_success_total",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Deprecated in favor of vllm:prompt_tokens_total
|
||||
self.gauge_avg_prompt_throughput = Gauge(
|
||||
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
|
||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||
documentation="Average prefill throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
# Deprecated in favor of vllm:generation_tokens_total
|
||||
self.gauge_avg_generation_throughput = Gauge(
|
||||
self.gauge_avg_generation_throughput = self._base_library.Gauge(
|
||||
name="vllm:avg_generation_throughput_toks_per_s",
|
||||
documentation="Average generation throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
|
||||
def _unregister_vllm_metrics(self) -> None:
|
||||
for collector in list(self._base_library.REGISTRY._collector_to_names):
|
||||
if hasattr(collector, "_name") and "vllm" in collector._name:
|
||||
self._base_library.REGISTRY.unregister(collector)
|
||||
|
||||
|
||||
class RayMetrics(Metrics):
|
||||
"""
|
||||
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
|
||||
Provides the same metrics as Metrics but uses Ray's util.metrics library.
|
||||
"""
|
||||
_base_library = ray_metrics
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
if ray_metrics is None:
|
||||
raise ImportError("RayMetrics requires Ray to be installed.")
|
||||
super().__init__(labelnames, max_model_len)
|
||||
|
||||
def _unregister_vllm_metrics(self) -> None:
|
||||
# No-op on purpose
|
||||
pass
|
||||
|
||||
|
||||
# end-metrics-definitions
|
||||
|
||||
@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class StatLogger:
|
||||
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
|
||||
def local_interval_elapsed(now: float, last_log: float,
|
||||
local_interval: float) -> bool:
|
||||
elapsed_time = now - last_log
|
||||
return elapsed_time > local_interval
|
||||
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str],
|
||||
max_model_len: int) -> None:
|
||||
# Metadata for logging locally.
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
|
||||
def get_throughput(tracked_stats: List[int], now: float,
|
||||
last_log: float) -> float:
|
||||
return float(np.sum(tracked_stats) / (now - last_log))
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
"""Base class for StatLogger."""
|
||||
|
||||
def __init__(self, local_interval: float) -> None:
|
||||
# Tracked stats over current local logging interval.
|
||||
self.num_prompt_tokens: List[int] = []
|
||||
self.num_generation_tokens: List[int] = []
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
|
||||
@abstractmethod
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def log(self, stats: Stats) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
"""Called by LLMEngine.
|
||||
Logs to Stdout every self.local_interval seconds."""
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to promethus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
# Log to stdout.
|
||||
logger.info(
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Swapped: %d reqs, "
|
||||
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
||||
"CPU KV cache usage: %.1f%%.",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
)
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
|
||||
if stats.spec_decode_metrics is not None:
|
||||
logger.info(
|
||||
self._format_spec_decode_metrics_str(
|
||||
stats.spec_decode_metrics))
|
||||
|
||||
def _format_spec_decode_metrics_str(
|
||||
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||
|
||||
return ("Speculative metrics: "
|
||||
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
|
||||
f"System efficiency: {metrics.system_efficiency:.3f}, "
|
||||
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
|
||||
f"Number of accepted tokens: {metrics.accepted_tokens}, "
|
||||
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
|
||||
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
|
||||
_metrics_cls = Metrics
|
||||
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str],
|
||||
max_model_len: int) -> None:
|
||||
super().__init__(local_interval)
|
||||
# Prometheus metrics
|
||||
self.labels = labels
|
||||
self.metrics = Metrics(labelnames=list(labels.keys()),
|
||||
max_model_len=max_model_len)
|
||||
self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
|
||||
max_model_len=max_model_len)
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
if type == "cache_config":
|
||||
self.metrics.info_cache_config.info(obj.metrics_info())
|
||||
|
||||
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
|
||||
return float(np.sum(tracked_stats) / (now - self.last_local_log))
|
||||
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to gauge.
|
||||
gauge.labels(**self.labels).set(data)
|
||||
|
||||
def _local_interval_elapsed(self, now: float) -> bool:
|
||||
elapsed_time = now - self.last_local_log
|
||||
return elapsed_time > self.local_interval
|
||||
def _log_counter(self, counter, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self.labels).inc(data)
|
||||
|
||||
def _log_counter_labels(self, counter, data: CollectionsCounter,
|
||||
label_key: str) -> None:
|
||||
# Convenience function for collection counter of labels.
|
||||
for label, count in data.items():
|
||||
counter.labels(**{**self.labels, label_key: label}).inc(count)
|
||||
|
||||
def _log_histogram(self, histogram, data: Union[List[int],
|
||||
List[float]]) -> None:
|
||||
# Convenience function for logging list to histogram.
|
||||
for datum in data:
|
||||
histogram.labels(**self.labels).observe(datum)
|
||||
|
||||
def _log_prometheus(self, stats: Stats) -> None:
|
||||
# System state data
|
||||
@ -279,26 +409,6 @@ class StatLogger:
|
||||
self._log_histogram(self.metrics.histogram_best_of_request,
|
||||
stats.best_of_requests)
|
||||
|
||||
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to gauge.
|
||||
gauge.labels(**self.labels).set(data)
|
||||
|
||||
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self.labels).inc(data)
|
||||
|
||||
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
|
||||
label_key: str) -> None:
|
||||
# Convenience function for collection counter of labels.
|
||||
for label, count in data.items():
|
||||
counter.labels(**{**self.labels, label_key: label}).inc(count)
|
||||
|
||||
def _log_histogram(self, histogram: Histogram,
|
||||
data: Union[List[int], List[float]]) -> None:
|
||||
# Convenience function for logging list to histogram.
|
||||
for datum in data:
|
||||
histogram.labels(**self.labels).observe(datum)
|
||||
|
||||
def _log_prometheus_interval(self, prompt_throughput: float,
|
||||
generation_throughput: float) -> None:
|
||||
# Logs metrics to prometheus that are computed every logging_interval.
|
||||
@ -313,11 +423,8 @@ class StatLogger:
|
||||
self.metrics.gauge_avg_generation_throughput.labels(
|
||||
**self.labels).set(generation_throughput)
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
"""Called by LLMEngine.
|
||||
Logs to prometheus and tracked stats every iteration.
|
||||
Logs to Stdout every self.local_interval seconds."""
|
||||
|
||||
def log(self, stats: Stats):
|
||||
"""Logs to prometheus and tracked stats every iteration."""
|
||||
# Log to prometheus.
|
||||
self._log_prometheus(stats)
|
||||
|
||||
@ -326,50 +433,28 @@ class StatLogger:
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if self._local_interval_elapsed(stats.now):
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to promethus if applicable).
|
||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now)
|
||||
generation_throughput = self._get_throughput(
|
||||
self.num_generation_tokens, now=stats.now)
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
self._log_prometheus_interval(
|
||||
prompt_throughput=prompt_throughput,
|
||||
generation_throughput=generation_throughput)
|
||||
|
||||
# Log to stdout.
|
||||
logger.info(
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Swapped: %d reqs, "
|
||||
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
||||
"CPU KV cache usage: %.1f%%.",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
)
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
|
||||
if stats.spec_decode_metrics is not None:
|
||||
logger.info(
|
||||
self._format_spec_decode_metrics_str(
|
||||
stats.spec_decode_metrics))
|
||||
|
||||
def _format_spec_decode_metrics_str(
|
||||
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||
|
||||
return ("Speculative metrics: "
|
||||
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
|
||||
f"System efficiency: {metrics.system_efficiency:.3f}, "
|
||||
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
|
||||
f"Number of accepted tokens: {metrics.accepted_tokens}, "
|
||||
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
|
||||
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||
_metrics_cls = RayMetrics
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user