[Frontend] Added support for HF's new continue_final_message parameter (#8942)
parent 1fb9c1b0bf
commit 6c9ba48fde
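Usage, for context: the request-level flag added below passes straight through the OpenAI-compatible server. A minimal sketch, assuming a locally running server and the facebook/opt-125m model (both assumptions, not part of this change); `continue_final_message` must be combined with `add_generation_prompt=False`, which the new validators below enforce.

# Sketch only: prefill the assistant's reply and let the model continue it.
# Assumes a server started with something like: vllm serve facebook/opt-125m
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
        # Open-ended final assistant message that the model will continue.
        {"role": "assistant", "content": "The capital of France is"},
    ],
    # vLLM-specific fields go through extra_body with the official client.
    extra_body={
        "add_generation_prompt": False,
        "continue_final_message": True,
    },
)
print(completion.choices[0].message.content)

The diff below threads the parameter end to end: the chat-template tests, the tokenize test, `apply_mistral_chat_template`, the offline `LLM.chat` API, the request schemas, and the chat and tokenization servers.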
@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]
 
 TEST_MESSAGES = [
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}
 
 
 def test_load_chat_template():
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
 
 
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )
 
     # Call the function and get the result
     result = apply_hf_chat_template(
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )
 
     # Test assertion
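The template tests above mirror Hugging Face's own `continue_final_message` behaviour. A minimal sketch of the upstream side, assuming a transformers release recent enough to support the parameter and the same ChatML template the test fixture uses (the template path is an assumption):

# Sketch only: the upstream transformers behaviour that vLLM now exposes.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
chatml = open("examples/template_chatml.jinja").read()  # assumed path

messages = [
    {"role": "user", "content": "What is the capital of"},
    {"role": "assistant", "content": "The capital of"},
]

# The rendered prompt ends with "...assistant\nThe capital of" and no EOS,
# so generation continues the final message instead of opening a new turn.
prompt = tok.apply_chat_template(messages,
                                 chat_template=chatml,
                                 tokenize=False,
                                 add_generation_prompt=False,
                                 continue_final_message=True)
print(prompt)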
@@ -104,17 +104,29 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
         "role": "user",
         "content": "Can I ask a question? vllm1"
     }]
+    for continue_final in [False, True]:
+        if add_generation and continue_final:
+            continue
+        if continue_final:
+            conversation.append({
+                "role": "assistant",
+                "content": "Sure,"
+            })
 
-    prompt = tokenizer.apply_chat_template(
-        add_generation_prompt=add_generation,
-        conversation=conversation,
-        tokenize=False)
-    tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+        prompt = tokenizer.apply_chat_template(
+            add_generation_prompt=add_generation,
+            continue_final_message=continue_final,
+            conversation=conversation,
+            tokenize=False)
+        tokens = tokenizer.encode(prompt,
+                                  add_special_tokens=add_special)
 
-    response = requests.post(base_url + "/tokenize",
-                             json={
-                                 "add_generation_prompt":
-                                 add_generation,
-                                 "add_special_tokens": add_special,
-                                 "messages": conversation,
-                                 "model": model_name
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_generation_prompt":
+                                     add_generation,
+                                     "continue_final_message":
+                                     continue_final,
+                                     "add_special_tokens": add_special,
+                                     "messages": conversation,
+                                     "model": model_name
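Outside the test harness, the extended /tokenize endpoint can be exercised directly. A minimal sketch, assuming a server on localhost:8000 serving facebook/opt-125m and reading the token list from the `tokens` field of the existing tokenize response schema:

# Sketch only: tokenize a prefilled conversation via the /tokenize endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/tokenize",  # assumed server address
    json={
        "model": "facebook/opt-125m",  # assumed model name
        "messages": [
            {"role": "user", "content": "Can I ask a question?"},
            {"role": "assistant", "content": "Sure,"},
        ],
        "add_generation_prompt": False,
        "continue_final_message": True,
        "add_special_tokens": False,
    },
)
resp.raise_for_status()
print(resp.json()["tokens"])  # token ids for the prompt ending in "Sure,"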
@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
     if chat_template is not None:
         logger.warning(
             "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
 
     return tokenizer.apply_chat_template(
         messages=messages,
@@ -501,6 +501,7 @@ class LLM:
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
@@ -528,6 +529,9 @@ class LLM:
                 If not provided, the model's default chat template will be used.
             add_generation_prompt: If True, adds a generation template
                 to each message.
+            continue_final_message: If True, continues the final message in
+                the conversation instead of starting a new one. Cannot be `True`
+                if `add_generation_prompt` is also `True`.
 
         Returns:
             A list of ``RequestOutput`` objects containing the generated
@@ -559,6 +563,7 @@ class LLM:
                 messages=msgs,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )
         else:
@@ -567,6 +572,7 @@ class LLM:
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )
 
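For the offline API, the new keyword threads through `LLM.chat` as shown above. A minimal sketch, with the model choice as an assumption:

# Sketch only: offline prefill through LLM.chat.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.chat(
    [
        {"role": "user", "content": "What is the capital of France?"},
        # Final assistant message to be continued rather than answered anew.
        {"role": "assistant", "content": "The capital of France is"},
    ],
    SamplingParams(temperature=0.0, max_tokens=16),
    add_generation_prompt=False,
    continue_final_message=True,
)
print(outputs[0].outputs[0].text)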
@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "This is a parameter used by chat template in tokenizer config of the "
             "model."),
     )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
     add_special_tokens: bool = Field(
         default=False,
         description=(
@@ -431,6 +440,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 " of the specified `tools`")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(default=True)
+    continue_final_message: bool = Field(default=False)
     add_special_tokens: bool = Field(default=False)
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
 
 TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
 
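The two model validators added above make the flags mutually exclusive at request-parsing time. A minimal sketch of what a caller sees when both are set, assuming the standard import path for these request models; pydantic surfaces the ValueError as a ValidationError:

# Sketch only: both flags set -> the new validator rejects the request
# before it ever reaches the engine.
import pydantic
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(
        model="facebook/opt-125m",  # assumed model name
        messages=[{"role": "user", "content": "Hi"}],
        add_generation_prompt=True,
        continue_final_message=True,
    )
except pydantic.ValidationError as err:
    # Wraps: "Cannot set both `continue_final_message` and
    # `add_generation_prompt` to True."
    print(err)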
@@ -140,6 +140,7 @@ class OpenAIServingChat(OpenAIServing):
                     messages=request.messages,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -150,6 +151,7 @@ class OpenAIServingChat(OpenAIServing):
                     conversation=conversation,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -361,7 +363,7 @@ class OpenAIServingChat(OpenAIServing):
 
                 # Send response to echo the input portion of the
                 # last message
-                if request.echo:
+                if request.echo or request.continue_final_message:
                     last_msg_content: str = ""
                     if conversation and "content" in conversation[
                             -1] and conversation[-1].get("role") == role:
@@ -716,7 +718,7 @@ class OpenAIServingChat(OpenAIServing):
                 stop_reason=output.stop_reason)
             choices.append(choice_data)
 
-        if request.echo:
+        if request.echo or request.continue_final_message:
             last_msg_content = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     messages=request.messages,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
             else:
                 prompt = apply_hf_chat_template(
@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     conversation=conversation,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
         else:
             prompt = request.prompt