[Core] Add audio_embeds support to chat completions (#29059)
Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com>
Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>

parent a982f5b5ea
commit 0730414999
@@ -365,6 +365,8 @@ You must enable this feature via `enable_mm_embeds=True`.
 The vLLM engine may crash if incorrect shape of embeddings is passed.
 Only enable this flag for trusted users!
 
+#### Image Embeddings
+
 ??? code
 
     ```python
@@ -441,6 +443,36 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings
         print(generated_text)
     ```
 
+#### Audio Embeddings
+
+You can pass pre-computed audio embeddings similar to image embeddings:
+
+??? code
+
+    ```python
+    from vllm import LLM
+    import torch
+
+    # Enable audio embeddings support
+    llm = LLM(model="fixie-ai/ultravox-v0_5-llama-3_2-1b", enable_mm_embeds=True)
+
+    # Refer to the HuggingFace repo for the correct format to use
+    prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"
+
+    # Load pre-computed audio embeddings
+    # torch.Tensor of shape (1, audio_feature_size, hidden_size of LM)
+    audio_embeds = torch.load(...)
+
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": {"audio": audio_embeds},
+    })
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+    ```
+
 ## Online Serving
 
 Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs users can provide to uniquely identify each media, which is used to cache the media results across requests.
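For online serving, the same pre-computed embeddings can be sent through the Chat Completions API as a base64-serialized tensor. A minimal client-side sketch, not part of this commit: it assumes a vLLM server launched with `--enable-mm-embeds`, uses a placeholder tensor path, and relies on the OpenAI Python client passing the extra `audio_embeds` content type through unchanged:

```python
import base64
import io

import torch
from openai import OpenAI

# Serialize the pre-computed audio embeddings to a base64 string.
audio_embeds = torch.load("audio_embeds.pt")  # placeholder path
buffer = io.BytesIO()
torch.save(audio_embeds, buffer)
base64_audio_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this audio?"},
            {"type": "audio_embeds", "audio_embeds": base64_audio_embeds},
        ],
    }],
)
print(chat.choices[0].message.content)
```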
@@ -103,6 +103,19 @@ def qwen2_audio_model_config():
     )
 
 
+@pytest.fixture(scope="function")
+def audio_embeds_model_config():
+    return ModelConfig(
+        QWEN2AUDIO_MODEL_ID,
+        runner="generate",
+        trust_remote_code=True,
+        limit_mm_per_prompt={
+            "audio": 2,
+        },
+        enable_mm_embeds=True,
+    )
+
+
 @pytest.fixture(scope="module")
 def qwen2_audio_tokenizer():
     return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@@ -843,6 +856,138 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
     _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
 
 
+def test_parse_chat_messages_empty_audio_embeds_with_uuid(
+    audio_embeds_model_config,
+    qwen2_audio_tokenizer,
+):
+    """Test audio_embeds with UUID (no actual embeds data)."""
+    uuid = "test-audio-uuid-123"
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this audio"},
+                    {"type": "audio_embeds", "audio_embeds": None, "uuid": uuid},
+                ],
+            }
+        ],
+        audio_embeds_model_config,
+        qwen2_audio_tokenizer,
+        content_format="string",
+    )
+
+    # Should have audio in mm_data as None (UUID provided)
+    assert mm_data is not None
+    assert "audio" in mm_data
+    assert mm_data["audio"] is None
+    # UUID should be recorded
+    assert mm_uuids is not None
+    assert "audio" in mm_uuids
+    _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
+
+
+def test_parse_chat_messages_audio_embeds_with_string(
+    audio_embeds_model_config,
+    qwen2_audio_tokenizer,
+):
+    """Test audio_embeds with base64 string embedding data."""
+    import base64
+    import io
+
+    import torch
+
+    # Create a sample audio embedding tensor
+    audio_embedding = torch.randn(1, 128, 768)
+
+    # Encode it as base64
+    buffer = io.BytesIO()
+    torch.save(audio_embedding, buffer)
+    buffer.seek(0)
+    binary_data = buffer.read()
+    base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this audio"},
+                    {
+                        "type": "audio_embeds",
+                        "audio_embeds": base64_audio_embedding,
+                    },
+                ],
+            }
+        ],
+        audio_embeds_model_config,
+        qwen2_audio_tokenizer,
+        content_format="string",
+    )
+
+    # Should have audio embedding in mm_data (single tensor, not a list)
+    assert mm_data is not None
+    assert "audio" in mm_data
+    assert isinstance(mm_data["audio"], torch.Tensor)
+    assert mm_data["audio"].shape == audio_embedding.shape
+    # No UUID provided
+    assert mm_uuids is not None
+    assert "audio" in mm_uuids
+    _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_audio_embeds_async(
+    audio_embeds_model_config,
+    qwen2_audio_tokenizer,
+):
+    """Test audio_embeds with async futures."""
+    import base64
+    import io
+
+    import torch
+
+    # Create a sample audio embedding tensor
+    audio_embedding = torch.randn(1, 128, 768)
+
+    # Encode it as base64
+    buffer = io.BytesIO()
+    torch.save(audio_embedding, buffer)
+    buffer.seek(0)
+    binary_data = buffer.read()
+    base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+
+    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this audio"},
+                    {
+                        "type": "audio_embeds",
+                        "audio_embeds": base64_audio_embedding,
+                    },
+                ],
+            }
+        ],
+        audio_embeds_model_config,
+        qwen2_audio_tokenizer,
+        content_format="string",
+    )
+
+    # Should have audio embedding in mm_data (single tensor, not a list)
+    mm_data = await mm_future
+    assert mm_data is not None
+    assert "audio" in mm_data
+    assert isinstance(mm_data["audio"], torch.Tensor)
+    assert mm_data["audio"].shape == audio_embedding.shape
+    # No UUID provided
+    assert mm_uuids is not None
+    assert "audio" in mm_uuids
+    _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
+
+
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     phi3v_model_config_image_embeds,
@@ -94,6 +94,22 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
     """
 
 
+class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
+    audio_embeds: str | dict[str, str] | None
+    """
+    The audio embeddings. It can be either:
+    - A single base64 string representing a serialized torch tensor.
+    - A dictionary where each value is a base64 string.
+    """
+    type: Required[Literal["audio_embeds"]]
+    """The type of the content part."""
+    uuid: str | None
+    """
+    User-provided UUID of a media. User must guarantee that it is properly
+    generated and unique for different medias.
+    """
+
+
 class VideoURL(TypedDict, total=False):
     url: Required[str]
     """
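For illustration (not part of the diff), content parts matching this TypedDict might look like the following sketch; the base64 payloads and the keys of the dict form are placeholders, since the dict keys are model-specific:

```python
# A single base64 string holding a serialized torch.Tensor.
part_single = {
    "type": "audio_embeds",
    "audio_embeds": "<base64-encoded torch tensor>",
}

# A dictionary of named base64 strings (field names are model-specific).
part_dict = {
    "type": "audio_embeds",
    "audio_embeds": {"audio_embeds": "<base64-encoded torch tensor>"},
}

# UUID-only reference: no embedding payload, relies on a previously cached item.
part_uuid_only = {
    "type": "audio_embeds",
    "audio_embeds": None,
    "uuid": "my-unique-audio-id",
}
```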
@@ -211,6 +227,7 @@ ChatCompletionContentPartParam: TypeAlias = (
     | CustomChatCompletionContentPILImageParam
     | CustomChatCompletionContentSimpleImageParam
     | ChatCompletionContentPartImageEmbedsParam
+    | ChatCompletionContentPartAudioEmbedsParam
     | CustomChatCompletionContentSimpleAudioParam
     | CustomChatCompletionContentSimpleVideoParam
     | str
@@ -599,7 +616,7 @@ def resolve_chat_template_content_format(
     return detected_format
 
 
-ModalityStr = Literal["image", "audio", "video", "image_embeds"]
+ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
 _T = TypeVar("_T")
 
 
@@ -684,6 +701,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             mm_uuids["image"] = uuids_by_modality["image_embeds"]
         if "image" in uuids_by_modality:
             mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
+        if "audio_embeds" in uuids_by_modality:
+            audio_embeds_uuids = uuids_by_modality["audio_embeds"]
+            if len(audio_embeds_uuids) > 1:
+                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+            mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
         if "audio" in uuids_by_modality:
             mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
         if "video" in uuids_by_modality:
@@ -703,6 +725,8 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
         items_by_modality = dict(self._items_by_modality)
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
+        if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
+            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
@@ -711,6 +735,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
             mm_inputs["image"] = image_embeds_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
+        if "audio_embeds" in items_by_modality:
+            audio_embeds_lst = items_by_modality["audio_embeds"]
+            if len(audio_embeds_lst) > 1:
+                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+            mm_inputs["audio"] = audio_embeds_lst[0]
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
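Since the tracker allows at most one `audio_embeds` item per request, a hypothetical test of the rejected path could reuse the fixtures added above. This is an illustrative sketch, not part of the commit, and it assumes UUID-only parts are registered as items the same way the existing image_embeds tests imply:

```python
import pytest

from vllm.entrypoints.chat_utils import parse_chat_messages


def test_two_audio_embeds_parts_rejected(
    audio_embeds_model_config,
    qwen2_audio_tokenizer,
):
    # Two audio_embeds parts in one request should trip the
    # "Only one message can have {'type': 'audio_embeds'}" check.
    with pytest.raises(ValueError, match="audio_embeds"):
        parse_chat_messages(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "audio_embeds", "audio_embeds": None, "uuid": "id-1"},
                        {"type": "audio_embeds", "audio_embeds": None, "uuid": "id-2"},
                    ],
                }
            ],
            audio_embeds_model_config,
            qwen2_audio_tokenizer,
            content_format="string",
        )
```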
@@ -738,6 +767,8 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
 
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
+        if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
+            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
@@ -746,6 +777,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
             mm_inputs["image"] = image_embeds_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
+        if "audio_embeds" in items_by_modality:
+            audio_embeds_lst = items_by_modality["audio_embeds"]
+            if len(audio_embeds_lst) > 1:
+                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+            mm_inputs["audio"] = audio_embeds_lst[0]
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
@@ -804,6 +840,14 @@ class BaseMultiModalContentParser(ABC):
     ) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def parse_audio_embeds(
+        self,
+        audio_embeds: str | dict[str, str] | None,
+        uuid: str | None = None,
+    ) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
         raise NotImplementedError
@@ -861,6 +905,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
 
         self._add_placeholder("image", placeholder)
 
+    def parse_audio_embeds(
+        self,
+        audio_embeds: str | dict[str, str] | None,
+        uuid: str | None = None,
+    ) -> None:
+        mm_config = self.model_config.get_multimodal_config()
+        if not mm_config.enable_mm_embeds:
+            raise ValueError(
+                "You must set `--enable-mm-embeds` to input `audio_embeds`"
+            )
+
+        if isinstance(audio_embeds, dict):
+            embeds = {
+                k: self._connector.fetch_audio_embedding(v)
+                for k, v in audio_embeds.items()
+            }
+            placeholder = self._tracker.add("audio_embeds", embeds, uuid)
+        elif isinstance(audio_embeds, str):
+            embedding = self._connector.fetch_audio_embedding(audio_embeds)
+            placeholder = self._tracker.add("audio_embeds", embedding, uuid)
+        else:
+            placeholder = self._tracker.add("audio_embeds", None, uuid)
+
+        self._add_placeholder("audio", placeholder)
+
     def parse_image_pil(
         self, image_pil: Image.Image | None, uuid: str | None = None
     ) -> None:
@@ -950,6 +1019,67 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         placeholder = self._tracker.add("image_embeds", future, uuid)
         self._add_placeholder("image", placeholder)
 
+    def parse_audio_embeds(
+        self,
+        audio_embeds: str | dict[str, str] | None,
+        uuid: str | None = None,
+    ) -> None:
+        mm_config = self.model_config.get_multimodal_config()
+        if not mm_config.enable_mm_embeds:
+            raise ValueError(
+                "You must set `--enable-mm-embeds` to input `audio_embeds`"
+            )
+
+        logger.info(
+            "🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
+            "is_str=%s, is_none=%s",
+            type(audio_embeds).__name__,
+            uuid,
+            isinstance(audio_embeds, dict),
+            isinstance(audio_embeds, str),
+            audio_embeds is None,
+        )
+
+        future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
+
+        if isinstance(audio_embeds, dict):
+            logger.info(
+                "🎵 Processing dict audio_embeds with %d entries",
+                len(audio_embeds),
+            )
+            embeds = {
+                k: self._connector.fetch_audio_embedding(v)
+                for k, v in audio_embeds.items()
+            }
+            future.set_result(embeds)
+            logger.info(
+                "🎵 Successfully loaded %d audio embeddings from dict",
+                len(embeds),
+            )
+
+        if isinstance(audio_embeds, str):
+            base64_size = len(audio_embeds)
+            logger.info(
+                "🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
+                base64_size,
+                base64_size / 1024,
+            )
+            embedding = self._connector.fetch_audio_embedding(audio_embeds)
+            future.set_result(embedding)
+            logger.info(
+                "🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
+                embedding.shape,
+                embedding.dtype,
+            )
+
+        if audio_embeds is None:
+            logger.info("🎵 Audio embeds is None (UUID-only reference)")
+            future.set_result(None)
+
+        placeholder = self._tracker.add("audio_embeds", future, uuid)
+        self._add_placeholder("audio", placeholder)
+        logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)
+
     def parse_image_pil(
         self, image_pil: Image.Image | None, uuid: str | None = None
     ) -> None:
@@ -1132,6 +1262,7 @@ def _get_full_multimodal_text_prompt(
 # No need to validate using Pydantic again
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
+_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 _PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
@@ -1155,6 +1286,7 @@ MM_PARSER_MAP: dict[
     "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
     "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
     "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
+    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
     "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
     "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
     "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
@@ -1223,8 +1355,17 @@ def _parse_chat_message_content_mm_part(
         )
         image_embeds = image_params.get("image_embeds", None)
         return "image_embeds", image_embeds
+    if "audio_embeds" in part:
+        # "audio_embeds" could be None if UUID is provided.
+        audio_params = cast(  # type: ignore[assignment]
+            ChatCompletionContentPartAudioEmbedsParam, part
+        )
+        audio_embeds = audio_params.get("audio_embeds", None)
+        return "audio_embeds", audio_embeds
     if "audio_url" in part:
-        audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
+        audio_params = cast(  # type: ignore[assignment]
+            CustomChatCompletionContentSimpleAudioParam, part
+        )
         audio_url = audio_params.get("audio_url", None)
         if isinstance(audio_url, dict):
             # Can potentially happen if user provides a uuid
@@ -1348,6 +1489,10 @@ def _parse_chat_message_content_part(
         content = cast(str | dict[str, str], content) if content is not None else None
         mm_parser.parse_image_embeds(content, uuid)
         modality = "image"
+    elif part_type == "audio_embeds":
+        content = cast(str | dict[str, str], content) if content is not None else None
+        mm_parser.parse_audio_embeds(content, uuid)
+        modality = "audio"
     elif part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content, uuid)
@@ -7,6 +7,8 @@ from typing import Literal
 
 import numpy as np
 import numpy.typing as npt
+import pybase64
+import torch
 
 from vllm.utils.import_utils import PlaceholderModule
 
@@ -116,3 +118,25 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
         data = buffer.getvalue()
 
         return base64.b64encode(data).decode("utf-8")
+
+
+class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def load_bytes(self, data: bytes) -> torch.Tensor:
+        buffer = BytesIO(data)
+        return torch.load(buffer, weights_only=True)
+
+    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
+        return self.load_bytes(pybase64.b64decode(data, validate=True))
+
+    def load_file(self, filepath: Path) -> torch.Tensor:
+        return torch.load(filepath, weights_only=True)
+
+    def encode_base64(self, media: torch.Tensor) -> str:
+        buffer = BytesIO()
+        torch.save(media, buffer)
+        buffer.seek(0)
+        binary_data = buffer.read()
+        return pybase64.b64encode(binary_data).decode("utf-8")
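As an illustrative sketch (not part of the commit), a base64 roundtrip through the new helper; the `vllm.multimodal.audio` import path is an assumption based on the relative imports in this file:

```python
import torch

# Assumed import path for the module touched in this diff.
from vllm.multimodal.audio import AudioEmbeddingMediaIO

media_io = AudioEmbeddingMediaIO()

# Encode a tensor to base64 and decode it back; values survive the roundtrip.
embeds = torch.randn(1, 128, 768)
encoded = media_io.encode_base64(embeds)
decoded = media_io.load_base64("", encoded)

assert torch.equal(decoded, embeds)
```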
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
 from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.registry import ExtensionManager
 
-from .audio import AudioMediaIO
+from .audio import AudioEmbeddingMediaIO, AudioMediaIO
 from .base import MediaIO
 from .image import ImageEmbeddingMediaIO, ImageMediaIO
 from .video import VideoMediaIO
@@ -342,6 +342,17 @@ class MediaConnector:
 
         return image_embedding_io.load_base64("", data)
 
+    def fetch_audio_embedding(
+        self,
+        data: str,
+    ) -> torch.Tensor:
+        """
+        Load audio embedding from a URL.
+        """
+        audio_embedding_io = AudioEmbeddingMediaIO()
+
+        return audio_embedding_io.load_base64("", data)
+
 
 def encode_audio_base64(
     audio: np.ndarray,