[V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (#16665)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent 1bcbcbf574
commit 340d7b1b21
@@ -6,7 +6,7 @@ import pytest
 import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, VllmConfig)
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -81,11 +82,17 @@ def create_scheduler(
         kv_connector_extra_config={"shared_storage_path": "local_storage"},
     ) if use_kv_connector else None
 
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
+        speculative_config=speculative_config,
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
 
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=1)
 
     # Test case 1: Stop on EOS token
     requests = create_requests(num_requests=2, max_tokens=10)
@@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [10, 11]
 
     # Test case 2: Stop on custom stop token
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2,
                                max_tokens=10,
                                stop_token_ids=[42, 43])
@@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13, 14]
 
     # Test case 3: Stop on max tokens
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2, max_tokens=2)
     for req in requests:
         req.num_computed_tokens = req.num_tokens
@@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13]
 
     # Test case 4: Ignore EOS flag
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=1, max_tokens=10)
     requests[0].sampling_params.ignore_eos = True
     requests[0].num_computed_tokens = requests[0].num_tokens
@@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
     [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
-        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
+         (2, 3, 3, [2, 1])),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
+        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
         ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
-         (6, 3)),  # multiple mismatches
+         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
     ])
 def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     """Test scheduling behavior with speculative decoding.
@@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     1. Speculated tokens get scheduled correctly
     2. Spec decoding stats properly count number of draft and accepted tokens
     """
-    scheduler = create_scheduler()
+    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
     requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
     req_ids = []
     req_to_index = {}
@@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     else:
         assert scheduler_stats.spec_decoding_stats is not None
         stats = scheduler_stats.spec_decoding_stats
-        assert stats.num_draft_tokens == expected[0]
-        assert stats.num_accepted_tokens == expected[1]
+        assert stats.num_drafts == expected[0]
+        assert stats.num_draft_tokens == expected[1]
+        assert stats.num_accepted_tokens == expected[2]
+        assert stats.num_accepted_tokens_per_pos == expected[3]
 
 
 def _assert_right_scheduler_output(
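For readers unfamiliar with the new `expected` tuples above: they are `(num_drafts, num_draft_tokens, num_accepted_tokens, num_accepted_tokens_per_pos)`. As an illustration (not part of the diff), the values can be derived by hand with a prefix comparison between each request's draft and the tokens actually sampled; this sketch reproduces the "multiple mismatches" case:

```python
# Illustrative only: derive the expected stats for one parametrized case.
spec_tokens = [[1, 2, 3], [4, 5, 6]]   # draft tokens proposed per request
output_tokens = [[1, 2, 7], [4, 8]]    # tokens actually sampled per request

num_spec_tokens = max(len(t) for t in spec_tokens)
num_drafts = num_draft_tokens = num_accepted_tokens = 0
per_pos = [0] * num_spec_tokens

for draft, sampled in zip(spec_tokens, output_tokens):
    if not draft:
        continue  # requests with no draft do not count as drafts
    num_drafts += 1
    num_draft_tokens += len(draft)
    for pos, token in enumerate(draft):
        if pos >= len(sampled) or sampled[pos] != token:
            break  # a draft is accepted only up to the first mismatch
        num_accepted_tokens += 1
        per_pos[pos] += 1

print((num_drafts, num_draft_tokens, num_accepted_tokens, per_pos))
# (2, 6, 3, [2, 1, 0]) -- the "multiple mismatches" expectation
```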
@@ -122,11 +122,12 @@ class Scheduler(SchedulerInterface):
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)
 
-        self.num_lookahead_tokens = 0
         speculative_config = vllm_config.speculative_config
-        if speculative_config and speculative_config.method == "eagle":
-            self.num_lookahead_tokens = \
-                speculative_config.num_speculative_tokens
+        self.num_spec_tokens = self.num_lookahead_tokens = 0
+        if speculative_config:
+            self.num_spec_tokens = speculative_config.num_speculative_tokens
+            if speculative_config.method == "eagle":
+                self.num_lookahead_tokens = self.num_spec_tokens
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -824,7 +825,8 @@ class Scheduler(SchedulerInterface):
         if not self.log_stats:
             return None
         if spec_decoding_stats is None:
-            spec_decoding_stats = SpecDecodingStats()
-        spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
-                                    num_accepted_tokens=num_accepted_tokens)
+            spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
+        spec_decoding_stats.observe_draft(
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted_tokens)
         return spec_decoding_stats
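The two hunks above give the scheduler a `num_spec_tokens` attribute and switch stats creation to `SpecDecodingStats.new()` / `observe_draft()`. A condensed sketch of that per-step pattern, using a hypothetical helper name (`aggregate_step_stats` is not a real vLLM function), looks like this:

```python
from typing import Optional

from vllm.v1.spec_decode.metrics import SpecDecodingStats


def aggregate_step_stats(
    num_spec_tokens: int,
    per_request: list[tuple[int, int]],  # (num_draft_tokens, num_accepted_tokens)
) -> Optional[SpecDecodingStats]:
    """Hypothetical helper mirroring how the scheduler accumulates stats."""
    stats: Optional[SpecDecodingStats] = None
    for num_draft_tokens, num_accepted_tokens in per_request:
        # Created lazily, sized to the configured draft length.
        if stats is None:
            stats = SpecDecodingStats.new(num_spec_tokens)
        stats.observe_draft(num_draft_tokens=num_draft_tokens,
                            num_accepted_tokens=num_accepted_tokens)
    return stats
```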
@@ -12,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
-from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
+from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
 
 logger = init_logger(__name__)
 
@@ -39,7 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
-        self.spec_decoding_metrics = SpecDecodingMetrics()
+        self.spec_decoding_logging = SpecDecodingLogging()
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
 
@@ -70,7 +70,7 @@ class LoggingStatLogger(StatLoggerBase):
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.observe(
+            self.spec_decoding_logging.observe(
                 scheduler_stats.spec_decoding_stats)
 
         self.last_scheduler_stats = scheduler_stats
@@ -112,7 +112,7 @@ class LoggingStatLogger(StatLoggerBase):
         )
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.log(log_fn=log_fn)
+            self.spec_decoding_logging.log(log_fn=log_fn)
 
 
 class PrometheusStatLogger(StatLoggerBase):
@@ -133,6 +133,9 @@ class PrometheusStatLogger(StatLoggerBase):
 
         max_model_len = vllm_config.model_config.max_model_len
 
+        self.spec_decoding_prom = SpecDecodingProm(
+            vllm_config.speculative_config, labelnames, labelvalues)
+
         #
         # Scheduler state
         #
@@ -323,24 +326,6 @@ class PrometheusStatLogger(StatLoggerBase):
                 self.labelname_running_lora_adapters,
             ])
 
-        #
-        # Speculative Decoding metrics
-        # The acceptance rate can be calculated using a PromQL query:
-        #
-        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
-        #     rate(vllm:spec_decode_num_draft_tokens_total[$interval])
-        #
-        self.counter_spec_decode_num_draft_tokens = \
-            prometheus_client.Counter(
-                name="vllm:spec_decode_num_draft_tokens_total",
-                documentation="Number of draft tokens.",
-                labelnames=labelnames).labels(*labelvalues)
-        self.counter_spec_decode_num_accepted_tokens = \
-            prometheus_client.Counter(
-                name="vllm:spec_decode_num_accepted_tokens_total",
-                documentation="Number of accepted tokens.",
-                labelnames=labelnames).labels(*labelvalues)
-
         #
         # Cache config info metric
         #
@@ -378,10 +363,8 @@ class PrometheusStatLogger(StatLoggerBase):
                 scheduler_stats.prefix_cache_stats.hits)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.counter_spec_decode_num_draft_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_draft_tokens)
-            self.counter_spec_decode_num_accepted_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.spec_decoding_prom.observe(
+                scheduler_stats.spec_decoding_stats)
 
         if iteration_stats is None:
             return
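The Prometheus side now delegates to `SpecDecodingProm` (defined in the next file). A minimal standalone sketch of the label-per-position counter pattern it relies on, using only `prometheus_client` with made-up metric and label values, is shown below; the real class additionally threads through vLLM's configured label names and values.

```python
import prometheus_client

# One counter family; each draft position becomes its own labeled child.
accepted_per_pos = prometheus_client.Counter(
    name="demo_accepted_tokens_per_pos",
    documentation="Accepted tokens per draft position (illustrative only).",
    labelnames=["model_name", "position"])

num_spec_tokens = 3
children = [
    accepted_per_pos.labels("demo-model", str(pos))
    for pos in range(num_spec_tokens)
]

# A draft whose first two positions were accepted:
for pos in range(2):
    children[pos].inc()
```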
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Optional
 
 import numpy as np
+import prometheus_client
 
+from vllm.config import SpeculativeConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -11,52 +14,151 @@ logger = init_logger(__name__)
 
 @dataclass
 class SpecDecodingStats:
+    """Per-step iteration decoding stats from scheduler.
+
+    Each scheduler step, statistics on spec decoding performance are
+    aggregated across requests by the scheduler and returned to the
+    frontend in EngineCoreOutputs->SchedulerStats.
+    """
+
+    num_spec_tokens: int
+    num_drafts: int = 0
     num_draft_tokens: int = 0
     num_accepted_tokens: int = 0
+    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)
 
-    def take(self):
-        copied = SpecDecodingStats(self.num_draft_tokens,
-                                   self.num_accepted_tokens)
-        self.reset()
-        return copied
+    @classmethod
+    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
+        return cls(num_spec_tokens=num_spec_tokens,
+                   num_accepted_tokens_per_pos=[0] * num_spec_tokens)
 
-    def reset(self):
-        self.num_draft_tokens = 0
-        self.num_accepted_tokens = 0
-
-    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
+        self.num_drafts += 1
         self.num_draft_tokens += num_draft_tokens
         self.num_accepted_tokens += num_accepted_tokens
+        assert num_accepted_tokens <= self.num_spec_tokens
+        for i in range(num_accepted_tokens):
+            self.num_accepted_tokens_per_pos[i] += 1
 
 
-class SpecDecodingMetrics:
+class SpecDecodingLogging:
+    """Aggregate and log spec decoding metrics.
+
+    LoggingStatLogger aggregates per-iteration metrics over a set
+    time interval using observe() and then logs them using log()
+    before resetting to zero.
+    """
 
     def __init__(self):
         self.reset()
 
     def reset(self):
+        self.num_drafts: list[int] = []
         self.num_draft_tokens: list[int] = []
         self.num_accepted_tokens: list[int] = []
+        self.accepted_tokens_per_pos_lists: list[list[int]] = []
 
     def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_drafts.append(spec_decoding_stats.num_drafts)
         self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
         self.num_accepted_tokens.append(
             spec_decoding_stats.num_accepted_tokens)
+        self.accepted_tokens_per_pos_lists.append(
+            spec_decoding_stats.num_accepted_tokens_per_pos)
 
     def log(self, log_fn=logger.info):
+        num_drafts = np.sum(self.num_drafts)
         num_draft_tokens = np.sum(self.num_draft_tokens)
         num_accepted_tokens = np.sum(self.num_accepted_tokens)
 
         draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
                                  100 if num_draft_tokens > 0 else float("nan"))
+        mean_acceptance_length = (num_accepted_tokens / num_drafts)
+
+        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
+        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
+        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)
+
         log_fn(
             "SpecDecoding metrics: "
             "Draft acceptance rate: %.1f%%, "
+            "Mean acceptance length: %.2f, "
             "Accepted: %d tokens, "
-            "Drafted: %d tokens",
+            "Drafted: %d tokens, "
+            "Per-position acceptance rate: %s",
             draft_acceptance_rate,
+            mean_acceptance_length,
             num_accepted_tokens,
             num_draft_tokens,
+            rates_str,
         )
         self.reset()
+
+
+class SpecDecodingProm:
+    """Record spec decoding metrics in Prometheus.
+
+    The acceptance rate can be calculated using a PromQL query:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+
+    The mean acceptance length can be calculated using:
+
+      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+      rate(vllm:spec_decode_num_drafts[$interval])
+
+    A per-position acceptance rate vector can be computed using
+
+      vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
+      vllm:spec_decode_num_drafts[$interval]
+    """
+
+    def __init__(self, speculative_config: Optional[SpeculativeConfig],
+                 labelnames: list[str], labelvalues: list[str]):
+        self.spec_decoding_enabled = speculative_config is not None
+        if not self.spec_decoding_enabled:
+            return
+
+        self.counter_spec_decode_num_drafts = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_drafts_total",
+                documentation="Number of spec decoding drafts.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
+        assert speculative_config is not None
+        num_spec_tokens = (speculative_config.num_speculative_tokens
+                           if self.spec_decoding_enabled else 0)
+        pos_labelnames = labelnames + ["position"]
+        base_counter = prometheus_client.Counter(
+            name="vllm:spec_decode_num_accepted_tokens_per_pos",
+            documentation="Accepted tokens per draft position.",
+            labelnames=pos_labelnames)
+        self.counter_spec_decode_num_accepted_tokens_per_pos: \
+            list[prometheus_client.Counter] = []
+        for pos in range(num_spec_tokens):
+            pos_labelvalues = labelvalues + [str(pos)]
+            self.counter_spec_decode_num_accepted_tokens_per_pos.append(
+                base_counter.labels(*pos_labelvalues))
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        if not self.spec_decoding_enabled:
+            return
+        self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
+        self.counter_spec_decode_num_draft_tokens.inc(
+            spec_decoding_stats.num_draft_tokens)
+        self.counter_spec_decode_num_accepted_tokens.inc(
+            spec_decoding_stats.num_accepted_tokens)
+        for pos, counter in enumerate(
+                self.counter_spec_decode_num_accepted_tokens_per_pos):
+            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
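To make the new fields concrete, here is a small self-contained example (illustrative numbers only) of what a `SpecDecodingStats` object carries after two observed drafts, along with the derived quantities that `SpecDecodingLogging.log()` reports when such stats are aggregated:

```python
from vllm.v1.spec_decode.metrics import SpecDecodingStats

stats = SpecDecodingStats.new(num_spec_tokens=3)
stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=2)
stats.observe_draft(num_draft_tokens=3, num_accepted_tokens=1)

assert stats.num_drafts == 2
assert stats.num_draft_tokens == 6
assert stats.num_accepted_tokens == 3
assert stats.num_accepted_tokens_per_pos == [2, 1, 0]

# Derived values, as computed on the logging side:
draft_acceptance_rate = stats.num_accepted_tokens / stats.num_draft_tokens  # 0.5
mean_acceptance_length = stats.num_accepted_tokens / stats.num_drafts       # 1.5
per_position_rates = [n / stats.num_drafts
                      for n in stats.num_accepted_tokens_per_pos]           # [1.0, 0.5, 0.0]
```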