[Core] Add audio_embeds support to chat completions (#29059)

Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com>
Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
jeremyteboul 2025-11-20 19:39:47 -08:00 committed by GitHub
parent a982f5b5ea
commit 0730414999
5 changed files with 360 additions and 3 deletions

View File

@ -365,6 +365,8 @@ You must enable this feature via `enable_mm_embeds=True`.
The vLLM engine may crash if embeddings with an incorrect shape are passed.
Only enable this flag for trusted users!
#### Image Embeddings
??? code
```python
@ -441,6 +443,36 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
print(generated_text)
```
#### Audio Embeddings
You can pass pre-computed audio embeddings in the same way as image embeddings:
??? code
```python
from vllm import LLM
import torch
# Enable audio embeddings support
llm = LLM(model="fixie-ai/ultravox-v0_5-llama-3_2-1b", enable_mm_embeds=True)
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"
# Load pre-computed audio embeddings
# torch.Tensor of shape (1, audio_feature_size, hidden_size of LM)
audio_embeds = torch.load(...)
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"audio": audio_embeds},
})
for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```
## Online Serving
Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs that users can provide to uniquely identify each media item; these UUIDs are used to cache media results across requests.
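Below is a minimal sketch of an online request using the new `audio_embeds` content part. The endpoint, model name, tensor shape, and UUID are illustrative, and the server is assumed to have been started with `--enable-mm-embeds`:
```python
import base64
import io

import torch
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Serialize a pre-computed audio embedding the same way the server loads it:
# torch.save to a buffer, then base64-encode the bytes.
audio_embeds = torch.randn(1, 128, 2048)  # (1, audio_feature_size, hidden_size); illustrative
buf = io.BytesIO()
torch.save(audio_embeds, buf)
b64_embeds = base64.b64encode(buf.getvalue()).decode("utf-8")

chat = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this audio?"},
                # "uuid" is optional; it lets the server cache this media item.
                {"type": "audio_embeds", "audio_embeds": b64_embeds, "uuid": "audio-1"},
            ],
        }
    ],
)
print(chat.choices[0].message.content)
```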

View File

@ -103,6 +103,19 @@ def qwen2_audio_model_config():
)
@pytest.fixture(scope="function")
def audio_embeds_model_config():
return ModelConfig(
QWEN2AUDIO_MODEL_ID,
runner="generate",
trust_remote_code=True,
limit_mm_per_prompt={
"audio": 2,
},
enable_mm_embeds=True,
)
@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@ -843,6 +856,138 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
def test_parse_chat_messages_empty_audio_embeds_with_uuid(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with UUID (no actual embeds data)."""
uuid = "test-audio-uuid-123"
conversation, mm_data, mm_uuids = parse_chat_messages(
[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this audio"},
{"type": "audio_embeds", "audio_embeds": None, "uuid": uuid},
],
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
# Should have audio in mm_data as None (UUID provided)
assert mm_data is not None
assert "audio" in mm_data
assert mm_data["audio"] is None
# UUID should be recorded
assert mm_uuids is not None
assert "audio" in mm_uuids
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
def test_parse_chat_messages_audio_embeds_with_string(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with base64 string embedding data."""
import base64
import io
import torch
# Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
conversation, mm_data, mm_uuids = parse_chat_messages(
[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this audio"},
{
"type": "audio_embeds",
"audio_embeds": base64_audio_embedding,
},
],
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
# Should have audio embedding in mm_data (single tensor, not a list)
assert mm_data is not None
assert "audio" in mm_data
assert isinstance(mm_data["audio"], torch.Tensor)
assert mm_data["audio"].shape == audio_embedding.shape
# No UUID provided
assert mm_uuids is not None
assert "audio" in mm_uuids
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
@pytest.mark.asyncio
async def test_parse_chat_messages_audio_embeds_async(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with async futures."""
import base64
import io
import torch
# Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this audio"},
{
"type": "audio_embeds",
"audio_embeds": base64_audio_embedding,
},
],
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
# Should have audio embedding in mm_data (single tensor, not a list)
mm_data = await mm_future
assert mm_data is not None
assert "audio" in mm_data
assert isinstance(mm_data["audio"], torch.Tensor)
assert mm_data["audio"].shape == audio_embedding.shape
# No UUID provided
assert mm_uuids is not None
assert "audio" in mm_uuids
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds,

View File

@ -94,6 +94,22 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
""" """
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
audio_embeds: str | dict[str, str] | None
"""
The audio embeddings. It can be either:
- A single base64 string representing a serialized torch tensor.
- A dictionary where each value is a base64 string.
"""
type: Required[Literal["audio_embeds"]]
"""The type of the content part."""
uuid: str | None
"""
User-provided UUID of the media item. The user must guarantee that it is
properly generated and unique across different media items.
"""
class VideoURL(TypedDict, total=False):
url: Required[str]
"""
@ -211,6 +227,7 @@ ChatCompletionContentPartParam: TypeAlias = (
| CustomChatCompletionContentPILImageParam
| CustomChatCompletionContentSimpleImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartAudioEmbedsParam
| CustomChatCompletionContentSimpleAudioParam
| CustomChatCompletionContentSimpleVideoParam
| str
@ -599,7 +616,7 @@ def resolve_chat_template_content_format(
return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T = TypeVar("_T")
@ -684,6 +701,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids["image"] = uuids_by_modality["image_embeds"] mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality: if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
if "audio_embeds" in uuids_by_modality:
audio_embeds_uuids = uuids_by_modality["audio_embeds"]
if len(audio_embeds_uuids) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
if "audio" in uuids_by_modality: if "audio" in uuids_by_modality:
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality: if "video" in uuids_by_modality:
@ -703,6 +725,8 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
@ -711,6 +735,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
@ -738,6 +767,8 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed") raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
@ -746,6 +777,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
@ -804,6 +840,14 @@ class BaseMultiModalContentParser(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@ -861,6 +905,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `audio_embeds`"
)
if isinstance(audio_embeds, dict):
embeds = {
k: self._connector.fetch_audio_embedding(v)
for k, v in audio_embeds.items()
}
placeholder = self._tracker.add("audio_embeds", embeds, uuid)
elif isinstance(audio_embeds, str):
embedding = self._connector.fetch_audio_embedding(audio_embeds)
placeholder = self._tracker.add("audio_embeds", embedding, uuid)
else:
placeholder = self._tracker.add("audio_embeds", None, uuid)
self._add_placeholder("audio", placeholder)
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@ -950,6 +1019,67 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder)
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `audio_embeds`"
)
logger.info(
"🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
"is_str=%s, is_none=%s",
type(audio_embeds).__name__,
uuid,
isinstance(audio_embeds, dict),
isinstance(audio_embeds, str),
audio_embeds is None,
)
future: asyncio.Future[torch.Tensor | dict[str, torch.Tensor] | None] = asyncio.Future()
if isinstance(audio_embeds, dict):
logger.info(
"🎵 Processing dict audio_embeds with %d entries",
len(audio_embeds),
)
embeds = {
k: self._connector.fetch_audio_embedding(v)
for k, v in audio_embeds.items()
}
future.set_result(embeds)
logger.info(
"🎵 Successfully loaded %d audio embeddings from dict",
len(embeds),
)
if isinstance(audio_embeds, str):
base64_size = len(audio_embeds)
logger.info(
"🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
base64_size,
base64_size / 1024,
)
embedding = self._connector.fetch_audio_embedding(audio_embeds)
future.set_result(embedding)
logger.info(
"🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
embedding.shape,
embedding.dtype,
)
if audio_embeds is None:
logger.info("🎵 Audio embeds is None (UUID-only reference)")
future.set_result(None)
placeholder = self._tracker.add("audio_embeds", future, uuid)
self._add_placeholder("audio", placeholder)
logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@ -1132,6 +1262,7 @@ def _get_full_multimodal_text_prompt(
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
@ -1155,6 +1286,7 @@ MM_PARSER_MAP: dict[
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None), "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None), "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None), "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get("input_audio", None), "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
@ -1223,8 +1355,17 @@ def _parse_chat_message_content_mm_part(
)
image_embeds = image_params.get("image_embeds", None)
return "image_embeds", image_embeds
if "audio_embeds" in part:
# "audio_embeds" could be None if UUID is provided.
audio_params = cast( # type: ignore[assignment]
ChatCompletionContentPartAudioEmbedsParam, part
)
audio_embeds = audio_params.get("audio_embeds", None)
return "audio_embeds", audio_embeds
if "audio_url" in part: if "audio_url" in part:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part) audio_params = cast( # type: ignore[assignment]
CustomChatCompletionContentSimpleAudioParam, part
)
audio_url = audio_params.get("audio_url", None)
if isinstance(audio_url, dict):
# Can potentially happen if user provides a uuid
@ -1348,6 +1489,10 @@ def _parse_chat_message_content_part(
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_image_embeds(content, uuid)
modality = "image"
elif part_type == "audio_embeds":
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_audio_embeds(content, uuid)
modality = "audio"
elif part_type == "audio_url": elif part_type == "audio_url":
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_audio(str_content, uuid) mm_parser.parse_audio(str_content, uuid)

View File

@ -7,6 +7,8 @@ from typing import Literal
import numpy as np
import numpy.typing as npt
import pybase64
import torch
from vllm.utils.import_utils import PlaceholderModule
@ -116,3 +118,25 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
data = buffer.getvalue()
return base64.b64encode(data).decode("utf-8")
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
def __init__(self) -> None:
super().__init__()
def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath, weights_only=True)
def encode_base64(self, media: torch.Tensor) -> str:
buffer = BytesIO()
torch.save(media, buffer)
buffer.seek(0)
binary_data = buffer.read()
return pybase64.b64encode(binary_data).decode("utf-8")
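For reference, a small round-trip sketch of the serialization format this class implements; the import path is an assumption based on the `from .audio import AudioEmbeddingMediaIO` line in the media connector, and the tensor shape is illustrative:
```python
import torch

# Assumed import path (the connector imports the class via `from .audio import ...`).
from vllm.multimodal.audio import AudioEmbeddingMediaIO

media_io = AudioEmbeddingMediaIO()

original = torch.randn(1, 128, 2048)          # illustrative embedding tensor
encoded = media_io.encode_base64(original)    # torch.save + base64 of the raw bytes
decoded = media_io.load_base64("", encoded)   # media_type is unused by this loader
assert torch.equal(original, decoded)         # lossless round trip
```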

View File

@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
@ -342,6 +342,17 @@ class MediaConnector:
return image_embedding_io.load_base64("", data)
def fetch_audio_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load an audio embedding from a base64-encoded string.
"""
audio_embedding_io = AudioEmbeddingMediaIO()
return audio_embedding_io.load_base64("", data)
def encode_audio_base64(
audio: np.ndarray,