diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py
index e96081c167ed9..6a4862123b517 100644
--- a/tests/entrypoints/llm/test_chat.py
+++ b/tests/entrypoints/llm/test_chat.py
@@ -89,3 +89,31 @@ def test_chat_multi_image(image_urls: list[str]):
     }]
     outputs = llm.chat(messages)
     assert len(outputs) >= 0
+
+
+def test_llm_chat_tokenization_no_double_bos():
+    """
+    LLM.chat() should not add special tokens when using chat templates.
+    Check we get a single BOS token for llama chat.
+    """
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "Hello!"
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
+    prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None)
+    assert prompt_token_ids is not None
+
+    bos_token = llm.get_tokenizer().bos_token_id
+
+    # Ensure we have a single BOS
+    assert prompt_token_ids[0] == bos_token
+    assert prompt_token_ids[1] != bos_token, "Double BOS"
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5fcf88f57a133..90bd5494c183d 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -117,7 +117,7 @@ class LLM:
         disable_async_output_proc: Disable async output processing. This may
             result in lower performance.
         hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running 
+            . If `True`, will use the token generated when running
             `huggingface-cli login` (stored in `~/.huggingface`).
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
@@ -251,8 +251,12 @@ class LLM:
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
-    def get_tokenizer(self) -> AnyTokenizer:
-        return self.llm_engine.get_tokenizer_group().tokenizer
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
+        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
+            lora_request)
 
     def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         tokenizer_group = self.llm_engine.get_tokenizer_group()
@@ -712,7 +716,7 @@ class LLM:
                 cast(list[ChatCompletionMessageParam], messages)
             ]
 
-        tokenizer = self.get_tokenizer()
+        tokenizer = self.get_tokenizer(lora_request)
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
@@ -735,9 +739,8 @@ class LLM:
                 content_format=resolved_content_format,
             )
 
-            prompt_data: Union[str, list[int]]
             if isinstance(tokenizer, MistralTokenizer):
-                prompt_data = apply_mistral_chat_template(
+                prompt_token_ids = apply_mistral_chat_template(
                     tokenizer,
                     messages=msgs,
                     chat_template=chat_template,
@@ -746,7 +749,7 @@ class LLM:
                     continue_final_message=continue_final_message,
                 )
             else:
-                prompt_data = apply_hf_chat_template(
+                prompt_str = apply_hf_chat_template(
                     tokenizer,
                     trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
@@ -755,12 +758,12 @@ class LLM:
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
                 )
+                # Special tokens are already included in chat templates so
+                # should not be added by the tokenizer in this case.
+                prompt_token_ids = tokenizer.encode(prompt_str,
+                                                    add_special_tokens=False)
 
-            prompt: Union[TokensPrompt, TextPrompt]
-            if is_list_of(prompt_data, int):
-                prompt = TokensPrompt(prompt_token_ids=prompt_data)
-            else:
-                prompt = TextPrompt(prompt=prompt_data)
+            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
 
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data
@@ -1059,8 +1062,6 @@ class LLM:
         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
 
-        scores: list[PoolingRequestOutput] = []
-
         scores = _cosine_similarity(tokenizer=tokenizer,
                                     embed_1=encoded_output_1,
                                     embed_2=encoded_output_2)
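For context, and not part of the patch above: the sketch below reproduces the double-BOS behaviour that test_llm_chat_tokenization_no_double_bos guards against, using only the standard `transformers` AutoTokenizer API. It assumes access to the same gated meta-llama/Llama-3.2-1B-Instruct checkpoint used in the test.

# Standalone sketch (assumption: transformers installed and model access granted).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Rendering the chat template as text already emits the BOS token.
prompt_str = tok.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
assert prompt_str.startswith(tok.bos_token)

# Encoding with the default add_special_tokens=True prepends a second BOS on
# top of the one already present in the rendered template text.
double = tok.encode(prompt_str)
single = tok.encode(prompt_str, add_special_tokens=False)
assert double[:2] == [tok.bos_token_id, tok.bos_token_id]
assert single[0] == tok.bos_token_id and single[1] != tok.bos_token_id

This is why the patch tokenizes the HF chat template output with add_special_tokens=False and passes a TokensPrompt, instead of handing the rendered string back to the engine to re-tokenize with special tokens.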