From 340d7b1b217794c1b89e13936c66b920fd8d842d Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Thu, 24 Apr 2025 16:57:40 +0100
Subject: [PATCH] [V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (#16665)

Signed-off-by: Mark McLoughlin
---
 tests/v1/core/test_scheduler.py |  39 ++++++----
 vllm/v1/core/sched/scheduler.py |  16 ++--
 vllm/v1/metrics/loggers.py      |  35 +++------
 vllm/v1/spec_decode/metrics.py  | 128 ++++++++++++++++++++++++++++----
 4 files changed, 158 insertions(+), 60 deletions(-)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 003f94259e2d..591284ec8541 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, VllmConfig)
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -81,11 +82,17 @@ def create_scheduler(
         kv_connector_extra_config={"shared_storage_path": "local_storage"},
     ) if use_kv_connector else None
 
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
+        speculative_config=speculative_config,
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
 
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=1)
 
     # Test case 1: Stop on EOS token
     requests = create_requests(num_requests=2, max_tokens=10)
@@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [10, 11]
 
     # Test case 2: Stop on custom stop token
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2,
                                max_tokens=10,
                                stop_token_ids=[42, 43])
@@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13, 14]
 
     # Test case 3: Stop on max tokens
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2, max_tokens=2)
     for req in requests:
         req.num_computed_tokens = req.num_tokens
@@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13]
 
     # Test case 4: Ignore EOS flag
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
@@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
 
 @pytest.mark.parametrize(
"spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match - ([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences - ([[1]], [[1, 2]], (1, 1)), # single token sequence - ([[]], [[5]], (0, 0)), # empty sequence + ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], + (2, 3, 3, [2, 1])), # multiple sequences + ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence + ([[]], [[5]], (0, 0, 0, [0])), # empty sequence ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], - (6, 3)), # multiple mismatches + (2, 6, 3, [2, 1, 0])), # multiple mismatches ]) def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): """Test scheduling behavior with speculative decoding. @@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): 1. Speculated tokens get scheduled correctly 2. Spec decoding stats properly count number of draft and accepted tokens """ - scheduler = create_scheduler() + num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) + scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) req_ids = [] req_to_index = {} @@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): else: assert scheduler_stats.spec_decoding_stats is not None stats = scheduler_stats.spec_decoding_stats - assert stats.num_draft_tokens == expected[0] - assert stats.num_accepted_tokens == expected[1] + assert stats.num_drafts == expected[0] + assert stats.num_draft_tokens == expected[1] + assert stats.num_accepted_tokens == expected[2] + assert stats.num_accepted_tokens_per_pos == expected[3] def _assert_right_scheduler_output( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5805950f39b0..44dd9b026c2d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -122,11 +122,12 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) - self.num_lookahead_tokens = 0 speculative_config = vllm_config.speculative_config - if speculative_config and speculative_config.method == "eagle": - self.num_lookahead_tokens = \ - speculative_config.num_speculative_tokens + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.method == "eagle": + self.num_lookahead_tokens = self.num_spec_tokens def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -824,7 +825,8 @@ class Scheduler(SchedulerInterface): if not self.log_stats: return None if spec_decoding_stats is None: - spec_decoding_stats = SpecDecodingStats() - spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) return spec_decoding_stats diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 4d70f27f8080..547e60467632 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,7 +12,7 @@ from vllm.logger import init_logger from 
vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats -from vllm.v1.spec_decode.metrics import SpecDecodingMetrics +from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) @@ -39,7 +39,7 @@ class LoggingStatLogger(StatLoggerBase): # Prefix cache metrics. This cannot be reset. # TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() - self.spec_decoding_metrics = SpecDecodingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 @@ -70,7 +70,7 @@ class LoggingStatLogger(StatLoggerBase): self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.observe( + self.spec_decoding_logging.observe( scheduler_stats.spec_decoding_stats) self.last_scheduler_stats = scheduler_stats @@ -112,7 +112,7 @@ class LoggingStatLogger(StatLoggerBase): ) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.log(log_fn=log_fn) + self.spec_decoding_logging.log(log_fn=log_fn) class PrometheusStatLogger(StatLoggerBase): @@ -133,6 +133,9 @@ class PrometheusStatLogger(StatLoggerBase): max_model_len = vllm_config.model_config.max_model_len + self.spec_decoding_prom = SpecDecodingProm( + vllm_config.speculative_config, labelnames, labelvalues) + # # Scheduler state # @@ -323,24 +326,6 @@ class PrometheusStatLogger(StatLoggerBase): self.labelname_running_lora_adapters, ]) - # - # Speculative Decoding metrics - # The acceptance rate can be calculated using a PromQL query: - # - # rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / - # rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - # - self.counter_spec_decode_num_draft_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_draft_tokens_total", - documentation="Number of draft tokens.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_accepted_tokens_total", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) - # # Cache config info metric # @@ -378,10 +363,8 @@ class PrometheusStatLogger(StatLoggerBase): scheduler_stats.prefix_cache_stats.hits) if scheduler_stats.spec_decoding_stats is not None: - self.counter_spec_decode_num_draft_tokens.inc( - scheduler_stats.spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( - scheduler_stats.spec_decoding_stats.num_accepted_tokens) + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index cc453b74f7eb..33ce98284e20 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional import numpy as np +import prometheus_client +from vllm.config import SpeculativeConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -11,52 +14,151 @@ logger = init_logger(__name__) @dataclass class SpecDecodingStats: + """Per-step iteration decoding stats from scheduler. 
+
+    Each scheduler step, statistics on spec decoding performance are
+    aggregated across requests by the scheduler and returned to the
+    frontend in EngineCoreOutputs->SchedulerStats.
+    """
+
+    num_spec_tokens: int
+    num_drafts: int = 0
     num_draft_tokens: int = 0
     num_accepted_tokens: int = 0
+    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)
 
-    def take(self):
-        copied = SpecDecodingStats(self.num_draft_tokens,
-                                   self.num_accepted_tokens)
-        self.reset()
-        return copied
+    @classmethod
+    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
+        return cls(num_spec_tokens=num_spec_tokens,
+                   num_accepted_tokens_per_pos=[0] * num_spec_tokens)
 
-    def reset(self):
-        self.num_draft_tokens = 0
-        self.num_accepted_tokens = 0
-
-    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
+        self.num_drafts += 1
         self.num_draft_tokens += num_draft_tokens
         self.num_accepted_tokens += num_accepted_tokens
+        assert num_accepted_tokens <= self.num_spec_tokens
+        for i in range(num_accepted_tokens):
+            self.num_accepted_tokens_per_pos[i] += 1
 
 
-class SpecDecodingMetrics:
+class SpecDecodingLogging:
+    """Aggregate and log spec decoding metrics.
+
+    LoggingStatLogger aggregates per-iteration metrics over a set
+    time interval using observe() and then logs them using log()
+    before resetting to zero.
+    """
 
     def __init__(self):
         self.reset()
 
     def reset(self):
+        self.num_drafts: list[int] = []
         self.num_draft_tokens: list[int] = []
         self.num_accepted_tokens: list[int] = []
+        self.accepted_tokens_per_pos_lists: list[list[int]] = []
 
     def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_drafts.append(spec_decoding_stats.num_drafts)
         self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
         self.num_accepted_tokens.append(
             spec_decoding_stats.num_accepted_tokens)
+        self.accepted_tokens_per_pos_lists.append(
+            spec_decoding_stats.num_accepted_tokens_per_pos)
 
     def log(self, log_fn=logger.info):
+        num_drafts = np.sum(self.num_drafts)
         num_draft_tokens = np.sum(self.num_draft_tokens)
         num_accepted_tokens = np.sum(self.num_accepted_tokens)
 
         draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
                                  100 if num_draft_tokens > 0 else float("nan"))
+        mean_acceptance_length = (num_accepted_tokens / num_drafts)
+
+        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
+        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
+        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)
 
         log_fn(
             "SpecDecoding metrics: "
             "Draft acceptance rate: %.1f%%, "
+            "Mean acceptance length: %.2f, "
             "Accepted: %d tokens, "
-            "Drafted: %d tokens",
+            "Drafted: %d tokens, "
+            "Per-position acceptance rate: %s",
             draft_acceptance_rate,
+            mean_acceptance_length,
             num_accepted_tokens,
             num_draft_tokens,
+            rates_str,
         )
         self.reset()
+
+
+class SpecDecodingProm:
+    """Record spec decoding metrics in Prometheus.
+
+    The acceptance rate can be calculated using a PromQL query:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+
+    The mean acceptance length can be calculated using:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_drafts_total[$interval])
+
+    A per-position acceptance rate vector can be computed using:
+
+      vllm:spec_decode_num_accepted_tokens_per_pos_total[$interval] /
+      vllm:spec_decode_num_drafts_total[$interval]
+    """
+
+    def __init__(self, speculative_config: Optional[SpeculativeConfig],
+                 labelnames: list[str], labelvalues: list[str]):
+        self.spec_decoding_enabled = speculative_config is not None
+        if not self.spec_decoding_enabled:
+            return
+
+        self.counter_spec_decode_num_drafts = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_drafts_total",
+                documentation="Number of spec decoding drafts.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
+        assert speculative_config is not None
+        num_spec_tokens = (speculative_config.num_speculative_tokens
+                           if self.spec_decoding_enabled else 0)
+        pos_labelnames = labelnames + ["position"]
+        base_counter = prometheus_client.Counter(
+            name="vllm:spec_decode_num_accepted_tokens_per_pos",
+            documentation="Accepted tokens per draft position.",
+            labelnames=pos_labelnames)
+        self.counter_spec_decode_num_accepted_tokens_per_pos: \
+            list[prometheus_client.Counter] = []
+        for pos in range(num_spec_tokens):
+            pos_labelvalues = labelvalues + [str(pos)]
+            self.counter_spec_decode_num_accepted_tokens_per_pos.append(
+                base_counter.labels(*pos_labelvalues))
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        if not self.spec_decoding_enabled:
+            return
+        self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
+        self.counter_spec_decode_num_draft_tokens.inc(
+            spec_decoding_stats.num_draft_tokens)
+        self.counter_spec_decode_num_accepted_tokens.inc(
+            spec_decoding_stats.num_accepted_tokens)
+        for pos, counter in enumerate(
+                self.counter_spec_decode_num_accepted_tokens_per_pos):
+            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
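
A minimal usage sketch of the new scheduler-side stats, mirroring the
"multiple mismatches" case from the updated test above. It assumes the patched
vllm.v1.spec_decode.metrics module is importable; the token values in the
comments come from the test's spec_tokens/output_tokens parameters.

    from vllm.v1.spec_decode.metrics import SpecDecodingStats

    # One stats object per scheduler step, sized for 3 speculative tokens.
    stats = SpecDecodingStats.new(num_spec_tokens=3)

    # Request 1: drafted [1, 2, 3], sampled [1, 2, 7] -> 2 tokens accepted.
    stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=2)
    # Request 2: drafted [4, 5, 6], sampled [4, 8] -> 1 token accepted.
    stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=1)

    assert stats.num_drafts == 2
    assert stats.num_draft_tokens == 6
    assert stats.num_accepted_tokens == 3
    assert stats.num_accepted_tokens_per_pos == [2, 1, 0]

    # Derived values, as SpecDecodingLogging.log() reports them:
    #   draft acceptance rate   = 3 / 6 * 100 = 50.0%
    #   mean acceptance length  = 3 / 2 = 1.50
    #   per-position acceptance = [2/2, 1/2, 0/2] = 1.000, 0.500, 0.000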