[Misc] improve example mlpspeculator and llm_engine_example (#16175)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Reid 2025-04-07 19:53:52 +08:00 committed by GitHub
parent 7699258ef0
commit dc3529dbf6
2 changed files with 21 additions and 7 deletions

examples/offline_inference/llm_engine_example.py

@@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates using the `LLMEngine`
for processing prompts with various sampling parameters.
"""
import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
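For context, a minimal way to wire these imports together looks roughly like the sketch below (not the example itself: the real file builds EngineArgs from command-line flags via EngineArgs.add_cli_args / EngineArgs.from_cli_args instead of hard-coding a model):

from vllm import EngineArgs, LLMEngine, SamplingParams

# Hard-coded engine arguments; the example derives these from argparse.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# Prompts paired with different sampling parameters, as in the example.
test_prompts = [
    ("Hello, my name is", SamplingParams(temperature=0.0)),
    ("The capital of France is", SamplingParams(temperature=0.8, top_p=0.95)),
]

The test_prompts list is then consumed by process_requests, whose loop appears in the next two hunks.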
@@ -26,6 +29,7 @@ def process_requests(engine: LLMEngine,
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
print('-' * 50)
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
@@ -37,6 +41,7 @@ def process_requests(engine: LLMEngine,
for request_output in request_outputs:
if request_output.finished:
print(request_output)
print('-' * 50)
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
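Reassembled from the two hunks above, the process_requests loop with the newly added separators reads roughly as follows (a sketch; the inline comments are added here, not part of the file):

from vllm import LLMEngine, SamplingParams


def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    print('-' * 50)
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            # Each request needs a unique string id.
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        # Advance the engine by one scheduling/model step; this returns
        # RequestOutput objects for the requests currently in flight.
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
                print('-' * 50)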

examples/offline_inference/mlpspeculator.py

@@ -1,4 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates text generation with an LLM,
comparing performance with and without speculative decoding.
Note that `v1` is not yet supported, so run with:
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""
import gc
import time
@@ -7,7 +14,7 @@ from vllm import LLM, SamplingParams
def time_generation(llm: LLM, prompts: list[str],
sampling_params: SamplingParams):
sampling_params: SamplingParams, title: str):
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
# Warmup first
@@ -16,11 +23,15 @@ def time_generation(llm: LLM, prompts: list[str],
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
print("-" * 50)
print(title)
print("time: ",
(end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs.
for output in outputs:
generated_text = output.outputs[0].text
print(f"text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
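Assembled from the hunk above, the updated time_generation helper now takes a title and prints per-token latency between separator lines (a sketch; the warmup section is abbreviated to a single call):

import time

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: list[str],
                    sampling_params: SamplingParams, title: str):
    # Warmup first so the timed run does not pay one-time setup costs.
    llm.generate(prompts, sampling_params)

    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()

    print("-" * 50)
    print(title)
    # Average wall-clock seconds per generated token across all outputs.
    print("time: ",
          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
    print("-" * 50)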
@@ -41,8 +52,7 @@ if __name__ == "__main__":
# Create an LLM without spec decoding
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
print("Without speculation")
time_generation(llm, prompts, sampling_params)
time_generation(llm, prompts, sampling_params, "Without speculation")
del llm
gc.collect()
@@ -55,5 +65,4 @@
},
)
print("With speculation")
time_generation(llm, prompts, sampling_params)
time_generation(llm, prompts, sampling_params, "With speculation")
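
For reference, the tail of the example now reduces to two calls of the helper above. In the sketch below, the prompts, sampling parameters, speculator model name, and speculative_config keys are illustrative assumptions; the diff only shows the closing lines of the second LLM(...) call.

import gc

from vllm import LLM, SamplingParams

prompts = ["The president of the United States is"]  # illustrative
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)  # illustrative

# Baseline: no speculative decoding.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
time_generation(llm, prompts, sampling_params, "Without speculation")

# Free the baseline engine before loading the speculative one.
del llm
gc.collect()

# With an MLPSpeculator draft model; the model name and config keys here are
# assumptions, not taken from the diff.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_config={
        "model": "ibm-ai-platform/llama-13b-accelerator",
    },
)
time_generation(llm, prompts, sampling_params, "With speculation")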