diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py
new file mode 100644
index 0000000000000..02475f7c150b8
--- /dev/null
+++ b/tests/v1/metrics/test_ray_metrics.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import ray
+
+from vllm.sampling_params import SamplingParams
+from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
+from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_v1_only(monkeypatch):
+    """
+    The change relies on V1 APIs, so set VLLM_USE_V1=1.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '1')
+
+
+MODELS = [
+    "distilbert/distilgpt2",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [16])
+def test_engine_log_metrics_ray(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """ Simple smoke test, verifying this can be used without exceptions.
+    Need to start a Ray cluster in order to verify outputs."""
+
+    @ray.remote(num_gpus=1)
+    class EngineTestActor:
+
+        async def run(self):
+            engine_args = AsyncEngineArgs(
+                model=model,
+                dtype=dtype,
+                disable_log_stats=False,
+            )
+
+            engine = AsyncLLM.from_engine_args(
+                engine_args, stat_loggers=[RayPrometheusStatLogger])
+
+            for i, prompt in enumerate(example_prompts):
+                engine.generate(
+                    request_id=f"request-id-{i}",
+                    prompt=prompt,
+                    sampling_params=SamplingParams(max_tokens=max_tokens),
+                )
+
+    # Create the actor and call the async method
+    actor = EngineTestActor.remote()  # type: ignore[attr-defined]
+    ray.get(actor.run.remote())
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 6ee40850beb10..2b75a3a2ecbd3 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -138,6 +138,10 @@ class LoggingStatLogger(StatLoggerBase):
 
 
 class PrometheusStatLogger(StatLoggerBase):
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
+    _spec_decoding_cls = SpecDecodingProm
 
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self._unregister_vllm_metrics()
@@ -156,18 +160,18 @@ class PrometheusStatLogger(StatLoggerBase):
 
         max_model_len = vllm_config.model_config.max_model_len
 
-        self.spec_decoding_prom = SpecDecodingProm(
+        self.spec_decoding_prom = self._spec_decoding_cls(
             vllm_config.speculative_config, labelnames, labelvalues)
 
         #
         # Scheduler state
         #
-        self.gauge_scheduler_running = prometheus_client.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.gauge_scheduler_waiting = prometheus_client.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames).labels(*labelvalues)
@@ -175,18 +179,18 @@ class PrometheusStatLogger(StatLoggerBase):
         #
         # GPU cache
         #
-        self.gauge_gpu_cache_usage = prometheus_client.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_queries = self._counter_cls(
             name="vllm:gpu_prefix_cache_queries",
             documentation=
             "GPU prefix cache queries, in terms of number of queried tokens.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
+        self.counter_gpu_prefix_cache_hits = self._counter_cls(
             name="vllm:gpu_prefix_cache_hits",
             documentation=
             "GPU prefix cache hits, in terms of number of cached tokens.",
@@ -195,24 +199,24 @@ class PrometheusStatLogger(StatLoggerBase):
         #
         # Counters
         #
-        self.counter_num_preempted_reqs = prometheus_client.Counter(
+        self.counter_num_preempted_reqs = self._counter_cls(
            name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_prompt_tokens = prometheus_client.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
-        self.counter_generation_tokens = prometheus_client.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_request_success: dict[FinishReason,
                                            prometheus_client.Counter] = {}
-        counter_request_success_base = prometheus_client.Counter(
+        counter_request_success_base = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])
@@ -225,21 +229,21 @@ class PrometheusStatLogger(StatLoggerBase):
         # Histograms of counts
         #
         self.histogram_num_prompt_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prompt_tokens",
                 documentation="Number of prefill tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_iteration_tokens = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:iteration_tokens_total",
                 documentation="Histogram of number of tokens per engine_step.",
                 buckets=[
@@ -249,7 +253,7 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_num_generation_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_max_num_generation_tokens",
                 documentation=
                 "Histogram of maximum number of requested generation tokens.",
@@ -257,14 +261,14 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_n_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_n",
                 documentation="Histogram of the n request parameter.",
                 buckets=[1, 2, 5, 10, 20],
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_max_tokens_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_params_max_tokens",
                 documentation="Histogram of the max_tokens request parameter.",
                 buckets=build_1_2_5_buckets(max_model_len),
@@ -274,7 +278,7 @@ class PrometheusStatLogger(StatLoggerBase):
         # Histogram of timing intervals
         #
         self.histogram_time_to_first_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_to_first_token_seconds",
                 documentation="Histogram of time to first token in seconds.",
                 buckets=[
@@ -285,7 +289,7 @@ class PrometheusStatLogger(StatLoggerBase):
                 labelnames=labelnames).labels(*labelvalues)
 
         self.histogram_time_per_output_token = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:time_per_output_token_seconds",
                 documentation="Histogram of time per output token in seconds.",
                 buckets=[
@@ -299,34 +303,34 @@ class PrometheusStatLogger(StatLoggerBase):
             40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
         ]
         self.histogram_e2e_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:e2e_request_latency_seconds",
                 documentation="Histogram of e2e request latency in seconds.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_queue_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_queue_time_seconds",
                 documentation=
                 "Histogram of time spent in WAITING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_inference_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_inference_time_seconds",
                 documentation=
                 "Histogram of time spent in RUNNING phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_prefill_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_prefill_time_seconds",
                 documentation=
                 "Histogram of time spent in PREFILL phase for request.",
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
         self.histogram_decode_time_request = \
-            prometheus_client.Histogram(
+            self._histogram_cls(
                 name="vllm:request_decode_time_seconds",
                 documentation=
                 "Histogram of time spent in DECODE phase for request.",
@@ -343,7 +347,7 @@ class PrometheusStatLogger(StatLoggerBase):
             self.labelname_running_lora_adapters = "running_lora_adapters"
             self.max_lora = vllm_config.lora_config.max_loras
             self.gauge_lora_info = \
-                prometheus_client.Gauge(
+                self._gauge_cls(
                     name="vllm:lora_requests_info",
                     documentation="Running stats on lora requests.",
                     labelnames=[
@@ -365,7 +369,7 @@ class PrometheusStatLogger(StatLoggerBase):
         # Info type metrics are syntactic sugar for a gauge permanently set to 1
         # Since prometheus multiprocessing mode does not support Info, emulate
         # info here with a gauge.
-        info_gauge = prometheus_client.Gauge(
+        info_gauge = self._gauge_cls(
             name=name,
             documentation=documentation,
             labelnames=metrics_info.keys()).labels(**metrics_info)
diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py
new file mode 100644
index 0000000000000..a51c3ed7f5720
--- /dev/null
+++ b/vllm/v1/metrics/ray_wrappers.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+import time
+from typing import Optional, Union
+
+from vllm.config import VllmConfig
+from vllm.v1.metrics.loggers import PrometheusStatLogger
+from vllm.v1.spec_decode.metrics import SpecDecodingProm
+
+try:
+    from ray.util import metrics as ray_metrics
+    from ray.util.metrics import Metric
+except ImportError:
+    ray_metrics = None
+
+
+class RayPrometheusMetric:
+
+    def __init__(self):
+        if ray_metrics is None:
+            raise ImportError(
+                "RayPrometheusMetric requires Ray to be installed.")
+
+        self.metric: Metric = None
+
+    def labels(self, *labels, **labelskwargs):
+        if labelskwargs:
+            for k, v in labelskwargs.items():
+                if not isinstance(v, str):
+                    labelskwargs[k] = str(v)
+
+            self.metric.set_default_tags(labelskwargs)
+
+        return self
+
+
+class RayGaugeWrapper(RayPrometheusMetric):
+    """Wraps around ray.util.metrics.Gauge to provide same API as
+    prometheus_client.Gauge"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: Optional[str] = "",
+                 labelnames: Optional[list[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self.metric = ray_metrics.Gauge(name=name,
+                                        description=documentation,
+                                        tag_keys=labelnames_tuple)
+
+    def set(self, value: Union[int, float]):
+        return self.metric.set(value)
+
+    def set_to_current_time(self):
+        # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
+        return self.metric.set(time.time())
+
+
+class RayCounterWrapper(RayPrometheusMetric):
+    """Wraps around ray.util.metrics.Counter to provide same API as
+    prometheus_client.Counter"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: Optional[str] = "",
+                 labelnames: Optional[list[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self.metric = ray_metrics.Counter(name=name,
+                                          description=documentation,
+                                          tag_keys=labelnames_tuple)
+
+    def inc(self, value: Union[int, float] = 1.0):
+        if value == 0:
+            return
+        return self.metric.inc(value)
+
+
+class RayHistogramWrapper(RayPrometheusMetric):
+    """Wraps around ray.util.metrics.Histogram to provide same API as
+    prometheus_client.Histogram"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: Optional[str] = "",
+                 labelnames: Optional[list[str]] = None,
+                 buckets: Optional[list[float]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        boundaries = buckets if buckets else []
+        self.metric = ray_metrics.Histogram(name=name,
+                                            description=documentation,
+                                            tag_keys=labelnames_tuple,
+                                            boundaries=boundaries)
+
+    def observe(self, value: Union[int, float]):
+        return self.metric.observe(value)
+
+
+class RaySpecDecodingProm(SpecDecodingProm):
+    """
+    RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
+    Provides the same metrics as SpecDecodingProm but uses Ray's
+    util.metrics library.
+    """
+
+    _counter_cls = RayCounterWrapper
+
+
+class RayPrometheusStatLogger(PrometheusStatLogger):
+    """RayPrometheusStatLogger uses Ray metrics instead."""
+
+    _gauge_cls = RayGaugeWrapper
+    _counter_cls = RayCounterWrapper
+    _histogram_cls = RayHistogramWrapper
+    _spec_decoding_cls = RaySpecDecodingProm
+
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+        super().__init__(vllm_config, engine_index)
+
+    @staticmethod
+    def _unregister_vllm_metrics():
+        # No-op on purpose
+        pass
diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py
index f71a59908ef39..899aa9200e85e 100644
--- a/vllm/v1/spec_decode/metrics.py
+++ b/vllm/v1/spec_decode/metrics.py
@@ -120,24 +120,30 @@ class SpecDecodingProm:
         vllm:spec_decode_num_drafts[$interval]
     """
 
-    def __init__(self, speculative_config: Optional[SpeculativeConfig],
-                 labelnames: list[str], labelvalues: list[str]):
+    _counter_cls = prometheus_client.Counter
+
+    def __init__(
+        self,
+        speculative_config: Optional[SpeculativeConfig],
+        labelnames: list[str],
+        labelvalues: list[str],
+    ):
         self.spec_decoding_enabled = speculative_config is not None
         if not self.spec_decoding_enabled:
             return
 
         self.counter_spec_decode_num_drafts = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_drafts_total",
                 documentation="Number of spec decoding drafts.",
                 labelnames=labelnames).labels(*labelvalues)
         self.counter_spec_decode_num_draft_tokens = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_draft_tokens_total",
                 documentation="Number of draft tokens.",
-                labelnames=labelnames).labels(*labelvalues)
+                labelnames=labelnames,).labels(*labelvalues)
         self.counter_spec_decode_num_accepted_tokens = \
-            prometheus_client.Counter(
+            self._counter_cls(
                 name="vllm:spec_decode_num_accepted_tokens_total",
                 documentation="Number of accepted tokens.",
                 labelnames=labelnames).labels(*labelvalues)
@@ -146,12 +152,13 @@ class SpecDecodingProm:
         num_spec_tokens = (speculative_config.num_speculative_tokens
                            if self.spec_decoding_enabled else 0)
         pos_labelnames = labelnames + ["position"]
-        base_counter = prometheus_client.Counter(
+        base_counter = self._counter_cls(
             name="vllm:spec_decode_num_accepted_tokens_per_pos",
             documentation="Accepted tokens per draft position.",
-            labelnames=pos_labelnames)
-        self.counter_spec_decode_num_accepted_tokens_per_pos: \
-            list[prometheus_client.Counter] = []
+            labelnames=pos_labelnames,
+        )
+        self.counter_spec_decode_num_accepted_tokens_per_pos: list[
+            prometheus_client.Counter] = []
         for pos in range(num_spec_tokens):
             pos_labelvalues = labelvalues + [str(pos)]
             self.counter_spec_decode_num_accepted_tokens_per_pos.append(