[Frontend] Added support for HF's new continue_final_message parameter (#8942)

Authored by danieljannai21 on 2024-09-29 20:59:47 +03:00; committed by GitHub
parent 1fb9c1b0bf
commit 6c9ba48fde
7 changed files with 102 additions and 28 deletions
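
For context, a minimal sketch of how the new parameter is intended to be used against the OpenAI-compatible server, assuming a vLLM instance at localhost:8000 serving facebook/opt-125m (both placeholders); `continue_final_message` is passed via `extra_body` since it is a vLLM extension rather than an official OpenAI field, and the printed output is illustrative only.

# Sketch: prefill the start of the assistant's reply and let the model continue it.
# Assumes a vLLM OpenAI-compatible server at localhost:8000 (placeholder).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
        # The final assistant message is left open-ended; the model continues it.
        {"role": "assistant", "content": "The capital of"},
    ],
    extra_body={
        "add_generation_prompt": False,
        "continue_final_message": True,
    },
)
print(completion.choices[0].message.content)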


@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()

 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]

 TEST_MESSAGES = [
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}


 def test_load_chat_template():
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )

     # Call the function and get the result
     result = apply_hf_chat_template(
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )

     # Test assertion


@@ -104,17 +104,29 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
             "role": "user",
             "content": "Can I ask a question? vllm1"
         }]
+        for continue_final in [False, True]:
+            if add_generation and continue_final:
+                continue
+            if continue_final:
+                conversation.append({
+                    "role": "assistant",
+                    "content": "Sure,"
+                })
             prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=add_generation,
+                continue_final_message=continue_final,
                 conversation=conversation,
                 tokenize=False)
-            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+            tokens = tokenizer.encode(prompt,
+                                      add_special_tokens=add_special)
             response = requests.post(base_url + "/tokenize",
                                      json={
                                          "add_generation_prompt":
                                          add_generation,
+                                         "continue_final_message":
+                                         continue_final,
                                          "add_special_tokens": add_special,
                                          "messages": conversation,
                                          "model": model_name


@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
     if chat_template is not None:
         logger.warning(
             "'chat_template' cannot be overridden for mistral tokenizer.")
+    if "add_generation_prompt" in kwargs:
+        logger.warning(
+            "'add_generation_prompt' is not supported for mistral tokenizer, "
+            "so it will be ignored.")
+    if "continue_final_message" in kwargs:
+        logger.warning(
+            "'continue_final_message' is not supported for mistral tokenizer, "
+            "so it will be ignored.")

     return tokenizer.apply_chat_template(
         messages=messages,


@@ -501,6 +501,7 @@ class LLM:
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
@@ -528,6 +529,9 @@ class LLM:
                 If not provided, the model's default chat template will be used.
             add_generation_prompt: If True, adds a generation template
                 to each message.
+            continue_final_message: If True, continues the final message in
+                the conversation instead of starting a new one. Cannot be `True`
+                if `add_generation_prompt` is also `True`.

         Returns:
             A list of ``RequestOutput`` objects containing the generated
@@ -559,6 +563,7 @@ class LLM:
                 messages=msgs,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )
         else:
@@ -567,6 +572,7 @@ class LLM:
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
                 tools=tools,
             )
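
For the offline entry point, usage would look roughly like the following sketch; the model name is a placeholder, and `add_generation_prompt` is disabled explicitly since the two options are mutually exclusive.

# Sketch: continue a prefilled assistant message via the offline LLM.chat() API.
from vllm import LLM

llm = LLM(model="facebook/opt-125m")
outputs = llm.chat(
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
        # Open-ended final message that the model should keep writing.
        {"role": "assistant", "content": "The capital of"},
    ],
    add_generation_prompt=False,
    continue_final_message=True,
)
print(outputs[0].outputs[0].text)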


@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "This is a parameter used by chat template in tokenizer config of the "
             "model."),
     )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
     add_special_tokens: bool = Field(
         default=False,
         description=(
@@ -431,6 +440,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
                              " of the specified `tools`")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+

 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
     messages: List[ChatCompletionMessageParam]

     add_generation_prompt: bool = Field(default=True)
+    continue_final_message: bool = Field(default=False)
     add_special_tokens: bool = Field(default=False)

+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+

 TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
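
Because the validators run in "before" mode, a request that explicitly enables both flags is rejected before any chat templating happens. A rough illustration, assuming the usual vllm.entrypoints.openai.protocol import path:

# Sketch: explicitly enabling both flags fails validation.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

ChatCompletionRequest(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
    continue_final_message=True,
)  # raises: cannot set both `continue_final_message` and `add_generation_prompt` to True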


@@ -140,6 +140,7 @@ class OpenAIServingChat(OpenAIServing):
                     messages=request.messages,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -150,6 +151,7 @@ class OpenAIServingChat(OpenAIServing):
                     conversation=conversation,
                     chat_template=request.chat_template or self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                     tools=tool_dicts,
                     documents=request.documents,
                     **(request.chat_template_kwargs or {}),
@@ -361,7 +363,7 @@ class OpenAIServingChat(OpenAIServing):
                 # Send response to echo the input portion of the
                 # last message
-                if request.echo:
+                if request.echo or request.continue_final_message:
                     last_msg_content: str = ""
                     if conversation and "content" in conversation[
                             -1] and conversation[-1].get("role") == role:
@@ -716,7 +718,7 @@ class OpenAIServingChat(OpenAIServing):
                 stop_reason=output.stop_reason)
             choices.append(choice_data)

-        if request.echo:
+        if request.echo or request.continue_final_message:
             last_msg_content = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
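
The last two hunks make the server prepend the prefilled assistant content to the generated continuation, reusing the existing echo path, so the returned message reads as one uninterrupted reply. An illustrative sketch of the resulting concatenation (the generated text here is made up):

# Illustrative only: what the echo/continue_final_message branch assembles.
prefill = "The capital of"                   # content of the final assistant message
continuation = " France is Paris."           # text generated by the model (made up)
returned_content = prefill + continuation    # what choices[0].message.content carries
assert returned_content.startswith(prefill)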


@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     messages=request.messages,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
             else:
                 prompt = apply_hf_chat_template(
@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     conversation=conversation,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
                 )
         else:
             prompt = request.prompt