diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 347185d8341e..3f4699281cce 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -169,15 +169,11 @@ class PrometheusStatLogger(StatLoggerBase): model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - if (len(self.engine_indexes) > 1 - and vllm_config.speculative_config is not None): - raise NotImplementedError("Prometheus metrics with Spec Decoding " - "with >1 EngineCore per AsyncLLM is not " - "supported yet.") - spec_decode_labelvalues = [ - vllm_config.model_config.served_model_name, - str(self.engine_indexes[0]) - ] + spec_decode_labelvalues: dict[int, list[str]] = { + idx: [model_name, str(idx)] + for idx in engine_indexes + } + self.spec_decoding_prom = self._spec_decoding_cls( vllm_config.speculative_config, labelnames, spec_decode_labelvalues) @@ -530,7 +526,7 @@ class PrometheusStatLogger(StatLoggerBase): if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + scheduler_stats.spec_decoding_stats, engine_idx) if iteration_stats is None: return diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 2aa8962f5739..282e6f65e7ab 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -140,27 +140,32 @@ class SpecDecodingProm: self, speculative_config: Optional[SpeculativeConfig], labelnames: list[str], - labelvalues: list[str], + per_engine_labelvalues: dict[int, list[str]], ): self.spec_decoding_enabled = speculative_config is not None if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts = \ - self._counter_cls( - name="vllm:spec_decode_num_drafts", - documentation="Number of spec decoding drafts.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_draft_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_draft_tokens", - documentation="Number of draft tokens.", - labelnames=labelnames,).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_accepted_tokens", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) + counter_drafts = self._counter_cls( + name="vllm:spec_decode_num_drafts", + documentation="Number of spec decoding drafts.", + labelnames=labelnames) + self.counter_spec_decode_num_drafts = make_per_engine( + counter_drafts, per_engine_labelvalues) + + counter_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_draft_tokens = make_per_engine( + counter_draft_tokens, per_engine_labelvalues) + + counter_accepted_tokens = self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens", + documentation="Number of accepted tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_accepted_tokens = make_per_engine( + counter_accepted_tokens, per_engine_labelvalues) assert speculative_config is not None num_spec_tokens = (speculative_config.num_speculative_tokens @@ -171,21 +176,36 @@ class SpecDecodingProm: documentation="Accepted tokens per draft position.", labelnames=pos_labelnames, ) - self.counter_spec_decode_num_accepted_tokens_per_pos: list[ - prometheus_client.Counter] = [] - for pos in range(num_spec_tokens): - pos_labelvalues = labelvalues + [str(pos)] - self.counter_spec_decode_num_accepted_tokens_per_pos.append( - base_counter.labels(*pos_labelvalues)) + self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ + int, list[prometheus_client.Counter]] = { + idx: [ + base_counter.labels(*lv, str(pos)) + for pos in range(num_spec_tokens) + ] + for idx, lv in per_engine_labelvalues.items() + } - def observe(self, spec_decoding_stats: SpecDecodingStats): + def observe(self, + spec_decoding_stats: SpecDecodingStats, + engine_idx: int = 0): if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts) - self.counter_spec_decode_num_draft_tokens.inc( + self.counter_spec_decode_num_drafts[engine_idx].inc( + spec_decoding_stats.num_drafts) + self.counter_spec_decode_num_draft_tokens[engine_idx].inc( spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( + self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( spec_decoding_stats.num_accepted_tokens) for pos, counter in enumerate( - self.counter_spec_decode_num_accepted_tokens_per_pos): + self. + counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) + + +def make_per_engine(counter: prometheus_client.Counter, + per_engine_labelvalues: dict[int, list[str]]): + """Create a counter for each label value.""" + return { + idx: counter.labels(*labelvalues) + for idx, labelvalues in per_engine_labelvalues.items() + }