Update Spec Decode metrics to include drafted and accepted token throughput (#24127)

Signed-off-by: Andrew Xia <axia@meta.com>
This commit is contained in:
Andrew Xia 2025-09-11 12:58:43 -07:00 committed by GitHub
parent b971f91504
commit 79ac59f32e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -58,6 +59,7 @@ class SpecDecodingLogging:
self.num_draft_tokens: list[int] = [] self.num_draft_tokens: list[int] = []
self.num_accepted_tokens: list[int] = [] self.num_accepted_tokens: list[int] = []
self.accepted_tokens_per_pos_lists: list[list[int]] = [] self.accepted_tokens_per_pos_lists: list[list[int]] = []
self.last_log_time = time.monotonic()
def observe(self, spec_decoding_stats: SpecDecodingStats): def observe(self, spec_decoding_stats: SpecDecodingStats):
self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_drafts.append(spec_decoding_stats.num_drafts)
@@ -73,6 +75,13 @@ class SpecDecodingLogging:
num_drafts = np.sum(self.num_drafts) num_drafts = np.sum(self.num_drafts)
num_draft_tokens = np.sum(self.num_draft_tokens) num_draft_tokens = np.sum(self.num_draft_tokens)
num_accepted_tokens = np.sum(self.num_accepted_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens)
draft_throughput = 0
accepted_throughput = 0
elapsed_time = time.monotonic() - self.last_log_time
if elapsed_time > 0:
draft_throughput = num_draft_tokens / elapsed_time
accepted_throughput = num_accepted_tokens / elapsed_time
draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
100 if num_draft_tokens > 0 else float("nan")) 100 if num_draft_tokens > 0 else float("nan"))
@@ -86,16 +95,20 @@ class SpecDecodingLogging:
log_fn( log_fn(
"SpecDecoding metrics: " "SpecDecoding metrics: "
"Draft acceptance rate: %.1f%%, "
"Mean acceptance length: %.2f, " "Mean acceptance length: %.2f, "
"Accepted throughput: %.2f tokens/s, "
"Drafted throughput: %.2f tokens/s, "
"Accepted: %d tokens, " "Accepted: %d tokens, "
"Drafted: %d tokens, " "Drafted: %d tokens, "
"Per-position acceptance rate: %s", "Per-position acceptance rate: %s, "
draft_acceptance_rate, "Avg Draft acceptance rate: %.1f%%",
mean_acceptance_length, mean_acceptance_length,
accepted_throughput,
draft_throughput,
num_accepted_tokens, num_accepted_tokens,
num_draft_tokens, num_draft_tokens,
rates_str, rates_str,
draft_acceptance_rate,
) )
self.reset() self.reset()