[Frontend] Support image object in llm.chat (#19635)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: Flora Feng <4florafeng@gmail.com>
Flora Feng 2025-07-05 23:47:13 -07:00 committed by GitHub
parent 4548c03c50
commit fe1e924811
4 changed files with 97 additions and 13 deletions


@@ -101,6 +101,49 @@ To substitute multiple images inside the same text prompt, you can pass in a list

Full example: <gh-file:examples/offline_inference/vision_language_multi_image.py>

If using the [LLM.chat](https://docs.vllm.ai/en/stable/models/generative_models.html#llmchat) method, you can pass images directly in the message content in several formats: image URLs, PIL `Image` objects, or pre-computed embeddings:

```python
import torch

from vllm import LLM
from vllm.assets.image import ImageAsset

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# The same image modality in three supported formats.
image_url = "https://picsum.photos/id/32/512/512"
image_pil = ImageAsset("cherry_blossom").pil_image
image_embeds = torch.load(...)  # pre-computed image embeddings

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "image_pil", "image_pil": image_pil},
            {"type": "image_embeds", "image_embeds": image_embeds},
            {"type": "text", "text": "What's in these images?"},
        ],
    },
]

# Perform inference and log output.
outputs = llm.chat(conversation)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```

Multi-image input can be extended to perform video captioning. We show this with [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) as it supports videos:

??? Code
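Because `image_pil` parts carry the image in memory, no URL fetching is involved at all. A minimal end-to-end sketch, reusing the llava-1.5-7b-hf model from the example above; the synthetic red image is purely illustrative:

```python
from PIL import Image
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Any in-memory PIL image works; here we synthesize one for illustration.
image = Image.new("RGB", (512, 512), color="red")

outputs = llm.chat([{
    "role": "user",
    "content": [
        {"type": "image_pil", "image_pil": image},
        {"type": "text", "text": "What color is this image?"},
    ],
}])
print(outputs[0].outputs[0].text)
```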


@@ -6,6 +6,7 @@ import argparse

from vllm import LLM
from vllm.sampling_params import SamplingParams
from vllm.assets.image import ImageAsset

# This script is an offline demo for running Mistral-Small-3.1
#

@@ -71,14 +72,16 @@ def run_simple_demo(args: argparse.Namespace):
    )

    prompt = "Describe this image in one sentence."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_pil",
                    "image_pil": ImageAsset("cherry_blossom").pil_image,
                },
            ],
        },
    ]
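`ImageAsset` replaces the remote picsum URL here, so the demo builds its prompt from a bundled sample image. A quick sketch of what the helper provides; the fetch-and-cache behavior is my assumption, only `.pil_image` appears in the diff:

```python
from vllm.assets.image import ImageAsset

# Load the sample asset as an in-memory PIL image (fetched and cached on
# first use, as far as I can tell).
image = ImageAsset("cherry_blossom").pil_image
print(image.size, image.mode)
```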


@@ -264,10 +264,8 @@ def test_parse_chat_messages_multiple_images(
                "url": image_url
            }
        }, {
            "type": "image_pil",
            "image_pil": ImageAsset('cherry_blossom').pil_image
        }, {
            "type": "text",
            "text": "What's in these images?"

@@ -303,10 +301,8 @@ async def test_parse_chat_messages_multiple_images_async(
                "url": image_url
            }
        }, {
            "type": "image_pil",
            "image_pil": ImageAsset('cherry_blossom').pil_image
        }, {
            "type": "text",
            "text": "What's in these images?"


@@ -28,7 +28,8 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
    InputAudio)
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
                          ProcessorMixin)
@@ -91,6 +92,25 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """The type of the content part."""


class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """
    image_pil: Image.Image
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """
    image_pil: Required[PILImage]


class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
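For illustration, a hedged sketch of what a conforming `image_pil` content part looks like as plain data; the synthetic image is an assumption, any `PIL.Image.Image` works:

```python
from PIL import Image

# Build an image_pil content part from any in-memory PIL image.
img = Image.new("RGB", (64, 64), color="red")
part = {"type": "image_pil", "image_pil": img}
```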
@@ -129,6 +149,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
@@ -631,6 +652,10 @@ class BaseMultiModalContentParser(ABC):
                           image_embeds: Union[str, dict[str, str]]) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(self, image_pil: Image.Image) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
@@ -677,6 +702,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
        self._add_placeholder(placeholder)

    def parse_image_pil(self, image_pil: Image.Image) -> None:
        placeholder = self._tracker.add("image", image_pil)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio = self._connector.fetch_audio(audio_url)
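This mirrors the existing URL path, except the image needs no fetch: it is registered with the tracker under the `"image"` modality and a placeholder string is spliced into the prompt text. A toy, self-contained sketch of that placeholder-tracking idea, not vLLM's actual tracker:

```python
from typing import Any

class ToyTracker:
    """Toy stand-in for the parser's item tracker (illustration only)."""

    def __init__(self) -> None:
        self.items: dict[str, list[Any]] = {}

    def add(self, modality: str, item: Any) -> str:
        # Record the item and hand back a placeholder to splice into the text.
        self.items.setdefault(modality, []).append(item)
        return f"<{modality}_{len(self.items[modality])}>"

tracker = ToyTracker()
print(tracker.add("image", object()))  # -> <image_1>
```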
@@ -733,6 +762,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder(placeholder)

    def parse_image_pil(self, image_pil: Image.Image) -> None:
        future: asyncio.Future[Image.Image] = asyncio.Future()
        future.set_result(image_pil)
        placeholder = self._tracker.add("image", future)
        self._add_placeholder(placeholder)

    def parse_audio(self, audio_url: str) -> None:
        audio_coro = self._connector.fetch_audio_async(audio_url)
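In the async parser there is nothing to await, so the in-memory image is wrapped in an already-resolved future and flows through the same pipeline as real fetch coroutines. A minimal sketch of that pattern:

```python
import asyncio

async def main() -> None:
    # Wrap an already-available value in a resolved future so downstream
    # code can await it uniformly alongside genuine async fetches.
    future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    future.set_result("already here")
    print(await future)

asyncio.run(main())
```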
@@ -851,12 +887,13 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: dict[

@@ -869,6 +906,7 @@ MM_PARSER_MAP: dict[
    lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds":
    lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url":
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio":
@@ -938,7 +976,7 @@ def _parse_chat_message_content_mm_part(
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds", "image_pil",
                                       "audio_url", "input_audio", "video_url")
@@ -1009,6 +1047,10 @@ def _parse_chat_message_content_part(
        else:
            return str_content

    if part_type == "image_pil":
        image_content = cast(Image.Image, content)
        mm_parser.parse_image_pil(image_content)
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "image_url":
        str_content = cast(str, content)
        mm_parser.parse_image(str_content)
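The return convention matches the other image part types: with `wrap_dicts` enabled the part collapses to a typed placeholder dict, otherwise the parser returns `None` and the image travels only as tracked multimodal data. A small sketch of just that convention, with a hypothetical helper name:

```python
# Hypothetical helper illustrating the wrap_dicts return convention.
def placeholder_for(part_type: str, wrap_dicts: bool):
    if part_type in ("image_url", "image_pil", "image_embeds"):
        return {"type": "image"} if wrap_dicts else None
    return None

assert placeholder_for("image_pil", True) == {"type": "image"}
assert placeholder_for("image_pil", False) is None
```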