[Misc] Log spec decode metrics (#6454)

parent 94162beb9f
commit 160e1d8c99
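The change adds five speculative decoding metrics to the Prometheus endpoint: vllm:spec_decode_draft_acceptance_rate, vllm:spec_decode_efficiency, and the three vllm:spec_decode_num_*_tokens_total counters. As a quick way to see them outside the test suite, here is a minimal sketch that scrapes a locally running server; the URL and port are assumptions for a default local OpenAI-compatible deployment, not part of this commit.

# Sketch: scrape the new spec decode metrics from a running vLLM server.
# The endpoint below is an assumption (default local server); adjust
# host/port for your own deployment.
from urllib.request import urlopen

METRICS_URL = "http://localhost:8000/metrics"

SPEC_DECODE_PREFIXES = (
    "vllm:spec_decode_draft_acceptance_rate",
    "vllm:spec_decode_efficiency",
    "vllm:spec_decode_num_accepted_tokens_total",
    "vllm:spec_decode_num_draft_tokens_total",
    "vllm:spec_decode_num_emitted_tokens_total",
)

with urlopen(METRICS_URL) as resp:
    payload = resp.read().decode()

for line in payload.splitlines():
    if line.startswith(SPEC_DECODE_PREFIXES):
        # e.g. vllm:spec_decode_draft_acceptance_rate{model_name="..."} 0.6
        print(line)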
@@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
    assert_metrics(engine, disable_log_stats, len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    k = 5

    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.4,
                     speculative_model=model,
                     num_speculative_tokens=k,
                     use_v2_block_manager=True) as vllm_model:

        # Force log interval to be 0 to catch all metrics.
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        stat_logger.local_interval = 0

        # Note that the purpose of this test is to verify spec decode
        # metrics instead of functional correctness, so the expected values
        # are intended to be loose.
        metric_name_to_expected_fn = {
            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
            "counter_spec_decode_num_emitted_tokens":
            lambda v: 0 <= v <= k + 1,
        }

        # Use one request to better inspect the metrics.
        prompts = example_prompts[:1]

        _ = vllm_model.generate_greedy(prompts, max_tokens)
        for metric_name, is_expected in metric_name_to_expected_fn.items():
            metric_val = getattr(
                stat_logger.metrics,
                metric_name).labels(**stat_logger.labels)._value.get()
            assert is_expected(metric_val), (
                f"the value of metric {metric_name} ({metric_val}) "
                "does not meet expectation")


def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                   num_requests: int) -> None:
    if disable_log_stats:
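For context on the loose bounds in test_metric_spec_decode: with a single proposal step and num_speculative_tokens = k, the draft counter advances by exactly k, at most k of those drafts are accepted, and at most k + 1 tokens are emitted (the accepted prefix plus one token from the target model). Assuming the usual ratio definitions of acceptance rate (accepted / draft) and efficiency (emitted / (k + 1)), which this diff does not spell out, a tiny worked example:

# Hypothetical single-step numbers chosen only to illustrate the bounds the
# test checks above; they are not taken from a real run.
k = 5                                   # num_speculative_tokens, as in the test
draft_tokens = k                        # one proposal step drafts exactly k tokens
accepted_tokens = 3                     # somewhere in [0, k]
emitted_tokens = accepted_tokens + 1    # accepted prefix + one target-model token

# Assumed definitions (ratios, so both land in [0, 1]).
draft_acceptance_rate = accepted_tokens / draft_tokens   # 3 / 5 = 0.6
system_efficiency = emitted_tokens / (k + 1)             # 4 / 6 ~= 0.67

assert 0 <= draft_acceptance_rate <= 1
assert 0 <= system_efficiency <= 1
assert emitted_tokens <= k + 1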
@@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
    }
    test_name = request.node.name

    model = kwargs["model"]
    draft_model = kwargs.get("speculative_model", None)
    same_draft_target_model = (draft_model is not None
                               and draft_model == model)

    def generator_inner():

        wait_for_gpu_memory_to_clear(
@@ -177,6 +182,13 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

        # Override the logging interval to 0 for the spec decode test run so
        # that all metrics are logged in time.
        if (baseline_or_test == "test" and not use_async
                and llm.llm_engine.log_stats):
            for stat_logger in llm.llm_engine.stat_loggers.values():
                stat_logger.local_interval = 0
        set_random_seed(seed)

        yield llm
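The local_interval override above matters because the Prometheus logger only flushes locally aggregated stats once the logging interval has elapsed, and short test prompts can finish before the first flush. Below is a minimal sketch of that interval-gated pattern; it is a toy stand-in written for illustration, not vLLM's actual logger.

import time


class IntervalGatedLogger:
    """Toy logger that only publishes stats after local_interval elapses.

    With local_interval = 0 every call publishes immediately, which is what
    the override above relies on so short test runs do not miss metrics.
    """

    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.monotonic()

    def maybe_log(self, stats: dict) -> None:
        now = time.monotonic()
        if now - self.last_local_log >= self.local_interval:
            print(stats)  # stand-in for pushing values to Prometheus metrics
            self.last_local_log = now


logger = IntervalGatedLogger(local_interval=0)
logger.maybe_log({"spec_decode_num_draft_tokens": 5})  # always published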
@@ -188,6 +200,9 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
        yield llm
        del llm

    # Set an attribute on the generator_outer function to allow us to
    # determine whether to further check the acceptance rate in tests.
    generator_outer.same_draft_target_model = same_draft_target_model  # type: ignore
    return generator_outer
@@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):

def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]]]:
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids
    return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
@@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
                                         print_tokens: bool = False,
                                         ensure_all_accepted: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.
@@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
        temperature=temperature,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    (spec_batch_tokens, spec_batch_token_ids,
     acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    (baseline_batch_tokens,
     baseline_batch_token_ids) = get_output_from_llm_generator(
         baseline_llm_generator, prompts, sampling_params)
    (baseline_batch_tokens, baseline_batch_token_ids,
     _) = get_output_from_llm_generator(baseline_llm_generator, prompts,
                                        sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)
@@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids

    if ensure_all_accepted:
        assert acceptance_rate == 1.0
@@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
        temperature=temperature,
    )

    batch_tokens, batch_token_ids = get_output_from_llm_generator(
    batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Expect a generation for each prompt in the batch.
@@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(

    Since this test is cheaper than other e2e correctness tests, we generate
    with a higher output_len.

    When the draft model is the same as the target model, we further check
    whether all speculative tokens are accepted.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
    ensure_all_accepted = test_llm_generator.same_draft_target_model
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True)
        force_output_len=True,
        ensure_all_accepted=ensure_all_accepted)


@pytest.mark.parametrize(
@@ -133,6 +133,30 @@ class Metrics:
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + [Metrics.labelname_finish_reason])

        # Speculative decoding stats
        self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
            name="vllm:spec_decode_draft_acceptance_rate",
            documentation="Speculative token acceptance rate.",
            labelnames=labelnames)
        self.gauge_spec_decode_efficiency = self._base_library.Gauge(
            name="vllm:spec_decode_efficiency",
            documentation="Speculative decoding system efficiency.",
            labelnames=labelnames)
        self.counter_spec_decode_num_accepted_tokens = (
            self._base_library.Counter(
                name="vllm:spec_decode_num_accepted_tokens_total",
                documentation="Number of accepted tokens.",
                labelnames=labelnames))
        self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
            name="vllm:spec_decode_num_draft_tokens_total",
            documentation="Number of draft tokens.",
            labelnames=labelnames)
        self.counter_spec_decode_num_emitted_tokens = (
            self._base_library.Counter(
                name="vllm:spec_decode_num_emitted_tokens_total",
                documentation="Number of emitted tokens.",
                labelnames=labelnames))

        # Deprecated in favor of vllm:prompt_tokens_total
        self.gauge_avg_prompt_throughput = self._base_library.Gauge(
            name="vllm:avg_prompt_throughput_toks_per_s",
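The gauges and counters above are standard prometheus_client primitives (for the Ray-based logger, self._base_library points at Ray's metric shims instead). A standalone sketch of how those primitives behave, and of the internal _value handle the metrics test reads; the label value and the local registry here are placeholders made up for illustration.

# Standalone prometheus_client sketch (not vLLM code): a gauge holds the
# latest value, a counter only accumulates, and both are addressed per label
# set. The model name below is a placeholder.
from prometheus_client import CollectorRegistry, Counter, Gauge

registry = CollectorRegistry()
labelnames = ["model_name"]
labels = {"model_name": "facebook/opt-125m"}

acceptance_rate = Gauge("spec_decode_draft_acceptance_rate",
                        "Speculative token acceptance rate.",
                        labelnames=labelnames,
                        registry=registry)
num_draft_tokens = Counter("spec_decode_num_draft_tokens",
                           "Number of draft tokens.",
                           labelnames=labelnames,
                           registry=registry)

acceptance_rate.labels(**labels).set(0.6)  # gauge: overwrite with latest value
num_draft_tokens.labels(**labels).inc(5)   # counter: monotonically accumulate

# The metrics test earlier reads values through the same internal handle.
assert acceptance_rate.labels(**labels)._value.get() == 0.6
assert num_draft_tokens.labels(**labels)._value.get() == 5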
@@ -454,6 +478,22 @@ class PrometheusStatLogger(StatLoggerBase):
            self.num_generation_tokens = []
            self.last_local_log = stats.now

            if stats.spec_decode_metrics is not None:
                self._log_gauge(
                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
                    stats.spec_decode_metrics.draft_acceptance_rate)
                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
                                stats.spec_decode_metrics.system_efficiency)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_accepted_tokens,
                    stats.spec_decode_metrics.accepted_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_draft_tokens,
                    stats.spec_decode_metrics.draft_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_emitted_tokens,
                    stats.spec_decode_metrics.emitted_tokens)


class RayPrometheusStatLogger(PrometheusStatLogger):
    """RayPrometheusStatLogger uses Ray metrics instead."""
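One note on the split above: the acceptance rate and efficiency gauges report the most recently logged value, while the *_total counters accumulate over the server's lifetime, so a dashboard can derive a rate over any window from two scrapes. A small sketch with made-up scrape values:

# Made-up values from two successive Prometheus scrapes of the cumulative
# counters; only meant to show how a windowed acceptance rate is derived.
scrape_t0 = {"accepted_total": 120, "draft_total": 200}
scrape_t1 = {"accepted_total": 156, "draft_total": 250}

delta_accepted = scrape_t1["accepted_total"] - scrape_t0["accepted_total"]
delta_draft = scrape_t1["draft_total"] - scrape_t0["draft_total"]

windowed_acceptance_rate = delta_accepted / delta_draft  # 36 / 50 = 0.72
print(f"{windowed_acceptance_rate=:.2f}")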