[V1][Metrics] Add API for accessing in-memory Prometheus metrics (#17010)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent 4318c0559d
commit 06a0338015
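
In short: this change adds LLM.get_metrics(), which snapshots the engine's in-memory Prometheus state as a list of Counter, Gauge, Histogram and Vector dataclasses (defined in the new vllm/v1/metrics/reader.py below). A minimal usage sketch, assuming a V1 engine with stat logging enabled; the model and metric name are taken from the examples and tests in this diff:

    from vllm import LLM
    from vllm.v1.metrics.reader import Counter

    llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
    llm.generate(["Hello, my name is"])

    for metric in llm.get_metrics():
        # Counters carry a single monotonically increasing integer value.
        if isinstance(metric, Counter) and metric.name == "vllm:generation_tokens":
            print(f"{metric.name} = {metric.value}")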
@@ -222,6 +222,7 @@ steps:
   - pytest -v -s v1/test_serial_utils.py
   - pytest -v -s v1/test_utils.py
   - pytest -v -s v1/test_oracle.py
+  - pytest -v -s v1/test_metrics_reader.py
   # TODO: accuracy does not match, whether setting
   # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
   - pytest -v -s v1/e2e
@@ -6,6 +6,7 @@ import os
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Vector
 
 
 def load_prompts(dataset_path, num_prompts):
@@ -105,30 +106,33 @@ def main():
         print(f"generated text: {output.outputs[0].text}")
         print("-" * 50)
 
-    if not hasattr(outputs, "metrics") or outputs.metrics is None:
+    try:
+        metrics = llm.get_metrics()
+    except AssertionError:
+        print("Metrics are not supported in the V0 engine.")
         return
 
-    # calculate the average number of accepted tokens per forward pass, +1 is
-    # to account for the token from the target model that's always going to be
-    # accepted
-    acceptance_counts = [0] * (args.num_spec_tokens + 1)
-    for output in outputs:
-        for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-            acceptance_counts[step] += count
+    num_drafts = num_accepted = 0
+    acceptance_counts = [0] * args.num_spec_tokens
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
+            assert isinstance(metric, Counter)
+            num_accepted += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                acceptance_counts[pos] += metric.values[pos]
 
     print("-" * 50)
-    print(
-        f"mean acceptance length (including bonus tokens): \
-        {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
-    )
+    print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
     print("-" * 50)
 
     # print acceptance at each token position
     for i in range(len(acceptance_counts)):
-        print(
-            f"acceptance at token {i}:"
-            f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
-        )
+        print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")
 
 
 if __name__ == "__main__":
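
A note on the arithmetic above: every draft (one per forward pass) always yields one token from the target model on top of the accepted draft tokens, hence the 1 + terms; the comment spelling this out was dropped in the rewrite. Illustrative numbers, not taken from the diff:

    # Hypothetical: 100 drafts, 150 accepted draft tokens in total.
    # mean acceptance length = 1 + 150 / 100 = 2.5 tokens per forward pass.
    # The per-position print divides each Vector entry by the draft count,
    # e.g. acceptance at token 0 = acceptance_counts[0] / num_drafts.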
examples/offline_inference/metrics.py (new file, +49 lines)

# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", disable_log_stats=False)

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Dump all metrics
    for metric in llm.get_metrics():
        if isinstance(metric, Gauge):
            print(f"{metric.name} (gauge) = {metric.value}")
        elif isinstance(metric, Counter):
            print(f"{metric.name} (counter) = {metric.value}")
        elif isinstance(metric, Vector):
            print(f"{metric.name} (vector) = {metric.values}")
        elif isinstance(metric, Histogram):
            print(f"{metric.name} (histogram)")
            print(f"    sum = {metric.sum}")
            print(f"    count = {metric.count}")
            for bucket_le, value in metric.buckets.items():
                print(f"    {bucket_le} = {value}")


if __name__ == "__main__":
    main()
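
Beyond the raw bucket dump above, a Histogram entry supports simple summary statistics; a small sketch (the helper name is ours, not part of this change):

    from vllm.v1.metrics.reader import Histogram

    def histogram_mean(metric: Histogram) -> float:
        # sum is the total of all observed values and count the number
        # of observations, so their ratio is the mean observed value.
        return metric.sum / metric.count if metric.count else 0.0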
@@ -6,6 +6,7 @@ from typing import Optional
 import pytest
 
 from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
 
 MODEL = "facebook/opt-125m"
 DTYPE = "half"
@@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
         raise AssertionError(
             f"{len(completion_counts)} unique completions; expected"
             f" {n}. Repeats: {repeats}")
+
+
+def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
+    max_tokens = 100
+    # Use spec decoding to test num_accepted_tokens_per_pos
+    speculative_config = {
+        "method": "ngram",
+        "prompt_lookup_max": 5,
+        "prompt_lookup_min": 3,
+        "num_speculative_tokens": 5,
+    }
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    with vllm_runner(
+            MODEL,
+            speculative_config=speculative_config,
+            disable_log_stats=False,
+    ) as vllm_model:
+        model: LLM = vllm_model.model
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens)
+        outputs = model.generate(example_prompts, sampling_params)
+
+        n_prompts = len(example_prompts)
+        assert len(outputs) == n_prompts
+
+        total_tokens = 0
+        for out in outputs:
+            assert len(out.outputs) == 1
+            total_tokens += len(out.outputs[0].token_ids)
+        assert total_tokens == max_tokens * n_prompts
+
+        metrics = model.get_metrics()
+
+        def find_metric(name) -> list[Metric]:
+            found = []
+            for metric in metrics:
+                if metric.name == name:
+                    found.append(metric)
+            return found
+
+        num_requests_running = find_metric("vllm:num_requests_running")
+        assert len(num_requests_running) == 1
+        assert isinstance(num_requests_running[0], Gauge)
+        assert num_requests_running[0].value == 0.0
+
+        generation_tokens = find_metric("vllm:generation_tokens")
+        assert len(generation_tokens) == 1
+        assert isinstance(generation_tokens[0], Counter)
+        assert generation_tokens[0].value == total_tokens
+
+        request_generation_tokens = find_metric(
+            "vllm:request_generation_tokens")
+        assert len(request_generation_tokens) == 1
+        assert isinstance(request_generation_tokens[0], Histogram)
+        assert "+Inf" in request_generation_tokens[0].buckets
+        assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
+        assert request_generation_tokens[0].count == n_prompts
+        assert request_generation_tokens[0].sum == total_tokens
+
+        num_accepted_tokens_per_pos = find_metric(
+            "vllm:spec_decode_num_accepted_tokens_per_pos")
+        assert len(num_accepted_tokens_per_pos) == 1
+        assert isinstance(num_accepted_tokens_per_pos[0], Vector)
+        assert len(num_accepted_tokens_per_pos[0].values) == 5
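
find_metric returns a list because, as the reader.py docstring below notes, one vLLM instance may expose several metrics with the same name but different labels (one per engine_index under data parallelism). A sketch of aggregating such a Counter across engines (helper name ours):

    from vllm.v1.metrics.reader import Counter, Metric

    def sum_counter(metrics: list[Metric], name: str) -> int:
        # Sum one named Counter across all label sets (e.g. engines).
        return sum(m.value for m in metrics
                   if isinstance(m, Counter) and m.name == name)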
tests/v1/test_metrics_reader.py (new file, +112 lines)

# SPDX-License-Identifier: Apache-2.0

import prometheus_client
import pytest

from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector,
                                    get_metrics_snapshot)


@pytest.fixture(autouse=True)
def test_registry(monkeypatch):
    # Use a custom registry for tests
    test_registry = prometheus_client.CollectorRegistry(auto_describe=True)
    monkeypatch.setattr("vllm.v1.metrics.reader.REGISTRY", test_registry)
    return test_registry


@pytest.mark.parametrize("num_engines", [1, 4])
def test_gauge_metric(test_registry, num_engines):
    g = prometheus_client.Gauge("vllm:test_gauge",
                                "Test gauge metric",
                                labelnames=["model", "engine_index"],
                                registry=test_registry)
    for i in range(num_engines):
        g.labels(model="foo", engine_index=str(i)).set(98.5)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Gauge)
        assert m.name == "vllm:test_gauge"
        assert m.value == 98.5
        assert m.labels["model"] == "foo"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_counter_metric(test_registry, num_engines):
    c = prometheus_client.Counter("vllm:test_counter",
                                  "Test counter metric",
                                  labelnames=["model", "engine_index"],
                                  registry=test_registry)
    for i in range(num_engines):
        c.labels(model="bar", engine_index=str(i)).inc(19)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Counter)
        assert m.name == "vllm:test_counter"
        assert m.value == 19
        assert m.labels["model"] == "bar"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_histogram_metric(test_registry, num_engines):
    h = prometheus_client.Histogram("vllm:test_histogram",
                                    "Test histogram metric",
                                    labelnames=["model", "engine_index"],
                                    buckets=[10, 20, 30, 40, 50],
                                    registry=test_registry)
    for i in range(num_engines):
        hist = h.labels(model="blaa", engine_index=str(i))
        hist.observe(42)
        hist.observe(21)
        hist.observe(7)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Histogram)
        assert m.name == "vllm:test_histogram"
        assert m.count == 3
        assert m.sum == 70
        assert m.buckets["10.0"] == 1
        assert m.buckets["20.0"] == 1
        assert m.buckets["30.0"] == 2
        assert m.buckets["40.0"] == 2
        assert m.buckets["50.0"] == 3
        assert m.labels["model"] == "blaa"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_vector_metric(test_registry, num_engines):
    c = prometheus_client.Counter(
        "vllm:spec_decode_num_accepted_tokens_per_pos",
        "Vector-like counter metric",
        labelnames=["position", "model", "engine_index"],
        registry=test_registry)
    for i in range(num_engines):
        c.labels(position="0", model="llama", engine_index=str(i)).inc(10)
        c.labels(position="1", model="llama", engine_index=str(i)).inc(5)
        c.labels(position="2", model="llama", engine_index=str(i)).inc(1)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Vector)
        assert m.name == "vllm:spec_decode_num_accepted_tokens_per_pos"
        assert m.values == [10, 5, 1]
        assert m.labels["model"] == "llama"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])
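
Note that the bucket counts asserted above are cumulative, per Prometheus convention: after observing 7, 21 and 42, buckets["30.0"] == 2 because both 7 and 21 fall at or below 30. A sketch for recovering per-bucket counts, assuming the dict is ordered by ascending upper bound as get_metrics_snapshot() builds it:

    def per_bucket_counts(buckets: dict[str, int]) -> dict[str, int]:
        # Convert cumulative Prometheus bucket counts into per-bucket counts.
        out: dict[str, int] = {}
        prev = 0
        for le, cumulative in buckets.items():
            out[le] = cumulative - prev
            prev = cumulative
        return out

    # e.g. {"10.0": 1, "20.0": 1, "30.0": 2} -> {"10.0": 1, "20.0": 0, "30.0": 1}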
@@ -4,7 +4,8 @@ import itertools
 import warnings
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
+                    cast, overload)
 
 import cloudpickle
 import torch.nn as nn
@@ -47,6 +48,9 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                         is_list_of)
 
+if TYPE_CHECKING:
+    from vllm.v1.metrics.reader import Metric
+
 logger = init_logger(__name__)
 
 _R = TypeVar("_R", default=Any)
@@ -1294,6 +1298,20 @@ class LLM:
         """
         self.llm_engine.wake_up(tags)
 
+    def get_metrics(self) -> list["Metric"]:
+        """Return a snapshot of aggregated metrics from Prometheus.
+
+        Returns:
+            A list of ``Metric`` instances capturing the current state
+            of all aggregated metrics from Prometheus.
+
+        Note:
+            This method is only available with the V1 LLM engine.
+        """
+        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+        assert isinstance(self.llm_engine, V1LLMEngine)
+        return self.llm_engine.get_metrics()
+
     # LEGACY
     def _convert_v1_inputs(
         self,
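
Because get_metrics() asserts a V1 engine, callers that may also run against V0 can guard the call; the speculative-decoding example earlier in this diff uses exactly this pattern:

    try:
        metrics = llm.get_metrics()
    except AssertionError:
        metrics = []  # V0 engine: no in-memory metrics API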
@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import StatLoggerFactory
+from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
+                                     StatLoggerFactory)
+from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
+from vllm.v1.metrics.stats import IterationStats
 
 logger = init_logger(__name__)
 
@@ -64,6 +67,11 @@ class LLMEngine:
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
+        self.log_stats = log_stats
+        self.stat_logger: Optional[StatLoggerBase] = None
+        if self.log_stats:
+            self.stat_logger = PrometheusStatLogger(vllm_config)
+
         # important: init dp group before init the engine_core
         # In the decoupled engine case this is handled in EngineCoreProc.
         parallel_config = vllm_config.parallel_config
@@ -86,7 +94,7 @@ class LLMEngine:
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
-                                                log_stats=False)
+                                                log_stats=self.log_stats)
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
@@ -94,7 +102,7 @@ class LLMEngine:
             asyncio_mode=False,
             vllm_config=vllm_config,
             executor_class=executor_class,
-            log_stats=False,  # FIXME: implement
+            log_stats=self.log_stats,
         )
 
         if not multiprocess_mode:
@@ -223,12 +231,21 @@ class LLMEngine:
         outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
+        iteration_stats = IterationStats() if self.log_stats else None
         processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs.outputs,
+            engine_core_timestamp=outputs.timestamp,
+            iteration_stats=iteration_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
+        # 4) Record stats
+        if self.stat_logger is not None:
+            assert outputs.scheduler_stats is not None
+            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
+                                    iteration_stats=iteration_stats)
+
         return processed_outputs.request_outputs
 
     def get_vllm_config(self):
@@ -260,6 +277,10 @@ class LLMEngine:
     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()
 
+    def get_metrics(self) -> list[Metric]:
+        assert self.log_stats, "Stat logging disabled"
+        return get_metrics_snapshot()
+
     def get_tokenizer_group(self) -> TokenizerGroup:
         if self.tokenizer is None:
             raise ValueError("Unable to get tokenizer because "
@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase):
         # Counters
         #
         self.counter_num_preempted_reqs = self._counter_cls(
-            name="vllm:num_preemptions_total",
+            name="vllm:num_preemptions",
             documentation="Cumulative number of preemptions from the engine.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_prompt_tokens = self._counter_cls(
-            name="vllm:prompt_tokens_total",
+            name="vllm:prompt_tokens",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_generation_tokens = self._counter_cls(
-            name="vllm:generation_tokens_total",
+            name="vllm:generation_tokens",
            documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_request_success: dict[FinishReason,
                                            prometheus_client.Counter] = {}
         counter_request_success_base = self._counter_cls(
-            name="vllm:request_success_total",
+            name="vllm:request_success",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])
         for reason in FinishReason:
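
The renames above drop the explicit _total suffix because prometheus_client treats that suffix as part of the counter exposition format rather than the metric name: it is stripped from the collected metric's name and re-appended to the sample names, which is what reader.py below relies on when it fetches counter samples with the "_total" suffix. A sketch of that behavior, to the best of our reading of prometheus_client:

    import prometheus_client

    registry = prometheus_client.CollectorRegistry()
    c = prometheus_client.Counter("vllm:example_total", "Example counter",
                                  registry=registry)
    c.inc(3)
    for metric in registry.collect():
        print(metric.name)  # "vllm:example" - the _total suffix is stripped
        for sample in metric.samples:
            print(sample.name, sample.value)
            # "vllm:example_total 3.0" (plus a "vllm:example_created" sample)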
vllm/v1/metrics/reader.py (new file, +245 lines)

# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional

from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample


@dataclass
class Metric:
    """A base class for prometheus metrics.

    Each metric may be associated with key=value labels, and
    in some cases a single vLLM instance may have multiple
    metrics with the same name but different sets of labels.
    """
    name: str
    labels: dict[str, str]


@dataclass
class Counter(Metric):
    """A monotonically increasing integer counter."""
    value: int


@dataclass
class Vector(Metric):
    """An ordered array of integer counters.

    This type - which doesn't exist in Prometheus - models one very
    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """
    values: list[int]


@dataclass
class Gauge(Metric):
    """A numerical value that can go up or down."""
    value: float


@dataclass
class Histogram(Metric):
    """Observations recorded in configurable buckets.

    Buckets are represented by a dictionary. The key is
    the upper limit of the bucket, and the value is the
    observed count in that bucket. A '+Inf' key always
    exists.

    The count property is the total count across all
    buckets, identical to the count of the '+Inf' bucket.

    The sum property is the total sum of all observed
    values.
    """
    count: int
    sum: float
    buckets: dict[str, int]


def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Example:
        >>> for metric in llm.get_metrics():
        ...     if isinstance(metric, Counter):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Gauge):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Histogram):
        ...         print(f"{metric}")
        ...         print(f"    sum = {metric.sum}")
        ...         print(f"    count = {metric.count}")
        ...         for bucket_le, value in metric.buckets.items():
        ...             print(f"    {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            samples = _get_samples(metric)
            for s in samples:
                collected.append(
                    Gauge(name=metric.name, labels=s.labels, value=s.value))
        elif metric.type == "counter":
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:num_accepted_tokens_per_pos special case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                for labels, values in _digest_num_accepted_by_pos_samples(
                        samples):
                    collected.append(
                        Vector(name=metric.name, labels=labels, values=values))
            else:
                for s in samples:
                    collected.append(
                        Counter(name=metric.name,
                                labels=s.labels,
                                value=int(s.value)))

        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            for labels, buckets, count_value, sum_value in _digest_histogram(
                    bucket_samples, count_samples, sum_samples):
                collected.append(
                    Histogram(name=metric.name,
                              labels=labels,
                              buckets=buckets,
                              count=count_value,
                              sum=sum_value))
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")

    return collected


def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    name = (metric.name + suffix) if suffix is not None else metric.name
    return [s for s in metric.samples if s.name == name]


def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
    labels_copy = labels.copy()
    labels_copy.pop(key_to_remove)
    return labels_copy


def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    #
    # In the case of DP, we have an indigestible
    # per-bucket-per-engine count as a list of labelled
    # samples, along with total and sum samples
    #
    # bucket_samples (in):
    #   labels = {bucket: 100, idx: 0}, value = 2
    #   labels = {bucket: 200, idx: 0}, value = 4
    #   labels = {bucket: Inf, idx: 0}, value = 10
    #   labels = {bucket: 100, idx: 1}, value = 1
    #   labels = {bucket: 200, idx: 1}, value = 5
    #   labels = {bucket: Inf, idx: 1}, value = 7
    # count_samples (in):
    #   labels = {idx: 0}, value = 10
    #   labels = {idx: 1}, value = 7
    # sum_samples (in):
    #   labels = {idx: 0}, value = 2000
    #   labels = {idx: 1}, value = 1200
    #
    # output: [
    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
    # ]
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for s in bucket_samples:
        bucket = s.labels["le"]
        labels_key = frozenset(_strip_label(s.labels, "le").items())
        if labels_key not in buckets_by_labels:
            buckets_by_labels[labels_key] = {}
        buckets_by_labels[labels_key][bucket] = int(s.value)

    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {}
    for s in count_samples:
        labels_key = frozenset(s.labels.items())
        counts_by_labels[labels_key] = int(s.value)

    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {}
    for s in sum_samples:
        labels_key = frozenset(s.labels.items())
        sums_by_labels[labels_key] = s.value

    assert set(buckets_by_labels.keys()) == set(
        counts_by_labels.keys()) == set(sums_by_labels.keys())

    output = []
    label_keys = list(buckets_by_labels.keys())
    for k in label_keys:
        labels = dict(k)
        output.append((labels, buckets_by_labels[k], counts_by_labels[k],
                       sums_by_labels[k]))
    return output


def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    #
    # In the case of DP, we have an indigestible
    # per-position-per-engine count as a list of
    # labelled samples
    #
    # samples (in):
    #   labels = {pos: 0, idx: 0}, value = 10
    #   labels = {pos: 1, idx: 0}, value = 7
    #   labels = {pos: 2, idx: 0}, value = 2
    #   labels = {pos: 0, idx: 1}, value = 5
    #   labels = {pos: 1, idx: 1}, value = 3
    #   labels = {pos: 2, idx: 1}, value = 1
    #
    # output: [
    #   {idx: 0}, [10, 7, 2]
    #   {idx: 1}, [5, 3, 1]
    # ]
    #
    max_pos = 0
    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}

    for s in samples:
        position = int(s.labels["position"])
        max_pos = max(max_pos, position)

        labels_key = frozenset(_strip_label(s.labels, "position").items())
        if labels_key not in values_by_labels:
            values_by_labels[labels_key] = {}
        values_by_labels[labels_key][position] = int(s.value)

    output = []
    for labels_key, values_by_position in values_by_labels.items():
        labels = dict(labels_key)
        values = [0] * (max_pos + 1)
        for pos, val in values_by_position.items():
            values[pos] = val
        output.append((labels, values))
    return output
@@ -134,17 +134,17 @@ class SpecDecodingProm:
 
         self.counter_spec_decode_num_drafts = \
             self._counter_cls(
-                name="vllm:spec_decode_num_drafts_total",
+                name="vllm:spec_decode_num_drafts",
                 documentation="Number of spec decoding drafts.",
                 labelnames=labelnames).labels(*labelvalues)
         self.counter_spec_decode_num_draft_tokens = \
             self._counter_cls(
-                name="vllm:spec_decode_num_draft_tokens_total",
+                name="vllm:spec_decode_num_draft_tokens",
                 documentation="Number of draft tokens.",
                 labelnames=labelnames).labels(*labelvalues)
         self.counter_spec_decode_num_accepted_tokens = \
             self._counter_cls(
-                name="vllm:spec_decode_num_accepted_tokens_total",
+                name="vllm:spec_decode_num_accepted_tokens",
                 documentation="Number of accepted tokens.",
                 labelnames=labelnames).labels(*labelvalues)