[Misc] improve example mlpspeculator and llm_engine_example (#16175)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

commit dc3529dbf6 (parent 7699258ef0)
llm_engine_example.py

@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""
+This file demonstrates using the `LLMEngine`
+for processing prompts with various sampling parameters.
+"""
 import argparse
 
 from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -26,6 +29,7 @@ def process_requests(engine: LLMEngine,
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0
 
+    print('-' * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
@@ -37,6 +41,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
+                print('-' * 50)
 
 
 def initialize_engine(args: argparse.Namespace) -> LLMEngine:
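For readers who have not opened the full example, the sketch below shows the add_request()/step() loop that llm_engine_example.py demonstrates and that the new docstring describes. It is not the example file itself; the model name and prompts are placeholder assumptions.

# Minimal sketch of the LLMEngine pattern documented above.
# Assumes vLLM is installed; "facebook/opt-125m" is only an illustrative model.
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# Each prompt carries its own SamplingParams, as in the example.
pending = [
    ("Hello, my name is", SamplingParams(temperature=0.8, max_tokens=32)),
    ("The future of AI is", SamplingParams(temperature=0.0, max_tokens=16)),
]

request_id = 0
while pending or engine.has_unfinished_requests():
    if pending:
        prompt, params = pending.pop(0)
        engine.add_request(str(request_id), prompt, params)
        request_id += 1
    # step() advances the engine by one iteration and returns RequestOutput objects.
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.request_id, request_output.outputs[0].text)
            print('-' * 50)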
mlpspeculator.py

@@ -1,4 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates the usage of text generation with an LLM model,
+comparing the performance with and without speculative decoding.
+
+Note that still not support `v1`:
+    VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
+"""
 
 import gc
 import time
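As an aside, the VLLM_USE_V1=0 prefix shown in the new docstring can also be expressed inside a script. The sketch below is not part of the commit; it only illustrates the same environment variable, which generally has to be set before vllm is imported.

# Sketch (not part of the commit): select the V0 engine programmatically
# instead of prefixing the command with VLLM_USE_V1=0.
import os

os.environ.setdefault("VLLM_USE_V1", "0")

from vllm import LLM, SamplingParams  # imported only after the environment is set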
@@ -7,7 +14,7 @@ from vllm import LLM, SamplingParams
 
 
 def time_generation(llm: LLM, prompts: list[str],
-                    sampling_params: SamplingParams):
+                    sampling_params: SamplingParams, title: str):
     # Generate texts from the prompts. The output is a list of RequestOutput
     # objects that contain the prompt, generated text, and other information.
     # Warmup first
@@ -16,11 +23,15 @@ def time_generation(llm: LLM, prompts: list[str],
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
     end = time.time()
-    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
+    print("-" * 50)
+    print(title)
+    print("time: ",
+          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
     # Print the outputs.
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"text: {generated_text!r}")
+    print("-" * 50)
 
 
 if __name__ == "__main__":
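The value printed by the new code is seconds per generated token, aggregated over the whole batch. A small hypothetical helper (not part of the commit) that derives the related metrics from the same quantities:

# Hypothetical helper, not part of the commit: reports the same per-token latency
# the example prints, plus tokens per second and total wall-clock time.
def report_timing(outputs, start: float, end: float, title: str) -> None:
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    elapsed = end - start
    print("-" * 50)
    print(title)
    print(f"total time:        {elapsed:.3f} s")
    print(f"seconds per token: {elapsed / total_tokens:.4f}")
    print(f"tokens per second: {total_tokens / elapsed:.1f}")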
@@ -41,8 +52,7 @@ if __name__ == "__main__":
     # Create an LLM without spec decoding
     llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
 
-    print("Without speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "Without speculation")
 
     del llm
     gc.collect()
@@ -55,5 +65,4 @@ if __name__ == "__main__":
         },
     )
 
-    print("With speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "With speculation")