From dc3529dbf65786fe25cce8144c76260266061c9d Mon Sep 17 00:00:00 2001
From: Reid <61492567+reidliu41@users.noreply.github.com>
Date: Mon, 7 Apr 2025 19:53:52 +0800
Subject: [PATCH] [Misc] improve example mlpspeculator and llm_engine_example
 (#16175)

Signed-off-by: reidliu41
Co-authored-by: reidliu41
---
 .../offline_inference/llm_engine_example.py |  7 ++++++-
 examples/offline_inference/mlpspeculator.py | 21 +++++++++++++------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py
index e94f47b72b2e9..abff90d1c0cb6 100644
--- a/examples/offline_inference/llm_engine_example.py
+++ b/examples/offline_inference/llm_engine_example.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""
+This file demonstrates how to use the `LLMEngine`
+to process prompts with various sampling parameters.
+"""
 import argparse
 
 from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -26,6 +29,7 @@ def process_requests(engine: LLMEngine,
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0
 
+    print('-' * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
@@ -37,6 +41,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
+                print('-' * 50)
 
 
 def initialize_engine(args: argparse.Namespace) -> LLMEngine:
diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py
index 380c53fab2201..a2a984b04e005 100644
--- a/examples/offline_inference/mlpspeculator.py
+++ b/examples/offline_inference/mlpspeculator.py
@@ -1,4 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates text generation with an LLM,
+comparing performance with and without speculative decoding.
+
+Note that this example does not yet support `v1`, so run it with:
+VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
+"""
 
 import gc
 import time
@@ -7,7 +14,7 @@ from vllm import LLM, SamplingParams
 
 
 def time_generation(llm: LLM, prompts: list[str],
-                    sampling_params: SamplingParams):
+                    sampling_params: SamplingParams, title: str):
     # Generate texts from the prompts. The output is a list of RequestOutput
     # objects that contain the prompt, generated text, and other information.
     # Warmup first
@@ -16,11 +23,15 @@ def time_generation(llm: LLM, prompts: list[str],
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
     end = time.time()
-    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
+    print("-" * 50)
+    print(title)
+    print("time: ",
+          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
     # Print the outputs.
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"text: {generated_text!r}")
+    print("-" * 50)
 
 
 if __name__ == "__main__":
@@ -41,8 +52,7 @@ if __name__ == "__main__":
     # Create an LLM without spec decoding
     llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
 
-    print("Without speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "Without speculation")
 
     del llm
     gc.collect()
@@ -55,5 +65,4 @@ if __name__ == "__main__":
         },
     )
 
-    print("With speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "With speculation")
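
For reference, below is a minimal sketch of how the refactored
time_generation(llm, prompts, sampling_params, title) helper is driven,
using only the vLLM offline API already present in the diff (LLM,
SamplingParams, llm.generate). The prompt, sampling values, and the
facebook/opt-125m model are illustrative placeholders, not the example's
actual Llama-2-13b-chat / speculator pair; the speculative run follows the
same call once the LLM is constructed with its speculative configuration.

# Sketch of exercising the refactored time_generation() helper.
# Placeholder model and sampling values; not part of the patch itself.
import time

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: list[str],
                    sampling_params: SamplingParams, title: str):
    # Warm up once so one-time setup does not skew the timing.
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print("-" * 50)
    print(title)
    # Seconds per generated token, averaged over all outputs.
    print("time: ",
          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    for output in outputs:
        print(f"text: {output.outputs[0].text!r}")
    print("-" * 50)


if __name__ == "__main__":
    prompts = ["The president of the United States is"]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
    # Placeholder model; the real example uses meta-llama/Llama-2-13b-chat-hf.
    llm = LLM(model="facebook/opt-125m")
    time_generation(llm, prompts, sampling_params, "Without speculation")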