Chat method for offline llm (#5049)

Co-authored-by: nunjunj <ray@g-3ff9f30f2ed650001.c.vllm-405802.internal>
Co-authored-by: nunjunj <ray@g-1df6075697c3f0001.c.vllm-405802.internal>
Co-authored-by: nunjunj <ray@g-c5a2c23abc49e0001.c.vllm-405802.internal>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

This commit is contained in: parent 4cd7d47fed, commit 3b19e39dc5
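In short, this change adds an LLM.chat() entrypoint so chat-style conversations can be run offline, without the OpenAI-compatible server. A minimal sketch of the intended call pattern, condensed from the example script and docstring added below (model name and temperature are simply the values used in that example; treat this as illustrative, not a definitive reference):

# Condensed from examples/offline_inference_chat.py (added below); illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Write an essay about the importance of higher education."},
]

# chat() renders the messages with the model's chat template and then calls generate().
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print(outputs[0].outputs[0].text)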
@@ -147,6 +147,7 @@ steps:
     - pip install awscli tensorizer # for llava example and tensorizer test
     - python3 offline_inference.py
     - python3 cpu_offload.py
+    - python3 offline_inference_chat.py
     - python3 offline_inference_with_prefix.py
     - python3 llm_engine_example.py
     - python3 offline_inference_vision_language.py
53  examples/offline_inference_chat.py  Normal file
@@ -0,0 +1,53 @@
+from vllm import LLM, SamplingParams
+
+llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+sampling_params = SamplingParams(temperature=0.5)
+
+
+def print_outputs(outputs):
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print("-" * 80)
+
+
+print("=" * 80)
+
+# In this script, we demonstrate how to pass input to the chat method:
+
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+outputs = llm.chat(conversation,
+                   sampling_params=sampling_params,
+                   use_tqdm=False)
+print_outputs(outputs)
+
+# A chat template can be optionally supplied.
+# If not, the model will use its default chat template.
+
+# with open('template_falcon_180b.jinja', "r") as f:
+#     chat_template = f.read()
+
+# outputs = llm.chat(
+#     conversation,
+#     sampling_params=sampling_params,
+#     use_tqdm=False,
+#     chat_template=chat_template,
+# )
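The commented-out block above loads a Jinja chat template from disk. As a rough sketch of the same override with an inline template string (the template text below is a made-up illustration, not a template shipped with vLLM; it reuses the conversation, sampling_params, and print_outputs defined in the script):

# Hypothetical inline Jinja chat template, for illustration only.
chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant: "
)

outputs = llm.chat(
    conversation,
    sampling_params=sampling_params,
    use_tqdm=False,
    chat_template=chat_template,
)
print_outputs(outputs)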
@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
@@ -6,6 +6,9 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         apply_chat_template,
+                                         parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -87,7 +90,7 @@ class LLM:
         disable_custom_all_reduce: See ParallelConfig
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
-
+
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
@@ -138,8 +141,12 @@ class LLM:

         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        removed_vision_keys = ("image_token_id", "image_feature_size",
-                               "image_input_shape", "image_input_type")
+        removed_vision_keys = (
+            "image_token_id",
+            "image_feature_size",
+            "image_input_shape",
+            "image_input_type",
+        )
         if any(k in kwargs for k in removed_vision_keys):
             raise TypeError(
                 "There is no need to pass vision-related arguments anymore.")
@@ -259,11 +266,12 @@ class LLM:
     ) -> List[RequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def generate(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -286,17 +294,17 @@ class LLM:
         Args:
             inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
-                When it is a single value, it is applied to every prompt.
-                When it is a list, the list must have the same length as the
+                None, we use the default sampling parameters.
+                When it is a single value, it is applied to every prompt.
+                When it is a list, the list must have the same length as the
                 prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
-            A list of `RequestOutput` objects containing the
+            A list of ``RequestOutput`` objects containing the
             generated completions in the same order as the input prompts.

         Note:
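As the docstring above notes, sampling_params may be either a single value applied to every prompt or a list paired one-to-one with the prompts. A small sketch of the list form (the prompts and temperatures are arbitrary illustrative values, reusing an llm instance like the one constructed in the example script):

from vllm import SamplingParams

# One SamplingParams per prompt, matched by position; the two lists must have equal length.
prompts = ["Hello, my name is", "The capital of France is"]
per_prompt_params = [SamplingParams(temperature=0.0), SamplingParams(temperature=0.8)]
outputs = llm.generate(prompts, per_prompt_params)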
@@ -339,6 +347,62 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)

+    def chat(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        add_generation_template: bool = True,
+    ) -> List[RequestOutput]:
+        """
+        Generates responses for chat messages.
+
+        Converts the messages to prompts using the tokenizer and calls
+        the :meth:`generate` method to generate the responses.
+
+        Args:
+            messages: A list of messages to generate responses for. Each
+                message is a dictionary with 'role' and 'content'
+                keys.
+            sampling_params: The sampling parameters for text generation.
+                If None, we use the default sampling parameters. When it
+                is a single value, it is applied to every prompt. When it
+                is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            chat_template: The template to use for structuring the chat.
+                If not provided, the model's default chat template will be used.
+            add_generation_template: If True, adds a generation template
+                to each message.
+
+        Returns:
+            A list of ``RequestOutput`` objects containing the generated
+            responses in the same order as the input messages.
+        """
+
+        tokenizer = self.get_tokenizer()
+        model_config = self.llm_engine.get_model_config()
+
+        conversations, _ = parse_chat_messages(messages, model_config,
+                                               tokenizer)
+
+        prompts = apply_chat_template(
+            tokenizer,
+            conversations,
+            chat_template=chat_template,
+            add_generation_template=add_generation_template)
+
+        return self.generate(
+            prompts,
+            sampling_params,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+        )
+
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
         self,
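The new method is essentially a thin wrapper: it renders the chat messages into prompt text with the tokenizer's chat template and then defers to generate(). A rough standalone equivalent using the Hugging Face tokenizer directly, reusing the llm, conversation, and sampling_params from the example script (a sketch only; the real path goes through vllm.entrypoints.chat_utils as shown above):

# Rough, illustrative equivalent of LLM.chat(); the actual implementation uses
# parse_chat_messages/apply_chat_template from vllm.entrypoints.chat_utils.
tokenizer = llm.get_tokenizer()
prompt = tokenizer.apply_chat_template(
    conversation,                # list of {"role": ..., "content": ...} dicts
    tokenize=False,              # return the rendered string rather than token ids
    add_generation_prompt=True,  # append the assistant turn marker
)
outputs = llm.generate(prompt, sampling_params)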
@@ -413,11 +477,12 @@ class LLM:
     ) -> List[EmbeddingRequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def encode(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -443,7 +508,7 @@ class LLM:
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
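For completeness, encode() is the embedding counterpart of generate() and returns EmbeddingRequestOutput objects. A hedged usage sketch (the model name is an assumption for illustration only; it is not part of this diff, and the model must be one served in embedding mode):

from vllm import LLM

# Assumed embedding-capable model; substitute whichever embedding model you serve.
embedding_llm = LLM(model="intfloat/e5-mistral-7b-instruct")
outputs = embedding_llm.encode(["Hello, world!"])
print(len(outputs))  # one EmbeddingRequestOutput per input prompt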
@@ -563,15 +628,15 @@ class LLM:
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
-                prompt_adapter_request=prompt_adapter_request)
+                prompt_adapter_request=prompt_adapter_request,
+            )

     def _add_request(
-            self,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            lora_request: Optional[Union[List[LoRARequest],
-                                         LoRARequest]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
@@ -579,7 +644,8 @@ class LLM:
             inputs,
             params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )

     def _add_guided_processor(
         self,
@@ -628,8 +694,8 @@ class LLM:
                         in_spd = total_in_toks / pbar.format_dict["elapsed"]
                         total_out_toks += sum(
                             len(stp.token_ids) for stp in output.outputs)
-                        out_spd = total_out_toks / pbar.format_dict[
-                            "elapsed"]
+                        out_spd = (total_out_toks /
+                                   pbar.format_dict["elapsed"])
                         pbar.postfix = (
                             f"est. speed input: {in_spd:.2f} toks/s, "
                             f"output: {out_spd:.2f} toks/s")