[Misc] Add ReplicaId to Ray metrics (#24267)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Co-authored-by: rongfu.leng <1275177125@qq.com>
This commit is contained in:
Seiji Eicher 2025-12-01 19:21:44 -08:00 committed by GitHub
parent fc95521ba5
commit 22274b2184
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,37 +7,55 @@ from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
try:
from ray import serve as ray_serve
from ray.util import metrics as ray_metrics
from ray.util.metrics import Metric
except ImportError:
ray_metrics = None
ray_serve = None
import regex as re
def _get_replica_id() -> str | None:
"""Get the current Ray Serve replica ID, or None if not in a Serve context."""
if ray_serve is None:
return None
try:
return ray_serve.get_replica_context().replica_id.unique_id
except ray_serve.exceptions.RayServeException:
return None
class RayPrometheusMetric:
def __init__(self):
if ray_metrics is None:
raise ImportError("RayPrometheusMetric requires Ray to be installed.")
self.metric: Metric = None
@staticmethod
def _get_tag_keys(labelnames: list[str] | None) -> tuple[str, ...]:
labels = list(labelnames) if labelnames else []
labels.append("ReplicaId")
return tuple(labels)
def labels(self, *labels, **labelskwargs):
if labels:
# -1 because ReplicaId was added automatically
expected = len(self.metric._tag_keys) - 1
if len(labels) != expected:
raise ValueError(
"Number of labels must match the number of tag keys. "
f"Expected {expected}, got {len(labels)}"
)
labelskwargs.update(zip(self.metric._tag_keys, labels))
labelskwargs["ReplicaId"] = _get_replica_id() or ""
if labelskwargs:
for k, v in labelskwargs.items():
if not isinstance(v, str):
labelskwargs[k] = str(v)
self.metric.set_default_tags(labelskwargs)
if labels:
if len(labels) != len(self.metric._tag_keys):
raise ValueError(
"Number of labels must match the number of tag keys. "
f"Expected {len(self.metric._tag_keys)}, got {len(labels)}"
)
self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels)))
return self
@staticmethod
@ -71,10 +89,14 @@ class RayGaugeWrapper(RayPrometheusMetric):
# "mostrecent", "all", "sum" do not apply. This logic can be manually
# implemented at the observability layer (Prometheus/Grafana).
del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Gauge(
name=name, description=documentation, tag_keys=labelnames_tuple
name=name,
description=documentation,
tag_keys=tag_keys,
)
def set(self, value: int | float):
@ -95,10 +117,12 @@ class RayCounterWrapper(RayPrometheusMetric):
documentation: str | None = "",
labelnames: list[str] | None = None,
):
labelnames_tuple = tuple(labelnames) if labelnames else None
tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Counter(
name=name, description=documentation, tag_keys=labelnames_tuple
name=name,
description=documentation,
tag_keys=tag_keys,
)
def inc(self, value: int | float = 1.0):
@ -118,13 +142,14 @@ class RayHistogramWrapper(RayPrometheusMetric):
labelnames: list[str] | None = None,
buckets: list[float] | None = None,
):
labelnames_tuple = tuple(labelnames) if labelnames else None
tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name)
boundaries = buckets if buckets else []
self.metric = ray_metrics.Histogram(
name=name,
description=documentation,
tag_keys=labelnames_tuple,
tag_keys=tag_keys,
boundaries=boundaries,
)