diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py index ea54038a2c775..0898ae65e7cd3 100644 --- a/tests/v1/metrics/test_ray_metrics.py +++ b/tests/v1/metrics/test_ray_metrics.py @@ -47,12 +47,15 @@ def test_engine_log_metrics_ray( engine_args, stat_loggers=[RayPrometheusStatLogger]) for i, prompt in enumerate(example_prompts): - engine.generate( + results = engine.generate( request_id=f"request-id-{i}", prompt=prompt, sampling_params=SamplingParams(max_tokens=max_tokens), ) + async for _ in results: + pass + # Create the actor and call the async method actor = EngineTestActor.remote() # type: ignore[attr-defined] ray.get(actor.run.remote()) diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index 18c8dcf0a0d35..cce692d6c09e7 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -31,6 +31,16 @@ class RayPrometheusMetric: self.metric.set_default_tags(labelskwargs) + if labels: + if len(labels) != len(self.metric._tag_keys): + raise ValueError( + "Number of labels must match the number of tag keys. " + f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" + ) + + self.metric.set_default_tags( + dict(zip(self.metric._tag_keys, labels))) + return self