[Bugfix] Fix Ray Metrics API usage (#6354)

commit 5f0b9933e6
parent a38524f338
Antoni Baum, 2024-07-17 12:40:10 -07:00, committed by GitHub
4 changed files with 195 additions and 40 deletions

tests/metrics/test_metrics.py

@@ -1,11 +1,13 @@
 from typing import List

 import pytest
+import ray
 from prometheus_client import REGISTRY

 from vllm import EngineArgs, LLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams

 MODELS = [
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                                                      labels)
             assert (
                 metric_value == num_requests), "Metrics should be collected"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [16])
+def test_engine_log_metrics_ray(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is quite weak - it only checks that we can use
+    # RayPrometheusStatLogger without exceptions.
+    # Checking whether the metrics are actually emitted is unfortunately
+    # non-trivial.
+
+    # We have to run in a Ray task for Ray metrics to be emitted correctly
+    @ray.remote(num_gpus=1)
+    def _inner():
+
+        class _RayPrometheusStatLogger(RayPrometheusStatLogger):
+
+            def __init__(self, *args, **kwargs):
+                self._i = 0
+                super().__init__(*args, **kwargs)
+
+            def log(self, *args, **kwargs):
+                self._i += 1
+                return super().log(*args, **kwargs)
+
+        engine_args = EngineArgs(
+            model=model,
+            dtype=dtype,
+            disable_log_stats=False,
+        )
+        engine = LLMEngine.from_engine_args(engine_args)
+        logger = _RayPrometheusStatLogger(
+            local_interval=0.5,
+            labels=dict(model_name=engine.model_config.served_model_name),
+            max_model_len=engine.model_config.max_model_len)
+        engine.add_logger("ray", logger)
+        for i, prompt in enumerate(example_prompts):
+            engine.add_request(
+                f"request-id-{i}",
+                prompt,
+                SamplingParams(max_tokens=max_tokens),
+            )
+        while engine.has_unfinished_requests():
+            engine.step()
+        assert logger._i > 0, ".log must be called at least once"
+
+    ray.get(_inner.remote())
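Note: as the comments in the new test say, verifying that the Ray metrics are actually emitted is non-trivial, so the test only counts `.log` calls inside a `@ray.remote` task. For the non-Ray path, values can be read back from the default `prometheus_client` registry, which is what the existing `assert_metrics` helper does. A hedged sketch of that kind of check (not part of this commit; it assumes the default `PrometheusStatLogger` is active and that the served model name is used as the `model_name` label, as in vllm/engine/metrics.py):

# Hedged sketch, not from this commit: read a vLLM counter back from the
# prometheus_client registry after running some requests.
from prometheus_client import REGISTRY


def emitted_prompt_tokens(model_name: str) -> float:
    value = REGISTRY.get_sample_value("vllm:prompt_tokens_total",
                                      {"model_name": model_name})
    return value if value is not None else 0.0

Ray's `ray.util.metrics` path has no equivalent in-process registry to query, which is why the test settles for counting logger calls.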

vllm/engine/async_llm_engine.py

@@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.metrics import StatLoggerBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
@@ -389,6 +390,7 @@ class AsyncLLMEngine:
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -451,6 +453,7 @@
             max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )
         return engine
@@ -957,3 +960,19 @@
             )
         else:
             return self.engine.is_tracing_enabled()
+
+    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
+        if self.engine_use_ray:
+            ray.get(
+                self.engine.add_logger.remote(  # type: ignore
+                    logger_name=logger_name, logger=logger))
+        else:
+            self.engine.add_logger(logger_name=logger_name, logger=logger)
+
+    def remove_logger(self, logger_name: str) -> None:
+        if self.engine_use_ray:
+            ray.get(
+                self.engine.remove_logger.remote(  # type: ignore
+                    logger_name=logger_name))
+        else:
+            self.engine.remove_logger(logger_name=logger_name)
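Note: with this change `AsyncLLMEngine` exposes the same `add_logger`/`remove_logger` hooks as `LLMEngine`, forwarding them through `ray.get(...remote(...))` when the engine itself runs as a Ray actor, and `from_engine_args` now accepts a `stat_loggers` dict. A hedged sketch of attaching a custom logger; the model name is an illustrative assumption, and the `Stats`/`SupportsMetricsInfo` imports and field names are taken from `vllm.engine.metrics` at the time of this commit, so verify them against your version:

# Hedged sketch, not from this commit: a trivial stat logger registered on an
# AsyncLLMEngine via the new add_logger passthrough.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import StatLoggerBase, Stats, SupportsMetricsInfo


class PrintingStatLogger(StatLoggerBase):
    """Prints scheduler counts every time the engine logs stats."""

    def log(self, stats: Stats) -> None:
        print(f"running={stats.num_running_sys} waiting={stats.num_waiting_sys}")

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        pass  # no config metadata to record


engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m", disable_log_stats=False))
engine.add_logger("printer", PrintingStatLogger(local_interval=1.0))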

vllm/engine/llm_engine.py

@@ -379,6 +379,7 @@ class LLMEngine:
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -423,6 +424,7 @@
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )
         return engine
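Note: the synchronous engine gets the same plumbing, so callers can hand loggers over at construction time instead of calling `add_logger` afterwards. A hedged sketch; the model name is illustrative, and whether the dict replaces or merges with the default "logging"/"prometheus" loggers should be checked against `LLMEngine.__init__` in your version:

# Hedged sketch, not from this commit: construct an LLMEngine with a custom
# stat_loggers dict via the parameter added to from_engine_args.
from typing import Dict

from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import StatLoggerBase, Stats, SupportsMetricsInfo


class NullStatLogger(StatLoggerBase):
    """Accepts stats and discards them."""

    def log(self, stats: Stats) -> None:
        pass

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        pass


loggers: Dict[str, StatLoggerBase] = {"null": NullStatLogger(local_interval=5.0)}
engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", disable_log_stats=False),
    stat_loggers=loggers)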

vllm/engine/metrics.py

@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
 # begin-metrics-definitions
 class Metrics:
     labelname_finish_reason = "finished_reason"
-    _base_library = prometheus_client
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram

     def __init__(self, labelnames: List[str], max_model_len: int):
         # Unregister any existing vLLM collectors
         self._unregister_vllm_metrics()

         # Config Information
-        self.info_cache_config = prometheus_client.Info(
-            name='vllm:cache_config',
-            documentation='information of cache_config')
+        self._create_info_cache_config()

         # System stats
         #   Scheduler State
-        self.gauge_scheduler_running = self._base_library.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests currently running on GPU.",
             labelnames=labelnames)
-        self.gauge_scheduler_waiting = self._base_library.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames)
-        self.gauge_scheduler_swapped = self._base_library.Gauge(
+        self.gauge_scheduler_swapped = self._gauge_cls(
             name="vllm:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
             labelnames=labelnames)
         #   KV Cache Usage in %
-        self.gauge_gpu_cache_usage = self._base_library.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)
-        self.gauge_cpu_cache_usage = self._base_library.Gauge(
+        self.gauge_cpu_cache_usage = self._gauge_cls(
             name="vllm:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)

         # Iteration stats
-        self.counter_num_preemption = self._base_library.Counter(
+        self.counter_num_preemption = self._counter_cls(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames)
-        self.counter_prompt_tokens = self._base_library.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames)
-        self.counter_generation_tokens = self._base_library.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames)
-        self.histogram_time_to_first_token = self._base_library.Histogram(
+        self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
@@ -86,7 +86,7 @@ class Metrics:
                 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
             ])
-        self.histogram_time_per_output_token = self._base_library.Histogram(
+        self.histogram_time_per_output_token = self._histogram_cls(
             name="vllm:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
@@ -97,83 +97,157 @@ class Metrics:
         # Request stats
         #   Latency
-        self.histogram_e2e_time_request = self._base_library.Histogram(
+        self.histogram_e2e_time_request = self._histogram_cls(
             name="vllm:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
             buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
         #   Metadata
-        self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
+        self.histogram_num_prompt_tokens_request = self._histogram_cls(
             name="vllm:request_prompt_tokens",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
         self.histogram_num_generation_tokens_request = \
-            self._base_library.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 labelnames=labelnames,
                 buckets=build_1_2_5_buckets(max_model_len),
             )
-        self.histogram_best_of_request = self._base_library.Histogram(
+        self.histogram_best_of_request = self._histogram_cls(
             name="vllm:request_params_best_of",
             documentation="Histogram of the best_of request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.histogram_n_request = self._base_library.Histogram(
+        self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.counter_request_success = self._base_library.Counter(
+        self.counter_request_success = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + [Metrics.labelname_finish_reason])

         # Speculatie decoding stats
-        self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
+        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
             name="vllm:spec_decode_draft_acceptance_rate",
             documentation="Speulative token acceptance rate.",
             labelnames=labelnames)
-        self.gauge_spec_decode_efficiency = self._base_library.Gauge(
+        self.gauge_spec_decode_efficiency = self._gauge_cls(
             name="vllm:spec_decode_efficiency",
             documentation="Speculative decoding system efficiency.",
             labelnames=labelnames)
-        self.counter_spec_decode_num_accepted_tokens = (
-            self._base_library.Counter(
-                name="vllm:spec_decode_num_accepted_tokens_total",
-                documentation="Number of accepted tokens.",
-                labelnames=labelnames))
-        self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
+        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
+            name="vllm:spec_decode_num_accepted_tokens_total",
+            documentation="Number of accepted tokens.",
+            labelnames=labelnames))
+        self.counter_spec_decode_num_draft_tokens = self._counter_cls(
             name="vllm:spec_decode_num_draft_tokens_total",
             documentation="Number of draft tokens.",
             labelnames=labelnames)
-        self.counter_spec_decode_num_emitted_tokens = (
-            self._base_library.Counter(
-                name="vllm:spec_decode_num_emitted_tokens_total",
-                documentation="Number of emitted tokens.",
-                labelnames=labelnames))
+        self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
+            name="vllm:spec_decode_num_emitted_tokens_total",
+            documentation="Number of emitted tokens.",
+            labelnames=labelnames))

         # Deprecated in favor of vllm:prompt_tokens_total
-        self.gauge_avg_prompt_throughput = self._base_library.Gauge(
+        self.gauge_avg_prompt_throughput = self._gauge_cls(
             name="vllm:avg_prompt_throughput_toks_per_s",
             documentation="Average prefill throughput in tokens/s.",
             labelnames=labelnames,
         )
         # Deprecated in favor of vllm:generation_tokens_total
-        self.gauge_avg_generation_throughput = self._base_library.Gauge(
+        self.gauge_avg_generation_throughput = self._gauge_cls(
             name="vllm:avg_generation_throughput_toks_per_s",
             documentation="Average generation throughput in tokens/s.",
             labelnames=labelnames,
         )

+    def _create_info_cache_config(self) -> None:
+        # Config Information
+        self.info_cache_config = prometheus_client.Info(
+            name='vllm:cache_config',
+            documentation='information of cache_config')
+
     def _unregister_vllm_metrics(self) -> None:
-        for collector in list(self._base_library.REGISTRY._collector_to_names):
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
             if hasattr(collector, "_name") and "vllm" in collector._name:
-                self._base_library.REGISTRY.unregister(collector)
+                prometheus_client.REGISTRY.unregister(collector)
+
+
+# end-metrics-definitions
+
+
+class _RayGaugeWrapper:
+    """Wraps around ray.util.metrics.Gauge to provide same API as
+    prometheus_client.Gauge"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._gauge = ray_metrics.Gauge(name=name,
+                                        description=documentation,
+                                        tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._gauge.set_default_tags(labels)
+        return self
+
+    def set(self, value: Union[int, float]):
+        return self._gauge.set(value)
+
+
+class _RayCounterWrapper:
+    """Wraps around ray.util.metrics.Counter to provide same API as
+    prometheus_client.Counter"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._counter = ray_metrics.Counter(name=name,
+                                            description=documentation,
+                                            tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._counter.set_default_tags(labels)
+        return self
+
+    def inc(self, value: Union[int, float] = 1.0):
+        if value == 0:
+            return
+        return self._counter.inc(value)
+
+
+class _RayHistogramWrapper:
+    """Wraps around ray.util.metrics.Histogram to provide same API as
+    prometheus_client.Histogram"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None,
+                 buckets: Optional[List[float]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._histogram = ray_metrics.Histogram(name=name,
+                                                description=documentation,
+                                                tag_keys=labelnames_tuple,
+                                                boundaries=buckets)
+
+    def labels(self, **labels):
+        self._histogram.set_default_tags(labels)
+        return self
+
+    def observe(self, value: Union[int, float]):
+        return self._histogram.observe(value)


 class RayMetrics(Metrics):
@@ -181,7 +255,9 @@ class RayMetrics(Metrics):
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _base_library = ray_metrics
+    _gauge_cls = _RayGaugeWrapper
+    _counter_cls = _RayCounterWrapper
+    _histogram_cls = _RayHistogramWrapper

     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
@@ -192,8 +268,9 @@ class RayMetrics(Metrics):
         # No-op on purpose
         pass

-
-# end-metrics-definitions
+    def _create_info_cache_config(self) -> None:
+        # No-op on purpose
+        pass


 def build_1_2_5_buckets(max_value: int) -> List[int]:
@@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
 class RayPrometheusStatLogger(PrometheusStatLogger):
     """RayPrometheusStatLogger uses Ray metrics instead."""
     _metrics_cls = RayMetrics
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        return None
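Note: the core of the fix is that `ray.util.metrics` is no longer treated as a drop-in replacement for the `prometheus_client` module. `Metrics` now selects its metric types through the `_gauge_cls`/`_counter_cls`/`_histogram_cls` class attributes, and the `_Ray*Wrapper` adapters translate the `prometheus_client` call pattern (`labels(...)` returning the metric, then `set`/`inc`/`observe`) onto Ray's `Gauge`/`Counter`/`Histogram`, which use `set_default_tags` and different constructor arguments. `RayMetrics` then only overrides the three class attributes and turns the Prometheus-specific `Info` metric and unregister step into no-ops. The same seam can host other backends; a hedged sketch (names invented for illustration) of an in-memory recorder useful in tests:

# Hedged sketch, not from this commit: anything exposing the
# labels()/set()/inc()/observe() surface can back Metrics, so a test double is
# one adapter class plus a subclass that overrides the *_cls attributes.
from typing import Dict, List, Optional, Tuple

from vllm.engine.metrics import Metrics


class _RecordingMetric:
    """Stores (labels, value) pairs instead of exporting them."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None,
                 buckets: Optional[List[float]] = None):
        self.name = name
        self.samples: List[Tuple[Dict[str, str], float]] = []
        self._tags: Dict[str, str] = {}

    def labels(self, **labels):
        self._tags = labels
        return self

    def set(self, value):  # Gauge API
        self.samples.append((dict(self._tags), float(value)))

    def inc(self, value=1.0):  # Counter API
        self.samples.append((dict(self._tags), float(value)))

    def observe(self, value):  # Histogram API
        self.samples.append((dict(self._tags), float(value)))


class RecordingMetrics(Metrics):
    """Same metric names as Metrics, but every series is a _RecordingMetric."""
    _gauge_cls = _RecordingMetric
    _counter_cls = _RecordingMetric
    _histogram_cls = _RecordingMetric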