From 4caf7044e052399f07089aa8f586d5bd641f7d53 Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Fri, 23 Feb 2024 00:00:12 +0200
Subject: [PATCH] Include tokens from prompt phase in `counter_generation_tokens` (#2802)

---
 .buildkite/test-pipeline.yaml |  3 +++
 tests/metrics/test_metrics.py | 34 +++++++++++++++++++++++++++++++++-
 vllm/engine/llm_engine.py     |  3 +++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index a91dcdfaf2ea..efcc4d2d07a1 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -52,6 +52,9 @@ steps:
 - label: LoRA Test
   command: pytest -v -s lora
 
+- label: Metrics Test
+  command: pytest -v -s metrics
+
 - label: Benchmarks
   working_dir: "/vllm-workspace/.buildkite"
   commands:
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index da608a6a18f9..fe09aa8237f2 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -9,13 +9,16 @@ MODELS = [
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_metrics(
+def test_metric_counter_prompt_tokens(
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
     max_tokens: int,
 ) -> None:
+    # Reset metric
+    vllm.engine.metrics.counter_prompt_tokens.set_value({}, 0)
+
     vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
     tokenizer = vllm_model.model.get_tokenizer()
     prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
@@ -31,3 +34,32 @@
     assert vllm_prompt_token_count == metric_count, (
         f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_metric_counter_generation_tokens(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # Reset metric
+    vllm.engine.metrics.counter_generation_tokens.set_value({}, 0)
+
+    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    tokenizer = vllm_model.model.get_tokenizer()
+    metric_count = vllm.engine.metrics.counter_generation_tokens.get_value({})
+    vllm_generation_count = 0
+    for i in range(len(example_prompts)):
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        prompt_ids = tokenizer.encode(example_prompts[i])
+        # vllm_output_ids contains both prompt tokens and generation tokens. We're interested only in the count of the generation tokens.
+        vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
+
+    assert vllm_generation_count == metric_count, (
+        f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}"
+    )
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f0de40f54db6..81c9281c5541 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -872,6 +872,9 @@ class LLMEngine:
             num_prompt_tokens = sum(
                 len(seq_group.prompt_token_ids)
                 for seq_group in scheduler_outputs.scheduled_seq_groups)
+            num_generation_tokens = sum(
+                seq_group.num_seqs()
+                for seq_group in scheduler_outputs.scheduled_seq_groups)
         else:
             num_generation_tokens = scheduler_outputs.num_batched_tokens