[Frontend] Enable Online Multi-image Support for MLlama (#9393)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Alex Brooks authored 2024-10-23 11:28:57 -06:00, committed by GitHub
commit 150b779081 (parent 9013e24f7b)
2 changed files with 230 additions and 37 deletions
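
For context, the "online" support in the title refers to vLLM's OpenAI-compatible chat API: with this change, a request to an mllama model can interleave several images with text, which the new tests below exercise. The following is a minimal sketch of such a request; the model name mirrors the tests, while the base URL, API key, and image URLs are placeholders, and the server is assumed to be configured to allow at least two images per prompt (analogous to the limit_mm_per_prompt setting used in the test configs).

# Hedged sketch (not part of this commit): an interleaved multi-image chat
# request against vLLM's OpenAI-compatible server. The base URL, API key,
# and image URLs are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "The content of the first image is:"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/first.jpg"}},
            {"type": "text", "text": "The content of the second image is:"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/second.jpg"}},
            {"type": "text", "text": "What animal is in the first image?"},
        ],
    }],
)
print(response.choices[0].message.content)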


@@ -8,11 +8,13 @@ from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@pytest.fixture(scope="module")
@@ -39,6 +41,30 @@ def phi3v_tokenizer():
)
@pytest.fixture(scope="module")
def mllama_model_config():
return ModelConfig(MLLAMA_MODEL_ID,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})
@pytest.fixture(scope="module")
def mllama_tokenizer():
return TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="module")
def image_url():
image = ImageAsset('cherry_blossom')
@@ -414,3 +440,153 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)
### Mllama currently wraps images / texts as interleaved dictionaries
def test_mllama_single_image(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that a single image is parsed correctly mllama."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
'type': 'text',
'text': 'The content of this image is:'
}, {
"image_url": image_url
}]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 1)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of this image is:'
}, {
'type': 'image'
}]
}]
def test_mllama_interleaved_images(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that multiple image are parsed as interleaved dicts."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
{
'type': 'text',
'text': 'The content of the first image is:'
},
{
"image_url": image_url
},
{
'type': 'text',
'text': 'The content of the second image is:'
},
{
"image_url": image_url
},
]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 2)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of the first image is:'
}, {
'type': 'image'
}, {
'type': 'text',
'text': 'The content of the second image is:'
}, {
'type': 'image'
}]
}]
@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID])
def test_multimodal_image_parsing_matches_hf(model, image_url):
"""Checks end to end hf alignment for multimodal [image] parsing."""
def get_conversation(is_hf: bool):
img_part = {"type": "image_url", "image_url": {"url": image_url}}
if is_hf:
img_part = {'type': 'image'}
return [{
'role':
'user',
'content': [
{
'type': 'text',
'text': 'The content of the first image is:'
},
img_part,
{
'type': 'text',
'text': 'The content of the second image is:'
},
img_part,
{
'type': 'text',
'text': 'What animal is in the first image?'
},
]
}]
# Build a config for the model
model_config = ModelConfig(model,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer
# Build and parse a conversation with {"type": "image"} using the tokenizer
hf_conversation = get_conversation(is_hf=True)
hf_result = tokenizer.apply_chat_template(
hf_conversation,
tokenize=False,
add_generation_prompt=True,
)
# Now parse with vLLMs chat utils & apply the template
vllm_conversation = get_conversation(is_hf=False)
conversation, _ = parse_chat_messages(
vllm_conversation,
model_config,
tokenizer_group,
)
vllm_result = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=None,
add_generation_prompt=True,
)
assert hf_result == vllm_result
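
The image_url fixture near the top of this test file wraps a bundled asset as a base64 data URL (via encode_image_base64). Outside the test suite, an equivalent URL for the request sketch above can be built with the standard library; in the snippet below the local file path is an assumption.

# Hedged sketch: building a base64 data URL usable as an "image_url" part,
# mirroring what the test fixture does with encode_image_base64.
# The local file path is a placeholder.
import base64

with open("cherry_blossom.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

image_url = f"data:image/jpeg;base64,{encoded}"
# image_url can now be used in a {"type": "image_url", "image_url": {"url": ...}}
# content part, as in the tests above or the online request sketch earlier.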


@@ -483,53 +483,70 @@ def _parse_chat_message_content_parts(
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
content: List[Union[str, Dict[str, str]]] = []
mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image = False
for part in parts:
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
if parse_res:
content.append(parse_res)
# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue
if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]
if has_image:
role_content = [{'type': 'image'}] + role_content
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
content=content)] # type: ignore
texts = cast(List[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
{"type": "image"}, respectively. Otherwise multimodal data will be
handled by mm_parser, and texts will be returned as strings to be joined
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning(
"Skipping multimodal part (type: '%s')"
"with empty / unparsable content.", part_type)
return None
if part_type in ("text", "refusal"):
return {'type': 'text', 'text': content} if wrap_dicts else content
if part_type == "image_url":
mm_parser.parse_image(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
mm_parser.parse_audio(content)
return {'type': 'audio'} if wrap_dicts else None
raise NotImplementedError(f"Unknown part type: {part_type}")
# No need to validate using Pydantic again
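
To make the two code paths concrete: with wrap_dicts=True (the mllama path) text and image parts are kept as interleaved dictionaries, while with wrap_dicts=False texts are joined and prefixed with counted placeholders. Below is a self-contained sketch of that behavior, independent of vLLM's actual classes; the helper name and the phi3v-style "<|image_N|>" placeholder format are illustrative assumptions.

# Standalone sketch (not vLLM's implementation) of the two content modes that
# the wrap_dicts flag separates. Helper name and placeholder format are
# illustrative.
from typing import Dict, List, Union

def render_parts(parts: List[Dict[str, str]], wrap_dicts: bool):
    content: List[Union[str, Dict[str, str]]] = []
    num_images = 0
    for part in parts:
        if part["type"] == "text":
            content.append({"type": "text", "text": part["text"]}
                           if wrap_dicts else part["text"])
        elif part["type"] == "image_url":
            num_images += 1
            if wrap_dicts:
                # mllama-style: keep an image marker in place
                content.append({"type": "image"})
        else:
            raise NotImplementedError(f"Unknown part type: {part['type']}")

    if wrap_dicts:
        # interleaved dicts, as asserted by the mllama tests above
        return content

    # default path: join the texts and prepend one placeholder per image
    text_prompt = "\n".join(p for p in content if isinstance(p, str))
    placeholders = "\n".join(f"<|image_{i + 1}|>" for i in range(num_images))
    return f"{placeholders}\n{text_prompt}" if num_images else text_prompt

parts = [
    {"type": "text", "text": "The content of this image is:"},
    {"type": "image_url", "image_url": "https://example.com/a.jpg"},
]
print(render_parts(parts, wrap_dicts=True))
# [{'type': 'text', 'text': 'The content of this image is:'}, {'type': 'image'}]
print(render_parts(parts, wrap_dicts=False))
# <|image_1|>
# The content of this image is: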