mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (#16665)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
1bcbcbf574
commit
340d7b1b21
@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, VllmConfig)
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@ -31,6 +31,7 @@ def create_scheduler(
|
||||
num_blocks: int = 10000,
|
||||
block_size: int = 16,
|
||||
max_model_len: Optional[int] = None,
|
||||
num_speculative_tokens: Optional[int] = None,
|
||||
) -> Scheduler:
|
||||
'''Create scheduler under test.
|
||||
|
||||
@ -81,11 +82,17 @@ def create_scheduler(
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
) if use_kv_connector else None
|
||||
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
if num_speculative_tokens is not None:
|
||||
speculative_config = SpeculativeConfig(
|
||||
model="ngram", num_speculative_tokens=num_speculative_tokens)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
|
||||
def test_stop_via_update_from_output():
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
scheduler = create_scheduler()
|
||||
scheduler = create_scheduler(num_speculative_tokens=1)
|
||||
|
||||
# Test case 1: Stop on EOS token
|
||||
requests = create_requests(num_requests=2, max_tokens=10)
|
||||
@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
|
||||
assert list(requests[1].output_token_ids) == [10, 11]
|
||||
|
||||
# Test case 2: Stop on custom stop token
|
||||
scheduler = create_scheduler()
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=2,
|
||||
max_tokens=10,
|
||||
stop_token_ids=[42, 43])
|
||||
@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
|
||||
assert list(requests[1].output_token_ids) == [13, 14]
|
||||
|
||||
# Test case 3: Stop on max tokens
|
||||
scheduler = create_scheduler()
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=2, max_tokens=2)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
|
||||
assert list(requests[1].output_token_ids) == [13]
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = create_scheduler()
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1)), # single token sequence
|
||||
([[]], [[5]], (0, 0)), # empty sequence
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
|
||||
(2, 3, 3, [2, 1])), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
|
||||
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
|
||||
(6, 3)), # multiple mismatches
|
||||
(2, 6, 3, [2, 1, 0])), # multiple mismatches
|
||||
])
|
||||
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
"""Test scheduling behavior with speculative decoding.
|
||||
@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
1. Speculated tokens get scheduled correctly
|
||||
2. Spec decoding stats properly count number of draft and accepted tokens
|
||||
"""
|
||||
scheduler = create_scheduler()
|
||||
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
|
||||
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
|
||||
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
else:
|
||||
assert scheduler_stats.spec_decoding_stats is not None
|
||||
stats = scheduler_stats.spec_decoding_stats
|
||||
assert stats.num_draft_tokens == expected[0]
|
||||
assert stats.num_accepted_tokens == expected[1]
|
||||
assert stats.num_drafts == expected[0]
|
||||
assert stats.num_draft_tokens == expected[1]
|
||||
assert stats.num_accepted_tokens == expected[2]
|
||||
assert stats.num_accepted_tokens_per_pos == expected[3]
|
||||
|
||||
|
||||
def _assert_right_scheduler_output(
|
||||
|
||||
@ -122,11 +122,12 @@ class Scheduler(SchedulerInterface):
|
||||
self.encoder_cache_manager = EncoderCacheManager(
|
||||
cache_size=encoder_cache_size)
|
||||
|
||||
self.num_lookahead_tokens = 0
|
||||
speculative_config = vllm_config.speculative_config
|
||||
if speculative_config and speculative_config.method == "eagle":
|
||||
self.num_lookahead_tokens = \
|
||||
speculative_config.num_speculative_tokens
|
||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||
if speculative_config:
|
||||
self.num_spec_tokens = speculative_config.num_speculative_tokens
|
||||
if speculative_config.method == "eagle":
|
||||
self.num_lookahead_tokens = self.num_spec_tokens
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
@ -824,7 +825,8 @@ class Scheduler(SchedulerInterface):
|
||||
if not self.log_stats:
|
||||
return None
|
||||
if spec_decoding_stats is None:
|
||||
spec_decoding_stats = SpecDecodingStats()
|
||||
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
|
||||
num_accepted_tokens=num_accepted_tokens)
|
||||
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
|
||||
spec_decoding_stats.observe_draft(
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
num_accepted_tokens=num_accepted_tokens)
|
||||
return spec_decoding_stats
|
||||
|
||||
@ -12,7 +12,7 @@ from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
|
||||
from vllm.v1.engine import FinishReason
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -39,7 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
# Prefix cache metrics. This cannot be reset.
|
||||
# TODO: Make the interval configurable.
|
||||
self.prefix_caching_metrics = PrefixCachingMetrics()
|
||||
self.spec_decoding_metrics = SpecDecodingMetrics()
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
|
||||
@ -70,7 +70,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_metrics.observe(
|
||||
self.spec_decoding_logging.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
@ -112,7 +112,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_metrics.log(log_fn=log_fn)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
@ -133,6 +133,9 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
self.spec_decoding_prom = SpecDecodingProm(
|
||||
vllm_config.speculative_config, labelnames, labelvalues)
|
||||
|
||||
#
|
||||
# Scheduler state
|
||||
#
|
||||
@ -323,24 +326,6 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.labelname_running_lora_adapters,
|
||||
])
|
||||
|
||||
#
|
||||
# Speculative Decoding metrics
|
||||
# The acceptance rate can be calculated using a PromQL query:
|
||||
#
|
||||
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
|
||||
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
|
||||
#
|
||||
self.counter_spec_decode_num_draft_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_accepted_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
|
||||
#
|
||||
# Cache config info metric
|
||||
#
|
||||
@ -378,10 +363,8 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
scheduler_stats.prefix_cache_stats.hits)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.counter_spec_decode_num_draft_tokens.inc(
|
||||
scheduler_stats.spec_decoding_stats.num_draft_tokens)
|
||||
self.counter_spec_decode_num_accepted_tokens.inc(
|
||||
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
|
||||
self.spec_decoding_prom.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import prometheus_client
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -11,52 +14,151 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class SpecDecodingStats:
|
||||
"""Per-step iteration decoding stats from scheduler.
|
||||
|
||||
Each scheduler step, statistics on spec decoding performance are
|
||||
aggregated across requests by the scheduler and returned to the
|
||||
frontend in EngineCoreOutputs->SchedulerStats.
|
||||
"""
|
||||
|
||||
num_spec_tokens: int
|
||||
num_drafts: int = 0
|
||||
num_draft_tokens: int = 0
|
||||
num_accepted_tokens: int = 0
|
||||
num_accepted_tokens_per_pos: list[int] = field(default_factory=list)
|
||||
|
||||
def take(self):
|
||||
copied = SpecDecodingStats(self.num_draft_tokens,
|
||||
self.num_accepted_tokens)
|
||||
self.reset()
|
||||
return copied
|
||||
@classmethod
|
||||
def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
|
||||
return cls(num_spec_tokens=num_spec_tokens,
|
||||
num_accepted_tokens_per_pos=[0] * num_spec_tokens)
|
||||
|
||||
def reset(self):
|
||||
self.num_draft_tokens = 0
|
||||
self.num_accepted_tokens = 0
|
||||
|
||||
def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
|
||||
def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
|
||||
self.num_drafts += 1
|
||||
self.num_draft_tokens += num_draft_tokens
|
||||
self.num_accepted_tokens += num_accepted_tokens
|
||||
assert num_accepted_tokens <= self.num_spec_tokens
|
||||
for i in range(num_accepted_tokens):
|
||||
self.num_accepted_tokens_per_pos[i] += 1
|
||||
|
||||
|
||||
class SpecDecodingMetrics:
|
||||
class SpecDecodingLogging:
|
||||
"""Aggregate and log spec decoding metrics.
|
||||
|
||||
LoggingStatLogger aggregates per-iteration metrics over a set
|
||||
time interval using observe() and then logs them using log()
|
||||
before resetting to zero.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.num_drafts: list[int] = []
|
||||
self.num_draft_tokens: list[int] = []
|
||||
self.num_accepted_tokens: list[int] = []
|
||||
self.accepted_tokens_per_pos_lists: list[list[int]] = []
|
||||
|
||||
def observe(self, spec_decoding_stats: SpecDecodingStats):
|
||||
self.num_drafts.append(spec_decoding_stats.num_drafts)
|
||||
self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
|
||||
self.num_accepted_tokens.append(
|
||||
spec_decoding_stats.num_accepted_tokens)
|
||||
self.accepted_tokens_per_pos_lists.append(
|
||||
spec_decoding_stats.num_accepted_tokens_per_pos)
|
||||
|
||||
def log(self, log_fn=logger.info):
|
||||
num_drafts = np.sum(self.num_drafts)
|
||||
num_draft_tokens = np.sum(self.num_draft_tokens)
|
||||
num_accepted_tokens = np.sum(self.num_accepted_tokens)
|
||||
|
||||
draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
|
||||
100 if num_draft_tokens > 0 else float("nan"))
|
||||
mean_acceptance_length = (num_accepted_tokens / num_drafts)
|
||||
|
||||
pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
|
||||
acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
|
||||
rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)
|
||||
|
||||
log_fn(
|
||||
"SpecDecoding metrics: "
|
||||
"Draft acceptance rate: %.1f%%, "
|
||||
"Mean acceptance length: %.2f, "
|
||||
"Accepted: %d tokens, "
|
||||
"Drafted: %d tokens",
|
||||
"Drafted: %d tokens, "
|
||||
"Per-position acceptance rate: %s",
|
||||
draft_acceptance_rate,
|
||||
mean_acceptance_length,
|
||||
num_accepted_tokens,
|
||||
num_draft_tokens,
|
||||
rates_str,
|
||||
)
|
||||
self.reset()
|
||||
|
||||
|
||||
class SpecDecodingProm:
|
||||
"""Record spec decoding metrics in Prometheus.
|
||||
|
||||
The acceptance rate can be calculated using a PromQL query:
|
||||
|
||||
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
|
||||
rate(vllm:spec_decode_num_draft_tokens_total[$interval])
|
||||
|
||||
The mean acceptance length can be calculated using:
|
||||
|
||||
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
|
||||
rate(vllm:spec_decode_num_drafts[$interval])
|
||||
|
||||
A per-position acceptance rate vector can be computed using
|
||||
|
||||
vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
|
||||
vllm:spec_decode_num_drafts[$interval]
|
||||
"""
|
||||
|
||||
def __init__(self, speculative_config: Optional[SpeculativeConfig],
|
||||
labelnames: list[str], labelvalues: list[str]):
|
||||
self.spec_decoding_enabled = speculative_config is not None
|
||||
if not self.spec_decoding_enabled:
|
||||
return
|
||||
|
||||
self.counter_spec_decode_num_drafts = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_drafts_total",
|
||||
documentation="Number of spec decoding drafts.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_draft_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_accepted_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
|
||||
assert speculative_config is not None
|
||||
num_spec_tokens = (speculative_config.num_speculative_tokens
|
||||
if self.spec_decoding_enabled else 0)
|
||||
pos_labelnames = labelnames + ["position"]
|
||||
base_counter = prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_accepted_tokens_per_pos",
|
||||
documentation="Accepted tokens per draft position.",
|
||||
labelnames=pos_labelnames)
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos: \
|
||||
list[prometheus_client.Counter] = []
|
||||
for pos in range(num_spec_tokens):
|
||||
pos_labelvalues = labelvalues + [str(pos)]
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos.append(
|
||||
base_counter.labels(*pos_labelvalues))
|
||||
|
||||
def observe(self, spec_decoding_stats: SpecDecodingStats):
|
||||
if not self.spec_decoding_enabled:
|
||||
return
|
||||
self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
|
||||
self.counter_spec_decode_num_draft_tokens.inc(
|
||||
spec_decoding_stats.num_draft_tokens)
|
||||
self.counter_spec_decode_num_accepted_tokens.inc(
|
||||
spec_decoding_stats.num_accepted_tokens)
|
||||
for pos, counter in enumerate(
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos):
|
||||
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user