mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 19:25:31 +08:00
[Bugfix] StatLoggers: cache spec decode metrics when they get collected. (#6645)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
01c16ede6b
commit
2f808e69ab
@ -1,3 +1,4 @@
|
|||||||
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -10,6 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|||||||
from vllm.engine.metrics import RayPrometheusStatLogger
|
from vllm.engine.metrics import RayPrometheusStatLogger
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
from ..conftest import cleanup
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"facebook/opt-125m",
|
"facebook/opt-125m",
|
||||||
]
|
]
|
||||||
@ -219,6 +222,94 @@ def test_metric_spec_decode(
|
|||||||
"does not meet expectation")
|
"does not meet expectation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [10])
|
||||||
|
@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
|
||||||
|
def test_metric_spec_decode_interval(
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
log_interval: int,
|
||||||
|
) -> None:
|
||||||
|
k = 5
|
||||||
|
|
||||||
|
engine_args = EngineArgs(model=model,
|
||||||
|
dtype=dtype,
|
||||||
|
disable_log_stats=False,
|
||||||
|
gpu_memory_utilization=0.4,
|
||||||
|
speculative_model=model,
|
||||||
|
num_speculative_tokens=k,
|
||||||
|
use_v2_block_manager=True,
|
||||||
|
enforce_eager=True)
|
||||||
|
|
||||||
|
engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
engine.add_request(
|
||||||
|
"request-id-0",
|
||||||
|
example_prompts[0],
|
||||||
|
SamplingParams(max_tokens=max_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
# set log internal
|
||||||
|
stat_logger = engine.stat_loggers['prometheus']
|
||||||
|
stat_logger.local_interval = log_interval
|
||||||
|
|
||||||
|
# prefill
|
||||||
|
engine.step()
|
||||||
|
|
||||||
|
# wait for 5 seconds to ensure that spec decode metrics
|
||||||
|
# get triggered in first decode step
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# first decode step should trigger async collection of metrics
|
||||||
|
engine.step()
|
||||||
|
|
||||||
|
# wait one second to allow H2D transfer to finish
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# second decode step should now be able to collect the spec
|
||||||
|
# decode stats and the request should also be finished
|
||||||
|
engine.step()
|
||||||
|
|
||||||
|
# must have finisehd now
|
||||||
|
assert not engine.has_unfinished_requests()
|
||||||
|
|
||||||
|
# wait to ensure logging occurs
|
||||||
|
time.sleep(log_interval)
|
||||||
|
|
||||||
|
# force logging
|
||||||
|
engine.step()
|
||||||
|
|
||||||
|
# Note that the purpose of this test is to verify spec decode
|
||||||
|
# metrics instead of functional correctness, so the expected values
|
||||||
|
# are intended to be loose.
|
||||||
|
metric_name_to_expected_fn = {
|
||||||
|
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
|
||||||
|
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
|
||||||
|
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
|
||||||
|
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
|
||||||
|
"counter_spec_decode_num_emitted_tokens":
|
||||||
|
lambda v: 0 <= v <= k + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
for metric_name, is_expected in metric_name_to_expected_fn.items():
|
||||||
|
metric_val = getattr(
|
||||||
|
stat_logger.metrics,
|
||||||
|
metric_name).labels(**stat_logger.labels)._value.get()
|
||||||
|
assert is_expected(metric_val), (
|
||||||
|
f"the value of metric {metric_name} ({metric_val}) "
|
||||||
|
"does not meet expectation")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
del engine
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
|
||||||
def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
|
def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
|
||||||
num_requests: int) -> None:
|
num_requests: int) -> None:
|
||||||
if disable_log_stats:
|
if disable_log_stats:
|
||||||
|
|||||||
@ -355,6 +355,7 @@ class StatLoggerBase(ABC):
|
|||||||
self.num_generation_tokens: List[int] = []
|
self.num_generation_tokens: List[int] = []
|
||||||
self.last_local_log = time.time()
|
self.last_local_log = time.time()
|
||||||
self.local_interval = local_interval
|
self.local_interval = local_interval
|
||||||
|
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||||
@ -364,6 +365,12 @@ class StatLoggerBase(ABC):
|
|||||||
def log(self, stats: Stats) -> None:
|
def log(self, stats: Stats) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def maybe_update_spec_decode_metrics(self, stats: Stats):
|
||||||
|
"""Save spec decode metrics (since they are unlikely
|
||||||
|
to be emitted at same time as log interval)."""
|
||||||
|
if stats.spec_decode_metrics is not None:
|
||||||
|
self.spec_decode_metrics = stats.spec_decode_metrics
|
||||||
|
|
||||||
|
|
||||||
class LoggingStatLogger(StatLoggerBase):
|
class LoggingStatLogger(StatLoggerBase):
|
||||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||||
@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||||
|
|
||||||
|
# Update spec decode metrics
|
||||||
|
self.maybe_update_spec_decode_metrics(stats)
|
||||||
|
|
||||||
# Log locally every local_interval seconds.
|
# Log locally every local_interval seconds.
|
||||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||||
self.local_interval):
|
self.local_interval):
|
||||||
@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
stats.cpu_cache_usage_sys * 100,
|
stats.cpu_cache_usage_sys * 100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.spec_decode_metrics is not None:
|
||||||
|
logger.info(
|
||||||
|
self._format_spec_decode_metrics_str(
|
||||||
|
self.spec_decode_metrics))
|
||||||
|
|
||||||
# Reset tracked stats for next interval.
|
# Reset tracked stats for next interval.
|
||||||
self.num_prompt_tokens = []
|
self.num_prompt_tokens = []
|
||||||
self.num_generation_tokens = []
|
self.num_generation_tokens = []
|
||||||
self.last_local_log = stats.now
|
self.last_local_log = stats.now
|
||||||
|
self.spec_decode_metrics = None
|
||||||
if stats.spec_decode_metrics is not None:
|
|
||||||
logger.info(
|
|
||||||
self._format_spec_decode_metrics_str(
|
|
||||||
stats.spec_decode_metrics))
|
|
||||||
|
|
||||||
def _format_spec_decode_metrics_str(
|
def _format_spec_decode_metrics_str(
|
||||||
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
self, metrics: "SpecDecodeWorkerMetrics") -> str:
|
||||||
@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||||
|
|
||||||
|
# Update spec decode metrics
|
||||||
|
self.maybe_update_spec_decode_metrics(stats)
|
||||||
|
|
||||||
# Log locally every local_interval seconds.
|
# Log locally every local_interval seconds.
|
||||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||||
self.local_interval):
|
self.local_interval):
|
||||||
@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
prompt_throughput=prompt_throughput,
|
prompt_throughput=prompt_throughput,
|
||||||
generation_throughput=generation_throughput)
|
generation_throughput=generation_throughput)
|
||||||
|
|
||||||
|
if self.spec_decode_metrics is not None:
|
||||||
|
self._log_gauge(
|
||||||
|
self.metrics.gauge_spec_decode_draft_acceptance_rate,
|
||||||
|
self.spec_decode_metrics.draft_acceptance_rate)
|
||||||
|
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
|
||||||
|
self.spec_decode_metrics.system_efficiency)
|
||||||
|
self._log_counter(
|
||||||
|
self.metrics.counter_spec_decode_num_accepted_tokens,
|
||||||
|
self.spec_decode_metrics.accepted_tokens)
|
||||||
|
self._log_counter(
|
||||||
|
self.metrics.counter_spec_decode_num_draft_tokens,
|
||||||
|
self.spec_decode_metrics.draft_tokens)
|
||||||
|
self._log_counter(
|
||||||
|
self.metrics.counter_spec_decode_num_emitted_tokens,
|
||||||
|
self.spec_decode_metrics.emitted_tokens)
|
||||||
|
|
||||||
# Reset tracked stats for next interval.
|
# Reset tracked stats for next interval.
|
||||||
self.num_prompt_tokens = []
|
self.num_prompt_tokens = []
|
||||||
self.num_generation_tokens = []
|
self.num_generation_tokens = []
|
||||||
self.last_local_log = stats.now
|
self.last_local_log = stats.now
|
||||||
|
self.spec_decode_metrics = None
|
||||||
if stats.spec_decode_metrics is not None:
|
|
||||||
self._log_gauge(
|
|
||||||
self.metrics.gauge_spec_decode_draft_acceptance_rate,
|
|
||||||
stats.spec_decode_metrics.draft_acceptance_rate)
|
|
||||||
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
|
|
||||||
stats.spec_decode_metrics.system_efficiency)
|
|
||||||
self._log_counter(
|
|
||||||
self.metrics.counter_spec_decode_num_accepted_tokens,
|
|
||||||
stats.spec_decode_metrics.accepted_tokens)
|
|
||||||
self._log_counter(
|
|
||||||
self.metrics.counter_spec_decode_num_draft_tokens,
|
|
||||||
stats.spec_decode_metrics.draft_tokens)
|
|
||||||
self._log_counter(
|
|
||||||
self.metrics.counter_spec_decode_num_emitted_tokens,
|
|
||||||
stats.spec_decode_metrics.emitted_tokens)
|
|
||||||
|
|
||||||
|
|
||||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user