[Misc] Add Ray Prometheus logger to V1 (#17925)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
parent 67da5720d4
commit 541817670c
tests/v1/metrics/test_ray_metrics.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import ray

from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch):
    """
    The change relies on V1 APIs, so set VLLM_USE_V1=1.
    """
    monkeypatch.setenv('VLLM_USE_V1', '1')


MODELS = [
    "distilbert/distilgpt2",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Simple smoke test, verifying this can be used without exceptions.
    Need to start a Ray cluster in order to verify outputs."""

    @ray.remote(num_gpus=1)
    class EngineTestActor:

        async def run(self):
            engine_args = AsyncEngineArgs(
                model=model,
                dtype=dtype,
                disable_log_stats=False,
            )

            engine = AsyncLLM.from_engine_args(
                engine_args, stat_loggers=[RayPrometheusStatLogger])

            for i, prompt in enumerate(example_prompts):
                engine.generate(
                    request_id=f"request-id-{i}",
                    prompt=prompt,
                    sampling_params=SamplingParams(max_tokens=max_tokens),
                )

    # Create the actor and call the async method
    actor = EngineTestActor.remote()  # type: ignore[attr-defined]
    ray.get(actor.run.remote())
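The smoke test above only issues the requests; it never drains the output streams. If one also wanted to consume the results, a minimal sketch could look like the following (the drain helper is hypothetical and not part of this commit; it assumes AsyncLLM.generate returns an async generator of request outputs, as elsewhere in the V1 engine API):

    async def drain(engine, prompt, sampling_params):
        # Hypothetical helper: iterate the stream returned by
        # AsyncLLM.generate and keep only the final output.
        final_output = None
        async for output in engine.generate(request_id="request-id-0",
                                            prompt=prompt,
                                            sampling_params=sampling_params):
            final_output = output
        return final_output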
vllm/v1/metrics/loggers.py
@@ -138,6 +138,10 @@ class LoggingStatLogger(StatLoggerBase):
 
 
 class PrometheusStatLogger(StatLoggerBase):
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
+    _spec_decoding_cls = SpecDecodingProm
 
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self._unregister_vllm_metrics()
@@ -156,18 +160,18 @@ class PrometheusStatLogger(StatLoggerBase):
 
         max_model_len = vllm_config.model_config.max_model_len
 
-        self.spec_decoding_prom = SpecDecodingProm(
+        self.spec_decoding_prom = self._spec_decoding_cls(
             vllm_config.speculative_config, labelnames, labelvalues)
 
         #
         # Scheduler state
         #
-        self.gauge_scheduler_running = prometheus_client.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.gauge_scheduler_waiting = prometheus_client.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames).labels(*labelvalues)
@@ -175,18 +179,18 @@ class PrometheusStatLogger(StatLoggerBase):
         #
         # GPU cache
         #
-        self.gauge_gpu_cache_usage = prometheus_client.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_queries = self._counter_cls(
             name="vllm:gpu_prefix_cache_queries",
             documentation=
             "GPU prefix cache queries, in terms of number of queried tokens.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_hits = self._counter_cls(
             name="vllm:gpu_prefix_cache_hits",
             documentation=
             "GPU prefix cache hits, in terms of number of cached tokens.",
@@ -195,24 +199,24 @@ class PrometheusStatLogger(StatLoggerBase):
         #
         # Counters
         #
-        self.counter_num_preempted_reqs = prometheus_client.Counter(
+        self.counter_num_preempted_reqs = self._counter_cls(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_prompt_tokens = prometheus_client.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_generation_tokens = prometheus_client.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_request_success: dict[FinishReason,
                                            prometheus_client.Counter] = {}
-        counter_request_success_base = prometheus_client.Counter(
+        counter_request_success_base = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])
@@ -225,21 +229,21 @@ class PrometheusStatLogger(StatLoggerBase):
         # Histograms of counts
         #
         self.histogram_num_prompt_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prompt_tokens",
                 documentation="Number of prefill tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_iteration_tokens = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:iteration_tokens_total",
                 documentation="Histogram of number of tokens per engine_step.",
                 buckets=[
@@ -249,7 +253,7 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_max_num_generation_tokens",
                 documentation=
                 "Histogram of maximum number of requested generation tokens.",
@@ -257,14 +261,14 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_n_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_n",
                 documentation="Histogram of the n request parameter.",
                 buckets=[1, 2, 5, 10, 20],
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_max_tokens",
                 documentation="Histogram of the max_tokens request parameter.",
                 buckets=build_1_2_5_buckets(max_model_len),
@@ -274,7 +278,7 @@ class PrometheusStatLogger(StatLoggerBase):
         # Histogram of timing intervals
         #
         self.histogram_time_to_first_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_to_first_token_seconds",
                 documentation="Histogram of time to first token in seconds.",
                 buckets=[
@@ -285,7 +289,7 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_time_per_output_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_per_output_token_seconds",
                 documentation="Histogram of time per output token in seconds.",
                 buckets=[
@@ -299,34 +303,34 @@ class PrometheusStatLogger(StatLoggerBase):
             40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
         ]
         self.histogram_e2e_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:e2e_request_latency_seconds",
                 documentation="Histogram of e2e request latency in seconds.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_queue_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_queue_time_seconds",
                 documentation=
                 "Histogram of time spent in WAITING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_inference_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_inference_time_seconds",
                 documentation=
                 "Histogram of time spent in RUNNING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_prefill_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prefill_time_seconds",
                 documentation=
                 "Histogram of time spent in PREFILL phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_decode_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_decode_time_seconds",
                 documentation=
                 "Histogram of time spent in DECODE phase for request.",
@@ -343,7 +347,7 @@ class PrometheusStatLogger(StatLoggerBase):
             self.labelname_running_lora_adapters = "running_lora_adapters"
             self.max_lora = vllm_config.lora_config.max_loras
             self.gauge_lora_info = \
-                prometheus_client.Gauge(
+                self._gauge_cls(
                     name="vllm:lora_requests_info",
                     documentation="Running stats on lora requests.",
                     labelnames=[
@@ -365,7 +369,7 @@ class PrometheusStatLogger(StatLoggerBase):
         # Info type metrics are syntactic sugar for a gauge permanently set to 1
         # Since prometheus multiprocessing mode does not support Info, emulate
         # info here with a gauge.
-        info_gauge = prometheus_client.Gauge(
+        info_gauge = self._gauge_cls(
             name=name,
             documentation=documentation,
             labelnames=metrics_info.keys()).labels(**metrics_info)
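Taken together, these hunks route every metric constructor in PrometheusStatLogger through a class attribute, so a subclass can swap the metrics backend without re-implementing __init__. A minimal sketch of the pattern in isolation (BaseLogger, InMemoryGauge and CustomLogger are hypothetical names used only for illustration):

    import prometheus_client

    class BaseLogger:
        # Default backend, mirroring PrometheusStatLogger._gauge_cls.
        _gauge_cls = prometheus_client.Gauge

        def __init__(self):
            # The construction code is written once and inherited by subclasses.
            self.gauge_running = self._gauge_cls(
                name="demo_num_requests_running",
                documentation="Number of requests in model execution batches.",
                labelnames=["engine"]).labels("0")

    class InMemoryGauge:
        """Stand-in backend that only implements the constructor/labels/set API."""

        def __init__(self, name, documentation, labelnames=()):
            self.name = name
            self.value = 0.0

        def labels(self, *labelvalues):
            return self

        def set(self, value):
            self.value = value

    class CustomLogger(BaseLogger):
        # Swapping one attribute changes the backend for every metric above.
        _gauge_cls = InMemoryGauge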
vllm/v1/metrics/ray_wrappers.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional, Union

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm

try:
    from ray.util import metrics as ray_metrics
    from ray.util.metrics import Metric
except ImportError:
    ray_metrics = None


class RayPrometheusMetric:

    def __init__(self):
        if ray_metrics is None:
            raise ImportError(
                "RayPrometheusMetric requires Ray to be installed.")

        self.metric: Metric = None

    def labels(self, *labels, **labelskwargs):
        if labelskwargs:
            for k, v in labelskwargs.items():
                if not isinstance(v, str):
                    labelskwargs[k] = str(v)

            self.metric.set_default_tags(labelskwargs)

        return self


class RayGaugeWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Gauge to provide same API as
    prometheus_client.Gauge"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=labelnames_tuple)

    def set(self, value: Union[int, float]):
        return self.metric.set(value)

    def set_to_current_time(self):
        # ray metrics doesn't have set_to_current_time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
        return self.metric.set(time.time())


class RayCounterWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Counter to provide same API as
    prometheus_client.Counter"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        self.metric = ray_metrics.Counter(name=name,
                                          description=documentation,
                                          tag_keys=labelnames_tuple)

    def inc(self, value: Union[int, float] = 1.0):
        if value == 0:
            return
        return self.metric.inc(value)


class RayHistogramWrapper(RayPrometheusMetric):
    """Wraps around ray.util.metrics.Histogram to provide same API as
    prometheus_client.Histogram"""

    def __init__(self,
                 name: str,
                 documentation: Optional[str] = "",
                 labelnames: Optional[list[str]] = None,
                 buckets: Optional[list[float]] = None):
        labelnames_tuple = tuple(labelnames) if labelnames else None
        boundaries = buckets if buckets else []
        self.metric = ray_metrics.Histogram(name=name,
                                            description=documentation,
                                            tag_keys=labelnames_tuple,
                                            boundaries=boundaries)

    def observe(self, value: Union[int, float]):
        return self.metric.observe(value)


class RaySpecDecodingProm(SpecDecodingProm):
    """
    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
    Provides the same metrics as SpecDecodingProm but uses Ray's
    util.metrics library.
    """

    _counter_cls = RayCounterWrapper


class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""

    _gauge_cls = RayGaugeWrapper
    _counter_cls = RayCounterWrapper
    _histogram_cls = RayHistogramWrapper
    _spec_decoding_cls = RaySpecDecodingProm

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)

    @staticmethod
    def _unregister_vllm_metrics():
        # No-op on purpose
        pass
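The wrappers expose the same constructor keywords and call chains the logger already uses (name/documentation/labelnames plus labels(...).set()/inc()/observe()) and translate them into ray.util.metrics calls. A rough usage sketch, assuming Ray is installed and a runtime has been initialized (e.g. via ray.init()) so that recorded values are actually exported; the metric name here is purely illustrative:

    import ray

    from vllm.v1.metrics.ray_wrappers import RayGaugeWrapper

    ray.init()

    # Same surface as prometheus_client.Gauge ...
    gauge = RayGaugeWrapper(
        name="vllm_demo_gauge",
        documentation="Example gauge recorded through ray.util.metrics.",
        labelnames=["model_name", "engine"])

    # ... and the same labels(...).set(...) chain; keyword labels are converted
    # to strings and applied via Metric.set_default_tags() before set().
    gauge.labels(model_name="distilbert/distilgpt2", engine="0").set(1.0)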
vllm/v1/spec_decode/metrics.py
@@ -120,24 +120,30 @@ class SpecDecodingProm:
         vllm:spec_decode_num_drafts[$interval]
     """
 
-    def __init__(self, speculative_config: Optional[SpeculativeConfig],
-                 labelnames: list[str], labelvalues: list[str]):
+    _counter_cls = prometheus_client.Counter
+
+    def __init__(
+        self,
+        speculative_config: Optional[SpeculativeConfig],
+        labelnames: list[str],
+        labelvalues: list[str],
+    ):
         self.spec_decoding_enabled = speculative_config is not None
         if not self.spec_decoding_enabled:
             return
 
         self.counter_spec_decode_num_drafts = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_drafts_total",
                 documentation="Number of spec decoding drafts.",
                 labelnames=labelnames).labels(*labelvalues)
         self.counter_spec_decode_num_draft_tokens = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_draft_tokens_total",
                 documentation="Number of draft tokens.",
-                labelnames=labelnames).labels(*labelvalues)
+                labelnames=labelnames,).labels(*labelvalues)
         self.counter_spec_decode_num_accepted_tokens = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_accepted_tokens_total",
                 documentation="Number of accepted tokens.",
                 labelnames=labelnames).labels(*labelvalues)
@@ -146,12 +152,13 @@ class SpecDecodingProm:
         num_spec_tokens = (speculative_config.num_speculative_tokens
                            if self.spec_decoding_enabled else 0)
         pos_labelnames = labelnames + ["position"]
-        base_counter = prometheus_client.Counter(
+        base_counter = self._counter_cls(
             name="vllm:spec_decode_num_accepted_tokens_per_pos",
             documentation="Accepted tokens per draft position.",
-            labelnames=pos_labelnames)
-        self.counter_spec_decode_num_accepted_tokens_per_pos: \
-            list[prometheus_client.Counter] = []
+            labelnames=pos_labelnames,
+        )
+        self.counter_spec_decode_num_accepted_tokens_per_pos: list[
+            prometheus_client.Counter] = []
         for pos in range(num_spec_tokens):
             pos_labelvalues = labelvalues + [str(pos)]
             self.counter_spec_decode_num_accepted_tokens_per_pos.append(
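With the spec-decode counters likewise built through _counter_cls, the Ray variant only has to override that one attribute. A quick way to see how the pieces of this diff line up (plain illustration of relationships already shown above, no new behavior):

    from vllm.v1.metrics.loggers import PrometheusStatLogger
    from vllm.v1.metrics.ray_wrappers import (RayCounterWrapper,
                                              RayPrometheusStatLogger,
                                              RaySpecDecodingProm)

    # Prometheus path: PrometheusStatLogger -> SpecDecodingProm -> prometheus_client.Counter
    # Ray path:        RayPrometheusStatLogger -> RaySpecDecodingProm -> RayCounterWrapper
    assert issubclass(RayPrometheusStatLogger, PrometheusStatLogger)
    assert RayPrometheusStatLogger._spec_decoding_cls is RaySpecDecodingProm
    assert RaySpecDecodingProm._counter_cls is RayCounterWrapper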