diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 6a4862123b51..742a66683445 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -1,15 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +import weakref import pytest from vllm import LLM +from vllm.distributed import cleanup_dist_env_and_memory from ..openai.test_vision import TEST_IMAGE_URLS -def test_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") +@pytest.fixture(scope="function") +def text_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + seed=0) + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +def test_chat(text_llm): prompt1 = "Explain the concept of entropy." messages = [ { @@ -21,13 +37,11 @@ def test_chat(): "content": prompt1 }, ] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 1 -def test_multi_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") - +def test_multi_chat(text_llm): prompt1 = "Explain the concept of entropy." prompt2 = "Explain what among us is." @@ -55,13 +69,14 @@ def test_multi_chat(): messages = [conversation1, conversation2] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 2 -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) -def test_chat_multi_image(image_urls: list[str]): +@pytest.fixture(scope="function") +def vision_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection llm = LLM( model="microsoft/Phi-3.5-vision-instruct", max_model_len=4096, @@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]): enforce_eager=True, trust_remote_code=True, limit_mm_per_prompt={"image": 2}, + seed=0, ) + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +def test_chat_multi_image(vision_llm, image_urls: list[str]): messages = [{ "role": "user", @@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]): }, ], }] - outputs = llm.chat(messages) + outputs = vision_llm.chat(messages) assert len(outputs) >= 0 -def test_llm_chat_tokenization_no_double_bos(): +def test_llm_chat_tokenization_no_double_bos(text_llm): """ LLM.chat() should not add special tokens when using chat templates. Check we get a single BOS token for llama chat. """ - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True) messages = [ { "role": "system", @@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos(): "content": "Hello!" }, ] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 1 - prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None) + + prompt_token_ids = outputs[0].prompt_token_ids assert prompt_token_ids is not None - bos_token = llm.get_tokenizer().bos_token_id + bos_token = text_llm.get_tokenizer().bos_token_id # Ensure we have a single BOS assert prompt_token_ids[0] == bos_token assert prompt_token_ids[1] != bos_token, "Double BOS" + + +@pytest.fixture(scope="function") +def thinking_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM( + model="Qwen/Qwen3-0.6B", + max_model_len=4096, + enforce_eager=True, + seed=0, + ) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("enable_thinking", [True, False]) +def test_chat_extra_kwargs(thinking_llm, enable_thinking): + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "What is 1+1?" + }, + ] + + outputs = thinking_llm.chat( + messages, + chat_template_kwargs={"enable_thinking": enable_thinking}, + ) + assert len(outputs) == 1 + + prompt_token_ids = outputs[0].prompt_token_ids + assert prompt_token_ids is not None + + think_id = thinking_llm.get_tokenizer().get_vocab()[""] + + if enable_thinking: + assert think_id not in prompt_token_ids + else: + # The chat template includes dummy thinking process + assert think_id in prompt_token_ids diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 653e61a11ebd..948e8f36e0e6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -656,6 +656,7 @@ class LLM: add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[list[dict[str, Any]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, ) -> list[RequestOutput]: """ @@ -696,6 +697,8 @@ class LLM: continue_final_message: If True, continues the final message in the conversation instead of starting a new one. Cannot be ``True`` if ``add_generation_prompt`` is also ``True``. + chat_template_kwargs: Additional kwargs to pass to the chat + template. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. @@ -726,6 +729,14 @@ class LLM: trust_remote_code=model_config.trust_remote_code, ) + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + prompts: list[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: @@ -743,20 +754,14 @@ class LLM: prompt_token_ids = apply_mistral_chat_template( tokenizer, messages=msgs, - chat_template=chat_template, - tools=tools, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, + **_chat_template_kwargs, ) else: prompt_str = apply_hf_chat_template( tokenizer, trust_remote_code=model_config.trust_remote_code, conversation=conversation, - chat_template=chat_template, - tools=tools, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, + **_chat_template_kwargs, ) # Special tokens are already included in chat templates so # should not be added by the tokenizer in this case.