From 2965c99c86b460ee819e4805764d769c7b7d3d8e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 30 Jun 2025 08:28:13 -0700 Subject: [PATCH] [Spec Decode] Clean up spec decode example (#20240) Signed-off-by: Woosuk Kwon --- examples/offline_inference/eagle.py | 144 ---------------------- examples/offline_inference/spec_decode.py | 40 +++--- 2 files changed, 21 insertions(+), 163 deletions(-) delete mode 100644 examples/offline_inference/eagle.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py deleted file mode 100644 index f4193fdb8bd38..0000000000000 --- a/examples/offline_inference/eagle.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import json -import os - -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.v1.metrics.reader import Counter, Vector - - -def load_prompts(dataset_path, num_prompts): - if os.path.exists(dataset_path): - prompts = [] - try: - with open(dataset_path) as f: - for line in f: - data = json.loads(line) - prompts.append(data["turns"][0]) - except Exception as e: - print(f"Error reading dataset: {e}") - return [] - else: - prompts = ["The future of AI is", "The president of the United States is"] - - return prompts[:num_prompts] - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset", - type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", - ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["eagle", "eagle3"] - ) - parser.add_argument("--max_num_seqs", type=int, default=8) - parser.add_argument("--num_prompts", type=int, default=80) - parser.add_argument("--num_spec_tokens", type=int, default=2) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--draft_tp", type=int, default=1) - parser.add_argument("--enforce_eager", action="store_true") - parser.add_argument("--enable_chunked_prefill", action="store_true") - parser.add_argument("--max_num_batched_tokens", type=int, default=2048) - parser.add_argument("--temp", type=float, default=0) - return parser.parse_args() - - -def main(): - args = parse_args() - - model_dir = "meta-llama/Llama-3.1-8B-Instruct" - - if args.method == "eagle": - eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - elif args.method == "eagle3": - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - else: - raise ValueError(f"unknown method: {args.method}") - - max_model_len = 2048 - - tokenizer = AutoTokenizer.from_pretrained(model_dir) - - prompts = load_prompts(args.dataset, args.num_prompts) - - prompt_ids = [ - tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], add_generation_prompt=True - ) - for prompt in prompts - ] - - llm = LLM( - model=model_dir, - trust_remote_code=True, - tensor_parallel_size=args.tp, - enable_chunked_prefill=args.enable_chunked_prefill, - max_num_batched_tokens=args.max_num_batched_tokens, - enforce_eager=args.enforce_eager, - max_model_len=max_model_len, - max_num_seqs=args.max_num_seqs, - gpu_memory_utilization=0.8, - speculative_config={ - "method": args.method, - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": max_model_len, - }, - disable_log_stats=False, - ) - - sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) - - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) - - # print the generated text - for output in outputs: - print("-" * 50) - print(f"prompt: {output.prompt}") - print(f"generated text: {output.outputs[0].text}") - print("-" * 50) - - try: - metrics = llm.get_metrics() - except AssertionError: - print("Metrics are not supported in the V0 engine.") - return - - num_drafts = num_accepted = 0 - acceptance_counts = [0] * args.num_spec_tokens - for metric in metrics: - if metric.name == "vllm:spec_decode_num_drafts": - assert isinstance(metric, Counter) - num_drafts += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens": - assert isinstance(metric, Counter) - num_accepted += metric.value - elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": - assert isinstance(metric, Vector) - for pos in range(len(metric.values)): - acceptance_counts[pos] += metric.values[pos] - - print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") - print("-" * 50) - - # print acceptance at each token position - for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") - - -if __name__ == "__main__": - print( - "[WARNING] Use examples/offline_inference/spec_decode.py" - " instead of this script." - ) - main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 6fa68d2ecee1d..90d103e5cb05d 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -16,24 +16,17 @@ def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) parser.add_argument( - "--dataset", + "--method", type=str, - default="./examples/data/gsm8k.jsonl", - help="downloaded from the eagle repo " - "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/", + default="eagle", + choices=["ngram", "eagle", "eagle3", "mtp"], ) - parser.add_argument( - "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"] - ) - parser.add_argument("--max-num-seqs", type=int, default=8) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-min", type=int, default=2) parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--draft-tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument("--max-num-batched-tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -41,7 +34,6 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) - parser.add_argument("--max-model-len", type=int, default=2048) return parser.parse_args() @@ -71,8 +63,6 @@ def main(): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": args.max_model_len, } elif args.method == "ngram": speculative_config = { @@ -80,7 +70,6 @@ def main(): "num_speculative_tokens": args.num_spec_tokens, "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, - "max_model_len": args.max_model_len, } else: raise ValueError(f"unknown method: {args.method}") @@ -92,7 +81,6 @@ def main(): enable_chunked_prefill=args.enable_chunked_prefill, max_num_batched_tokens=args.max_num_batched_tokens, enforce_eager=args.enforce_eager, - max_model_len=args.max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config=speculative_config, @@ -116,27 +104,41 @@ def main(): print("Metrics are not supported in the V0 engine.") return - num_drafts = num_accepted = 0 + total_num_output_tokens = sum( + len(output.outputs[0].token_ids) for output in outputs + ) + num_drafts = 0 + num_draft_tokens = 0 + num_accepted_tokens = 0 acceptance_counts = [0] * args.num_spec_tokens for metric in metrics: if metric.name == "vllm:spec_decode_num_drafts": assert isinstance(metric, Counter) num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_draft_tokens": + assert isinstance(metric, Counter) + num_draft_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens": assert isinstance(metric, Counter) - num_accepted += metric.value + num_accepted_tokens += metric.value elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": assert isinstance(metric, Vector) for pos in range(len(metric.values)): acceptance_counts[pos] += metric.values[pos] print("-" * 50) - print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") + print(f"total_num_output_tokens: {total_num_output_tokens}") + print(f"num_drafts: {num_drafts}") + print(f"num_draft_tokens: {num_draft_tokens}") + print(f"num_accepted_tokens: {num_accepted_tokens}") + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + print(f"mean acceptance length: {acceptance_length:.2f}") print("-" * 50) # print acceptance at each token position for i in range(len(acceptance_counts)): - print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") + acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 + print(f"acceptance at token {i}: {acceptance_rate:.2f}") if __name__ == "__main__":