[V1][Metrics] Add API for accessing in-memory Prometheus metrics (#17010)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2025-05-27 10:37:06 +01:00 committed by GitHub
parent 4318c0559d
commit 06a0338015
10 changed files with 543 additions and 28 deletions
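The new API is LLM.get_metrics(), which returns typed snapshots of the engine's in-memory Prometheus metrics. A minimal usage sketch, condensed from the example and tests added in this commit (stats logging must be enabled):

from vllm import LLM
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector

llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
llm.generate(["Hello, my name is"])

for metric in llm.get_metrics():
    if isinstance(metric, (Counter, Gauge)):
        print(f"{metric.name} = {metric.value}")
    elif isinstance(metric, Vector):
        print(f"{metric.name} = {metric.values}")
    elif isinstance(metric, Histogram):
        print(f"{metric.name}: count={metric.count}, sum={metric.sum}")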

View File

@@ -222,6 +222,7 @@ steps:
    - pytest -v -s v1/test_serial_utils.py
    - pytest -v -s v1/test_utils.py
    - pytest -v -s v1/test_oracle.py
+   - pytest -v -s v1/test_metrics_reader.py
    # TODO: accuracy does not match, whether setting
    # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
    - pytest -v -s v1/e2e

View File

@@ -6,6 +6,7 @@ import os
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Vector


def load_prompts(dataset_path, num_prompts):
@@ -105,30 +106,33 @@ def main():
        print(f"generated text: {output.outputs[0].text}")
        print("-" * 50)

-    if not hasattr(outputs, "metrics") or outputs.metrics is None:
+    try:
+        metrics = llm.get_metrics()
+    except AssertionError:
+        print("Metrics are not supported in the V0 engine.")
        return

-    # calculate the average number of accepted tokens per forward pass, +1 is
-    # to account for the token from the target model that's always going to be
-    # accepted
-    acceptance_counts = [0] * (args.num_spec_tokens + 1)
-    for output in outputs:
-        for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-            acceptance_counts[step] += count
+    num_drafts = num_accepted = 0
+    acceptance_counts = [0] * args.num_spec_tokens
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
+            assert isinstance(metric, Counter)
+            num_accepted += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                acceptance_counts[pos] += metric.values[pos]

    print("-" * 50)
-    print(
-        f"mean acceptance length (including bonus tokens): \
-        {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
-    )
+    print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i in range(len(acceptance_counts)):
-        print(
-            f"acceptance at token {i}:"
-            f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
-        )
+        print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")


if __name__ == "__main__":
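For a concrete feel for the new formula (illustrative numbers, not measured output): with num_drafts = 200 and num_accepted = 320, the mean acceptance length is 1 + 320 / 200 = 2.6, i.e. on average each draft contributes 1.6 accepted draft tokens plus the one token the target model always emits.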

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", disable_log_stats=False)

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Dump all metrics
    for metric in llm.get_metrics():
        if isinstance(metric, Gauge):
            print(f"{metric.name} (gauge) = {metric.value}")
        elif isinstance(metric, Counter):
            print(f"{metric.name} (counter) = {metric.value}")
        elif isinstance(metric, Vector):
            print(f"{metric.name} (vector) = {metric.values}")
        elif isinstance(metric, Histogram):
            print(f"{metric.name} (histogram)")
            print(f"    sum = {metric.sum}")
            print(f"    count = {metric.count}")
            for bucket_le, value in metric.buckets.items():
                print(f"    {bucket_le} = {value}")


if __name__ == "__main__":
    main()
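To read a single value from the snapshot, filter by metric name; a small sketch reusing the imports above (the name is one asserted on in the tests later in this commit):

gen_tokens = next(m for m in llm.get_metrics()
                  if m.name == "vllm:generation_tokens")
assert isinstance(gen_tokens, Counter)
print(f"generated tokens so far: {gen_tokens.value}")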

View File

@@ -6,6 +6,7 @@ from typing import Optional
import pytest

from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector

MODEL = "facebook/opt-125m"
DTYPE = "half"

@@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
        raise AssertionError(
            f"{len(completion_counts)} unique completions; expected"
            f" {n}. Repeats: {repeats}")


def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
    max_tokens = 100
    # Use spec decoding to test num_accepted_tokens_per_pos
    speculative_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": 5,
    }
    monkeypatch.setenv("VLLM_USE_V1", "1")
    with vllm_runner(
            MODEL,
            speculative_config=speculative_config,
            disable_log_stats=False,
    ) as vllm_model:
        model: LLM = vllm_model.model
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens)
        outputs = model.generate(example_prompts, sampling_params)

        n_prompts = len(example_prompts)
        assert len(outputs) == n_prompts
        total_tokens = 0
        for out in outputs:
            assert len(out.outputs) == 1
            total_tokens += len(out.outputs[0].token_ids)
        assert total_tokens == max_tokens * n_prompts

        metrics = model.get_metrics()

        def find_metric(name) -> list[Metric]:
            found = []
            for metric in metrics:
                if metric.name == name:
                    found.append(metric)
            return found

        num_requests_running = find_metric("vllm:num_requests_running")
        assert len(num_requests_running) == 1
        assert isinstance(num_requests_running[0], Gauge)
        assert num_requests_running[0].value == 0.0

        generation_tokens = find_metric("vllm:generation_tokens")
        assert len(generation_tokens) == 1
        assert isinstance(generation_tokens[0], Counter)
        assert generation_tokens[0].value == total_tokens

        request_generation_tokens = find_metric(
            "vllm:request_generation_tokens")
        assert len(request_generation_tokens) == 1
        assert isinstance(request_generation_tokens[0], Histogram)
        assert "+Inf" in request_generation_tokens[0].buckets
        assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
        assert request_generation_tokens[0].count == n_prompts
        assert request_generation_tokens[0].sum == total_tokens

        num_accepted_tokens_per_pos = find_metric(
            "vllm:spec_decode_num_accepted_tokens_per_pos")
        assert len(num_accepted_tokens_per_pos) == 1
        assert isinstance(num_accepted_tokens_per_pos[0], Vector)
        assert len(num_accepted_tokens_per_pos[0].values) == 5

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0

import prometheus_client
import pytest

from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector,
                                    get_metrics_snapshot)


@pytest.fixture(autouse=True)
def test_registry(monkeypatch):
    # Use a custom registry for tests
    test_registry = prometheus_client.CollectorRegistry(auto_describe=True)
    monkeypatch.setattr("vllm.v1.metrics.reader.REGISTRY", test_registry)
    return test_registry


@pytest.mark.parametrize("num_engines", [1, 4])
def test_gauge_metric(test_registry, num_engines):
    g = prometheus_client.Gauge("vllm:test_gauge",
                                "Test gauge metric",
                                labelnames=["model", "engine_index"],
                                registry=test_registry)
    for i in range(num_engines):
        g.labels(model="foo", engine_index=str(i)).set(98.5)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Gauge)
        assert m.name == "vllm:test_gauge"
        assert m.value == 98.5
        assert m.labels["model"] == "foo"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_counter_metric(test_registry, num_engines):
    c = prometheus_client.Counter("vllm:test_counter",
                                  "Test counter metric",
                                  labelnames=["model", "engine_index"],
                                  registry=test_registry)
    for i in range(num_engines):
        c.labels(model="bar", engine_index=str(i)).inc(19)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Counter)
        assert m.name == "vllm:test_counter"
        assert m.value == 19
        assert m.labels["model"] == "bar"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_histogram_metric(test_registry, num_engines):
    h = prometheus_client.Histogram("vllm:test_histogram",
                                    "Test histogram metric",
                                    labelnames=["model", "engine_index"],
                                    buckets=[10, 20, 30, 40, 50],
                                    registry=test_registry)
    for i in range(num_engines):
        hist = h.labels(model="blaa", engine_index=str(i))
        hist.observe(42)
        hist.observe(21)
        hist.observe(7)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Histogram)
        assert m.name == "vllm:test_histogram"
        assert m.count == 3
        assert m.sum == 70
        assert m.buckets["10.0"] == 1
        assert m.buckets["20.0"] == 1
        assert m.buckets["30.0"] == 2
        assert m.buckets["40.0"] == 2
        assert m.buckets["50.0"] == 3
        assert m.labels["model"] == "blaa"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])


@pytest.mark.parametrize("num_engines", [1, 4])
def test_vector_metric(test_registry, num_engines):
    c = prometheus_client.Counter(
        "vllm:spec_decode_num_accepted_tokens_per_pos",
        "Vector-like counter metric",
        labelnames=["position", "model", "engine_index"],
        registry=test_registry)
    for i in range(num_engines):
        c.labels(position="0", model="llama", engine_index=str(i)).inc(10)
        c.labels(position="1", model="llama", engine_index=str(i)).inc(5)
        c.labels(position="2", model="llama", engine_index=str(i)).inc(1)

    metrics = get_metrics_snapshot()
    assert len(metrics) == num_engines
    engine_labels = [str(i) for i in range(num_engines)]
    for m in metrics:
        assert isinstance(m, Vector)
        assert m.name == "vllm:spec_decode_num_accepted_tokens_per_pos"
        assert m.values == [10, 5, 1]
        assert m.labels["model"] == "llama"
        assert m.labels["engine_index"] in engine_labels
        engine_labels.remove(m.labels["engine_index"])
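As these tests show, get_metrics_snapshot() simply walks a prometheus_client registry (the module-level REGISTRY unless patched, as the fixture above does) and keeps only vllm:-prefixed metrics. A hedged sketch of calling it directly, outside the LLM API:

from vllm.v1.metrics.reader import get_metrics_snapshot

# Reads whatever vllm:* metrics the running engine has registered
# in prometheus_client's default REGISTRY.
for metric in get_metrics_snapshot():
    print(metric.name, metric.labels)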

View File

@@ -4,7 +4,8 @@ import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
-from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
+                    cast, overload)

import cloudpickle
import torch.nn as nn

@@ -47,6 +48,9 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)

+if TYPE_CHECKING:
+    from vllm.v1.metrics.reader import Metric
+
logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)
@@ -1294,6 +1298,20 @@ class LLM:
        """
        self.llm_engine.wake_up(tags)

+    def get_metrics(self) -> list["Metric"]:
+        """Return a snapshot of aggregated metrics from Prometheus.
+
+        Returns:
+            A list of ``Metric`` instances capturing the current state
+            of all aggregated metrics from Prometheus.
+
+        Note:
+            This method is only available with the V1 LLM engine.
+        """
+        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+        assert isinstance(self.llm_engine, V1LLMEngine)
+        return self.llm_engine.get_metrics()
+
    # LEGACY
    def _convert_v1_inputs(
        self,

View File

@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import StatLoggerFactory
+from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
+                                     StatLoggerFactory)
+from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
+from vllm.v1.metrics.stats import IterationStats

logger = init_logger(__name__)

@@ -64,6 +67,11 @@ class LLMEngine:
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

+        self.log_stats = log_stats
+        self.stat_logger: Optional[StatLoggerBase] = None
+        if self.log_stats:
+            self.stat_logger = PrometheusStatLogger(vllm_config)
+
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        parallel_config = vllm_config.parallel_config

@@ -86,7 +94,7 @@ class LLMEngine:
        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
-                                                log_stats=False)
+                                                log_stats=self.log_stats)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
@@ -94,7 +102,7 @@ class LLMEngine:
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
-            log_stats=False,  # FIXME: implement
+            log_stats=self.log_stats,
        )

        if not multiprocess_mode:

@@ -223,12 +231,21 @@ class LLMEngine:
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
+        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs.outputs,
+            engine_core_timestamp=outputs.timestamp,
+            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

+        # 4) Record stats
+        if self.stat_logger is not None:
+            assert outputs.scheduler_stats is not None
+            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
+                                    iteration_stats=iteration_stats)
+
        return processed_outputs.request_outputs

    def get_vllm_config(self):

@@ -260,6 +277,10 @@ class LLMEngine:
    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

+    def get_metrics(self) -> list[Metric]:
+        assert self.log_stats, "Stat logging disabled"
+        return get_metrics_snapshot()
+
    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
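The snapshot path depends on stats logging being enabled: the PrometheusStatLogger is only constructed when log_stats is true, and get_metrics() asserts on it. A small sketch of the failure mode when stats are disabled:

from vllm import LLM

llm = LLM(model="facebook/opt-125m", disable_log_stats=True)
llm.get_metrics()  # raises AssertionError("Stat logging disabled")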

View File

@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase):
        # Counters
        #
        self.counter_num_preempted_reqs = self._counter_cls(
-            name="vllm:num_preemptions_total",
+            name="vllm:num_preemptions",
            documentation="Cumulative number of preemption from the engine.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_prompt_tokens = self._counter_cls(
-            name="vllm:prompt_tokens_total",
+            name="vllm:prompt_tokens",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_generation_tokens = self._counter_cls(
-            name="vllm:generation_tokens_total",
+            name="vllm:generation_tokens",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames).labels(*labelvalues)

        self.counter_request_success: dict[FinishReason,
                                           prometheus_client.Counter] = {}
        counter_request_success_base = self._counter_cls(
-            name="vllm:request_success_total",
+            name="vllm:request_success",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + ["finished_reason"])
        for reason in FinishReason:
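These renames drop the "_total" suffix from the declared counter names: prometheus_client handles that suffix itself for counters (the exported sample is always "<name>_total"), and the reader added in this commit matches counter samples by appending "_total" to the collected metric name. A small standalone illustration of that prometheus_client behavior (assumed standard library behavior, not code from this commit):

import prometheus_client

registry = prometheus_client.CollectorRegistry()
c = prometheus_client.Counter("vllm:prompt_tokens", "Prefill tokens",
                              registry=registry)
c.inc(3)
for metric in registry.collect():
    print(metric.name)                       # vllm:prompt_tokens
    print([s.name for s in metric.samples])  # typically includes vllm:prompt_tokens_total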

vllm/v1/metrics/reader.py (new file, 245 lines)
View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional

from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample


@dataclass
class Metric:
    """A base class for prometheus metrics.

    Each metric may be associated with key=value labels, and
    in some cases a single vLLM instance may have multiple
    metrics with the same name but different sets of labels.
    """
    name: str
    labels: dict[str, str]


@dataclass
class Counter(Metric):
    """A monotonically increasing integer counter."""
    value: int


@dataclass
class Vector(Metric):
    """An ordered array of integer counters.

    This type - which doesn't exist in Prometheus - models one very
    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """
    values: list[int]


@dataclass
class Gauge(Metric):
    """A numerical value that can go up or down."""
    value: float


@dataclass
class Histogram(Metric):
    """Observations recorded in configurable buckets.

    Buckets are represented by a dictionary. The key is
    the upper limit of the bucket, and the value is the
    observed count in that bucket. A '+Inf' key always
    exists.

    The count property is the total count across all
    buckets, identical to the count of the '+Inf' bucket.

    The sum property is the total sum of all observed
    values.
    """
    count: int
    sum: float
    buckets: dict[str, int]
def get_metrics_snapshot() -> list[Metric]:
    """An API for accessing in-memory Prometheus metrics.

    Example:
        >>> for metric in llm.get_metrics():
        ...     if isinstance(metric, Counter):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Gauge):
        ...         print(f"{metric} = {metric.value}")
        ...     elif isinstance(metric, Histogram):
        ...         print(f"{metric}")
        ...         print(f"    sum = {metric.sum}")
        ...         print(f"    count = {metric.count}")
        ...         for bucket_le, value in metric.buckets.items():
        ...             print(f"    {bucket_le} = {value}")
    """
    collected: list[Metric] = []
    for metric in REGISTRY.collect():
        if not metric.name.startswith("vllm:"):
            continue
        if metric.type == "gauge":
            samples = _get_samples(metric)
            for s in samples:
                collected.append(
                    Gauge(name=metric.name, labels=s.labels, value=s.value))
        elif metric.type == "counter":
            samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                #
                # Ugly vllm:spec_decode_num_accepted_tokens_per_pos special
                # case.
                #
                # This metric is a vector of counters - for each spec
                # decoding token position, we observe the number of
                # accepted tokens using a Counter labeled with 'position'.
                # We convert these into a vector of integer values.
                #
                for labels, values in _digest_num_accepted_by_pos_samples(
                        samples):
                    collected.append(
                        Vector(name=metric.name, labels=labels, values=values))
            else:
                for s in samples:
                    collected.append(
                        Counter(name=metric.name,
                                labels=s.labels,
                                value=int(s.value)))
        elif metric.type == "histogram":
            #
            # A histogram has a number of '_bucket' samples where
            # the 'le' label represents the upper limit of the bucket.
            # We convert these bucketized values into a dict of values
            # indexed by the value of the 'le' label. The 'le=+Inf'
            # label is a special case, catching all values observed.
            #
            bucket_samples = _get_samples(metric, "_bucket")
            count_samples = _get_samples(metric, "_count")
            sum_samples = _get_samples(metric, "_sum")
            for labels, buckets, count_value, sum_value in _digest_histogram(
                    bucket_samples, count_samples, sum_samples):
                collected.append(
                    Histogram(name=metric.name,
                              labels=labels,
                              buckets=buckets,
                              count=count_value,
                              sum=sum_value))
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")

    return collected
def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    name = (metric.name + suffix) if suffix is not None else metric.name
    return [s for s in metric.samples if s.name == name]


def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
    labels_copy = labels.copy()
    labels_copy.pop(key_to_remove)
    return labels_copy
def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    #
    # In the case of DP, we have an indigestible
    # per-bucket-per-engine count as a list of labelled
    # samples, along with count and sum samples
    #
    # bucket_samples (in):
    #   labels = {bucket: 100, idx: 0}, value = 2
    #   labels = {bucket: 200, idx: 0}, value = 4
    #   labels = {bucket: Inf, idx: 0}, value = 10
    #   labels = {bucket: 100, idx: 1}, value = 1
    #   labels = {bucket: 200, idx: 1}, value = 5
    #   labels = {bucket: Inf, idx: 1}, value = 7
    # count_samples (in):
    #   labels = {idx: 0}, value = 10
    #   labels = {idx: 1}, value = 7
    # sum_samples (in):
    #   labels = {idx: 0}, value = 2000
    #   labels = {idx: 1}, value = 1200
    #
    # output: [
    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
    # ]
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for s in bucket_samples:
        bucket = s.labels["le"]
        labels_key = frozenset(_strip_label(s.labels, "le").items())
        if labels_key not in buckets_by_labels:
            buckets_by_labels[labels_key] = {}
        buckets_by_labels[labels_key][bucket] = int(s.value)

    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {}
    for s in count_samples:
        labels_key = frozenset(s.labels.items())
        counts_by_labels[labels_key] = int(s.value)

    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {}
    for s in sum_samples:
        labels_key = frozenset(s.labels.items())
        sums_by_labels[labels_key] = s.value

    assert set(buckets_by_labels.keys()) == set(
        counts_by_labels.keys()) == set(sums_by_labels.keys())

    output = []
    label_keys = list(buckets_by_labels.keys())
    for k in label_keys:
        labels = dict(k)
        output.append((labels, buckets_by_labels[k], counts_by_labels[k],
                       sums_by_labels[k]))
    return output
def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    #
    # In the case of DP, we have an indigestible
    # per-position-per-engine count as a list of
    # labelled samples
    #
    # samples (in):
    #   labels = {pos: 0, idx: 0}, value = 10
    #   labels = {pos: 1, idx: 0}, value = 7
    #   labels = {pos: 2, idx: 0}, value = 2
    #   labels = {pos: 0, idx: 1}, value = 5
    #   labels = {pos: 1, idx: 1}, value = 3
    #   labels = {pos: 2, idx: 1}, value = 1
    #
    # output: [
    #   {idx: 0}, [10, 7, 2]
    #   {idx: 1}, [5, 3, 1]
    # ]
    #
    max_pos = 0
    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
    for s in samples:
        position = int(s.labels["position"])
        max_pos = max(max_pos, position)
        labels_key = frozenset(_strip_label(s.labels, "position").items())
        if labels_key not in values_by_labels:
            values_by_labels[labels_key] = {}
        values_by_labels[labels_key][position] = int(s.value)

    output = []
    for labels_key, values_by_position in values_by_labels.items():
        labels = dict(labels_key)
        values = [0] * (max_pos + 1)
        for pos, val in values_by_position.items():
            values[pos] = val
        output.append((labels, values))
    return output

View File

@@ -134,17 +134,17 @@ class SpecDecodingProm:
        self.counter_spec_decode_num_drafts = \
            self._counter_cls(
-                name="vllm:spec_decode_num_drafts_total",
+                name="vllm:spec_decode_num_drafts",
                documentation="Number of spec decoding drafts.",
                labelnames=labelnames).labels(*labelvalues)
        self.counter_spec_decode_num_draft_tokens = \
            self._counter_cls(
-                name="vllm:spec_decode_num_draft_tokens_total",
+                name="vllm:spec_decode_num_draft_tokens",
                documentation="Number of draft tokens.",
                labelnames=labelnames,).labels(*labelvalues)
        self.counter_spec_decode_num_accepted_tokens = \
            self._counter_cls(
-                name="vllm:spec_decode_num_accepted_tokens_total",
+                name="vllm:spec_decode_num_accepted_tokens",
                documentation="Number of accepted tokens.",
                labelnames=labelnames).labels(*labelvalues)