Run ruff format on a few files. (#24075)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>

parent 1c41310584
commit f399182e8c
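
The diff below is mechanical: it reapplies `ruff format` (black-compatible style) to code that was previously laid out for yapf. Two changes recur throughout: the closing parenthesis of a wrapped call moves onto its own line, and backslash-continued conditions become parenthesized blocks with one clause per line. The sketch below illustrates both patterns under those assumptions; `warn_once` and `check_mistral_kwargs` are hypothetical names invented for this illustration and do not appear in the diff.

from typing import Optional


def warn_once(message: str) -> None:
    """Hypothetical stand-in for the logger.warning_once calls touched below."""
    print(message)


def check_mistral_kwargs(
    chat_template: Optional[str],
    add_generation_prompt: bool,
    continue_final_message: bool,
) -> None:
    # Pre-commit layout (left-hand side of the diff):
    #
    #     if chat_template is not None and \
    #             add_generation_prompt and not continue_final_message:
    #         warn_once(
    #             "'add_generation_prompt' is not supported for mistral tokenizer, "
    #             "so it will be ignored.")
    #
    # Post-commit layout (right-hand side of the diff): parenthesized condition,
    # one clause per line, and the call's closing parenthesis on its own line.
    if (
        chat_template is not None
        and add_generation_prompt
        and not continue_final_message
    ):
        warn_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored."
        )

The hunks that follow apply the same treatment; per the commit title, no runtime behavior is intended to change.
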
@@ -103,6 +103,7 @@ class PILImage(BaseModel):
     """
     A PIL.Image.Image object.
     """
+
     image_pil: Image.Image
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
         "image_pil": ImageAsset('cherry_blossom').pil_image
     }
     """
+
     image_pil: Required[PILImage]
 
 
@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
         "image_url": "https://example.com/image.jpg"
     }
     """
+
     image_url: Required[str]
 
 
@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
         "audio_url": "https://example.com/audio.mp3"
     }
     """
+
     audio_url: Required[str]
 
 
@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
         "video_url": "https://example.com/video.mp4"
     }
     """
+
     video_url: Required[str]
 
 
@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
 
+
 ChatCompletionContentPartParam: TypeAlias = Union[
-    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
+    OpenAIChatCompletionContentPartParam,
+    ChatCompletionContentPartAudioParam,
     ChatCompletionContentPartInputAudioParam,
-    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
+    ChatCompletionContentPartVideoParam,
+    ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentPILImageParam,
     CustomChatCompletionContentSimpleImageParam,
     ChatCompletionContentPartImageEmbedsParam,
     CustomChatCompletionContentSimpleAudioParam,
-    CustomChatCompletionContentSimpleVideoParam, str,
-    CustomThinkCompletionContentParam]
+    CustomChatCompletionContentSimpleVideoParam,
+    str,
+    CustomThinkCompletionContentParam,
+]
 
 
 class CustomChatCompletionMessageParam(TypedDict, total=False):
     """Enables custom roles in the Chat Completion API."""
 
     role: Required[str]
     """The role of the message's author."""
 
@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
     """The tool calls generated by the model, such as function calls."""
 
 
-ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
-                                   CustomChatCompletionMessageParam,
-                                   OpenAIHarmonyMessage]
+ChatCompletionMessageParam = Union[
+    OpenAIChatCompletionMessageParam,
+    CustomChatCompletionMessageParam,
+    OpenAIHarmonyMessage,
+]
 
 
 # TODO: Make fields ReadOnly once mypy supports it
@@ -262,13 +274,13 @@ def _is_var_or_elems_access(
     key: Optional[str] = None,
 ) -> bool:
     if isinstance(node, jinja2.nodes.Filter):
-        return (node.node is not None
-                and _is_var_or_elems_access(node.node, varname, key))
+        return node.node is not None and _is_var_or_elems_access(
+            node.node, varname, key)
     if isinstance(node, jinja2.nodes.Test):
         return _is_var_or_elems_access(node.node, varname, key)
 
-    if (isinstance(node, jinja2.nodes.Getitem)
-            and isinstance(node.arg, jinja2.nodes.Slice)):
+    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
+            node.arg, jinja2.nodes.Slice):
         return _is_var_or_elems_access(node.node, varname, key)
 
 # yapf: disable
@@ -373,15 +385,18 @@ def resolve_mistral_chat_template(
 ) -> Optional[str]:
     if chat_template is not None:
         logger.warning_once(
-            "'chat_template' cannot be overridden for mistral tokenizer.")
+            "'chat_template' cannot be overridden for mistral tokenizer."
+        )
     if "add_generation_prompt" in kwargs:
         logger.warning_once(
             "'add_generation_prompt' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
+            "so it will be ignored."
+        )
     if "continue_final_message" in kwargs:
         logger.warning_once(
             "'continue_final_message' is not supported for mistral tokenizer, "
-            "so it will be ignored.")
+            "so it will be ignored."
+        )
     return None
 
 
@@ -401,23 +416,35 @@ def resolve_hf_chat_template(
     try:
         processor = cached_get_processor(
             tokenizer.name_or_path,
-            processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
-                           ProcessorMixin),
+            processor_cls=(
+                PreTrainedTokenizer,
+                PreTrainedTokenizerFast,
+                ProcessorMixin,
+            ),
             trust_remote_code=model_config.trust_remote_code,
         )
-        if isinstance(processor, ProcessorMixin) and \
-                hasattr(processor, 'chat_template') and \
-                processor.chat_template is not None:
+        if (
+            isinstance(processor, ProcessorMixin)
+            and hasattr(processor, "chat_template")
+            and processor.chat_template is not None
+        ):
             return processor.chat_template
     except Exception:
-        logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True)  # noqa: E501
+        logger.debug(
+            "Failed to load AutoProcessor chat template for %s",
+            tokenizer.name_or_path,
+            exc_info=True,
+        )  # noqa: E501
 
     # 3rd priority: AutoTokenizer chat template
     try:
         return tokenizer.get_chat_template(chat_template, tools=tools)
     except Exception:
-        logger.debug("Failed to load AutoTokenizer chat template for %s",
-                     tokenizer.name_or_path, exc_info=True)
+        logger.debug(
+            "Failed to load AutoTokenizer chat template for %s",
+            tokenizer.name_or_path,
+            exc_info=True,
+        )
 
     # 4th priority: Predefined fallbacks
     path = get_chat_template_fallback_path(
@@ -425,12 +452,16 @@ def resolve_hf_chat_template(
         tokenizer_name_or_path=model_config.tokenizer,
     )
     if path is not None:
-        logger.info("Loading chat template fallback for %s as there isn't one "
-                    "defined on HF Hub.", tokenizer.name_or_path)
+        logger.info(
+            "Loading chat template fallback for %s as there isn't one "
+            "defined on HF Hub.",
+            tokenizer.name_or_path,
+        )
         chat_template = load_chat_template(path)
     else:
-        logger.debug("There is no chat template fallback for %s",
-                     tokenizer.name_or_path)
+        logger.debug(
+            "There is no chat template fallback for %s", tokenizer.name_or_path
+        )
 
     return chat_template
 
@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format(
     else:
         hf_chat_template = None
 
-    jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
-                  else load_chat_template(chat_template, is_literal=True))
+    jinja_text = (
+        hf_chat_template
+        if isinstance(hf_chat_template, str)
+        else load_chat_template(chat_template, is_literal=True)
+    )
 
-    detected_format = ("string" if jinja_text is None else
-                       _detect_content_format(jinja_text, default="string"))
+    detected_format = (
+        "string"
+        if jinja_text is None
+        else _detect_content_format(jinja_text, default="string")
+    )
 
     return detected_format
 
@@ -512,7 +549,6 @@ def resolve_chat_template_content_format(
     return detected_format
 
 
-
 ModalityStr = Literal["image", "audio", "video", "image_embeds"]
 _T = TypeVar("_T")
 
@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     @cached_property
     def model_cls(self) -> type[SupportsMultiModal]:
         from vllm.model_executor.model_loader import get_model_cls
+
         model_cls = get_model_cls(self.model_config)
         return cast(type[SupportsMultiModal], model_cls)
 
@@ -574,28 +611,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 
 
 class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
-
     def all_mm_data(self) -> Optional[MultiModalDataDict]:
         if not self._items_by_modality:
             return None
         mm_inputs = {}
         items_by_modality = dict(self._items_by_modality)
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
-            raise ValueError(\
-                "Mixing raw image and embedding inputs is not allowed")
+            raise ValueError(
+                "Mixing raw image and embedding inputs is not allowed"
+            )
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
             if len(image_embeds_lst) > 1:
-                raise ValueError(\
-                    "Only one message can have {'type': 'image_embeds'}")
+                raise ValueError(
+                    "Only one message can have {'type': 'image_embeds'}"
+                )
             mm_inputs["image"] = image_embeds_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
             mm_inputs["video"] = items_by_modality["video"]  # A list of videos
         return mm_inputs
 
     def create_parser(self) -> "BaseMultiModalContentParser":
@@ -603,32 +641,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
 
 
 class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
-
     async def all_mm_data(self) -> Optional[MultiModalDataDict]:
         if not self._items_by_modality:
             return None
         mm_inputs = {}
         items_by_modality = {
             modality: await asyncio.gather(*items)
             for modality, items in self._items_by_modality.items()
         }
 
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
             raise ValueError(
-                "Mixing raw image and embedding inputs is not allowed")
+                "Mixing raw image and embedding inputs is not allowed"
+            )
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
             if len(image_embeds_lst) > 1:
                 raise ValueError(
-                    "Only one message can have {'type': 'image_embeds'}")
+                    "Only one message can have {'type': 'image_embeds'}"
+                )
             mm_inputs["image"] = image_embeds_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
             mm_inputs["video"] = items_by_modality["video"]  # A list of videos
         return mm_inputs
 
     def create_parser(self) -> "BaseMultiModalContentParser":
@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
|||||||
|
|
||||||
|
|
||||||
class BaseMultiModalContentParser(ABC):
|
class BaseMultiModalContentParser(ABC):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC):
|
|||||||
# }
|
# }
|
||||||
self._placeholder_storage: dict[str, list] = defaultdict(list)
|
self._placeholder_storage: dict[str, list] = defaultdict(list)
|
||||||
|
|
||||||
def _add_placeholder(self, modality: ModalityStr,
|
def _add_placeholder(
|
||||||
placeholder: Optional[str]):
|
self, modality: ModalityStr, placeholder: Optional[str]
|
||||||
|
):
|
||||||
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
|
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
|
||||||
if placeholder:
|
if placeholder:
|
||||||
self._placeholder_storage[mod_placeholder].append(placeholder)
|
self._placeholder_storage[mod_placeholder].append(placeholder)
|
||||||
@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_image_embeds(self,
|
def parse_image_embeds(
|
||||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
self, image_embeds: Union[str, dict[str, str]]
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class MultiModalContentParser(BaseMultiModalContentParser):
|
class MultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
placeholder = self._tracker.add("image", image)
|
placeholder = self._tracker.add("image", image)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_embeds(self,
|
def parse_image_embeds(
|
||||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
self, image_embeds: Union[str, dict[str, str]]
|
||||||
|
) -> None:
|
||||||
if isinstance(image_embeds, dict):
|
if isinstance(image_embeds, dict):
|
||||||
embeds = {
|
embeds = {
|
||||||
k: self._connector.fetch_image_embedding(v)
|
k: self._connector.fetch_image_embedding(v)
|
||||||
@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
|
|
||||||
|
|
||||||
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._tracker = tracker
|
self._tracker = tracker
|
||||||
self._connector = MediaConnector(
|
self._connector = MediaConnector(
|
||||||
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
|
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
|
||||||
allowed_local_media_path=tracker.allowed_local_media_path
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_image(self, image_url: str) -> None:
|
def parse_image(self, image_url: str) -> None:
|
||||||
@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
placeholder = self._tracker.add("image", image_coro)
|
placeholder = self._tracker.add("image", image_coro)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_embeds(self,
|
def parse_image_embeds(
|
||||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
self, image_embeds: Union[str, dict[str, str]]
|
||||||
|
) -> None:
|
||||||
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
|
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
|
||||||
|
|
||||||
if isinstance(image_embeds, dict):
|
if isinstance(image_embeds, dict):
|
||||||
@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
future.set_result(embeds)
|
future.set_result(embeds)
|
||||||
|
|
||||||
if isinstance(image_embeds, str):
|
if isinstance(image_embeds, str):
|
||||||
embedding = self._connector.\
|
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||||
fetch_image_embedding(image_embeds)
|
|
||||||
future.set_result(embedding)
|
future.set_result(embedding)
|
||||||
|
|
||||||
placeholder = self._tracker.add("image_embeds", future)
|
placeholder = self._tracker.add("image_embeds", future)
|
||||||
@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
|||||||
return
|
return
|
||||||
|
|
||||||
elif isinstance(chat_template, Path) and not chat_template.exists():
|
elif isinstance(chat_template, Path) and not chat_template.exists():
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError("the supplied chat template path doesn't exist")
|
||||||
"the supplied chat template path doesn't exist")
|
|
||||||
|
|
||||||
elif isinstance(chat_template, str):
|
elif isinstance(chat_template, str):
|
||||||
JINJA_CHARS = "{}\n"
|
JINJA_CHARS = "{}\n"
|
||||||
if not any(c in chat_template
|
if (
|
||||||
for c in JINJA_CHARS) and not Path(chat_template).exists():
|
not any(c in chat_template for c in JINJA_CHARS)
|
||||||
|
and not Path(chat_template).exists()
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The supplied chat template string ({chat_template}) "
|
f"The supplied chat template string ({chat_template}) "
|
||||||
f"appears path-like, but doesn't exist!")
|
f"appears path-like, but doesn't exist!"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{type(chat_template)} is not a valid chat template type")
|
f"{type(chat_template)} is not a valid chat template type"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_chat_template(
|
def _load_chat_template(
|
||||||
@ -835,8 +877,9 @@ def _load_chat_template(
|
|||||||
|
|
||||||
if is_literal:
|
if is_literal:
|
||||||
if isinstance(chat_template, Path):
|
if isinstance(chat_template, Path):
|
||||||
raise TypeError("chat_template is expected to be read directly "
|
raise TypeError(
|
||||||
"from its value")
|
"chat_template is expected to be read directly from its value"
|
||||||
|
)
|
||||||
|
|
||||||
return chat_template
|
return chat_template
|
||||||
|
|
||||||
@ -849,9 +892,11 @@ def _load_chat_template(
|
|||||||
|
|
||||||
JINJA_CHARS = "{}\n"
|
JINJA_CHARS = "{}\n"
|
||||||
if not any(c in chat_template for c in JINJA_CHARS):
|
if not any(c in chat_template for c in JINJA_CHARS):
|
||||||
msg = (f"The supplied chat template ({chat_template}) "
|
msg = (
|
||||||
f"looks like a file path, but it failed to be "
|
f"The supplied chat template ({chat_template}) "
|
||||||
f"opened. Reason: {e}")
|
f"looks like a file path, but it failed to be "
|
||||||
|
f"opened. Reason: {e}"
|
||||||
|
)
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e
|
||||||
|
|
||||||
# If opening a file fails, set chat template to be args to
|
# If opening a file fails, set chat template to be args to
|
||||||
@ -870,8 +915,9 @@ def load_chat_template(
|
|||||||
return _cached_load_chat_template(chat_template, is_literal=is_literal)
|
return _cached_load_chat_template(chat_template, is_literal=is_literal)
|
||||||
|
|
||||||
|
|
||||||
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
|
def _get_interleaved_text_prompt(
|
||||||
texts: list[str]) -> str:
|
placeholder_storage: dict[str, list], texts: list[str]
|
||||||
|
) -> str:
|
||||||
for idx, elem in enumerate(texts):
|
for idx, elem in enumerate(texts):
|
||||||
if elem in placeholder_storage:
|
if elem in placeholder_storage:
|
||||||
texts[idx] = placeholder_storage[elem].pop(0)
|
texts[idx] = placeholder_storage[elem].pop(0)
|
||||||
@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
|
|||||||
|
|
||||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||||
# (similar to chat template)
|
# (similar to chat template)
|
||||||
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
|
def _get_full_multimodal_text_prompt(
|
||||||
texts: list[str],
|
placeholder_storage: dict[str, list],
|
||||||
interleave_strings: bool
|
texts: list[str],
|
||||||
) -> str:
|
interleave_strings: bool,
|
||||||
|
) -> str:
|
||||||
"""Combine multimodal prompts for a multimodal language model."""
|
"""Combine multimodal prompts for a multimodal language model."""
|
||||||
|
|
||||||
# flatten storage to make it looks like
|
# flatten storage to make it looks like
|
||||||
@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
|
|||||||
# Look through the text prompt to check for missing placeholders
|
# Look through the text prompt to check for missing placeholders
|
||||||
missing_placeholders: list[str] = []
|
missing_placeholders: list[str] = []
|
||||||
for placeholder in placeholder_counts:
|
for placeholder in placeholder_counts:
|
||||||
|
|
||||||
# For any existing placeholder in the text prompt, we leave it as is
|
# For any existing placeholder in the text prompt, we leave it as is
|
||||||
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||||
|
|
||||||
@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
|
|||||||
"Placeholder count is negative! "
|
"Placeholder count is negative! "
|
||||||
"Ensure that the 'interleave_strings' flag is disabled "
|
"Ensure that the 'interleave_strings' flag is disabled "
|
||||||
"(current value: %s) "
|
"(current value: %s) "
|
||||||
"when manually placing image placeholders.", interleave_strings
|
"when manually placing image placeholders.",
|
||||||
|
interleave_strings,
|
||||||
)
|
)
|
||||||
logger.debug("Input prompt: %s", text_prompt)
|
logger.debug("Input prompt: %s", text_prompt)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Found more '{placeholder}' placeholders in input prompt than "
|
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||||
"actual multimodal data items.")
|
"actual multimodal data items."
|
||||||
|
)
|
||||||
|
|
||||||
missing_placeholders.extend([placeholder] *
|
missing_placeholders.extend(
|
||||||
placeholder_counts[placeholder])
|
[placeholder] * placeholder_counts[placeholder]
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE: Default behaviour: we always add missing placeholders
|
# NOTE: Default behaviour: we always add missing placeholders
|
||||||
# at the front of the prompt, if interleave_strings=False
|
# at the front of the prompt, if interleave_strings=False
|
||||||
@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
|
|||||||
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
|
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
|
||||||
|
|
||||||
_ResponsesInputImageParser = TypeAdapter(
|
_ResponsesInputImageParser = TypeAdapter(
|
||||||
ResponseInputImageParam).validate_python
|
ResponseInputImageParam
|
||||||
|
).validate_python
|
||||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
|
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
|
||||||
|
|
||||||
# Define a mapping from part types to their corresponding parsing functions.
|
# Define a mapping from part types to their corresponding parsing functions.
|
||||||
@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[
|
|||||||
str,
|
str,
|
||||||
Callable[[ChatCompletionContentPartParam], _ContentPart],
|
Callable[[ChatCompletionContentPartParam], _ContentPart],
|
||||||
] = {
|
] = {
|
||||||
"text":
|
"text": lambda part: _TextParser(part).get("text", None),
|
||||||
lambda part: _TextParser(part).get("text", None),
|
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
|
||||||
"thinking":
|
"input_text": lambda part: _TextParser(part).get("text", None),
|
||||||
lambda part: _ThinkParser(part).get("thinking", None),
|
"input_image": lambda part: _ResponsesInputImageParser(part).get(
|
||||||
"input_text":
|
"image_url", None
|
||||||
lambda part: _TextParser(part).get("text", None),
|
),
|
||||||
"input_image":
|
"image_url": lambda part: _ImageParser(part)
|
||||||
lambda part: _ResponsesInputImageParser(part).get("image_url", None),
|
.get("image_url", {})
|
||||||
"image_url":
|
.get("url", None),
|
||||||
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
|
"image_embeds": lambda part: _ImageEmbedsParser(part).get(
|
||||||
"image_embeds":
|
"image_embeds", None
|
||||||
lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
|
),
|
||||||
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
|
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
|
||||||
"audio_url":
|
"audio_url": lambda part: _AudioParser(part)
|
||||||
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
|
.get("audio_url", {})
|
||||||
"input_audio":
|
.get("url", None),
|
||||||
lambda part: _InputAudioParser(part).get("input_audio", None),
|
"input_audio": lambda part: _InputAudioParser(part).get(
|
||||||
"refusal":
|
"input_audio", None
|
||||||
lambda part: _RefusalParser(part).get("refusal", None),
|
),
|
||||||
"video_url":
|
"refusal": lambda part: _RefusalParser(part).get("refusal", None),
|
||||||
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
|
"video_url": lambda part: _VideoParser(part)
|
||||||
|
.get("video_url", {})
|
||||||
|
.get("url", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _parse_chat_message_content_mm_part(
|
def _parse_chat_message_content_mm_part(
|
||||||
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
|
part: ChatCompletionContentPartParam,
|
||||||
|
) -> tuple[str, _ContentPart]:
|
||||||
"""
|
"""
|
||||||
Parses a given multi-modal content part based on its type.
|
Parses a given multi-modal content part based on its type.
|
||||||
|
|
||||||
@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part(
|
|||||||
ValueError: If the 'type' field is missing and no direct URL is found.
|
ValueError: If the 'type' field is missing and no direct URL is found.
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
part, dict) # This is needed to avoid mypy errors: part.get() from str
|
part, dict
|
||||||
|
) # This is needed to avoid mypy errors: part.get() from str
|
||||||
part_type = part.get("type", None)
|
part_type = part.get("type", None)
|
||||||
|
|
||||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
|
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
|
||||||
@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part(
|
|||||||
# Special case for 'image_url.detail'
|
# Special case for 'image_url.detail'
|
||||||
# We only support 'auto', which is the default
|
# We only support 'auto', which is the default
|
||||||
if part_type == "image_url" and part.get("detail", "auto") != "auto":
|
if part_type == "image_url" and part.get("detail", "auto") != "auto":
|
||||||
logger.warning("'image_url.detail' is currently not supported "
|
logger.warning(
|
||||||
"and will be ignored.")
|
"'image_url.detail' is currently not supported "
|
||||||
|
"and will be ignored."
|
||||||
|
)
|
||||||
|
|
||||||
return part_type, content
|
return part_type, content
|
||||||
|
|
||||||
@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part(
|
|||||||
# 'type' is required field by pydantic
|
# 'type' is required field by pydantic
|
||||||
if part_type is None:
|
if part_type is None:
|
||||||
if part.get("image_url") is not None:
|
if part.get("image_url") is not None:
|
||||||
image_params = cast(CustomChatCompletionContentSimpleImageParam,
|
image_params = cast(
|
||||||
part)
|
CustomChatCompletionContentSimpleImageParam, part
|
||||||
|
)
|
||||||
return "image_url", image_params.get("image_url", "")
|
return "image_url", image_params.get("image_url", "")
|
||||||
if part.get("audio_url") is not None:
|
if part.get("audio_url") is not None:
|
||||||
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
|
audio_params = cast(
|
||||||
part)
|
CustomChatCompletionContentSimpleAudioParam, part
|
||||||
|
)
|
||||||
return "audio_url", audio_params.get("audio_url", "")
|
return "audio_url", audio_params.get("audio_url", "")
|
||||||
if part.get("input_audio") is not None:
|
if part.get("input_audio") is not None:
|
||||||
input_audio_params = cast(dict[str, str], part)
|
input_audio_params = cast(dict[str, str], part)
|
||||||
return "input_audio", input_audio_params
|
return "input_audio", input_audio_params
|
||||||
if part.get("video_url") is not None:
|
if part.get("video_url") is not None:
|
||||||
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
|
video_params = cast(
|
||||||
part)
|
CustomChatCompletionContentSimpleVideoParam, part
|
||||||
|
)
|
||||||
return "video_url", video_params.get("video_url", "")
|
return "video_url", video_params.get("video_url", "")
|
||||||
# Raise an error if no 'type' or direct URL is found.
|
# Raise an error if no 'type' or direct URL is found.
|
||||||
raise ValueError("Missing 'type' field in multimodal part.")
|
raise ValueError("Missing 'type' field in multimodal part.")
|
||||||
@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part(
|
|||||||
return part_type, "unknown part_type content"
|
return part_type, "unknown part_type content"
|
||||||
|
|
||||||
|
|
||||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
|
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
|
||||||
"image_embeds", "image_pil",
|
"text",
|
||||||
"audio_url", "input_audio", "video_url")
|
"refusal",
|
||||||
|
"image_url",
|
||||||
|
"image_embeds",
|
||||||
|
"image_pil",
|
||||||
|
"audio_url",
|
||||||
|
"input_audio",
|
||||||
|
"video_url",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_chat_message_content_parts(
|
def _parse_chat_message_content_parts(
|
||||||
@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts(
|
|||||||
part,
|
part,
|
||||||
mm_parser,
|
mm_parser,
|
||||||
wrap_dicts=wrap_dicts,
|
wrap_dicts=wrap_dicts,
|
||||||
interleave_strings=interleave_strings
|
interleave_strings=interleave_strings,
|
||||||
)
|
)
|
||||||
if parse_res:
|
if parse_res:
|
||||||
content.append(parse_res)
|
content.append(parse_res)
|
||||||
|
|
||||||
if wrap_dicts:
|
if wrap_dicts:
|
||||||
# Parsing wraps images and texts as interleaved dictionaries
|
# Parsing wraps images and texts as interleaved dictionaries
|
||||||
return [ConversationMessage(role=role,
|
return [ConversationMessage(role=role, content=content)] # type: ignore
|
||||||
content=content)] # type: ignore
|
|
||||||
texts = cast(list[str], content)
|
texts = cast(list[str], content)
|
||||||
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
|
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
|
||||||
if mm_placeholder_storage:
|
if mm_placeholder_storage:
|
||||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
|
text_prompt = _get_full_multimodal_text_prompt(
|
||||||
texts,
|
mm_placeholder_storage, texts, interleave_strings
|
||||||
interleave_strings)
|
)
|
||||||
else:
|
else:
|
||||||
text_prompt = "\n".join(texts)
|
text_prompt = "\n".join(texts)
|
||||||
|
|
||||||
@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part(
|
|||||||
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
|
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping multimodal part '%s' (type: '%s') "
|
"Skipping multimodal part '%s' (type: '%s') "
|
||||||
"with empty / unparsable content.", part, part_type)
|
"with empty / unparsable content.",
|
||||||
|
part,
|
||||||
|
part_type,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if part_type in ("text", "input_text", "refusal", "thinking"):
|
if part_type in ("text", "input_text", "refusal", "thinking"):
|
||||||
str_content = cast(str, content)
|
str_content = cast(str, content)
|
||||||
if wrap_dicts:
|
if wrap_dicts:
|
||||||
return {'type': 'text', 'text': str_content}
|
return {"type": "text", "text": str_content}
|
||||||
else:
|
else:
|
||||||
return str_content
|
return str_content
|
||||||
|
|
||||||
@ -1137,8 +1205,12 @@ def _parse_chat_message_content_part(
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||||
|
|
||||||
return {'type': modality} if wrap_dicts else (
|
return (
|
||||||
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
|
{"type": modality}
|
||||||
|
if wrap_dicts
|
||||||
|
else (
|
||||||
|
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1171,14 +1243,16 @@ def _parse_chat_message_content(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for result_msg in result:
|
for result_msg in result:
|
||||||
if role == 'assistant':
|
if role == "assistant":
|
||||||
parsed_msg = _AssistantParser(message)
|
parsed_msg = _AssistantParser(message)
|
||||||
|
|
||||||
# The 'tool_calls' is not None check ensures compatibility.
|
# The 'tool_calls' is not None check ensures compatibility.
|
||||||
# It's needed only if downstream code doesn't strictly
|
# It's needed only if downstream code doesn't strictly
|
||||||
# follow the OpenAI spec.
|
# follow the OpenAI spec.
|
||||||
if ("tool_calls" in parsed_msg
|
if (
|
||||||
and parsed_msg["tool_calls"] is not None):
|
"tool_calls" in parsed_msg
|
||||||
|
and parsed_msg["tool_calls"] is not None
|
||||||
|
):
|
||||||
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
parsed_msg = _ToolParser(message)
|
parsed_msg = _ToolParser(message)
|
||||||
@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
|||||||
# so, for messages that have tool_calls, parse the string (which we get
|
# so, for messages that have tool_calls, parse the string (which we get
|
||||||
# from openAI format) to dict
|
# from openAI format) to dict
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if (message["role"] == "assistant" and "tool_calls" in message
|
if (
|
||||||
and isinstance(message["tool_calls"], list)):
|
message["role"] == "assistant"
|
||||||
|
and "tool_calls" in message
|
||||||
|
and isinstance(message["tool_calls"], list)
|
||||||
|
):
|
||||||
for item in message["tool_calls"]:
|
for item in message["tool_calls"]:
|
||||||
item["function"]["arguments"] = json.loads(
|
item["function"]["arguments"] = json.loads(
|
||||||
item["function"]["arguments"])
|
item["function"]["arguments"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_messages(
|
def parse_chat_messages(
|
||||||
@ -1224,7 +1301,7 @@ def parse_chat_messages(
|
|||||||
content_format == "string"
|
content_format == "string"
|
||||||
and model_config.multimodal_config is not None
|
and model_config.multimodal_config is not None
|
||||||
and model_config.multimodal_config.interleave_mm_strings
|
and model_config.multimodal_config.interleave_mm_strings
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation.extend(sub_messages)
|
conversation.extend(sub_messages)
|
||||||
@ -1252,7 +1329,7 @@ def parse_chat_messages_futures(
|
|||||||
content_format == "string"
|
content_format == "string"
|
||||||
and model_config.multimodal_config is not None
|
and model_config.multimodal_config is not None
|
||||||
and model_config.multimodal_config.interleave_mm_strings
|
and model_config.multimodal_config.interleave_mm_strings
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation.extend(sub_messages)
|
conversation.extend(sub_messages)
|
||||||
@ -1283,10 +1360,10 @@ def apply_hf_chat_template(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"As of transformers v4.44, default chat template is no longer "
|
"As of transformers v4.44, default chat template is no longer "
|
||||||
"allowed, so you must provide a chat template if the tokenizer "
|
"allowed, so you must provide a chat template if the tokenizer "
|
||||||
"does not define one.")
|
"does not define one."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
return tokenizer.apply_chat_template(
|
return tokenizer.apply_chat_template(
|
||||||
conversation=conversation, # type: ignore[arg-type]
|
conversation=conversation, # type: ignore[arg-type]
|
||||||
tools=tools, # type: ignore[arg-type]
|
tools=tools, # type: ignore[arg-type]
|
||||||
@ -1298,13 +1375,14 @@ def apply_hf_chat_template(
|
|||||||
# External library exceptions can sometimes occur despite the framework's
|
# External library exceptions can sometimes occur despite the framework's
|
||||||
# internal exception management capabilities.
|
# internal exception management capabilities.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
# Log and report any library-related exceptions for further
|
# Log and report any library-related exceptions for further
|
||||||
# investigation.
|
# investigation.
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"An error occurred in `transformers` while applying chat template")
|
"An error occurred in `transformers` while applying chat template"
|
||||||
|
)
|
||||||
raise ValueError(str(e)) from e
|
raise ValueError(str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
def apply_mistral_chat_template(
|
def apply_mistral_chat_template(
|
||||||
tokenizer: MistralTokenizer,
|
tokenizer: MistralTokenizer,
|
||||||
messages: list[ChatCompletionMessageParam],
|
messages: list[ChatCompletionMessageParam],
|
||||||
@ -1337,26 +1415,26 @@ def apply_mistral_chat_template(
|
|||||||
# External library exceptions can sometimes occur despite the framework's
|
# External library exceptions can sometimes occur despite the framework's
|
||||||
# internal exception management capabilities.
|
# internal exception management capabilities.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
# Log and report any library-related exceptions for further
|
# Log and report any library-related exceptions for further
|
||||||
# investigation.
|
# investigation.
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"An error occurred in `mistral_common` while applying chat "
|
"An error occurred in `mistral_common` while applying chat template"
|
||||||
"template")
|
)
|
||||||
raise ValueError(str(e)) from e
|
raise ValueError(str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
|
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
|
||||||
idx = 0
|
idx = 0
|
||||||
for msg in conversation:
|
for msg in conversation:
|
||||||
if msg['role'] == 'assistant':
|
if msg["role"] == "assistant":
|
||||||
tool_calls = msg.get('tool_calls')
|
tool_calls = msg.get("tool_calls")
|
||||||
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
def make_tool_call_id(id_type:str='random', func_name=None, idx=None):
|
|
||||||
|
|
||||||
if id_type=='kimi_k2':
|
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
|
||||||
return f'functions.{func_name}:{idx}'
|
if id_type == "kimi_k2":
|
||||||
|
return f"functions.{func_name}:{idx}"
|
||||||
else:
|
else:
|
||||||
# by default return random
|
# by default return random
|
||||||
return f"chatcmpl-tool-{random_uuid()}"
|
return f"chatcmpl-tool-{random_uuid()}"
|
||||||
|
|||||||
@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
CompletionLikeRequest = Union[
|
||||||
EmbeddingCompletionRequest, RerankRequest,
|
CompletionRequest,
|
||||||
ClassificationRequest, ScoreRequest,
|
DetokenizeRequest,
|
||||||
TokenizeCompletionRequest]
|
EmbeddingCompletionRequest,
|
||||||
|
RerankRequest,
|
||||||
|
ClassificationRequest,
|
||||||
|
ScoreRequest,
|
||||||
|
TokenizeCompletionRequest,
|
||||||
|
]
|
||||||
|
|
||||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||||
TokenizeChatRequest]
|
TokenizeChatRequest]
|
||||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
|
AnyRequest = Union[
|
||||||
ResponsesRequest, IOProcessorRequest]
|
CompletionLikeRequest,
|
||||||
|
ChatLikeRequest,
|
||||||
|
SpeechToTextRequest,
|
||||||
|
ResponsesRequest,
|
||||||
|
IOProcessorRequest,
|
||||||
|
]
|
||||||
|
|
||||||
AnyResponse = Union[
|
AnyResponse = Union[
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel):
|
|||||||
Mixin for request processing,
|
Mixin for request processing,
|
||||||
handling prompt preparation and engine input.
|
handling prompt preparation and engine input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
||||||
engine_prompts: Optional[Union[list[EngineTokensPrompt],
|
engine_prompts: Optional[Union[list[EngineTokensPrompt],
|
||||||
list[EngineEmbedsPrompt]]] = []
|
list[EngineEmbedsPrompt]]] = []
|
||||||
@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel):
|
|||||||
Mixin for response generation,
|
Mixin for response generation,
|
||||||
managing result generators and final batch results.
|
managing result generators and final batch results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result_generator: Optional[AsyncGenerator[tuple[int, Union[
|
result_generator: Optional[AsyncGenerator[tuple[int, Union[
|
||||||
RequestOutput, PoolingRequestOutput]], None]] = None
|
RequestOutput, PoolingRequestOutput]], None]] = None
|
||||||
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
|
final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
|
||||||
@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
|
||||||
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
|
class ServeContext(
|
||||||
Generic[RequestT]):
|
RequestProcessingMixin,
|
||||||
|
ResponseGenerationMixin,
|
||||||
|
BaseModel,
|
||||||
|
Generic[RequestT],
|
||||||
|
):
|
||||||
# Shared across all requests
|
# Shared across all requests
|
||||||
request: RequestT
|
request: RequestT
|
||||||
raw_request: Optional[Request] = None
|
raw_request: Optional[Request] = None
|
||||||
@ -298,8 +314,8 @@ class OpenAIServing:
|
|||||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
|
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
|
||||||
None)
|
None)
|
||||||
|
|
||||||
if truncate_prompt_tokens is not None and \
|
if (truncate_prompt_tokens is not None
|
||||||
truncate_prompt_tokens > self.max_model_len:
|
and truncate_prompt_tokens > self.max_model_len):
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"truncate_prompt_tokens value is "
|
"truncate_prompt_tokens value is "
|
||||||
"greater than max_model_len."
|
"greater than max_model_len."
|
||||||
@ -344,10 +360,12 @@ class OpenAIServing:
|
|||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"Request prompts not available")
|
"Request prompts not available")
|
||||||
|
|
||||||
self._log_inputs(request_id_item,
|
self._log_inputs(
|
||||||
ctx.request_prompts[i],
|
request_id_item,
|
||||||
params=pooling_params,
|
ctx.request_prompts[i],
|
||||||
lora_request=ctx.lora_request)
|
params=pooling_params,
|
||||||
|
lora_request=ctx.lora_request,
|
||||||
|
)
|
||||||
|
|
||||||
# Mypy has an existing bug related to inferring the variance of
|
# Mypy has an existing bug related to inferring the variance of
|
||||||
# TypedDicts with `builtins.enumerate`:
|
# TypedDicts with `builtins.enumerate`:
|
||||||
@ -410,10 +428,11 @@ class OpenAIServing:
|
|||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
def create_error_response(
|
def create_error_response(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
err_type: str = "BadRequestError",
|
err_type: str = "BadRequestError",
|
||||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||||
|
) -> ErrorResponse:
|
||||||
if self.log_error_stack:
|
if self.log_error_stack:
|
||||||
exc_type, _, _ = sys.exc_info()
|
exc_type, _, _ = sys.exc_info()
|
||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
@ -424,10 +443,11 @@ class OpenAIServing:
|
|||||||
message=message, type=err_type, code=status_code.value))
|
message=message, type=err_type, code=status_code.value))
|
||||||
|
|
||||||
def create_streaming_error_response(
|
def create_streaming_error_response(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
err_type: str = "BadRequestError",
|
err_type: str = "BadRequestError",
|
||||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||||
|
) -> str:
|
||||||
json_str = json.dumps(
|
json_str = json.dumps(
|
||||||
self.create_error_response(message=message,
|
self.create_error_response(message=message,
|
||||||
err_type=err_type,
|
err_type=err_type,
|
||||||
@ -438,25 +458,25 @@ class OpenAIServing:
|
|||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
) -> Optional[ErrorResponse]:
|
) -> Optional[ErrorResponse]:
|
||||||
|
|
||||||
error_response = None
|
error_response = None
|
||||||
|
|
||||||
if self._is_model_supported(request.model):
|
if self._is_model_supported(request.model):
|
||||||
return None
|
return None
|
||||||
if request.model in self.models.lora_requests:
|
if request.model in self.models.lora_requests:
|
||||||
return None
|
return None
|
||||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
|
if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
|
||||||
load_result := await self.models.resolve_lora(request.model)):
|
(load_result := await self.models.resolve_lora(request.model))):
|
||||||
if isinstance(load_result, LoRARequest):
|
if isinstance(load_result, LoRARequest):
|
||||||
return None
|
return None
|
||||||
if isinstance(load_result, ErrorResponse) and \
|
if (isinstance(load_result, ErrorResponse) and
|
||||||
load_result.error.code == HTTPStatus.BAD_REQUEST.value:
|
load_result.error.code == HTTPStatus.BAD_REQUEST.value):
|
||||||
error_response = load_result
|
error_response = load_result
|
||||||
|
|
||||||
return error_response or self.create_error_response(
|
return error_response or self.create_error_response(
|
||||||
message=f"The model `{request.model}` does not exist.",
|
message=f"The model `{request.model}` does not exist.",
|
||||||
err_type="NotFoundError",
|
err_type="NotFoundError",
|
||||||
status_code=HTTPStatus.NOT_FOUND)
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_active_default_mm_loras(
|
def _get_active_default_mm_loras(
|
||||||
self, request: AnyRequest) -> Optional[LoRARequest]:
|
self, request: AnyRequest) -> Optional[LoRARequest]:
|
||||||
@ -487,7 +507,6 @@ class OpenAIServing:
|
|||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
supports_default_mm_loras: bool = False,
|
supports_default_mm_loras: bool = False,
|
||||||
) -> Optional[LoRARequest]:
|
) -> Optional[LoRARequest]:
|
||||||
|
|
||||||
if request.model in self.models.lora_requests:
|
if request.model in self.models.lora_requests:
|
||||||
return self.models.lora_requests[request.model]
|
return self.models.lora_requests[request.model]
|
||||||
|
|
||||||
@ -548,13 +567,15 @@ class OpenAIServing:
|
|||||||
prompt,
|
prompt,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_model_len)
|
max_length=self.max_model_len,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
encoded = await async_tokenizer(
|
encoded = await async_tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=truncate_prompt_tokens)
|
max_length=truncate_prompt_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = encoded.input_ids
|
input_ids = encoded.input_ids
|
||||||
input_text = prompt
|
input_text = prompt
|
||||||
@ -595,16 +616,22 @@ class OpenAIServing:
|
|||||||
|
|
||||||
# Note: EmbeddingRequest, ClassificationRequest,
|
# Note: EmbeddingRequest, ClassificationRequest,
|
||||||
# and ScoreRequest doesn't have max_tokens
|
# and ScoreRequest doesn't have max_tokens
|
||||||
if isinstance(request,
|
if isinstance(
|
||||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
request,
|
||||||
ScoreRequest, RerankRequest, ClassificationRequest)):
|
(
|
||||||
|
EmbeddingChatRequest,
|
||||||
|
EmbeddingCompletionRequest,
|
||||||
|
ScoreRequest,
|
||||||
|
RerankRequest,
|
||||||
|
ClassificationRequest,
|
||||||
|
),
|
||||||
|
):
|
||||||
# Note: input length can be up to the entire model context length
|
# Note: input length can be up to the entire model context length
|
||||||
# since these requests don't generate tokens.
|
# since these requests don't generate tokens.
|
||||||
if token_num > self.max_model_len:
|
if token_num > self.max_model_len:
|
||||||
operations: dict[type[AnyRequest], str] = {
|
operations: dict[type[AnyRequest], str] = {
|
||||||
ScoreRequest: "score",
|
ScoreRequest: "score",
|
||||||
ClassificationRequest: "classification"
|
ClassificationRequest: "classification",
|
||||||
}
|
}
|
||||||
operation = operations.get(type(request),
|
operation = operations.get(type(request),
|
||||||
"embedding generation")
|
"embedding generation")
|
||||||
@ -618,8 +645,11 @@ class OpenAIServing:
|
|||||||
|
|
||||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||||
# and does not require model context length validation
|
# and does not require model context length validation
|
||||||
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
|
if isinstance(
|
||||||
DetokenizeRequest)):
|
request,
|
||||||
|
(TokenizeCompletionRequest, TokenizeChatRequest,
|
||||||
|
DetokenizeRequest),
|
||||||
|
):
|
||||||
return TextTokensPrompt(prompt=input_text,
|
return TextTokensPrompt(prompt=input_text,
|
||||||
prompt_token_ids=input_ids)
|
prompt_token_ids=input_ids)
|
||||||
|
|
||||||
@@ -639,8 +669,8 @@ class OpenAIServing:
                 f"{token_num} input tokens. Please reduce the length of "
                 "the input messages.")
 
-        if max_tokens is not None and \
-                token_num + max_tokens > self.max_model_len:
+        if (max_tokens is not None
+                and token_num + max_tokens > self.max_model_len):
             raise ValueError(
                 "'max_tokens' or 'max_completion_tokens' is too large: "
                 f"{max_tokens}. This model's maximum context length is "
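
The condition rewritten above is plain arithmetic: the prompt tokens plus the requested completion budget must fit inside the model context. A standalone sketch of both checks in this validation path, with made-up numbers:

from typing import Optional


def validate_lengths(token_num: int,
                     max_tokens: Optional[int],
                     max_model_len: int) -> None:
    # Pooling-style requests: only the prompt has to fit.
    if max_tokens is None:
        if token_num > max_model_len:
            raise ValueError(f"{token_num} input tokens exceed {max_model_len}")
        return
    # Generation requests: prompt plus completion budget has to fit.
    if token_num + max_tokens > max_model_len:
        raise ValueError(f"{token_num} prompt + {max_tokens} completion tokens "
                         f"exceed the {max_model_len}-token context")


validate_lengths(token_num=8000, max_tokens=None, max_model_len=8192)  # fits
try:
    validate_lengths(token_num=8000, max_tokens=500, max_model_len=8192)
except ValueError as exc:
    print(exc)  # 8000 prompt + 500 completion tokens exceed the 8192-token context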
@@ -745,13 +775,14 @@ class OpenAIServing:
         tasks = []
         for prompt_input in batch_inputs:
             if prompt_input["is_tokens"] is False:
-                assert tokenizer is not None, \
-                    "Tokenizer is required for text prompts"
+                assert tokenizer is not None, (
+                    "Tokenizer is required for text prompts")
                 task = self._normalize_prompt_text_to_input(
                     request,
                     prompt_input["content"],
                     tokenizer=tokenizer,
-                    add_special_tokens=add_special_tokens)
+                    add_special_tokens=add_special_tokens,
+                )
             else:
                 task = self._normalize_prompt_tokens_to_input(
                     request, prompt_input["content"], tokenizer=tokenizer)
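
In the loop above each prompt becomes an awaitable normalization task; a batch like this is typically resolved with asyncio.gather, though the gather call itself is outside the hunk, so treat this as a generic sketch with a dummy coroutine in place of _normalize_prompt_text_to_input:

import asyncio


async def normalize(prompt: str) -> dict:
    # Dummy stand-in for the per-prompt normalization coroutine.
    await asyncio.sleep(0)
    return {"prompt": prompt, "prompt_token_ids": list(range(len(prompt)))}


async def main() -> None:
    batch = ["hello", "world"]
    tasks = [normalize(p) for p in batch]
    results = await asyncio.gather(*tasks)
    print(results)


asyncio.run(main())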
@@ -766,9 +797,14 @@ class OpenAIServing:
     @overload
     async def _preprocess_completion(
         self,
-        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
-                       RerankRequest, ClassificationRequest, ScoreRequest,
-                       TokenizeCompletionRequest],
+        request: Union[
+            DetokenizeRequest,
+            EmbeddingCompletionRequest,
+            RerankRequest,
+            ClassificationRequest,
+            ScoreRequest,
+            TokenizeCompletionRequest,
+        ],
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
         add_special_tokens: bool = ...,
@@ -783,8 +819,10 @@ class OpenAIServing:
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
         add_special_tokens: bool = ...,
-    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
-            EngineTokensPrompt, EngineEmbedsPrompt]]]:
+    ) -> tuple[
+        list[Union[TextTokensPrompt, EmbedsPrompt]],
+        list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
+    ]:
         ...
 
     async def _preprocess_completion(
@@ -794,32 +832,38 @@ class OpenAIServing:
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
         add_special_tokens: bool = True,
-    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
-            TextTokensPrompt, EmbedsPrompt]]], Union[
-                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
-                                                     EngineEmbedsPrompt]]]]:
-        if not isinstance(request,
-                          CompletionRequest) and input_or_inputs is None:
+    ) -> tuple[
+        Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
+                                                 EmbedsPrompt]]],
+        Union[
+            list[EngineTokensPrompt],
+            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
+        ],
+    ]:
+        if (not isinstance(request, CompletionRequest)
+                and input_or_inputs is None):
             raise ValueError(
                 "Prompt embeds with non-completion requests is not"
                 " currently supported.")
 
-        (request_prompts_text, request_prompts_embeds
-         ) = await self._tokenize_prompt_input_or_inputs_async(
-             request,
-             tokenizer,
-             input_or_inputs,
-             add_special_tokens=add_special_tokens,
-         )
+        (
+            request_prompts_text,
+            request_prompts_embeds,
+        ) = await self._tokenize_prompt_input_or_inputs_async(
+            request,
+            tokenizer,
+            input_or_inputs,
+            add_special_tokens=add_special_tokens,
+        )
 
         engine_prompts_text = [
             EngineTokensPrompt(
                 prompt_token_ids=request_prompt_text["prompt_token_ids"])
             for request_prompt_text in request_prompts_text
         ]
-        cache_salt = request.cache_salt if (
-            hasattr(request, "cache_salt")
-            and request.cache_salt is not None) else None
+        cache_salt = (request.cache_salt if
+                      (hasattr(request, "cache_salt")
+                       and request.cache_salt is not None) else None)
         if cache_salt:
             for prompt_text in engine_prompts_text:
                 prompt_text["cache_salt"] = cache_salt
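
The cache_salt expression above (in both its old and new spelling) behaves like a getattr with a default, and the resolved salt is then copied onto every engine prompt. A small sketch of that behavior using plain dicts in place of EngineTokensPrompt:

from typing import Any, Optional


class _FakeRequest:
    cache_salt = "tenant-42"  # pretend the request carries a salt


def resolve_cache_salt(request: Any) -> Optional[str]:
    # Same result as the conditional expression in the hunk above:
    # the attribute if it exists and is not None, otherwise None.
    return getattr(request, "cache_salt", None)


engine_prompts = [{"prompt_token_ids": [1, 2, 3]}, {"prompt_token_ids": [4]}]
cache_salt = resolve_cache_salt(_FakeRequest())
if cache_salt:
    for prompt in engine_prompts:
        prompt["cache_salt"] = cache_salt

assert all(p["cache_salt"] == "tenant-42" for p in engine_prompts)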
@@ -831,8 +875,8 @@ class OpenAIServing:
         # non-completion requests and if we don't add the overload here,
         # everywhere this function is used outside of serving_completion will
         # need logic asserting that only text prompts are in the request.
-        if not isinstance(request,
-                          CompletionRequest) and input_or_inputs is not None:
+        if (not isinstance(request, CompletionRequest)
+                and input_or_inputs is not None):
             return request_prompts_text, engine_prompts_text
 
         engine_prompts_embeds = [
@@ -862,8 +906,11 @@ class OpenAIServing:
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
         add_special_tokens: bool = False,
-    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
-               list[EngineTokensPrompt]]:
+    ) -> tuple[
+        list[ConversationMessage],
+        Sequence[RequestPrompt],
+        list[EngineTokensPrompt],
+    ]:
         model_config = self.model_config
 
         resolved_content_format = resolve_chat_template_content_format(
@@ -925,8 +972,8 @@ class OpenAIServing:
 
         if tokenizer is None:
             assert isinstance(request_prompt, str), (
-                "Prompt has to be a string", \
-                "when the tokenizer is not initialised"
+                "Prompt has to be a string",
+                "when the tokenizer is not initialised",
             )
             prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                              prompt_token_ids=[1])
@@ -943,7 +990,8 @@ class OpenAIServing:
                 "Prompt has to be either a string or a list of token ids")
             prompt_inputs = TextTokensPrompt(
                 prompt=tokenizer.decode(request_prompt),
-                prompt_token_ids=request_prompt)
+                prompt_token_ids=request_prompt,
+            )
 
         engine_prompt = EngineTokensPrompt(
             prompt_token_ids=prompt_inputs["prompt_token_ids"])
@@ -1007,22 +1055,23 @@ class OpenAIServing:
                     prompt_token_ids=prompt_token_ids)
                 request_prompt = prompt_token_ids
                 # Update the sampling params.
-                sampling_params.max_tokens = (self.max_model_len -
-                                              len(prompt_token_ids))
+                sampling_params.max_tokens = self.max_model_len - len(
+                    prompt_token_ids)
                 # OPTIMIZATION
                 priority = orig_priority - 1
 
     @staticmethod
     def _load_prompt_embeds(
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> list[EmbedsPrompt]:
 
         def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
-            tensor = torch.load(io.BytesIO(
-                pybase64.b64decode(embed, validate=True)),
+            tensor = torch.load(
+                io.BytesIO(pybase64.b64decode(embed, validate=True)),
                 weights_only=True,
-                map_location=torch.device("cpu"))
+                map_location=torch.device("cpu"),
+            )
             assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                 torch.float32,
                 torch.bfloat16,
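
The _load_prompt_embeds change above only re-wraps the torch.load call, but it shows the wire format: prompt embeddings arrive as base64-encoded, torch.save-serialized tensors. A round-trip sketch of how a client could encode such a payload and how it gets decoded (the hunk uses pybase64; the standard base64 module behaves the same way here):

import base64
import io

import torch

# Client side: serialize an embeddings tensor and base64-encode it.
embeds = torch.randn(4, 16, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(embeds, buffer)
payload = base64.b64encode(buffer.getvalue())

# Server side: decode and load on CPU, restricted to plain tensor data.
tensor = torch.load(
    io.BytesIO(base64.b64decode(payload, validate=True)),
    weights_only=True,
    map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype == torch.float32
assert tensor.shape == (4, 16)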
@@ -1061,7 +1110,7 @@ class OpenAIServing:
             prompt = inputs
         elif isinstance(inputs, list):
             prompt_token_ids = inputs
-        elif 'prompt_embeds' in inputs:
+        elif "prompt_embeds" in inputs:
             prompt_embeds = inputs.get("prompt_embeds")
         else:
             prompt = inputs["prompt"]
@@ -1101,10 +1150,12 @@ class OpenAIServing:
         return raw_request.headers.get("X-Request-Id", default)
 
     @staticmethod
-    def _get_decoded_token(logprob: Logprob,
-                           token_id: int,
-                           tokenizer: AnyTokenizer,
-                           return_as_token_id: bool = False) -> str:
+    def _get_decoded_token(
+        logprob: Logprob,
+        token_id: int,
+        tokenizer: AnyTokenizer,
+        return_as_token_id: bool = False,
+    ) -> str:
         if return_as_token_id:
             return f"token_id:{token_id}"
 
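
Only the return_as_token_id fast path of _get_decoded_token is visible in this hunk. A hedged sketch of that branch, with a hypothetical fallback that decodes through the tokenizer (the real remaining body is not shown in this diff):

from typing import Any


def get_decoded_token_sketch(
    logprob: Any,
    token_id: int,
    tokenizer: Any,
    return_as_token_id: bool = False,
) -> str:
    # Visible branch from the hunk above.
    if return_as_token_id:
        return f"token_id:{token_id}"
    # Hypothetical fallback: prefer an already-decoded token on the logprob
    # object, otherwise ask the tokenizer to decode the id.
    decoded = getattr(logprob, "decoded_token", None)
    if decoded is not None:
        return decoded
    return tokenizer.decode(token_id)


class _FakeTokenizer:
    def decode(self, token_id: int) -> str:
        return f"<tok{token_id}>"


class _FakeLogprob:
    decoded_token = None


print(get_decoded_token_sketch(_FakeLogprob(), 42, _FakeTokenizer()))        # <tok42>
print(get_decoded_token_sketch(_FakeLogprob(), 42, _FakeTokenizer(), True))  # token_id:42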
@@ -1117,9 +1168,11 @@ class OpenAIServing:
             return True
         return self.models.is_base_model(model_name)
 
-    def _get_model_name(self,
-                        model_name: Optional[str] = None,
-                        lora_request: Optional[LoRARequest] = None) -> str:
+    def _get_model_name(
+        self,
+        model_name: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> str:
         if lora_request:
             return lora_request.lora_name
         if not model_name:
@@ -1129,7 +1182,7 @@ class OpenAIServing:
 
 def clamp_prompt_logprobs(
     prompt_logprobs: Union[PromptLogprobs,
-                           None]) -> Union[PromptLogprobs, None]:
+                           None], ) -> Union[PromptLogprobs, None]:
     if prompt_logprobs is None:
         return prompt_logprobs
 
@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs(
         if logprob_dict is None:
             continue
         for logprob_values in logprob_dict.values():
-            if logprob_values.logprob == float('-inf'):
+            if logprob_values.logprob == float("-inf"):
                 logprob_values.logprob = -9999.0
     return prompt_logprobs
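
The quote-style fix above lands in clamp_prompt_logprobs, which rewrites -inf log-probabilities to -9999.0, presumably so the values stay representable in strict JSON output (JSON has no Infinity literal). A standalone sketch of the same clamping over a simplified logprob structure:

import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class LogprobSketch:
    """Simplified stand-in for vLLM's Logprob object."""
    logprob: float


def clamp(prompt_logprobs: Optional[list[Optional[dict[int, LogprobSketch]]]]):
    if prompt_logprobs is None:
        return prompt_logprobs
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float("-inf"):
                logprob_values.logprob = -9999.0
    return prompt_logprobs


clamped = clamp([None, {7: LogprobSketch(float("-inf")), 9: LogprobSketch(-0.3)}])
assert clamped[1][7].logprob == -9999.0
print(json.dumps({"token_7_logprob": clamped[1][7].logprob}))  # serializes cleanly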