mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:35:01 +08:00
[Spec Decoding]Support Spec Decoding Metrics in DP Mode (#24049)
Signed-off-by: wuhang <wuhang6@huawei.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
6dc8da5dc1
commit
90f3f7d73e
@ -169,15 +169,11 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
model_name = vllm_config.model_config.served_model_name
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
if (len(self.engine_indexes) > 1
|
||||
and vllm_config.speculative_config is not None):
|
||||
raise NotImplementedError("Prometheus metrics with Spec Decoding "
|
||||
"with >1 EngineCore per AsyncLLM is not "
|
||||
"supported yet.")
|
||||
spec_decode_labelvalues = [
|
||||
vllm_config.model_config.served_model_name,
|
||||
str(self.engine_indexes[0])
|
||||
]
|
||||
spec_decode_labelvalues: dict[int, list[str]] = {
|
||||
idx: [model_name, str(idx)]
|
||||
for idx in engine_indexes
|
||||
}
|
||||
|
||||
self.spec_decoding_prom = self._spec_decoding_cls(
|
||||
vllm_config.speculative_config, labelnames,
|
||||
spec_decode_labelvalues)
|
||||
@ -530,7 +526,7 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_prom.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
scheduler_stats.spec_decoding_stats, engine_idx)
|
||||
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
@ -140,27 +140,32 @@ class SpecDecodingProm:
|
||||
self,
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
labelnames: list[str],
|
||||
labelvalues: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
):
|
||||
self.spec_decoding_enabled = speculative_config is not None
|
||||
if not self.spec_decoding_enabled:
|
||||
return
|
||||
|
||||
self.counter_spec_decode_num_drafts = \
|
||||
self._counter_cls(
|
||||
name="vllm:spec_decode_num_drafts",
|
||||
documentation="Number of spec decoding drafts.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_draft_tokens = \
|
||||
self._counter_cls(
|
||||
name="vllm:spec_decode_num_draft_tokens",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames,).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_accepted_tokens = \
|
||||
self._counter_cls(
|
||||
name="vllm:spec_decode_num_accepted_tokens",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
counter_drafts = self._counter_cls(
|
||||
name="vllm:spec_decode_num_drafts",
|
||||
documentation="Number of spec decoding drafts.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_drafts = make_per_engine(
|
||||
counter_drafts, per_engine_labelvalues)
|
||||
|
||||
counter_draft_tokens = self._counter_cls(
|
||||
name="vllm:spec_decode_num_draft_tokens",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_draft_tokens = make_per_engine(
|
||||
counter_draft_tokens, per_engine_labelvalues)
|
||||
|
||||
counter_accepted_tokens = self._counter_cls(
|
||||
name="vllm:spec_decode_num_accepted_tokens",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
|
||||
counter_accepted_tokens, per_engine_labelvalues)
|
||||
|
||||
assert speculative_config is not None
|
||||
num_spec_tokens = (speculative_config.num_speculative_tokens
|
||||
@ -171,21 +176,36 @@ class SpecDecodingProm:
|
||||
documentation="Accepted tokens per draft position.",
|
||||
labelnames=pos_labelnames,
|
||||
)
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos: list[
|
||||
prometheus_client.Counter] = []
|
||||
for pos in range(num_spec_tokens):
|
||||
pos_labelvalues = labelvalues + [str(pos)]
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos.append(
|
||||
base_counter.labels(*pos_labelvalues))
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
|
||||
int, list[prometheus_client.Counter]] = {
|
||||
idx: [
|
||||
base_counter.labels(*lv, str(pos))
|
||||
for pos in range(num_spec_tokens)
|
||||
]
|
||||
for idx, lv in per_engine_labelvalues.items()
|
||||
}
|
||||
|
||||
def observe(self, spec_decoding_stats: SpecDecodingStats):
|
||||
def observe(self,
|
||||
spec_decoding_stats: SpecDecodingStats,
|
||||
engine_idx: int = 0):
|
||||
if not self.spec_decoding_enabled:
|
||||
return
|
||||
self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
|
||||
self.counter_spec_decode_num_draft_tokens.inc(
|
||||
self.counter_spec_decode_num_drafts[engine_idx].inc(
|
||||
spec_decoding_stats.num_drafts)
|
||||
self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
|
||||
spec_decoding_stats.num_draft_tokens)
|
||||
self.counter_spec_decode_num_accepted_tokens.inc(
|
||||
self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
|
||||
spec_decoding_stats.num_accepted_tokens)
|
||||
for pos, counter in enumerate(
|
||||
self.counter_spec_decode_num_accepted_tokens_per_pos):
|
||||
self.
|
||||
counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]):
|
||||
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
|
||||
|
||||
|
||||
def make_per_engine(counter: prometheus_client.Counter,
|
||||
per_engine_labelvalues: dict[int, list[str]]):
|
||||
"""Create a counter for each label value."""
|
||||
return {
|
||||
idx: counter.labels(*labelvalues)
|
||||
for idx, labelvalues in per_engine_labelvalues.items()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user