Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 03:25:01 +08:00)
Chat method for offline llm (#5049)
Co-authored-by: nunjunj <ray@g-3ff9f30f2ed650001.c.vllm-405802.internal>
Co-authored-by: nunjunj <ray@g-1df6075697c3f0001.c.vllm-405802.internal>
Co-authored-by: nunjunj <ray@g-c5a2c23abc49e0001.c.vllm-405802.internal>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 4cd7d47fed
commit 3b19e39dc5
@@ -147,6 +147,7 @@ steps:
   - pip install awscli tensorizer # for llava example and tensorizer test
   - python3 offline_inference.py
   - python3 cpu_offload.py
+  - python3 offline_inference_chat.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 offline_inference_vision_language.py
examples/offline_inference_chat.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from vllm import LLM, SamplingParams
+
+llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+sampling_params = SamplingParams(temperature=0.5)
+
+
+def print_outputs(outputs):
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print("-" * 80)
+
+
+print("=" * 80)
+
+# In this script, we demonstrate how to pass input to the chat method:
+
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+outputs = llm.chat(conversation,
+                   sampling_params=sampling_params,
+                   use_tqdm=False)
+print_outputs(outputs)
+
+# A chat template can be optionally supplied.
+# If not, the model will use its default chat template.
+
+# with open('template_falcon_180b.jinja', "r") as f:
+#     chat_template = f.read()
+
+# outputs = llm.chat(
+#     conversation,
+#     sampling_params=sampling_params,
+#     use_tqdm=False,
+#     chat_template=chat_template,
+# )
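The commented-out block reads a custom template from a .jinja file, but since chat_template is ultimately handed to the tokenizer's chat-template rendering, an inline Jinja string should work just as well. A hedged sketch continuing the variables defined above (the template text is purely illustrative, not one shipped with any model):

# Hypothetical inline template -- illustrative only.
custom_template = ("{% for message in messages %}"
                   "{{ message['role'] }}: {{ message['content'] }}\n"
                   "{% endfor %}"
                   "assistant:")

outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   chat_template=custom_template,
                   use_tqdm=False)
print_outputs(outputs)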
@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
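The new test is a smoke test that only checks the number of returned outputs. A slightly stronger variant could also assert that the single response is a non-empty string; a sketch reusing the same llm and messages (these extra assertions are not part of the commit):

outputs = llm.chat(messages)
assert len(outputs) == 1

# The output structure mirrors llm.generate(): one RequestOutput per prompt,
# each carrying at least one completion in .outputs.
generated_text = outputs[0].outputs[0].text
assert isinstance(generated_text, str)
assert len(generated_text) > 0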
@@ -6,6 +6,9 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         apply_chat_template,
+                                         parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
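These imports pull the chat plumbing from vllm.entrypoints.chat_utils, the helpers shared with the OpenAI-compatible server entrypoints, so plain role/content dictionaries are what ChatCompletionMessageParam expects. A small sketch of a typed message list, assuming the type is re-exported from chat_utils exactly as in the import above:

from typing import List

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

messages: List[ChatCompletionMessageParam] = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]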
@@ -138,8 +141,12 @@ class LLM:
 
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        removed_vision_keys = ("image_token_id", "image_feature_size",
-                               "image_input_shape", "image_input_type")
+        removed_vision_keys = (
+            "image_token_id",
+            "image_feature_size",
+            "image_input_shape",
+            "image_input_type",
+        )
         if any(k in kwargs for k in removed_vision_keys):
             raise TypeError(
                 "There is no need to pass vision-related arguments anymore.")
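The reflowed removed_vision_keys tuple is behaviorally identical: constructing an LLM with any of the old vision kwargs still fails fast, before the engine (or any weights) is created. A sketch of the guard in action (the model name and kwarg value are illustrative):

from vllm import LLM

try:
    # image_token_id is one of the removed vision kwargs; passing it now
    # raises immediately in LLM.__init__.
    LLM(model="llava-hf/llava-1.5-7b-hf", image_token_id=32000)
except TypeError as err:
    print(err)  # "There is no need to pass vision-related arguments anymore."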
@@ -259,11 +266,12 @@ class LLM:
     ) -> List[RequestOutput]:
         ...
 
-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def generate(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
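Both @deprecate_kwargs call sites touched by this diff are only reflowed. The decorator itself is a vllm-internal helper that warns when the named legacy keyword arguments are passed while DEPRECATE_LEGACY is enabled; a rough sketch of the idea, not the actual implementation (the signature details here are assumptions):

import functools
import warnings
from typing import Callable


def deprecate_kwargs(*kw_names: str,
                     is_deprecated: Callable[[], bool] = lambda: True,
                     additional_message: str = ""):
    """Warn when any of ``kw_names`` is passed to the wrapped function."""

    def wrapper(fn):

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                used = [k for k in kw_names if k in kwargs]
                if used:
                    warnings.warn(
                        f"Keyword arguments {used} are deprecated. "
                        f"{additional_message}",
                        DeprecationWarning,
                        stacklevel=2)
            return fn(*args, **kwargs)

        return inner

    return wrapper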
@@ -296,7 +304,7 @@ class LLM:
                 generation, if any.
 
         Returns:
-            A list of `RequestOutput` objects containing the
+            A list of ``RequestOutput`` objects containing the
             generated completions in the same order as the input prompts.
 
         Note:
@@ -339,6 +347,62 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
 
+    def chat(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        add_generation_template: bool = True,
+    ) -> List[RequestOutput]:
+        """
+        Generates responses for chat messages.
+
+        Converts the messages to prompts using the tokenizer and calls
+        the :meth:`generate` method to generate the responses.
+
+        Args:
+            messages: A list of messages to generate responses for. Each
+                message is a dictionary with 'role' and 'content'
+                keys.
+            sampling_params: The sampling parameters for text generation.
+                If None, we use the default sampling parameters. When it
+                is a single value, it is applied to every prompt. When it
+                is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            chat_template: The template to use for structuring the chat.
+              If not provided, the model's default chat template will be used.
+            add_generation_template: If True, adds a generation template
+                to each message.
+
+        Returns:
+            A list of ``RequestOutput`` objects containing the generated
+            responses in the same order as the input messages.
+        """
+
+        tokenizer = self.get_tokenizer()
+        model_config = self.llm_engine.get_model_config()
+
+        conversations, _ = parse_chat_messages(messages, model_config,
+                                               tokenizer)
+
+        prompts = apply_chat_template(
+            tokenizer,
+            conversations,
+            chat_template=chat_template,
+            add_generation_template=add_generation_template)
+
+        return self.generate(
+            prompts,
+            sampling_params,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+        )
+
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
         self,
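As the docstring says, chat() just renders the messages into a prompt with the tokenizer's chat template and delegates to generate(). For intuition, here is a rough standalone equivalent built on the public transformers API rather than vllm's chat_utils helpers; only a sketch, and it ignores details such as multimodal message handling:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model = "meta-llama/Meta-Llama-3-8B-Instruct"
llm = LLM(model=model)
tok = AutoTokenizer.from_pretrained(model)

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]

# Roughly what llm.chat(messages) does internally: render the chat template,
# then generate from the resulting prompt string.
prompt = tok.apply_chat_template(messages,
                                 tokenize=False,
                                 add_generation_prompt=True)
outputs = llm.generate(prompt, SamplingParams(temperature=0.5))
print(outputs[0].outputs[0].text)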
@@ -413,11 +477,12 @@ class LLM:
     ) -> List[EmbeddingRequestOutput]:
         ...
 
-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def encode(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -563,15 +628,15 @@ class LLM:
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
-                prompt_adapter_request=prompt_adapter_request)
+                prompt_adapter_request=prompt_adapter_request,
+            )
 
     def _add_request(
         self,
         inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        lora_request: Optional[Union[List[LoRARequest],
-                                     LoRARequest]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
@@ -579,7 +644,8 @@ class LLM:
             inputs,
             params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
     def _add_guided_processor(
         self,
@@ -628,8 +694,8 @@ class LLM:
                     in_spd = total_in_toks / pbar.format_dict["elapsed"]
                     total_out_toks += sum(
                         len(stp.token_ids) for stp in output.outputs)
-                    out_spd = total_out_toks / pbar.format_dict[
-                        "elapsed"]
+                    out_spd = (total_out_toks /
+                               pbar.format_dict["elapsed"])
                     pbar.postfix = (
                         f"est. speed input: {in_spd:.2f} toks/s, "
                         f"output: {out_spd:.2f} toks/s")