[V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (#16665)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Authored by Mark McLoughlin on 2025-04-24 16:57:40 +01:00; committed by GitHub
parent 1bcbcbf574
commit 340d7b1b21
4 changed files with 158 additions and 60 deletions

tests/v1/core/test_scheduler.py

@@ -6,7 +6,7 @@ import pytest
 import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, VllmConfig)
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
@@ -81,11 +82,17 @@ def create_scheduler(
         kv_connector_extra_config={"shared_storage_path": "local_storage"},
     ) if use_kv_connector else None
 
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
+        speculative_config=speculative_config,
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=1)
 
     # Test case 1: Stop on EOS token
     requests = create_requests(num_requests=2, max_tokens=10)
@@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [10, 11]
 
     # Test case 2: Stop on custom stop token
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2,
                                max_tokens=10,
                                stop_token_ids=[42, 43])
@@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13, 14]
 
     # Test case 3: Stop on max tokens
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2, max_tokens=2)
     for req in requests:
         req.num_computed_tokens = req.num_tokens
@@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13]
 
     # Test case 4: Ignore EOS flag
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=1, max_tokens=10)
     requests[0].sampling_params.ignore_eos = True
     requests[0].num_computed_tokens = requests[0].num_tokens
@@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
     [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
-        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
+         (2, 3, 3, [2, 1])),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
+        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
         ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
-         (6, 3)),  # multiple mismatches
+         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
     ])
 def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     """Test scheduling behavior with speculative decoding.
@@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     1. Speculated tokens get scheduled correctly
     2. Spec decoding stats properly count number of draft and accepted tokens
     """
-    scheduler = create_scheduler()
+    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
     requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
     req_ids = []
     req_to_index = {}
@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
else:
assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output(
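
The expected tuples above follow directly from the bookkeeping the test asserts: each request that had draft tokens scheduled counts as one draft, and a draft token is accepted only while it matches the sampled output, so acceptance stops at the first mismatch. As a sanity check, here is a small standalone sketch (not part of this commit; the helper name is illustrative) that reproduces the (num_drafts, num_draft_tokens, num_accepted_tokens, num_accepted_tokens_per_pos) tuples:

    def expected_spec_stats(spec_tokens, output_tokens):
        # Mirrors the bookkeeping asserted above: one draft per request that
        # had draft tokens; acceptance stops at the first mismatch.
        num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
        num_drafts = num_draft_tokens = num_accepted = 0
        per_pos = [0] * num_spec_tokens
        for draft, output in zip(spec_tokens, output_tokens):
            if not draft:
                continue  # no draft tokens scheduled -> nothing recorded
            num_drafts += 1
            num_draft_tokens += len(draft)
            pos = 0
            while pos < len(draft) and draft[pos] == output[pos]:
                per_pos[pos] += 1
                pos += 1
            num_accepted += pos
        return (num_drafts, num_draft_tokens, num_accepted, per_pos)

    # e.g. the "multiple mismatches" case from the parametrize list:
    assert expected_spec_stats([[1, 2, 3], [4, 5, 6]],
                               [[1, 2, 7], [4, 8]]) == (2, 6, 3, [2, 1, 0])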

vllm/v1/core/sched/scheduler.py

@@ -122,11 +122,12 @@ class Scheduler(SchedulerInterface):
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)
 
-        self.num_lookahead_tokens = 0
         speculative_config = vllm_config.speculative_config
-        if speculative_config and speculative_config.method == "eagle":
-            self.num_lookahead_tokens = \
-                speculative_config.num_speculative_tokens
+        self.num_spec_tokens = self.num_lookahead_tokens = 0
+        if speculative_config:
+            self.num_spec_tokens = speculative_config.num_speculative_tokens
+            if speculative_config.method == "eagle":
+                self.num_lookahead_tokens = self.num_spec_tokens
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -824,7 +825,8 @@ class Scheduler(SchedulerInterface):
         if not self.log_stats:
             return None
         if spec_decoding_stats is None:
-            spec_decoding_stats = SpecDecodingStats()
-        spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
-                                    num_accepted_tokens=num_accepted_tokens)
+            spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
+        spec_decoding_stats.observe_draft(
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted_tokens)
         return spec_decoding_stats
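
The hunk above is where the scheduler records spec decoding activity: it lazily creates one SpecDecodingStats per step, then calls observe_draft() once per request whose draft was verified. A quick illustration of the resulting per-step counters, mirroring the "multiple mismatches" test case (standalone sketch, not part of the commit):

    from vllm.v1.spec_decode.metrics import SpecDecodingStats

    # One scheduler step, num_spec_tokens=3, two requests with verified drafts.
    stats = SpecDecodingStats.new(3)
    stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=2)  # [1,2,3] vs [1,2,7]
    stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=1)  # [4,5,6] vs [4,8]

    assert stats.num_drafts == 2
    assert stats.num_draft_tokens == 6
    assert stats.num_accepted_tokens == 3
    # Position 0 accepted in both drafts, position 1 in one, position 2 in none.
    assert stats.num_accepted_tokens_per_pos == [2, 1, 0]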

vllm/v1/metrics/loggers.py

@@ -12,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
-from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
+from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
 
 logger = init_logger(__name__)
@@ -39,7 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
-        self.spec_decoding_metrics = SpecDecodingMetrics()
+        self.spec_decoding_logging = SpecDecodingLogging()
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
@@ -70,7 +70,7 @@ class LoggingStatLogger(StatLoggerBase):
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.observe(
+            self.spec_decoding_logging.observe(
                 scheduler_stats.spec_decoding_stats)
 
         self.last_scheduler_stats = scheduler_stats
@@ -112,7 +112,7 @@ class LoggingStatLogger(StatLoggerBase):
         )
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.log(log_fn=log_fn)
+            self.spec_decoding_logging.log(log_fn=log_fn)
 
 
 class PrometheusStatLogger(StatLoggerBase):
@@ -133,6 +133,9 @@ class PrometheusStatLogger(StatLoggerBase):
         max_model_len = vllm_config.model_config.max_model_len
 
+        self.spec_decoding_prom = SpecDecodingProm(
+            vllm_config.speculative_config, labelnames, labelvalues)
+
         #
         # Scheduler state
         #
@ -323,24 +326,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.labelname_running_lora_adapters,
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
#
# Cache config info metric
#
@@ -378,10 +363,8 @@ class PrometheusStatLogger(StatLoggerBase):
                 scheduler_stats.prefix_cache_stats.hits)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.counter_spec_decode_num_draft_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_draft_tokens)
-            self.counter_spec_decode_num_accepted_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.spec_decoding_prom.observe(
+                scheduler_stats.spec_decoding_stats)
 
         if iteration_stats is None:
             return
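
The net effect in this file is a move, not a behavior change: the two existing counters migrate from PrometheusStatLogger into the new SpecDecodingProm helper, which additionally owns the new drafts counter and the per-position counter family. For reference, a minimal standalone sketch of that per-position pattern with prometheus_client (the demo metric and label names are illustrative, not vLLM's):

    import prometheus_client

    # One Counter family with an extra "position" label; children are bound
    # up front so the hot path is just an inc() on a list element.
    labelnames = ["model_name"]
    labelvalues = ["demo-model"]  # illustrative label value
    num_spec_tokens = 3

    per_pos_family = prometheus_client.Counter(
        name="demo_accepted_tokens_per_pos",  # illustrative name
        documentation="Accepted tokens per draft position.",
        labelnames=labelnames + ["position"])
    per_pos_counters = [
        per_pos_family.labels(*labelvalues, str(pos))
        for pos in range(num_spec_tokens)
    ]

    # Record one step's per-position acceptance counts.
    for pos, count in enumerate([2, 1, 0]):
        per_pos_counters[pos].inc(count)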

vllm/v1/spec_decode/metrics.py

@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Optional
 
 import numpy as np
+import prometheus_client
 
+from vllm.config import SpeculativeConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -11,52 +14,151 @@ logger = init_logger(__name__)
 @dataclass
 class SpecDecodingStats:
+    """Per-step iteration decoding stats from scheduler.
+
+    Each scheduler step, statistics on spec decoding performance are
+    aggregated across requests by the scheduler and returned to the
+    frontend in EngineCoreOutputs->SchedulerStats.
+    """
+    num_spec_tokens: int
+    num_drafts: int = 0
     num_draft_tokens: int = 0
     num_accepted_tokens: int = 0
+    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)
 
-    def take(self):
-        copied = SpecDecodingStats(self.num_draft_tokens,
-                                   self.num_accepted_tokens)
-        self.reset()
-        return copied
+    @classmethod
+    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
+        return cls(num_spec_tokens=num_spec_tokens,
+                   num_accepted_tokens_per_pos=[0] * num_spec_tokens)
 
-    def reset(self):
-        self.num_draft_tokens = 0
-        self.num_accepted_tokens = 0
-
-    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
+        self.num_drafts += 1
         self.num_draft_tokens += num_draft_tokens
         self.num_accepted_tokens += num_accepted_tokens
+        assert num_accepted_tokens <= self.num_spec_tokens
+        for i in range(num_accepted_tokens):
+            self.num_accepted_tokens_per_pos[i] += 1
 
 
-class SpecDecodingMetrics:
+class SpecDecodingLogging:
+    """Aggregate and log spec decoding metrics.
+
+    LoggingStatLogger aggregates per-iteration metrics over a set
+    time interval using observe() and then logs them using log()
+    before resetting to zero.
+    """
 
     def __init__(self):
         self.reset()
 
     def reset(self):
+        self.num_drafts: list[int] = []
         self.num_draft_tokens: list[int] = []
         self.num_accepted_tokens: list[int] = []
+        self.accepted_tokens_per_pos_lists: list[list[int]] = []
 
     def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_drafts.append(spec_decoding_stats.num_drafts)
         self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
         self.num_accepted_tokens.append(
             spec_decoding_stats.num_accepted_tokens)
+        self.accepted_tokens_per_pos_lists.append(
+            spec_decoding_stats.num_accepted_tokens_per_pos)
 
     def log(self, log_fn=logger.info):
+        num_drafts = np.sum(self.num_drafts)
         num_draft_tokens = np.sum(self.num_draft_tokens)
         num_accepted_tokens = np.sum(self.num_accepted_tokens)
         draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
                                  100 if num_draft_tokens > 0 else float("nan"))
+        mean_acceptance_length = (num_accepted_tokens / num_drafts)
+
+        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
+        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
+        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)
+
         log_fn(
             "SpecDecoding metrics: "
             "Draft acceptance rate: %.1f%%, "
+            "Mean acceptance length: %.2f, "
             "Accepted: %d tokens, "
-            "Drafted: %d tokens",
+            "Drafted: %d tokens, "
+            "Per-position acceptance rate: %s",
             draft_acceptance_rate,
+            mean_acceptance_length,
             num_accepted_tokens,
             num_draft_tokens,
+            rates_str,
         )
         self.reset()
+
+
+class SpecDecodingProm:
+    """Record spec decoding metrics in Prometheus.
+
+    The acceptance rate can be calculated using a PromQL query:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+
+    The mean acceptance length can be calculated using:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_drafts[$interval])
+
+    A per-position acceptance rate vector can be computed using
+
+      vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
+      vllm:spec_decode_num_drafts[$interval]
+    """
+
+    def __init__(self, speculative_config: Optional[SpeculativeConfig],
+                 labelnames: list[str], labelvalues: list[str]):
+        self.spec_decoding_enabled = speculative_config is not None
+        if not self.spec_decoding_enabled:
+            return
+
+        self.counter_spec_decode_num_drafts = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_drafts_total",
+                documentation="Number of spec decoding drafts.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
+        assert speculative_config is not None
+        num_spec_tokens = (speculative_config.num_speculative_tokens
+                           if self.spec_decoding_enabled else 0)
+        pos_labelnames = labelnames + ["position"]
+        base_counter = prometheus_client.Counter(
+            name="vllm:spec_decode_num_accepted_tokens_per_pos",
+            documentation="Accepted tokens per draft position.",
+            labelnames=pos_labelnames)
+        self.counter_spec_decode_num_accepted_tokens_per_pos: \
+            list[prometheus_client.Counter] = []
+        for pos in range(num_spec_tokens):
+            pos_labelvalues = labelvalues + [str(pos)]
+            self.counter_spec_decode_num_accepted_tokens_per_pos.append(
+                base_counter.labels(*pos_labelvalues))
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        if not self.spec_decoding_enabled:
+            return
+        self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
+        self.counter_spec_decode_num_draft_tokens.inc(
+            spec_decoding_stats.num_draft_tokens)
+        self.counter_spec_decode_num_accepted_tokens.inc(
+            spec_decoding_stats.num_accepted_tokens)
+        for pos, counter in enumerate(
+                self.counter_spec_decode_num_accepted_tokens_per_pos):
+            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
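
End to end, the flow is: the scheduler emits one SpecDecodingStats per step; the logging frontend accumulates them with SpecDecodingLogging.observe() and periodically reduces them in log(); SpecDecodingProm.observe() bumps the Prometheus counters. A small standalone walk-through of the logging side (numbers chosen so the reductions are easy to check; not part of the commit):

    from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingStats

    spec_logging = SpecDecodingLogging()

    # Two engine steps, num_spec_tokens=3, two drafts of 3 tokens per step.
    for accepted_per_draft in ([2, 0], [1, 1]):
        stats = SpecDecodingStats.new(3)
        for num_accepted in accepted_per_draft:
            stats.observe_draft(num_draft_tokens=3,
                                num_accepted_tokens=num_accepted)
        spec_logging.observe(stats)

    # Over the interval: 4 drafts, 12 draft tokens, 4 accepted tokens, so a
    # 33.3% draft acceptance rate, mean acceptance length 1.00, and
    # per-position acceptance rates [0.750, 0.250, 0.000]; log() then resets.
    spec_logging.log()  # emits via vLLM's logger by default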