diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 647f1c7b7f34..0c1f19371a16 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @pytest.fixture(scope="function") def phi3v_model_config(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="function") def phi3v_model_config_mm_interleaved(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -77,14 +81,16 @@ def phi3v_tokenizer(): @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): - return ModelConfig(QWEN25OMNI_MODEL_ID, - runner="generate", - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - "audio": 1, - "video": 1, - }) + return ModelConfig( + QWEN25OMNI_MODEL_ID, + runner="generate", + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }, + ) @pytest.fixture(scope="module") @@ -99,11 +105,13 @@ def qwen25omni_tokenizer(): @pytest.fixture(scope="module") def mllama_model_config(): - return ModelConfig(MLLAMA_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MLLAMA_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -118,11 +126,13 @@ def mllama_tokenizer(): @pytest.fixture(scope="function") def mistral_model_config(): - return ModelConfig(MISTRAL_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MISTRAL_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -137,21 +147,21 @@ def mistral_tokenizer(): @pytest.fixture(scope="module") def image_url(): - image = ImageAsset('cherry_blossom') + image = ImageAsset("cherry_blossom") base64 = encode_image_base64(image.pil_image) return f"data:image/jpeg;base64,{base64}" @pytest.fixture(scope="module") def video_url(): - video = VideoAsset('baby_reading', 1) + video = VideoAsset("baby_reading", 1) base64 = encode_video_base64(video.np_ndarrays) return f"data:video/jpeg;base64,{base64}" @pytest.fixture(scope="module") def audio_url(): - audio = AudioAsset('mary_had_lamb') + audio = AudioAsset("mary_had_lamb") base64 = encode_audio_base64(*audio.audio_and_sample_rate) return f"data:audio/ogg;base64,{base64}" @@ -195,15 +205,18 @@ def test_parse_chat_messages_single_image( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" 
+ }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -223,58 +236,69 @@ def test_parse_chat_messages_empty_system( ): # Test string format conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], mistral_model_config, mistral_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": "Who are you?" - }] + assert conversation == [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "Who are you?" + }, + ] # Test openai format conversation, _ = parse_chat_messages( - [{ + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], + mistral_model_config, + mistral_tokenizer, + content_format="openai", + ) + assert conversation == [ + { "role": "system", - "content": "" - }, { + "content": [{ + "type": "text", + "text": "" + }] + }, + { "role": "user", "content": [{ "type": "text", "text": "Who are you?" }] - }], - mistral_model_config, - mistral_tokenizer, - content_format="openai", - ) - assert conversation == [{ - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }] + }, + ] @pytest.mark.asyncio @@ -287,15 +311,18 @@ async def test_parse_chat_messages_single_image_async( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -318,18 +345,22 @@ def test_parse_chat_messages_multiple_images( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -340,7 +371,7 @@ def test_parse_chat_messages_multiple_images( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) @@ -355,18 +386,22 @@ async def test_parse_chat_messages_multiple_images_async( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" 
+ }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -377,7 +412,7 @@ async def test_parse_chat_messages_multiple_images_async( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(await mm_future, 2) @@ -391,22 +426,26 @@ def test_parse_chat_messages_placeholder_already_in_prompt( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -416,7 +455,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( "role": "user", "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?" + "What's in <|image_1|> and how does it compare to <|image_2|>?", }] _assert_mm_data_is_image_input(mm_data, 2) @@ -447,9 +486,9 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "type": "text", "text": - "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 - } - ] + "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501 + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -461,7 +500,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "user", "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?" + "other one?", }] _assert_mm_data_is_image_input(mm_data, 2) @@ -472,34 +511,44 @@ def test_parse_chat_messages_multiple_images_across_messages( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about this one?" - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -527,19 +576,23 @@ def test_parse_chat_messages_context_text_format( phi3v_tokenizer, ): conversation, mm_data = parse_chat_messages( - [{ - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What about this one?" - }], + [ + { + "role": "user", + "content": [{ + "type": "text", + "text": "What's in this text?" + }], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What about this one?" 
+ }, + ], phi3v_model_config, phi3v_tokenizer, content_format="openai", @@ -551,21 +604,21 @@ def test_parse_chat_messages_context_text_format( "content": [{ "type": "text", "text": "What's in this text?" - }] + }], }, { "role": "assistant", "content": [{ "type": "text", "text": "Some stuff." - }] + }], }, { "role": "user", "content": [{ "type": "text", "text": "What about this one?" - }] + }], }, ] @@ -578,31 +631,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -618,42 +677,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What about these two?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -670,12 +741,14 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "role": "user", "content": [ - "What's in these images?", { + "What's in these images?", + { "image_url": image_url - }, { + }, + { "image_url": image_url - } - ] + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -686,7 +759,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" 
+ "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) @@ -700,26 +773,32 @@ def test_parse_chat_messages_multiple_images_interleave( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -731,7 +810,7 @@ def test_parse_chat_messages_multiple_images_interleave( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(mm_data, 2) @@ -746,26 +825,32 @@ async def test_parse_chat_messages_multiple_images_interleave_async( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -777,7 +862,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(await mm_data, 2) @@ -788,135 +873,161 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Be accurate." - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" 
+ }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + ], + }, + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|image_1|>\nBe accurate." - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What's on this image?\n<|image_2|>" - }] + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] _assert_mm_data_is_image_input(mm_data, 2) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( - qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, - image_url, video_url, audio_url): + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "And what's in the video?" - }, { - "type": "video_url", - "video_url": { - "url": video_url - } - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + } + }, + ], + }, + ], qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" - }] + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) @@ -929,7 +1040,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( with pytest.raises( ValueError, match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items."): + "than actual multimodal data items.", + ): parse_chat_messages( [{ "role": @@ -952,9 +1064,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( "text", "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }, - ] + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -973,12 +1085,15 @@ def test_mllama_single_image( [{ "role": "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] + "content": [ + { + "type": "text", + "text": "The content of this image is:" + }, + { + "image_url": image_url + }, + ], }], mllama_model_config, mllama_tokenizer, @@ -986,14 +1101,17 @@ def test_mllama_single_image( ) _assert_mm_data_is_image_input(mm_data, 1) assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - 'type': 'image' - }] + "role": + "user", + "content": [ + { + "type": "text", + "text": "The content of this image is:" + }, + { + "type": "image" + }, + ], }] @@ -1009,20 +1127,20 @@ def test_mllama_interleaved_images( "user", "content": [ { - 'type': 'text', - 'text': 'The content of the first image is:' + "type": "text", + "text": "The content of the first image is:", }, { "image_url": image_url }, { - 'type': 'text', - 'text': 'The content of the second image is:' + "type": "text", + "text": "The content of the second image is:", }, { "image_url": image_url }, - ] + ], }], mllama_model_config, mllama_tokenizer, @@ -1030,19 +1148,24 @@ def test_mllama_interleaved_images( ) _assert_mm_data_is_image_input(mm_data, 2) assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of the first image is:' - }, { - 'type': 'image' - }, { - 'type': 'text', - 'text': 'The content of the second image is:' - }, { - 'type': 'image' - }] + "role": + "user", + "content": [ + { + "type": "text", + "text": "The content of the first image is:" + }, + { + "type": "image" + }, + { + "type": "text", + "text": "The content of the second image is:" + }, + { + "type": "image" + }, + ], }] @@ -1053,34 +1176,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): def get_conversation(is_hf: bool): img_part = {"type": "image_url", "image_url": {"url": image_url}} if is_hf: - img_part = {'type': 'image'} + img_part = {"type": "image"} return [{ - 'role': - 'user', - 'content': [ + "role": + "user", + "content": [ { - 'type': 'text', - 'text': 'The content of the first image is:' + "type": "text", + "text": "The content of the first image is:", }, img_part, { - 'type': 'text', - 'text': 'The content of the second image is:' + "type": "text", + "text": "The content of the second image is:", }, img_part, { - 'type': 'text', - 'text': 'What animal is in the first image?' 
+ "type": "text", + "text": "What animal is in the first image?", }, - ] + ], }] # Build a config for the model - model_config = ModelConfig(model, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + model_config = ModelConfig( + model, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -1126,7 +1251,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): [ QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ]) + ], +) @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" @@ -1152,14 +1278,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) tokenizer = tokenizer_group.tokenizer - tools = [{ + tools = ([{ "type": "function", "function": { "name": "dummy_function_name", "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] if use_tools else None + "parameters": sample_json_schema, + }, + }] if use_tools else None) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1954cbcbf1ed..80e2c44a0251 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -103,6 +103,7 @@ class PILImage(BaseModel): """ A PIL.Image.Image object. """ + image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) @@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): "image_pil": ImageAsset('cherry_blossom').pil_image } """ + image_pil: Required[PILImage] @@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): "image_url": "https://example.com/image.jpg" } """ + image_url: Required[str] @@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): "audio_url": "https://example.com/audio.mp3" } """ + audio_url: Required[str] @@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): "video_url": "https://example.com/video.mp4" } """ + video_url: Required[str] @@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartVideoParam, + ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str, - CustomThinkCompletionContentParam] + CustomChatCompletionContentSimpleVideoParam, + str, + CustomThinkCompletionContentParam, +] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" + role: Required[str] """The role of the message's author.""" @@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): """The tool calls generated by the model, such as function 
calls.""" -ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam, - OpenAIHarmonyMessage] +ChatCompletionMessageParam = Union[ + OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam, + OpenAIHarmonyMessage, +] # TODO: Make fields ReadOnly once mypy supports it @@ -262,13 +274,13 @@ def _is_var_or_elems_access( key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): - return (node.node is not None - and _is_var_or_elems_access(node.node, varname, key)) + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) - if (isinstance(node, jinja2.nodes.Getitem) - and isinstance(node.arg, jinja2.nodes.Slice)): + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable @@ -373,15 +385,18 @@ def resolve_mistral_chat_template( ) -> Optional[str]: if chat_template is not None: logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer.") + "'chat_template' cannot be overridden for mistral tokenizer." + ) if "add_generation_prompt" in kwargs: logger.warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." + ) if "continue_final_message" in kwargs: logger.warning_once( "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." + ) return None @@ -401,23 +416,35 @@ def resolve_hf_chat_template( try: processor = cached_get_processor( tokenizer.name_or_path, - processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin), + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), trust_remote_code=model_config.trust_remote_code, ) - if isinstance(processor, ProcessorMixin) and \ - hasattr(processor, 'chat_template') and \ - processor.chat_template is not None: + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and processor.chat_template is not None + ): return processor.chat_template except Exception: - logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # noqa: E501 # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: - logger.debug("Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, exc_info=True) + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( @@ -425,12 +452,16 @@ def resolve_hf_chat_template( tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: - logger.info("Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", tokenizer.name_or_path) + logger.info( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) chat_template = load_chat_template(path) else: - logger.debug("There is no chat template fallback for %s", - tokenizer.name_or_path) + logger.debug( + "There is no chat template fallback for %s", 
tokenizer.name_or_path + ) return chat_template @@ -452,11 +483,17 @@ def _resolve_chat_template_content_format( else: hf_chat_template = None - jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True)) + jinja_text = ( + hf_chat_template + if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) - detected_format = ("string" if jinja_text is None else - _detect_content_format(jinja_text, default="string")) + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) return detected_format @@ -512,7 +549,6 @@ def resolve_chat_template_content_format( return detected_format - ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") @@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): @cached_property def model_cls(self) -> type[SupportsMultiModal]: from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsMultiModal], model_cls) @@ -574,28 +611,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]): - def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError(\ - "Mixing raw image and embedding inputs is not allowed") + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError(\ - "Only one message can have {'type': 'image_embeds'}") + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -603,32 +641,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): - async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = { - modality: await asyncio.gather(*items) - for modality, items in self._items_by_modality.items() - } + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( - "Mixing raw image and embedding inputs is not allowed") + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( - "Only one message can have {'type': 'image_embeds'}") + "Only one message can have {'type': 'image_embeds'}" + ) 
mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class BaseMultiModalContentParser(ABC): - def __init__(self) -> None: super().__init__() @@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC): # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder(self, modality: ModalityStr, - placeholder: Optional[str]): + def _add_placeholder( + self, modality: ModalityStr, placeholder: Optional[str] + ): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: self._placeholder_storage[mod_placeholder].append(placeholder) @@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC): raise NotImplementedError @abstractmethod - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, image_embeds: Union[str, dict[str, str]] + ) -> None: raise NotImplementedError @abstractmethod @@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC): class MultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: MultiModalItemTracker) -> None: super().__init__() @@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser): placeholder = self._tracker.add("image", image) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, image_embeds: Union[str, dict[str, str]] + ) -> None: if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) @@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser): class AsyncMultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker self._connector = MediaConnector( media_io_kwargs=self._tracker._model_config.media_io_kwargs, - allowed_local_media_path=tracker.allowed_local_media_path + allowed_local_media_path=tracker.allowed_local_media_path, ) def parse_image(self, image_url: str) -> None: @@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): placeholder = self._tracker.add("image", image_coro) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, image_embeds: Union[str, dict[str, str]] + ) -> None: future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() if isinstance(image_embeds, dict): @@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): future.set_result(embeds) if isinstance(image_embeds, str): - embedding = self._connector.\ - fetch_image_embedding(image_embeds) + embedding = self._connector.fetch_image_embedding(image_embeds) future.set_result(embedding) placeholder = self._tracker.add("image_embeds", future) @@ 
-809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): return elif isinstance(chat_template, Path) and not chat_template.exists(): - raise FileNotFoundError( - "the supplied chat template path doesn't exist") + raise FileNotFoundError("the supplied chat template path doesn't exist") elif isinstance(chat_template, str): JINJA_CHARS = "{}\n" - if not any(c in chat_template - for c in JINJA_CHARS) and not Path(chat_template).exists(): + if ( + not any(c in chat_template for c in JINJA_CHARS) + and not Path(chat_template).exists() + ): raise ValueError( f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!") + f"appears path-like, but doesn't exist!" + ) else: raise TypeError( - f"{type(chat_template)} is not a valid chat template type") + f"{type(chat_template)} is not a valid chat template type" + ) def _load_chat_template( @@ -835,8 +877,9 @@ def _load_chat_template( if is_literal: if isinstance(chat_template, Path): - raise TypeError("chat_template is expected to be read directly " - "from its value") + raise TypeError( + "chat_template is expected to be read directly from its value" + ) return chat_template @@ -849,9 +892,11 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}" + ) raise ValueError(msg) from e # If opening a file fails, set chat template to be args to @@ -870,8 +915,9 @@ def load_chat_template( return _cached_load_chat_template(chat_template, is_literal=is_literal) -def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], - texts: list[str]) -> str: +def _get_interleaved_text_prompt( + placeholder_storage: dict[str, list], texts: list[str] +) -> str: for idx, elem in enumerate(texts): if elem in placeholder_storage: texts[idx] = placeholder_storage[elem].pop(0) @@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], - texts: list[str], - interleave_strings: bool - ) -> str: +def _get_full_multimodal_text_prompt( + placeholder_storage: dict[str, list], + texts: list[str], + interleave_strings: bool, +) -> str: """Combine multimodal prompts for a multimodal language model.""" # flatten storage to make it looks like @@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], # Look through the text prompt to check for missing placeholders missing_placeholders: list[str] = [] for placeholder in placeholder_counts: - # For any existing placeholder in the text prompt, we leave it as is placeholder_counts[placeholder] -= text_prompt.count(placeholder) @@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], "Placeholder count is negative! 
" "Ensure that the 'interleave_strings' flag is disabled " "(current value: %s) " - "when manually placing image placeholders.", interleave_strings + "when manually placing image placeholders.", + interleave_strings, ) logger.debug("Input prompt: %s", text_prompt) raise ValueError( f"Found more '{placeholder}' placeholders in input prompt than " - "actual multimodal data items.") + "actual multimodal data items." + ) - missing_placeholders.extend([placeholder] * - placeholder_counts[placeholder]) + missing_placeholders.extend( + [placeholder] * placeholder_counts[placeholder] + ) # NOTE: Default behaviour: we always add missing placeholders # at the front of the prompt, if interleave_strings=False @@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ResponsesInputImageParser = TypeAdapter( - ResponseInputImageParam).validate_python + ResponseInputImageParam +).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. @@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[ str, Callable[[ChatCompletionContentPartParam], _ContentPart], ] = { - "text": - lambda part: _TextParser(part).get("text", None), - "thinking": - lambda part: _ThinkParser(part).get("thinking", None), - "input_text": - lambda part: _TextParser(part).get("text", None), - "input_image": - lambda part: _ResponsesInputImageParser(part).get("image_url", None), - "image_url": - lambda part: _ImageParser(part).get("image_url", {}).get("url", None), - "image_embeds": - lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + "text": lambda part: _TextParser(part).get("text", None), + "thinking": lambda part: _ThinkParser(part).get("thinking", None), + "input_text": lambda part: _TextParser(part).get("text", None), + "input_image": lambda part: _ResponsesInputImageParser(part).get( + "image_url", None + ), + "image_url": lambda part: _ImageParser(part) + .get("image_url", {}) + .get("url", None), + "image_embeds": lambda part: _ImageEmbedsParser(part).get( + "image_embeds", None + ), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), - "audio_url": - lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), - "input_audio": - lambda part: _InputAudioParser(part).get("input_audio", None), - "refusal": - lambda part: _RefusalParser(part).get("refusal", None), - "video_url": - lambda part: _VideoParser(part).get("video_url", {}).get("url", None), + "audio_url": lambda part: _AudioParser(part) + .get("audio_url", {}) + .get("url", None), + "input_audio": lambda part: _InputAudioParser(part).get( + "input_audio", None + ), + "refusal": lambda part: _RefusalParser(part).get("refusal", None), + "video_url": lambda part: _VideoParser(part) + .get("video_url", {}) + .get("url", None), } def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + part: ChatCompletionContentPartParam, +) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part( ValueError: If the 'type' field is missing and no direct URL is found. 
""" assert isinstance( - part, dict) # This is needed to avoid mypy errors: part.get() from str + part, dict + ) # This is needed to avoid mypy errors: part.get() from str part_type = part.get("type", None) if isinstance(part_type, str) and part_type in MM_PARSER_MAP: @@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part( # Special case for 'image_url.detail' # We only support 'auto', which is the default if part_type == "image_url" and part.get("detail", "auto") != "auto": - logger.warning("'image_url.detail' is currently not supported " - "and will be ignored.") + logger.warning( + "'image_url.detail' is currently not supported " + "and will be ignored." + ) return part_type, content @@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part( # 'type' is required field by pydantic if part_type is None: if part.get("image_url") is not None: - image_params = cast(CustomChatCompletionContentSimpleImageParam, - part) + image_params = cast( + CustomChatCompletionContentSimpleImageParam, part + ) return "image_url", image_params.get("image_url", "") if part.get("audio_url") is not None: - audio_params = cast(CustomChatCompletionContentSimpleAudioParam, - part) + audio_params = cast( + CustomChatCompletionContentSimpleAudioParam, part + ) return "audio_url", audio_params.get("audio_url", "") if part.get("input_audio") is not None: input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params if part.get("video_url") is not None: - video_params = cast(CustomChatCompletionContentSimpleVideoParam, - part) + video_params = cast( + CustomChatCompletionContentSimpleVideoParam, part + ) return "video_url", video_params.get("video_url", "") # Raise an error if no 'type' or direct URL is found. raise ValueError("Missing 'type' field in multimodal part.") @@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part( return part_type, "unknown part_type content" -VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", - "image_embeds", "image_pil", - "audio_url", "input_audio", "video_url") +VALID_MESSAGE_CONTENT_MM_PART_TYPES = ( + "text", + "refusal", + "image_url", + "image_embeds", + "image_pil", + "audio_url", + "input_audio", + "video_url", +) def _parse_chat_message_content_parts( @@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts( part, mm_parser, wrap_dicts=wrap_dicts, - interleave_strings=interleave_strings + interleave_strings=interleave_strings, ) if parse_res: content.append(parse_res) if wrap_dicts: # Parsing wraps images and texts as interleaved dictionaries - return [ConversationMessage(role=role, - content=content)] # type: ignore + return [ConversationMessage(role=role, content=content)] # type: ignore texts = cast(list[str], content) mm_placeholder_storage = mm_parser.mm_placeholder_storage() if mm_placeholder_storage: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, - texts, - interleave_strings) + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_storage, texts, interleave_strings + ) else: text_prompt = "\n".join(texts) @@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part( if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: logger.warning( "Skipping multimodal part '%s' (type: '%s') " - "with empty / unparsable content.", part, part_type) + "with empty / unparsable content.", + part, + part_type, + ) return None if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: - return {'type': 'text', 
'text': str_content} + return {"type": "text", "text": str_content} else: return str_content @@ -1137,8 +1205,12 @@ def _parse_chat_message_content_part( else: raise NotImplementedError(f"Unknown part type: {part_type}") - return {'type': modality} if wrap_dicts else ( - MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + return ( + {"type": modality} + if wrap_dicts + else ( + MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + ) ) @@ -1171,14 +1243,16 @@ def _parse_chat_message_content( ) for result_msg in result: - if role == 'assistant': + if role == "assistant": parsed_msg = _AssistantParser(message) # The 'tool_calls' is not None check ensures compatibility. # It's needed only if downstream code doesn't strictly # follow the OpenAI spec. - if ("tool_calls" in parsed_msg - and parsed_msg["tool_calls"] is not None): + if ( + "tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None + ): result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if (message["role"] == "assistant" and "tool_calls" in message - and isinstance(message["tool_calls"], list)): - + if ( + message["role"] == "assistant" + and "tool_calls" in message + and isinstance(message["tool_calls"], list) + ): for item in message["tool_calls"]: item["function"]["arguments"] = json.loads( - item["function"]["arguments"]) + item["function"]["arguments"] + ) def parse_chat_messages( @@ -1224,7 +1301,7 @@ def parse_chat_messages( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) @@ -1252,7 +1329,7 @@ def parse_chat_messages_futures( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) @@ -1283,10 +1360,10 @@ def apply_hf_chat_template( raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one.") + "does not define one." + ) try: - return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] @@ -1298,13 +1375,14 @@ def apply_hf_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. logger.exception( - "An error occurred in `transformers` while applying chat template") + "An error occurred in `transformers` while applying chat template" + ) raise ValueError(str(e)) from e + def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], @@ -1337,26 +1415,26 @@ def apply_mistral_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. 
         logger.exception(
-            "An error occurred in `mistral_common` while applying chat "
-            "template")
+            "An error occurred in `mistral_common` while applying chat template"
+        )
         raise ValueError(str(e)) from e
 
+
 def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
     idx = 0
     for msg in conversation:
-        if msg['role'] == 'assistant':
-            tool_calls = msg.get('tool_calls')
-            idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
+        if msg["role"] == "assistant":
+            tool_calls = msg.get("tool_calls")
+            idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
     return idx
 
 
-def make_tool_call_id(id_type:str='random', func_name=None, idx=None):
-    if id_type=='kimi_k2':
-        return f'functions.{func_name}:{idx}'
+def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
+    if id_type == "kimi_k2":
+        return f"functions.{func_name}:{idx}"
     else:
         # by default return random
         return f"chatcmpl-tool-{random_uuid()}"
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 796b8ab5fc2c..f506f7de1682 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
 
 logger = init_logger(__name__)
 
-CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
-                              EmbeddingCompletionRequest, RerankRequest,
-                              ClassificationRequest, ScoreRequest,
-                              TokenizeCompletionRequest]
+CompletionLikeRequest = Union[
+    CompletionRequest,
+    DetokenizeRequest,
+    EmbeddingCompletionRequest,
+    RerankRequest,
+    ClassificationRequest,
+    ScoreRequest,
+    TokenizeCompletionRequest,
+]
 
 ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                         TokenizeChatRequest]
 SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
-AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
-                   ResponsesRequest, IOProcessorRequest]
+AnyRequest = Union[
+    CompletionLikeRequest,
+    ChatLikeRequest,
+    SpeechToTextRequest,
+    ResponsesRequest,
+    IOProcessorRequest,
+]
 
 AnyResponse = Union[
     CompletionResponse,
@@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel):
     Mixin for request processing,
     handling prompt preparation and engine input.
     """
+
     request_prompts: Optional[Sequence[RequestPrompt]] = []
     engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                    list[EngineEmbedsPrompt]]] = []
@@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel):
     Mixin for response generation, managing result generators and final
     batch results.
     """
+
     result_generator: Optional[AsyncGenerator[tuple[int, Union[
         RequestOutput, PoolingRequestOutput]], None]] = None
     final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
@@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
-class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
-                   Generic[RequestT]):
+class ServeContext(
+    RequestProcessingMixin,
+    ResponseGenerationMixin,
+    BaseModel,
+    Generic[RequestT],
+):
     # Shared across all requests
     request: RequestT
     raw_request: Optional[Request] = None
@@ -298,8 +314,8 @@ class OpenAIServing:
         truncate_prompt_tokens = getattr(ctx.request,
                                          "truncate_prompt_tokens", None)
-        if truncate_prompt_tokens is not None and \
-                truncate_prompt_tokens > self.max_model_len:
+        if (truncate_prompt_tokens is not None
+                and truncate_prompt_tokens > self.max_model_len):
             return self.create_error_response(
                 "truncate_prompt_tokens value is "
                 "greater than max_model_len."
             )
@@ -344,10 +360,12 @@ class OpenAIServing:
                 return self.create_error_response(
                     "Request prompts not available")
 
-            self._log_inputs(request_id_item,
-                             ctx.request_prompts[i],
-                             params=pooling_params,
-                             lora_request=ctx.lora_request)
+            self._log_inputs(
+                request_id_item,
+                ctx.request_prompts[i],
+                params=pooling_params,
+                lora_request=ctx.lora_request,
+            )
 
             # Mypy has an existing bug related to inferring the variance of
             # TypedDicts with `builtins.enumerate`:
@@ -410,10 +428,11 @@ class OpenAIServing:
             return self.create_error_response(str(e))
 
     def create_error_response(
-            self,
-            message: str,
-            err_type: str = "BadRequestError",
-            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+    ) -> ErrorResponse:
         if self.log_error_stack:
             exc_type, _, _ = sys.exc_info()
             if exc_type is not None:
@@ -424,10 +443,11 @@ class OpenAIServing:
             message=message, type=err_type, code=status_code.value))
 
     def create_streaming_error_response(
-            self,
-            message: str,
-            err_type: str = "BadRequestError",
-            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+    ) -> str:
         json_str = json.dumps(
             self.create_error_response(message=message,
                                        err_type=err_type,
@@ -438,25 +458,25 @@ class OpenAIServing:
         self,
         request: AnyRequest,
     ) -> Optional[ErrorResponse]:
-
         error_response = None
         if self._is_model_supported(request.model):
             return None
         if request.model in self.models.lora_requests:
             return None
 
-        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
-                load_result := await self.models.resolve_lora(request.model)):
+        if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
+            (load_result := await self.models.resolve_lora(request.model))):
             if isinstance(load_result, LoRARequest):
                 return None
-            if isinstance(load_result, ErrorResponse) and \
-                    load_result.error.code == HTTPStatus.BAD_REQUEST.value:
+            if (isinstance(load_result, ErrorResponse) and
+                    load_result.error.code == HTTPStatus.BAD_REQUEST.value):
                 error_response = load_result
 
         return error_response or self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
-            status_code=HTTPStatus.NOT_FOUND)
+            status_code=HTTPStatus.NOT_FOUND,
+        )
 
     def _get_active_default_mm_loras(
             self, request: AnyRequest) -> Optional[LoRARequest]:
@@ -487,7 +507,6 @@ class OpenAIServing:
         request: AnyRequest,
         supports_default_mm_loras: bool = False,
     ) -> Optional[LoRARequest]:
-
         if request.model in self.models.lora_requests:
             return self.models.lora_requests[request.model]
 
@@ -548,13 +567,15 @@ class OpenAIServing:
                 prompt,
                 add_special_tokens=add_special_tokens,
                 truncation=True,
-                max_length=self.max_model_len)
+                max_length=self.max_model_len,
+            )
         else:
             encoded = await async_tokenizer(
                 prompt,
                 add_special_tokens=add_special_tokens,
                 truncation=True,
-                max_length=truncate_prompt_tokens)
+                max_length=truncate_prompt_tokens,
+            )
 
         input_ids = encoded.input_ids
         input_text = prompt
@@ -595,16 +616,22 @@ class OpenAIServing:
 
         # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
-        if isinstance(request,
-                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
-                       ScoreRequest, RerankRequest, ClassificationRequest)):
-
+        if isinstance(
+            request,
+            (
+                EmbeddingChatRequest,
+                EmbeddingCompletionRequest,
+                ScoreRequest,
+                RerankRequest,
+                ClassificationRequest,
+            ),
+        ):
             # Note: input length can be up to the entire model context length
             # since these requests don't generate tokens.
             if token_num > self.max_model_len:
                 operations: dict[type[AnyRequest], str] = {
                     ScoreRequest: "score",
-                    ClassificationRequest: "classification"
+                    ClassificationRequest: "classification",
                 }
                 operation = operations.get(type(request),
                                            "embedding generation")
@@ -618,8 +645,11 @@ class OpenAIServing:
 
         # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
         # and does not require model context length validation
-        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
-                                DetokenizeRequest)):
+        if isinstance(
+            request,
+            (TokenizeCompletionRequest, TokenizeChatRequest,
+             DetokenizeRequest),
+        ):
             return TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)
 
@@ -639,8 +669,8 @@ class OpenAIServing:
                 f"{token_num} input tokens. Please reduce the length of "
                 "the input messages.")
 
-        if max_tokens is not None and \
-                token_num + max_tokens > self.max_model_len:
+        if (max_tokens is not None
+                and token_num + max_tokens > self.max_model_len):
             raise ValueError(
                 "'max_tokens' or 'max_completion_tokens' is too large: "
                 f"{max_tokens}. This model's maximum context length is "
@@ -745,13 +775,14 @@ class OpenAIServing:
         tasks = []
         for prompt_input in batch_inputs:
             if prompt_input["is_tokens"] is False:
-                assert tokenizer is not None, \
-                    "Tokenizer is required for text prompts"
+                assert tokenizer is not None, (
+                    "Tokenizer is required for text prompts")
                 task = self._normalize_prompt_text_to_input(
                     request,
                     prompt_input["content"],
                     tokenizer=tokenizer,
-                    add_special_tokens=add_special_tokens)
+                    add_special_tokens=add_special_tokens,
+                )
             else:
                 task = self._normalize_prompt_tokens_to_input(
                     request, prompt_input["content"], tokenizer=tokenizer)
@@ -766,9 +797,14 @@ class OpenAIServing:
     @overload
     async def _preprocess_completion(
         self,
-        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
-                       RerankRequest, ClassificationRequest, ScoreRequest,
-                       TokenizeCompletionRequest],
+        request: Union[
+            DetokenizeRequest,
+            EmbeddingCompletionRequest,
+            RerankRequest,
+            ClassificationRequest,
+            ScoreRequest,
+            TokenizeCompletionRequest,
+        ],
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
         add_special_tokens: bool = ...,
@@ -783,8 +819,10 @@ class OpenAIServing:
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
         add_special_tokens: bool = ...,
-    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
-            EngineTokensPrompt, EngineEmbedsPrompt]]]:
+    ) -> tuple[
+        list[Union[TextTokensPrompt, EmbedsPrompt]],
+        list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
+    ]:
         ...
 
     async def _preprocess_completion(
@@ -794,32 +832,38 @@ class OpenAIServing:
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
         add_special_tokens: bool = True,
-    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
-            TextTokensPrompt, EmbedsPrompt]]], Union[
-                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
-                                                     EngineEmbedsPrompt]]]]:
-        if not isinstance(request,
-                          CompletionRequest) and input_or_inputs is None:
+    ) -> tuple[
+        Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
+                                                 EmbedsPrompt]]],
+        Union[
+            list[EngineTokensPrompt],
+            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
+        ],
+    ]:
+        if (not isinstance(request, CompletionRequest)
+                and input_or_inputs is None):
             raise ValueError(
                 "Prompt embeds with non-completion requests is not"
                 " currently supported.")
 
-        (request_prompts_text, request_prompts_embeds
-         ) = await self._tokenize_prompt_input_or_inputs_async(
-             request,
-             tokenizer,
-             input_or_inputs,
-             add_special_tokens=add_special_tokens,
-         )
+        (
+            request_prompts_text,
+            request_prompts_embeds,
+        ) = await self._tokenize_prompt_input_or_inputs_async(
+            request,
+            tokenizer,
+            input_or_inputs,
+            add_special_tokens=add_special_tokens,
+        )
 
         engine_prompts_text = [
             EngineTokensPrompt(
                 prompt_token_ids=request_prompt_text["prompt_token_ids"])
             for request_prompt_text in request_prompts_text
         ]
 
-        cache_salt = request.cache_salt if (
-            hasattr(request, "cache_salt")
-            and request.cache_salt is not None) else None
+        cache_salt = (request.cache_salt if
+                      (hasattr(request, "cache_salt")
+                       and request.cache_salt is not None) else None)
         if cache_salt:
             for prompt_text in engine_prompts_text:
                 prompt_text["cache_salt"] = cache_salt
@@ -831,8 +875,8 @@ class OpenAIServing:
         # non-completion requests and if we don't add the overload here,
         # everywhere this function is used outside of serving_completion will
         # need logic asserting that only text prompts are in the request.
-        if not isinstance(request,
-                          CompletionRequest) and input_or_inputs is not None:
+        if (not isinstance(request, CompletionRequest)
+                and input_or_inputs is not None):
             return request_prompts_text, engine_prompts_text
 
         engine_prompts_embeds = [
@@ -862,8 +906,11 @@ class OpenAIServing:
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
         add_special_tokens: bool = False,
-    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
-               list[EngineTokensPrompt]]:
+    ) -> tuple[
+        list[ConversationMessage],
+        Sequence[RequestPrompt],
+        list[EngineTokensPrompt],
+    ]:
         model_config = self.model_config
 
         resolved_content_format = resolve_chat_template_content_format(
@@ -925,8 +972,8 @@ class OpenAIServing:
 
         if tokenizer is None:
             assert isinstance(request_prompt, str), (
-                "Prompt has to be a string", \
-                "when the tokenizer is not initialised"
+                "Prompt has to be a string",
+                "when the tokenizer is not initialised",
             )
             prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                              prompt_token_ids=[1])
@@ -943,7 +990,8 @@ class OpenAIServing:
                 "Prompt has to be either a string or a list of token ids")
             prompt_inputs = TextTokensPrompt(
                 prompt=tokenizer.decode(request_prompt),
-                prompt_token_ids=request_prompt)
+                prompt_token_ids=request_prompt,
+            )
 
         engine_prompt = EngineTokensPrompt(
             prompt_token_ids=prompt_inputs["prompt_token_ids"])
@@ -1007,22 +1055,23 @@ class OpenAIServing:
                     prompt_token_ids=prompt_token_ids)
                 request_prompt = prompt_token_ids
                 # Update the sampling params.
-                sampling_params.max_tokens = (self.max_model_len -
-                                              len(prompt_token_ids))
+                sampling_params.max_tokens = self.max_model_len - len(
+                    prompt_token_ids)
                 # OPTIMIZATION
                 priority = orig_priority - 1
 
     @staticmethod
     def _load_prompt_embeds(
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> list[EmbedsPrompt]:
 
         def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
-            tensor = torch.load(io.BytesIO(
-                pybase64.b64decode(embed, validate=True)),
-                                weights_only=True,
-                                map_location=torch.device("cpu"))
+            tensor = torch.load(
+                io.BytesIO(pybase64.b64decode(embed, validate=True)),
+                weights_only=True,
+                map_location=torch.device("cpu"),
+            )
             assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                 torch.float32,
                 torch.bfloat16,
@@ -1061,7 +1110,7 @@ class OpenAIServing:
             prompt = inputs
         elif isinstance(inputs, list):
             prompt_token_ids = inputs
-        elif 'prompt_embeds' in inputs:
+        elif "prompt_embeds" in inputs:
             prompt_embeds = inputs.get("prompt_embeds")
         else:
             prompt = inputs["prompt"]
@@ -1101,10 +1150,12 @@ class OpenAIServing:
         return raw_request.headers.get("X-Request-Id", default)
 
     @staticmethod
-    def _get_decoded_token(logprob: Logprob,
-                           token_id: int,
-                           tokenizer: AnyTokenizer,
-                           return_as_token_id: bool = False) -> str:
+    def _get_decoded_token(
+        logprob: Logprob,
+        token_id: int,
+        tokenizer: AnyTokenizer,
+        return_as_token_id: bool = False,
+    ) -> str:
         if return_as_token_id:
             return f"token_id:{token_id}"
 
@@ -1117,9 +1168,11 @@ class OpenAIServing:
             return True
         return self.models.is_base_model(model_name)
 
-    def _get_model_name(self,
-                        model_name: Optional[str] = None,
-                        lora_request: Optional[LoRARequest] = None) -> str:
+    def _get_model_name(
+        self,
+        model_name: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> str:
         if lora_request:
             return lora_request.lora_name
         if not model_name:
@@ -1129,7 +1182,7 @@ class OpenAIServing:
 
 def clamp_prompt_logprobs(
     prompt_logprobs: Union[PromptLogprobs,
-                           None]) -> Union[PromptLogprobs, None]:
+                           None], ) -> Union[PromptLogprobs, None]:
     if prompt_logprobs is None:
         return prompt_logprobs
 
@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs(
         if logprob_dict is None:
             continue
         for logprob_values in logprob_dict.values():
-            if logprob_values.logprob == float('-inf'):
+            if logprob_values.logprob == float("-inf"):
                 logprob_values.logprob = -9999.0
     return prompt_logprobs
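
Reviewer note, not part of the patch: the last two hunks only adjust quote style and signature wrapping, so `clamp_prompt_logprobs` still rewrites `-inf` logprob values to `-9999.0` in place. A minimal usage sketch, assuming vLLM is installed and using a hypothetical stand-in dataclass instead of vLLM's own `Logprob` type (the function only touches the `logprob` attribute):

    # Illustrative sketch only; FakeLogprob is a hypothetical stand-in for vLLM's Logprob.
    from dataclasses import dataclass

    from vllm.entrypoints.openai.serving_engine import clamp_prompt_logprobs


    @dataclass
    class FakeLogprob:
        logprob: float


    prompt_logprobs = [None, {42: FakeLogprob(logprob=float("-inf"))}]
    clamp_prompt_logprobs(prompt_logprobs)
    # -inf entries are clamped in place to -9999.0; None entries are skipped.
    assert prompt_logprobs[1][42].logprob == -9999.0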