[BugFix][Frontend] Fix LLM.chat() tokenization (#16081)
Signed-off-by: Nick Hill <nhill@redhat.com>
Parent: 65e262b93b
Commit: 70116459c3
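For context on the bug being fixed: chat templates such as Llama's already prepend a BOS token to the rendered prompt, so re-encoding that string with the tokenizer's default add_special_tokens=True produces a second BOS at the start of the prompt. Below is a minimal sketch of the symptom using a plain Hugging Face tokenizer, assuming access to the gated meta-llama/Llama-3.2-1B-Instruct checkpoint used by the new test.

# Minimal reproduction of the double-BOS symptom, independent of vLLM.
# Assumes access to meta-llama/Llama-3.2-1B-Instruct (a gated repo);
# any chat model whose template emits BOS behaves the same way.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello!"},
]

# The chat template itself already inserts <|begin_of_text|> (BOS).
prompt_str = tok.apply_chat_template(messages,
                                     tokenize=False,
                                     add_generation_prompt=True)

with_specials = tok.encode(prompt_str)                               # old behaviour
without_specials = tok.encode(prompt_str, add_special_tokens=False)  # what the fix does

# Expected for Llama-3-style tokenizers: the first list starts with two
# BOS ids, the second with exactly one.
print(with_specials[:2], without_specials[:2], tok.bos_token_id)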
@@ -89,3 +89,31 @@ def test_chat_multi_image(image_urls: list[str]):
     }]
     outputs = llm.chat(messages)
     assert len(outputs) >= 0
+
+
+def test_llm_chat_tokenization_no_double_bos():
+    """
+    LLM.chat() should not add special tokens when using chat templates.
+    Check we get a single BOS token for llama chat.
+    """
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "Hello!"
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
+    prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)
+    assert prompt_token_ids is not None
+
+    bos_token = llm.get_tokenizer().bos_token_id
+
+    # Ensure we have a single BOS
+    assert prompt_token_ids[0] == bos_token
+    assert prompt_token_ids[1] != bos_token, "Double BOS"
@@ -251,8 +251,12 @@ class LLM:
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
-    def get_tokenizer(self) -> AnyTokenizer:
-        return self.llm_engine.get_tokenizer_group().tokenizer
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
+        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
+            lora_request)
 
     def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         tokenizer_group = self.llm_engine.get_tokenizer_group()
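The widened get_tokenizer() signature lets callers fetch the tokenizer that belongs to a specific LoRA adapter; calling it with no argument keeps the previous behaviour of returning the base-model tokenizer. A hypothetical usage sketch follows; the adapter name, id, and path are placeholders rather than anything from this commit.

# Hypothetical use of the updated LLM.get_tokenizer() signature.
# Adapter name/id/path are illustrative placeholders.
from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enable_lora=True,
          enforce_eager=True)

# No argument: the base-model tokenizer, same as before this change.
base_tok = llm.get_tokenizer()

# With a LoRARequest: the tokenizer resolved for that adapter
# (typically falling back to the base tokenizer if the adapter
# does not ship its own).
lora_req = LoRARequest("my-adapter", 1, "/path/to/my-adapter")
lora_tok = llm.get_tokenizer(lora_req)

print(type(base_tok).__name__, type(lora_tok).__name__)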
@@ -712,7 +716,7 @@ class LLM:
             cast(list[ChatCompletionMessageParam], messages)
         ]
 
-        tokenizer = self.get_tokenizer()
+        tokenizer = self.get_tokenizer(lora_request)
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
@@ -735,9 +739,8 @@ class LLM:
                 content_format=resolved_content_format,
             )
 
-            prompt_data: Union[str, list[int]]
             if isinstance(tokenizer, MistralTokenizer):
-                prompt_data = apply_mistral_chat_template(
+                prompt_token_ids = apply_mistral_chat_template(
                     tokenizer,
                     messages=msgs,
                     chat_template=chat_template,
@@ -746,7 +749,7 @@ class LLM:
                     continue_final_message=continue_final_message,
                 )
             else:
-                prompt_data = apply_hf_chat_template(
+                prompt_str = apply_hf_chat_template(
                     tokenizer,
                     trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
@@ -755,12 +758,12 @@ class LLM:
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
                 )
-
-            prompt: Union[TokensPrompt, TextPrompt]
-            if is_list_of(prompt_data, int):
-                prompt = TokensPrompt(prompt_token_ids=prompt_data)
-            else:
-                prompt = TextPrompt(prompt=prompt_data)
+                # Special tokens are already included in chat templates so
+                # should not be added by the tokenizer in this case.
+                prompt_token_ids = tokenizer.encode(prompt_str,
+                                                    add_special_tokens=False)
+
+            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
 
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
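The hunk above is the heart of the fix: the HF chat-template output is a string that already contains its special tokens, so it is now encoded with add_special_tokens=False and wrapped in a TokensPrompt, the same shape the Mistral branch produces. A rough standalone sketch of that flow, using the public tokenizer API rather than vLLM's internal apply_hf_chat_template helper, and again assuming the gated Llama checkpoint:

# Rough sketch of the new chat() flow: render the template, encode it
# without adding special tokens again, and feed vLLM a TokensPrompt.
# Uses the public HF tokenizer API instead of vLLM's internal helpers.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
tok = llm.get_tokenizer()

conversation = [{"role": "user", "content": "Hello!"}]
prompt_str = tok.apply_chat_template(conversation,
                                     tokenize=False,
                                     add_generation_prompt=True)

# Special tokens (BOS etc.) are already in prompt_str, so do not add them again.
prompt_token_ids = tok.encode(prompt_str, add_special_tokens=False)

outputs = llm.generate(TokensPrompt(prompt_token_ids=prompt_token_ids),
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)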
@@ -1059,8 +1062,6 @@ class LLM:
         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
 
-        scores: list[PoolingRequestOutput] = []
-
         scores = _cosine_similarity(tokenizer=tokenizer,
                                     embed_1=encoded_output_1,
                                     embed_2=encoded_output_2)