[Misc] Log spec decode metrics (#6454)

This commit is contained in:
Cody Yu 2024-07-16 13:37:10 -07:00 committed by GitHub
parent 94162beb9f
commit 160e1d8c99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 137 additions and 14 deletions

View File

@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
assert_metrics(engine, disable_log_stats, len(example_prompts))
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
k = 5
with vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
stat_logger.local_interval = 0
# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
metric_name_to_expected_fn = {
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
"counter_spec_decode_num_emitted_tokens":
lambda v: 0 <= v <= k + 1,
}
# Use one request to better inspect the metrics.
prompts = example_prompts[:1]
_ = vllm_model.generate_greedy(prompts, max_tokens)
for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
metric_name).labels(**stat_logger.labels)._value.get()
assert is_expected(metric_val), (
f"the value of metric {metric_name} ({metric_val}) "
"does not meet expectation")
def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:

View File

@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
}
test_name = request.node.name
model = kwargs["model"]
draft_model = kwargs.get("speculative_model", None)
same_draft_target_model = (draft_model is not None
and draft_model == model)
def generator_inner():
wait_for_gpu_memory_to_clear(
@ -177,6 +182,13 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
# Override logging interval to 0 for spec decode test run to
# log all metrics in time.
if (baseline_or_test == "test" and not use_async
and llm.llm_engine.log_stats):
for sate_logger in llm.llm_engine.stat_loggers.values():
sate_logger.local_interval = 0
set_random_seed(seed)
yield llm
@ -188,6 +200,9 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
yield llm
del llm
# Set an attribute to the generator_outer function to allow us to
# determine whether to further check the acceptance rate in tests.
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
return generator_outer
@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm
return tokens, token_ids
return tokens, token_ids, acceptance_rate
def get_logprobs_from_llm_generator(
@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
print_tokens: bool = False,
ensure_all_accepted: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(spec_batch_tokens, spec_batch_token_ids,
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
(baseline_batch_tokens, baseline_batch_token_ids,
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
if ensure_all_accepted:
assert acceptance_rate == 1.0

View File

@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
temperature=temperature,
)
batch_tokens, batch_token_ids = get_output_from_llm_generator(
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
# Expect a generation for each prompt in the batch.
@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted = test_llm_generator.same_draft_target_model
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
force_output_len=True,
ensure_all_accepted=ensure_all_accepted)
@pytest.mark.parametrize(

View File

@ -133,6 +133,30 @@ class Metrics:
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculatie decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
self.gauge_spec_decode_efficiency = self._base_library.Gauge(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
@ -454,6 +478,22 @@ class PrometheusStatLogger(StatLoggerBase):
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""