[Bugfix]: Make chat content text allow type content (#9358)
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
This commit is contained in:
parent b7df53cd42 / commit 33bab41060
@@ -103,6 +103,23 @@ vllm serve <model> --chat-template ./path-to-chat-template.jinja
 vLLM community provides a set of chat templates for popular models. You can find them in the examples
 directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 
+With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies
+both a `type` and a `text` field. An example is provided below:
+
+```python
+completion = client.chat.completions.create(
+    model="NousResearch/Meta-Llama-3-8B-Instruct",
+    messages=[
+        {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]}
+    ]
+)
+```
+
+Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
+`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which
+format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify
+between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
+this, unless explicitly specified.
+
 ## Command line arguments for the server
 
 ```{argparse}
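For readers who want to try the new behavior end to end, here is a minimal usage sketch. It assumes a locally running vLLM OpenAI-compatible server started with the new flag; the base URL, API key, and model choice are illustrative assumptions, not part of this change:

```python
# Usage sketch (assumption): server started with
#   vllm serve meta-llama/Llama-Guard-3-1B --chat-template-text-format openai
# The base URL and API key below are placeholders for a local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[
        # With --chat-template-text-format openai, typed content parts are kept
        # as dicts so templates that expect the OpenAI spec can consume them.
        {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]}
    ],
)
print(completion.choices[0].message.content)
```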
@@ -26,6 +26,7 @@ class MockModelConfig:
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
+    chat_template_text_format = "string"
     max_model_len = 100
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
@@ -17,7 +17,7 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def phi3v_model_config():
     return ModelConfig(PHI3V_MODEL_ID,
                        task="generate",
@@ -26,6 +26,7 @@ def phi3v_model_config():
                        trust_remote_code=True,
                        dtype="bfloat16",
                        seed=0,
+                       chat_template_text_format="string",
                        limit_mm_per_prompt={
                            "image": 2,
                        })
@@ -330,6 +331,51 @@ def test_parse_chat_messages_multiple_images_across_messages(
     _assert_mm_data_is_image_input(mm_data, 2)
 
 
+def test_parse_chat_messages_context_text_format(
+    phi3v_model_config,
+    phi3v_tokenizer,
+):
+    phi3v_model_config.chat_template_text_format = "openai"
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What's in this text?"
+            }]
+        }, {
+            "role": "assistant",
+            "content": "Some stuff."
+        }, {
+            "role": "user",
+            "content": "What about this one?"
+        }], phi3v_model_config, phi3v_tokenizer)
+
+    assert conversation == [
+        {
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What's in this text?"
+            }]
+        },
+        {
+            "role": "assistant",
+            "content": [{
+                "type": "text",
+                "text": "Some stuff."
+            }]
+        },
+        {
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "What about this one?"
+            }]
+        },
+    ]
+
+
 def test_parse_chat_messages_rejects_too_many_images_in_one_message(
     phi3v_model_config,
     phi3v_tokenizer,
@@ -142,6 +142,7 @@ class ModelConfig:
                  use_async_output_proc: bool = True,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  config_format: ConfigFormat = ConfigFormat.AUTO,
+                 chat_template_text_format: str = "string",
                  mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -176,6 +177,7 @@ class ModelConfig:
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
+        self.chat_template_text_format = chat_template_text_format
         self.mm_processor_kwargs = mm_processor_kwargs
 
         # Set enforce_eager to False if the value is unset.
@@ -89,6 +89,7 @@ class EngineArgs:
     task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
+    chat_template_text_format: str = 'string'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
@@ -250,6 +251,14 @@ class EngineArgs:
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
             '"mistral" will always use the `mistral_common` tokenizer.')
+        parser.add_argument(
+            '--chat-template-text-format',
+            type=str,
+            default=EngineArgs.chat_template_text_format,
+            choices=['string', 'openai'],
+            help='The format to render text content within a chat template. '
+            '"string" will keep the content field as a string whereas '
+            '"openai" will parse content in the current OpenAI format.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
@@ -858,6 +867,7 @@ class EngineArgs:
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
+            chat_template_text_format=self.chat_template_text_format,
             trust_remote_code=self.trust_remote_code,
             dtype=self.dtype,
             seed=self.seed,
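To make the command-line surface of the change concrete, here is a small standalone sketch that mirrors the added `--chat-template-text-format` option using plain argparse; it illustrates the flag's default and choices and is not vLLM's actual parser:

```python
# Standalone sketch mirroring the new CLI flag added in the hunk above.
import argparse

parser = argparse.ArgumentParser(description="chat-template-text-format demo")
parser.add_argument(
    "--chat-template-text-format",
    type=str,
    default="string",              # same default as EngineArgs.chat_template_text_format
    choices=["string", "openai"],
    help='"string" keeps text content as a plain string; '
         '"openai" keeps it in the OpenAI typed-content format.')

args = parser.parse_args(["--chat-template-text-format", "openai"])
print(args.chat_template_text_format)  # -> openai
```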
@@ -254,7 +254,7 @@ class LLMEngine:
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "mm_processor_kwargs=%s)",
+            "chat_template_text_format=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -289,6 +289,7 @@ class LLMEngine:
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
             use_cached_outputs,
+            model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
     role: Required[str]
     """The role of the message's author."""
 
-    content: Optional[str]
+    content: Union[Optional[str], List[Dict[str, str]]]
     """The contents of the message"""
 
     tool_call_id: Optional[str]
@@ -431,7 +431,7 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
 def _parse_chat_message_content_mm_part(
         part: ChatCompletionContentPartParam) -> Tuple[str, str]:
     """
-    Parses a given multi modal content part based on its type.
+    Parses a given multi-modal content part based on its type.
 
     Args:
         part: A dict containing the content part, with a potential 'type' field.
@@ -485,21 +485,26 @@ def _parse_chat_message_content_parts(
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
     mm_tracker: BaseMultiModalItemTracker,
+    chat_template_text_format: str,
 ) -> List[ConversationMessage]:
     content: List[Union[str, Dict[str, str]]] = []
 
     mm_parser = mm_tracker.create_parser()
-    keep_multimodal_content = \
+    wrap_dicts = \
         mm_tracker._model_config.hf_config.model_type in \
-        MODEL_KEEP_MULTI_MODAL_CONTENT
+        MODEL_KEEP_MULTI_MODAL_CONTENT or \
+        (chat_template_text_format == "openai")
 
     for part in parts:
         parse_res = _parse_chat_message_content_part(
-            part, mm_parser, wrap_dicts=keep_multimodal_content)
+            part,
+            mm_parser,
+            wrap_dicts=wrap_dicts,
+        )
         if parse_res:
             content.append(parse_res)
 
-    if keep_multimodal_content:
+    if wrap_dicts:
         # Parsing wraps images and texts as interleaved dictionaries
         return [ConversationMessage(role=role,
                                     content=content)]  # type: ignore
@@ -560,6 +565,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
 def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     mm_tracker: BaseMultiModalItemTracker,
+    chat_template_text_format: str,
 ) -> List[ConversationMessage]:
     role = message["role"]
     content = message.get("content")
@@ -575,6 +581,7 @@ def _parse_chat_message_content(
         role,
         content,  # type: ignore
         mm_tracker,
+        chat_template_text_format,
     )
 
     for result_msg in result:
@@ -618,7 +625,11 @@ def parse_chat_messages(
     mm_tracker = MultiModalItemTracker(model_config, tokenizer)
 
     for msg in messages:
-        sub_messages = _parse_chat_message_content(msg, mm_tracker)
+        sub_messages = _parse_chat_message_content(
+            msg,
+            mm_tracker,
+            model_config.chat_template_text_format,
+        )
 
         conversation.extend(sub_messages)
 
@@ -636,7 +647,11 @@ def parse_chat_messages_futures(
     mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
 
     for msg in messages:
-        sub_messages = _parse_chat_message_content(msg, mm_tracker)
+        sub_messages = _parse_chat_message_content(
+            msg,
+            mm_tracker,
+            model_config.chat_template_text_format,
+        )
 
         conversation.extend(sub_messages)
 
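The practical effect of the `wrap_dicts` change is easier to see outside the full parser. The following is a simplified, standalone sketch of the two rendering modes; it only handles text parts and uses a made-up helper name, so treat it as an illustration rather than vLLM's real implementation:

```python
# Simplified illustration of "string" vs "openai" text-content handling.
# Assumption: only "text" parts are considered; the real parser also handles
# image/audio parts and model-specific placeholder insertion.
from typing import Dict, List, Union


def render_text_parts(
    parts: List[Dict[str, str]],
    chat_template_text_format: str,
) -> Union[str, List[Dict[str, str]]]:
    texts = [part["text"] for part in parts if part.get("type") == "text"]
    if chat_template_text_format == "openai":
        # Keep OpenAI-style dicts so chat templates that iterate over typed
        # content parts (e.g. Llama Guard 3) see them unchanged.
        return [{"type": "text", "text": text} for text in texts]
    # Default "string" behaviour: collapse the parts into a single string.
    return "\n".join(texts)


parts = [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]
print(render_text_parts(parts, "string"))  # one plain string
print(render_text_parts(parts, "openai"))  # list of {"type": "text", ...} dicts
```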
@@ -384,7 +384,7 @@ class OpenAIServingChat(OpenAIServing):
                 # Send response to echo the input portion of the
                 # last message
                 if request.echo or request.continue_final_message:
-                    last_msg_content: str = ""
+                    last_msg_content: Union[str, List[Dict[str, str]]] = ""
                     if conversation and "content" in conversation[
                             -1] and conversation[-1].get("role") == role:
                         last_msg_content = conversation[-1]["content"] or ""
@@ -724,10 +724,13 @@ class OpenAIServingChat(OpenAIServing):
             choices.append(choice_data)
 
         if request.echo or request.continue_final_message:
-            last_msg_content = ""
+            last_msg_content: Union[str, List[Dict[str, str]]] = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""
+            if isinstance(last_msg_content, list):
+                last_msg_content = "\n".join(msg['text']
+                                             for msg in last_msg_content)
 
             for choice in choices:
                 full_message = last_msg_content + (choice.message.content
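For the echo path, the added `isinstance` check simply flattens list-style content back into a single string before it is prepended to the generated text. A tiny sketch of the same idea, with illustrative data:

```python
# Minimal sketch of flattening OpenAI-style content before echoing it back.
from typing import Dict, List, Union

last_msg_content: Union[str, List[Dict[str, str]]] = [
    {"type": "text", "text": "What's in this text?"},
    {"type": "text", "text": "And a second part."},
]

if isinstance(last_msg_content, list):
    # Join each part's text with newlines, matching the serving code above.
    last_msg_content = "\n".join(msg["text"] for msg in last_msg_content)

print(last_msg_content)
```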