Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-15 10:46:18 +08:00
[Misc] improve example mlpspeculator and llm_engine_example (#16175)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

commit dc3529dbf6
parent 7699258ef0

llm_engine_example.py

@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""
+This file demonstrates using the `LLMEngine`
+for processing prompts with various sampling parameters.
+"""
 import argparse
 
 from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -26,6 +29,7 @@ def process_requests(engine: LLMEngine,
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0
 
+    print('-' * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
@@ -37,6 +41,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
+                print('-' * 50)
 
 
 def initialize_engine(args: argparse.Namespace) -> LLMEngine:
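
For context, these hunks touch only part of llm_engine_example.py. The surrounding driver wires initialize_engine and process_requests together roughly as sketched below; this is a minimal sketch, not the file's exact contents. The model name, prompts, and hard-coded EngineArgs are placeholders; the real example parses them from the command line.

# Minimal sketch of how the pieces shown in this diff fit together
# (placeholder model and prompts, not the real file's contents).
from vllm import EngineArgs, LLMEngine, SamplingParams


def main():
    # Build the engine; the model here stands in for whatever the CLI flags select.
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

    test_prompts = [
        ("Hello, my name is", SamplingParams(temperature=0.0)),
        ("The capital of France is", SamplingParams(temperature=0.8, top_p=0.95)),
    ]

    request_id = 0
    print('-' * 50)
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        # step() runs one decoding iteration and returns the requests it updated.
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
                print('-' * 50)


if __name__ == "__main__":
    main()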

mlpspeculator.py

@@ -1,4 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates the usage of text generation with an LLM model,
+comparing the performance with and without speculative decoding.
+
+Note that still not support `v1`:
+VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
+"""
 
 import gc
 import time
@@ -7,7 +14,7 @@ from vllm import LLM, SamplingParams
 
 
 def time_generation(llm: LLM, prompts: list[str],
-                    sampling_params: SamplingParams):
+                    sampling_params: SamplingParams, title: str):
     # Generate texts from the prompts. The output is a list of RequestOutput
     # objects that contain the prompt, generated text, and other information.
     # Warmup first
@@ -16,11 +23,15 @@ def time_generation(llm: LLM, prompts: list[str],
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
     end = time.time()
-    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
+    print("-" * 50)
+    print(title)
+    print("time: ",
+          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
     # Print the outputs.
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"text: {generated_text!r}")
+    print("-" * 50)
 
 
 if __name__ == "__main__":
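
The "time" value printed by the updated helper is the average wall-clock seconds per generated token: the elapsed time divided by the total number of tokens generated across all outputs. A tiny illustration of the same arithmetic with made-up numbers:

# Made-up numbers, only to illustrate the per-token latency metric printed above.
elapsed = 2.4                        # seconds spent inside llm.generate(...)
tokens_per_output = [128, 96, 160]   # len(o.outputs[0].token_ids) for each output
print("time: ", elapsed / sum(tokens_per_output))  # 2.4 / 384 = 0.00625 s per token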
@@ -41,8 +52,7 @@ if __name__ == "__main__":
     # Create an LLM without spec decoding
     llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
 
-    print("Without speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "Without speculation")
 
     del llm
     gc.collect()
@@ -55,5 +65,4 @@ if __name__ == "__main__":
         },
     )
 
-    print("With speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "With speculation")
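
For context, the "}," and ")" context lines in the last hunk close a speculative-decoding LLM(...) construction that is not fully visible in the diff. Below is a minimal sketch of what such a construction typically looks like, assuming the dict-style speculative_config argument; the draft-model name, config keys, prompts, and sampling settings are illustrative rather than taken from the diff.

# Sketch only: the exact draft model and config keys in mlpspeculator.py may differ.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_config={
        # An MLPSpeculator-style draft model (illustrative name).
        "model": "ibm-ai-platform/llama-13b-accelerator",
    },
)

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)

In the example itself, this llm would instead be passed to the updated helper, e.g. time_generation(llm, prompts, sampling_params, "With speculation"), so the label is printed by the helper rather than by a separate print call.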