# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw

logger = init_logger(__name__)

MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}


class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: str | dict[str, str] | None
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """

    type: Required[Literal["image_embeds"]]
    """The type of the content part."""

    uuid: str | None
    """
    User-provided UUID of the media. The user must guarantee that it is
    properly generated and unique across different media items.
    """


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""

""" image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) class CustomChatCompletionContentPILImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a PIL image. Example: { "image_pil": ImageAsset('cherry_blossom').pil_image } """ image_pil: PILImage | None uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain image_url. This is supported by OpenAI API, although it is not documented. Example: { "image_url": "https://example.com/image.jpg" } """ image_url: str | None uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "audio_url": "https://example.com/audio.mp3" } """ audio_url: str | None class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): """A simpler version of the param that only accepts a plain audio_url. Example: { "video_url": "https://example.com/video.mp4" } """ video_url: str | None uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. """ class CustomThinkCompletionContentParam(TypedDict, total=False): """A Think Completion Content Param that accepts a plain text and a boolean. Example: { "thinking": "I am thinking about the answer", "closed": True, "type": "thinking" } """ thinking: Required[str] """The thinking content.""" closed: bool """Whether the thinking is closed.""" type: Required[Literal["thinking"]] """The thinking type.""" ChatCompletionContentPartParam: TypeAlias = ( OpenAIChatCompletionContentPartParam | ChatCompletionContentPartAudioParam | ChatCompletionContentPartInputAudioParam | ChatCompletionContentPartVideoParam | ChatCompletionContentPartRefusalParam | CustomChatCompletionContentPILImageParam | CustomChatCompletionContentSimpleImageParam | ChatCompletionContentPartImageEmbedsParam | CustomChatCompletionContentSimpleAudioParam | CustomChatCompletionContentSimpleVideoParam | str | CustomThinkCompletionContentParam ) class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" role: Required[str] """The role of the message's author.""" content: str | list[ChatCompletionContentPartParam] """The contents of the message.""" name: str """An optional name for the participant. Provides the model information to differentiate between participants of the same role. 
""" tool_call_id: str | None """Tool call that this message is responding to.""" tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" ChatCompletionMessageParam: TypeAlias = ( OpenAIChatCompletionMessageParam | CustomChatCompletionMessageParam | OpenAIHarmonyMessage ) # TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" content: str | None | list[dict[str, str]] """The contents of the message""" tool_call_id: str | None """Tool call that this message is responding to.""" name: str | None """The name of the function to call""" tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] # Used internally _ChatTemplateContentFormat = Literal["string", "openai"] def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: if isinstance(node, jinja2.nodes.Name): return node.ctx == "load" and node.name == varname return False def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: if isinstance(node, jinja2.nodes.Getitem): return ( _is_var_access(node.node, varname) and isinstance(node.arg, jinja2.nodes.Const) and node.arg.value == key ) if isinstance(node, jinja2.nodes.Getattr): return _is_var_access(node.node, varname) and node.attr == key return False def _is_var_or_elems_access( node: jinja2.nodes.Node, varname: str, key: str | None = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): return node.node is not None and _is_var_or_elems_access( node.node, varname, key ) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) if isinstance(node, jinja2.nodes.Getitem) and isinstance( node.arg, jinja2.nodes.Slice ): return _is_var_or_elems_access(node.node, varname, key) return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # Global variable that is implicitly defined at the root yield root, varname # Iterative BFS related_varnames = deque([varname]) while related_varnames: related_varname = related_varnames.popleft() for assign_ast in root.find_all(jinja2.nodes.Assign): lhs = assign_ast.target rhs = assign_ast.node if _is_var_or_elems_access(rhs, related_varname): assert isinstance(lhs, jinja2.nodes.Name) yield assign_ast, lhs.name # Avoid infinite looping for self-assignment if lhs.name != related_varname: related_varnames.append(lhs.name) # NOTE: The proper way to handle this is to build a CFG so that we can handle # the scope in which each variable is defined, but that is too complicated def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): messages_varnames = [ varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") ] # Search for {%- for message in messages -%} loops for loop_ast in root.find_all(jinja2.nodes.For): loop_iter = loop_ast.iter loop_target = loop_ast.target for varname in messages_varnames: if _is_var_or_elems_access(loop_iter, varname): assert isinstance(loop_target, jinja2.nodes.Name) yield loop_ast, loop_target.name break def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): message_varnames = [ varname for _, varname in _iter_nodes_assign_messages_item(root) ] # Search for {%- for content in message['content'] -%} 
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    message_varnames = [
        varname for _, varname in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in message_varnames:
            if _is_var_or_elems_access(loop_iter, varname, "content"):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                break


def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None


@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"

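# For illustration only (assumed toy templates, not tied to any model): the
# detection above distinguishes templates that use `message['content']` as a
# plain string from those that iterate over it as a list of parts, e.g.
#
#   _detect_content_format(
#       "{% for m in messages %}{{ m['content'] }}{% endfor %}",
#       default="string",
#   )  # -> "string"
#
#   _detect_content_format(
#       "{% for m in messages %}{% for p in m['content'] %}"
#       "{{ p['text'] }}{% endfor %}{% endfor %}",
#       default="string",
#   )  # -> "openai"
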
def resolve_mistral_chat_template(
    chat_template: str | None,
    **kwargs: Any,
) -> str | None:
    if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
        raise ValueError(
            "'chat_template' or 'chat_template_kwargs' cannot be overridden "
            "for mistral tokenizer."
        )
    return None


_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    model_config: ModelConfig,
) -> str | None:
    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=model_config.trust_remote_code,
        )
        if (
            isinstance(processor, ProcessorMixin)
            and hasattr(processor, "chat_template")
            and (chat_template := processor.chat_template) is not None
        ):
            _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
            return chat_template
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = None
    return None


def resolve_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
) -> str | None:
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer, model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info_once(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template


def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )
    else:
        hf_chat_template = None

    jinja_text = (
        hf_chat_template
        if isinstance(hf_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format


@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )


def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format

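# For illustration only (a sketch; `tokenizer` and `model_config` are assumed to
# be provided by the serving layer): a typical call site lets the content format
# be auto-detected from the resolved chat template, e.g.
#
#   content_format = resolve_chat_template_content_format(
#       chat_template=None,
#       tools=None,
#       given_format="auto",
#       tokenizer=tokenizer,
#       model_config=model_config,
#   )
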
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")


class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of items does not exceed the configured maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        self._items_by_modality = defaultdict[str, list[_T | None]](list)
        self._uuids_by_modality = defaultdict[str, list[str | None]](list)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: _T | None,
        uuid: str | None = None,
    ) -> str | None:
        """
        Add a multi-modal item to the current prompt and return the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier
        of the media.
        """
        input_modality = modality.replace("_embeds", "")
        num_items = len(self._items_by_modality[modality]) + 1

        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[modality].append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, num_items)

    def all_mm_uuids(self) -> MultiModalUUIDDict | None:
        if not self._items_by_modality:
            return None

        mm_uuids = {}
        uuids_by_modality = dict(self._uuids_by_modality)
        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in uuids_by_modality:
            image_embeds_uuids = uuids_by_modality["image_embeds"]
            if len(image_embeds_uuids) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
        if "video" in uuids_by_modality:
            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos

        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    def all_mm_data(self) -> MultiModalDataDict | None:
        if not self._items_by_modality:
            return None

        mm_inputs = {}
        items_by_modality = dict(self._items_by_modality)
        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)

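# For illustration only (a sketch; `model_config`, `tokenizer` and `pil_image`
# are assumed to be provided by the caller): the synchronous tracker collects
# items per modality and hands back the model-specific placeholder string, e.g.
#
#   tracker = MultiModalItemTracker(model_config, tokenizer)
#   placeholder = tracker.add("image", pil_image)  # model-specific, e.g. "<image>"
#   mm_data = tracker.all_mm_data()    # {"image": [pil_image]}
#   mm_uuids = tracker.all_mm_uuids()  # {"image": [None]} unless a uuid was passed
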
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    async def all_mm_data(self) -> MultiModalDataDict | None:
        if not self._items_by_modality:
            return None

        mm_inputs = {}
        items_by_modality = {}

        for modality, items in self._items_by_modality.items():
            coros = []
            for item in items:
                if item is not None:
                    coros.append(item)
                else:
                    coros.append(asyncio.sleep(0))
            items_by_modality[modality] = await asyncio.gather(*coros)

        if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError("Mixing raw image and embedding inputs is not allowed")

        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["", "", ""],
        #   "<##AUDIO##>": ["