[Spec Decoding]Support Spec Decoding Metrics in DP Mode (#24049)

Signed-off-by: wuhang <wuhang6@huawei.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
wuhang 2025-09-15 05:11:09 +08:00 committed by GitHub
parent 6dc8da5dc1
commit 90f3f7d73e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 37 deletions

View File

@ -169,15 +169,11 @@ class PrometheusStatLogger(StatLoggerBase):
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
if (len(self.engine_indexes) > 1
and vllm_config.speculative_config is not None):
raise NotImplementedError("Prometheus metrics with Spec Decoding "
"with >1 EngineCore per AsyncLLM is not "
"supported yet.")
spec_decode_labelvalues = [
vllm_config.model_config.served_model_name,
str(self.engine_indexes[0])
]
spec_decode_labelvalues: dict[int, list[str]] = {
idx: [model_name, str(idx)]
for idx in engine_indexes
}
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames,
spec_decode_labelvalues)
@ -530,7 +526,7 @@ class PrometheusStatLogger(StatLoggerBase):
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats)
scheduler_stats.spec_decoding_stats, engine_idx)
if iteration_stats is None:
return

View File

@ -140,27 +140,32 @@ class SpecDecodingProm:
self,
speculative_config: Optional[SpeculativeConfig],
labelnames: list[str],
labelvalues: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.spec_decoding_enabled = speculative_config is not None
if not self.spec_decoding_enabled:
return
self.counter_spec_decode_num_drafts = \
self._counter_cls(
name="vllm:spec_decode_num_drafts",
documentation="Number of spec decoding drafts.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_draft_tokens = \
self._counter_cls(
name="vllm:spec_decode_num_draft_tokens",
documentation="Number of draft tokens.",
labelnames=labelnames,).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
counter_drafts = self._counter_cls(
name="vllm:spec_decode_num_drafts",
documentation="Number of spec decoding drafts.",
labelnames=labelnames)
self.counter_spec_decode_num_drafts = make_per_engine(
counter_drafts, per_engine_labelvalues)
counter_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_draft_tokens = make_per_engine(
counter_draft_tokens, per_engine_labelvalues)
counter_accepted_tokens = self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens",
documentation="Number of accepted tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
counter_accepted_tokens, per_engine_labelvalues)
assert speculative_config is not None
num_spec_tokens = (speculative_config.num_speculative_tokens
@ -171,21 +176,36 @@ class SpecDecodingProm:
documentation="Accepted tokens per draft position.",
labelnames=pos_labelnames,
)
self.counter_spec_decode_num_accepted_tokens_per_pos: list[
prometheus_client.Counter] = []
for pos in range(num_spec_tokens):
pos_labelvalues = labelvalues + [str(pos)]
self.counter_spec_decode_num_accepted_tokens_per_pos.append(
base_counter.labels(*pos_labelvalues))
self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
int, list[prometheus_client.Counter]] = {
idx: [
base_counter.labels(*lv, str(pos))
for pos in range(num_spec_tokens)
]
for idx, lv in per_engine_labelvalues.items()
}
def observe(self, spec_decoding_stats: SpecDecodingStats):
def observe(self,
spec_decoding_stats: SpecDecodingStats,
engine_idx: int = 0):
if not self.spec_decoding_enabled:
return
self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
self.counter_spec_decode_num_draft_tokens.inc(
self.counter_spec_decode_num_drafts[engine_idx].inc(
spec_decoding_stats.num_drafts)
self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
spec_decoding_stats.num_accepted_tokens)
for pos, counter in enumerate(
self.counter_spec_decode_num_accepted_tokens_per_pos):
self.
counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]):
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(counter: prometheus_client.Counter,
per_engine_labelvalues: dict[int, list[str]]):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}