[Misc] Various cleanups for MM input processing (#29970)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 80f8af4b2f
commit 9ae2f60374

@@ -795,14 +795,12 @@ The following example demonstrates how to pass image embeddings to the OpenAI server
 ??? code
 
     ```python
+    from vllm.utils.serial_utils import tensor2base64
+
     image_embedding = torch.load(...)
     grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
 
-    buffer = io.BytesIO()
-    torch.save(image_embedding, buffer)
-    buffer.seek(0)
-    binary_data = buffer.read()
-    base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
+    base64_image_embedding = tensor2base64(image_embedding)
 
     client = OpenAI(
         # defaults to os.environ.get("OPENAI_API_KEY")

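For orientation, a minimal end-to-end sketch of the pattern the updated docs describe, assuming a vLLM server started with `--enable-mm-embeds`; the URL, API key, model name, and embedding shape below are illustrative, not part of this diff:

```python
import torch
from openai import OpenAI

from vllm.utils.serial_utils import tensor2base64

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Hypothetical embedding; real shapes depend on the model's vision encoder.
image_embedding = torch.randn(576, 1024)

chat_completion = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # any embeds-capable model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_embeds",
                    "image_embeds": tensor2base64(image_embedding),
                },
            ],
        }
    ],
)
```
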
@@ -28,13 +28,11 @@ Dependencies:
 - openai
 """
 
-import base64
-import io
-
 import torch
 import transformers
 from openai import OpenAI
+from vllm.utils.serial_utils import tensor2base64
 
 
 def main():
     client = OpenAI(

@@ -58,11 +56,7 @@ def main():
     prompt_embeds = embedding_layer(token_ids).squeeze(0)
 
     # Prompt embeddings
-    buffer = io.BytesIO()
-    torch.save(prompt_embeds, buffer)
-    buffer.seek(0)
-    binary_data = buffer.read()
-    encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
+    encoded_embeds = tensor2base64(prompt_embeds)
 
     completion = client.completions.create(
         model=model_name,

@@ -2,64 +2,47 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import base64
-import io
 
 import numpy as np
 import pytest
 import requests
 import torch
 
+from vllm.utils.serial_utils import tensor2base64
+
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
-DTYPE = "float16"
 
 
-def _terratorch_dummy_inputs(model_name: str):
+def _terratorch_dummy_messages():
     pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
     location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
 
-    buffer_tiff = io.BytesIO()
-    torch.save(pixel_values, buffer_tiff)
-    buffer_tiff.seek(0)
-    binary_data = buffer_tiff.read()
-    base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8")
-
-    buffer_coord = io.BytesIO()
-    torch.save(location_coords, buffer_coord)
-    buffer_coord.seek(0)
-    binary_data = buffer_coord.read()
-    base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
-
-    return {
-        "model": model_name,
-        "additional_data": {"prompt_token_ids": [1]},
-        "encoding_format": "base64",
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_embeds",
-                        "image_embeds": {
-                            "pixel_values": base64_tensor_embedding,
-                            "location_coords": base64_coord_embedding,
-                        },
-                    }
-                ],
-            }
-        ],
-    }
+    return [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_embeds",
+                    "image_embeds": {
+                        "pixel_values": tensor2base64(pixel_values),
+                        "location_coords": tensor2base64(location_coords),
+                    },
+                }
+            ],
+        }
+    ]
 
 
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_single_request(model_name: str):
+@pytest.mark.parametrize(
+    "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
+)
+def test_single_request(model_name: str):
     args = [
         "--runner",
         "pooling",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        DTYPE,
+        "float16",
         "--enforce-eager",
         "--trust-remote-code",
        "--max-num-seqs",

@@ -70,11 +53,15 @@ async def test_single_request(model_name: str):
         "--enable-mm-embeds",
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as server:
-        prompt = _terratorch_dummy_inputs(model_name)
-
-        # test single pooling
-        response = requests.post(server.url_for("pooling"), json=prompt)
+    with RemoteOpenAIServer(model_name, args) as server:
+        response = requests.post(
+            server.url_for("pooling"),
+            json={
+                "model": model_name,
+                "messages": _terratorch_dummy_messages(),
+                "encoding_format": "base64",
+            },
+        )
         response.raise_for_status()
 
         output = response.json()["data"][0]["data"]

@@ -29,6 +29,7 @@ from vllm.multimodal.utils import (
     encode_video_base64,
 )
 from vllm.tokenizers import MistralTokenizer, get_tokenizer
+from vllm.utils.serial_utils import tensor2base64
 
 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import VLLM_PATH

@@ -85,11 +86,6 @@ def phi3v_model_config_image_embeds():
     )
 
 
-@pytest.fixture(scope="module")
-def phi3v_tokenizer():
-    return get_tokenizer(PHI3V_MODEL_ID)
-
-
 @pytest.fixture(scope="function")
 def qwen2_audio_model_config():
     return ModelConfig(

@@ -115,11 +111,6 @@ def audio_embeds_model_config():
     )
 
 
-@pytest.fixture(scope="module")
-def qwen2_audio_tokenizer():
-    return get_tokenizer(QWEN2AUDIO_MODEL_ID)
-
-
 @pytest.fixture(scope="function")
 def qwen25omni_model_config_mm_interleaved():
     return ModelConfig(

@@ -134,11 +125,6 @@ def qwen25omni_model_config_mm_interleaved():
     )
 
 
-@pytest.fixture(scope="module")
-def qwen25omni_tokenizer():
-    return get_tokenizer(QWEN25OMNI_MODEL_ID)
-
-
 @pytest.fixture(scope="function")
 def mistral_model_config():
     return ModelConfig(

@@ -150,11 +136,6 @@ def mistral_model_config():
     )
 
 
-@pytest.fixture(scope="module")
-def mistral_tokenizer():
-    return get_tokenizer(MISTRAL_MODEL_ID)
-
-
 @pytest.fixture(scope="module")
 def image_url():
     image = ImageAsset("cherry_blossom")

@@ -239,7 +220,6 @@ def _assert_mm_data_inputs(
 
 def test_parse_chat_messages_single_image(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -253,7 +233,6 @@ def test_parse_chat_messages_single_image(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -266,7 +245,6 @@ def test_parse_chat_messages_single_image(
 
 def test_parse_chat_messages_single_image_with_uuid(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -287,7 +265,6 @@ def test_parse_chat_messages_single_image_with_uuid(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -300,7 +277,6 @@ def test_parse_chat_messages_single_image_with_uuid(
 
 def test_parse_chat_messages_single_empty_image_with_uuid(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -319,7 +295,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -332,7 +307,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
 
 def test_parse_chat_messages_single_image_with_bad_uuid_format(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -354,7 +328,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -367,7 +340,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
 
 def test_parse_chat_messages_multiple_images_with_uuids(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid1 = "my_uuid_1"

@@ -397,7 +369,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -413,7 +384,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
 
 def test_parse_chat_messages_multiple_empty_images_with_uuids(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid1 = "my_uuid_1"

@@ -439,7 +409,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -455,7 +424,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
 
 def test_parse_chat_messages_mixed_empty_images_with_uuids(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid1 = "my_uuid_1"

@@ -483,7 +451,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -500,7 +467,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_single_image_with_uuid_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -519,7 +485,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -533,7 +498,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_with_uuid_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -552,7 +516,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -566,7 +529,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_with_uuids_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid1 = "my_uuid_1"

@@ -592,7 +554,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -609,7 +570,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid1 = "my_uuid_1"

@@ -635,7 +595,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -652,7 +611,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid2 = "my_uuid_2"

@@ -676,7 +634,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -692,7 +649,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
 
 def test_parse_chat_messages_empty_system(
     mistral_model_config,
-    mistral_tokenizer,
 ):
     # Test string format
     conversation, _, _ = parse_chat_messages(

@@ -704,7 +660,6 @@ def test_parse_chat_messages_empty_system(
             },
         ],
         mistral_model_config,
-        mistral_tokenizer,
         content_format="string",
     )
     assert conversation == [

@@ -722,7 +677,6 @@ def test_parse_chat_messages_empty_system(
             },
         ],
         mistral_model_config,
-        mistral_tokenizer,
         content_format="openai",
     )
     assert conversation == [

@@ -734,7 +688,6 @@ def test_parse_chat_messages_empty_system(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_single_image_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_future, mm_uuids = parse_chat_messages_futures(

@@ -748,7 +701,6 @@ async def test_parse_chat_messages_single_image_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -761,7 +713,6 @@ async def test_parse_chat_messages_single_image_async(
 
 def test_parse_chat_messages_multiple_images(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -779,7 +730,6 @@ def test_parse_chat_messages_multiple_images(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -795,7 +745,6 @@ def test_parse_chat_messages_multiple_images(
 
 def test_parse_chat_messages_empty_pil_image_with_uuid(
     phi3v_model_config,
-    phi3v_tokenizer,
 ):
     uuid = "abcd"
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -809,7 +758,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -825,7 +773,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
 
 def test_parse_chat_messages_empty_image_embeds_with_uuid(
     phi3v_model_config_image_embeds,
-    phi3v_tokenizer,
 ):
     uuid = "abcd"
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -839,7 +786,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
             }
         ],
         phi3v_model_config_image_embeds,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -857,7 +803,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
 
 def test_parse_chat_messages_empty_audio_embeds_with_uuid(
     audio_embeds_model_config,
-    qwen2_audio_tokenizer,
 ):
     """Test audio_embeds with UUID (no actual embeds data)."""
     uuid = "test-audio-uuid-123"

@@ -873,7 +818,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
             }
         ],
         audio_embeds_model_config,
-        qwen2_audio_tokenizer,
         content_format="string",
     )
 

@@ -889,11 +833,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
 
 def test_parse_chat_messages_audio_embeds_with_string(
     audio_embeds_model_config,
-    qwen2_audio_tokenizer,
 ):
     """Test audio_embeds with base64 string embedding data."""
-    import base64
-    import io
 
     import torch
 

@@ -901,11 +842,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
     audio_embedding = torch.randn(1, 128, 768)
 
     # Encode it as base64
-    buffer = io.BytesIO()
-    torch.save(audio_embedding, buffer)
-    buffer.seek(0)
-    binary_data = buffer.read()
-    base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+    base64_audio_embedding = tensor2base64(audio_embedding)
 
     conversation, mm_data, mm_uuids = parse_chat_messages(
         [

@@ -921,7 +858,6 @@ def test_parse_chat_messages_audio_embeds_with_string(
             }
         ],
         audio_embeds_model_config,
-        qwen2_audio_tokenizer,
         content_format="string",
     )
 

@@ -939,11 +875,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_audio_embeds_async(
     audio_embeds_model_config,
-    qwen2_audio_tokenizer,
 ):
     """Test audio_embeds with async futures."""
-    import base64
-    import io
 
     import torch
 

@@ -951,11 +884,7 @@ async def test_parse_chat_messages_audio_embeds_async(
     audio_embedding = torch.randn(1, 128, 768)
 
     # Encode it as base64
-    buffer = io.BytesIO()
-    torch.save(audio_embedding, buffer)
-    buffer.seek(0)
-    binary_data = buffer.read()
-    base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+    base64_audio_embedding = tensor2base64(audio_embedding)
 
     conversation, mm_future, mm_uuids = parse_chat_messages_futures(
         [

@@ -971,7 +900,6 @@ async def test_parse_chat_messages_audio_embeds_async(
             }
         ],
         audio_embeds_model_config,
-        qwen2_audio_tokenizer,
         content_format="string",
     )
 

@@ -990,7 +918,6 @@ async def test_parse_chat_messages_audio_embeds_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     phi3v_model_config_image_embeds,
-    phi3v_tokenizer,
 ):
     uuid = "abcd"
     conversation, mm_future, mm_uuids = parse_chat_messages_futures(

@@ -1004,7 +931,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
             }
         ],
         phi3v_model_config_image_embeds,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1024,7 +950,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_async(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_future, mm_uuids = parse_chat_messages_futures(

@@ -1042,7 +967,6 @@ async def test_parse_chat_messages_multiple_images_async(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1058,7 +982,6 @@ async def test_parse_chat_messages_multiple_images_async(
 
 def test_parse_chat_messages_placeholder_already_in_prompt(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1076,7 +999,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
     assert conversation == [

@@ -1091,7 +1013,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
 
 def test_parse_chat_messages_placeholder_one_already_in_prompt(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1110,7 +1031,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1127,7 +1047,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
 
 def test_parse_chat_messages_multiple_images_across_messages(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1149,7 +1068,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
             },
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1164,7 +1082,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
 
 def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -1195,7 +1112,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
             },
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1210,7 +1126,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
 
 def test_parse_chat_messages_context_text_format(
     phi3v_model_config,
-    phi3v_tokenizer,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(
         [

@@ -1222,7 +1137,6 @@ def test_parse_chat_messages_context_text_format(
             {"role": "user", "content": "What about this one?"},
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="openai",
     )
 

@@ -1246,7 +1160,6 @@ def test_parse_chat_messages_context_text_format(
 
 def test_parse_chat_messages_rejects_too_many_images_in_one_message(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     with warnings.catch_warnings():

@@ -1277,14 +1190,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
                     }
                 ],
                 phi3v_model_config,
-                phi3v_tokenizer,
                 content_format="string",
             )
 
 
 def test_parse_chat_messages_rejects_too_many_images_across_messages(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     with warnings.catch_warnings():

@@ -1322,14 +1233,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
                    },
                 ],
                 phi3v_model_config,
-                phi3v_tokenizer,
                 content_format="string",
             )
 
 
 def test_parse_chat_messages_multiple_images_uncommon_input(
     phi3v_model_config,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1344,7 +1253,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
             }
         ],
         phi3v_model_config,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1360,7 +1268,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
 
 def test_parse_chat_messages_multiple_images_interleave(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1380,7 +1287,6 @@ def test_parse_chat_messages_multiple_images_interleave(
             }
         ],
         phi3v_model_config_mm_interleaved,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1398,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_interleave(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_interleave_async(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages_futures(

@@ -1418,7 +1323,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
             }
         ],
         phi3v_model_config_mm_interleaved,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1436,7 +1340,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -1465,7 +1368,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
             }
         ],
         phi3v_model_config_mm_interleaved,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1482,7 +1384,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
 
 def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -1505,7 +1406,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
             },
         ],
         phi3v_model_config_mm_interleaved,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1523,7 +1423,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
 
 def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     image_uuid = str(hash(image_url))

@@ -1555,7 +1454,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
             },
         ],
         phi3v_model_config_mm_interleaved,
-        phi3v_tokenizer,
         content_format="string",
     )
 

@@ -1573,7 +1471,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
 
 def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
     qwen25omni_model_config_mm_interleaved,
-    qwen25omni_tokenizer,
     image_url,
     video_url,
     audio_url,

@@ -1601,7 +1498,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
             },
         ],
         qwen25omni_model_config_mm_interleaved,
-        qwen25omni_tokenizer,
         content_format="string",
     )
 

@@ -1627,7 +1523,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
 
 def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
     qwen25omni_model_config_mm_interleaved,
-    qwen25omni_tokenizer,
     image_url,
     video_url,
     audio_url,

@@ -1671,7 +1566,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
             },
         ],
         qwen25omni_model_config_mm_interleaved,
-        qwen25omni_tokenizer,
         content_format="string",
     )
 

@@ -1699,7 +1593,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
 
 def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave(  # noqa: E501
     qwen25omni_model_config_mm_interleaved,
-    qwen25omni_tokenizer,
     image_url,
     video_url,
     audio_url,

@@ -1743,7 +1636,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave(
             },
         ],
         qwen25omni_model_config_mm_interleaved,
-        qwen25omni_tokenizer,
         content_format="string",
     )
 

@@ -1775,7 +1667,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave(
 
 def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave(  # noqa: E501
     qwen25omni_model_config_mm_interleaved,
-    qwen25omni_tokenizer,
     image_url,
     video_url,
     audio_url,

@@ -1811,7 +1702,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave(
             },
         ],
         qwen25omni_model_config_mm_interleaved,
-        qwen25omni_tokenizer,
         content_format="string",
     )
 

@@ -1837,7 +1727,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave(
 
 def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
     phi3v_model_config_mm_interleaved,
-    phi3v_tokenizer,
     image_url,
 ):
     with pytest.raises(

@@ -1861,7 +1750,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
                 }
             ],
             phi3v_model_config_mm_interleaved,
-            phi3v_tokenizer,
             content_format="string",
         )
 

@@ -2237,9 +2125,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
     assert resolved_format == expected_format
 
 
-def test_parse_chat_messages_include_thinking_chunk(
-    mistral_model_config, mistral_tokenizer
-):
+def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
     messages = [
         {
             "role": "system",

@@ -2269,7 +2155,6 @@ def test_parse_chat_messages_include_thinking_chunk(
     conversation_with_thinking, _, _ = parse_chat_messages(
         messages,
         mistral_model_config,
-        mistral_tokenizer,
         content_format="openai",
     )
 

@@ -2353,7 +2238,6 @@ def test_apply_mistral_chat_template_thinking_chunk():
 
 def test_parse_chat_messages_single_empty_audio_with_uuid(
     qwen2_audio_model_config,
-    qwen2_audio_tokenizer,
 ):
     audio_uuid = "abcd"
     conversation, mm_data, mm_uuids = parse_chat_messages(

@@ -2371,7 +2255,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
             }
         ],
         qwen2_audio_model_config,
-        qwen2_audio_tokenizer,
         content_format="string",
     )
 

@@ -2389,7 +2272,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
 @pytest.mark.asyncio
 async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
     qwen2_audio_model_config,
-    qwen2_audio_tokenizer,
 ):
     audio_uuid = "abcd"
     conversation, mm_future, mm_uuids = parse_chat_messages_futures(

@@ -2407,7 +2289,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
             }
         ],
         qwen2_audio_model_config,
-        qwen2_audio_tokenizer,
         content_format="string",
     )
 

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import base64
-import io
 import json
 
 import openai  # use the official client for correctness check

@@ -13,6 +11,7 @@ from transformers import AutoConfig
 
 from tests.conftest import ImageTestAssets
 from tests.utils import RemoteOpenAIServer
+from vllm.utils.serial_utils import tensor2base64
 
 # any model with a chat template should work here
 MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

@@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds):
     yield async_client
 
 
-def encode_image_embedding_to_base64(image_embedding) -> str:
-    """
-    Encode image embedding to base64 string
-    """
-    buffer = io.BytesIO()
-    torch.save(image_embedding, buffer)
-    buffer.seek(0)
-    binary_data = buffer.read()
-    base64_image_embedding = base64.b64encode(binary_data).decode("utf-8")
-    return base64_image_embedding
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])

@@ -73,7 +60,7 @@ async def test_completions_with_image_embeds(
 ):
     # Test case: Single image embeds input
     image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
-    base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
+    base64_image_embedding = tensor2base64(image_embeds)
     chat_completion = await client_with_image_embeds.chat.completions.create(
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},

@@ -536,7 +536,7 @@ def resolve_hf_chat_template(
 def _resolve_chat_template_content_format(
     chat_template: str | None,
     tools: list[dict[str, Any]] | None,
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
     *,
     model_config: ModelConfig,
 ) -> _ChatTemplateContentFormat:

@@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
     chat_template: str | None,
     tools: list[dict[str, Any]] | None,
     given_format: ChatTemplateContentFormatOption,
-    tokenizer: TokenizerLike,
+    tokenizer: TokenizerLike | None,
     *,
     model_config: ModelConfig,
 ) -> _ChatTemplateContentFormat:

@@ -627,11 +627,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     maximum per prompt.
     """
 
-    def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
+    def __init__(self, model_config: ModelConfig):
         super().__init__()
 
         self._model_config = model_config
-        self._tokenizer = tokenizer
 
         self._items_by_modality = defaultdict[str, list[_T | None]](list)
         self._uuids_by_modality = defaultdict[str, list[str | None]](list)

@@ -1612,7 +1611,6 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
 def parse_chat_messages(
     messages: list[ChatCompletionMessageParam],
     model_config: ModelConfig,
-    tokenizer: TokenizerLike,
     content_format: _ChatTemplateContentFormat,
 ) -> tuple[
     list[ConversationMessage],

@@ -1620,7 +1618,7 @@ def parse_chat_messages(
     MultiModalUUIDDict | None,
 ]:
     conversation: list[ConversationMessage] = []
-    mm_tracker = MultiModalItemTracker(model_config, tokenizer)
+    mm_tracker = MultiModalItemTracker(model_config)
 
     for msg in messages:
         sub_messages = _parse_chat_message_content(

@@ -1644,7 +1642,6 @@ def parse_chat_messages(
 def parse_chat_messages_futures(
     messages: list[ChatCompletionMessageParam],
     model_config: ModelConfig,
-    tokenizer: TokenizerLike,
     content_format: _ChatTemplateContentFormat,
 ) -> tuple[
     list[ConversationMessage],

@@ -1652,7 +1649,7 @@ def parse_chat_messages_futures(
     MultiModalUUIDDict | None,
 ]:
     conversation: list[ConversationMessage] = []
-    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
+    mm_tracker = AsyncMultiModalItemTracker(model_config)
 
     for msg in messages:
         sub_messages = _parse_chat_message_content(

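With the tracker no longer holding a tokenizer, a minimal sketch of the new call shape (`messages` and `model_config` are assumed to be defined by the caller; the signature itself is taken from the hunks above):

```python
from vllm.entrypoints.chat_utils import parse_chat_messages

# messages: list[ChatCompletionMessageParam], model_config: ModelConfig
# (both assumed to be constructed by the caller, e.g. an API server).
conversation, mm_data, mm_uuids = parse_chat_messages(
    messages,
    model_config,
    content_format="string",
)
```
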
@@ -834,7 +834,6 @@ class LLM:
             conversation, mm_data, mm_uuids = parse_chat_messages(
                 msgs,
                 model_config,
-                tokenizer,
                 content_format=resolved_content_format,
             )
 

@@ -1088,11 +1088,6 @@ class OpenAIServing:
         Sequence[RequestPrompt],
         list[EngineTokensPrompt],
     ]:
-        if tokenizer is None:
-            raise ValueError(
-                "Unable to get tokenizer because `skip_tokenizer_init=True`"
-            )
-
         model_config = self.model_config
 
         resolved_content_format = resolve_chat_template_content_format(

@@ -1105,7 +1100,6 @@ class OpenAIServing:
         conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
             messages,
             model_config,
-            tokenizer,
             content_format=resolved_content_format,
         )
 

@@ -89,12 +89,10 @@ def parse_score_data(
     data_1: str | ScoreContentPartParam,
     data_2: str | ScoreContentPartParam,
     model_config: ModelConfig,
-    tokenizer: TokenizerLike,
 ) -> tuple[str, str, MultiModalDataDict | None]:
-    mm_tracker = MultiModalItemTracker(model_config, tokenizer)
+    mm_tracker = MultiModalItemTracker(model_config)
 
     content_1 = _parse_score_content(data_1, mm_tracker)
-
     content_2 = _parse_score_content(data_2, mm_tracker)
 
     def ensure_str(content: _ContentPart | None) -> str:

@@ -188,7 +186,6 @@ def get_score_prompt(
         data_1,
         data_2,
         model_config,
-        tokenizer,
     )
     from vllm.model_executor.model_loader import get_model_cls
 

@@ -62,6 +62,7 @@ from vllm.multimodal.inputs import (
 from vllm.multimodal.parse import (
     DictEmbeddingItems,
     ImageSize,
+    ModalityDataItems,
     MultiModalDataItems,
     MultiModalDataParser,
 )

@@ -570,7 +571,7 @@ class HunYuanVLMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[ImageItem],
-    ):
+    ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,

@@ -1000,7 +1000,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[ImageItem],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,

@@ -1017,7 +1017,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[VideoItem],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,

@@ -333,7 +333,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[ImageItem],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,

@@ -350,7 +350,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[VideoItem],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,

@@ -11,6 +11,7 @@ import pybase64
 import torch
 
 from vllm.utils.import_utils import PlaceholderModule
+from vllm.utils.serial_utils import tensor2base64
 
 from .base import MediaIO
 

@@ -135,8 +136,4 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
         return torch.load(filepath, weights_only=True)
 
     def encode_base64(self, media: torch.Tensor) -> str:
-        buffer = BytesIO()
-        torch.save(media, buffer)
-        buffer.seek(0)
-        binary_data = buffer.read()
-        return pybase64.b64encode(binary_data).decode("utf-8")
+        return tensor2base64(media)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import base64
+import io
 import sys
 from dataclasses import dataclass
 from typing import Literal

@@ -52,6 +53,15 @@ Endianness = Literal["native", "big", "little"]
 EncodingFormat = Literal["float", "base64", "bytes"]
 
 
+def tensor2base64(x: torch.Tensor) -> str:
+    with io.BytesIO() as buf:
+        torch.save(x, buf)
+        buf.seek(0)
+        binary_data = buf.read()
+
+    return base64.b64encode(binary_data).decode("utf-8")
+
+
 def tensor2binary(
     tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
 ) -> bytes:

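A quick round-trip check of the new helper; the decode side mirrors the `torch.save` encoding above and is an assumption, not part of this diff:

```python
import base64
import io

import torch

from vllm.utils.serial_utils import tensor2base64

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
encoded = tensor2base64(t)

# Inverse: base64-decode, then torch.load the in-memory buffer.
decoded = torch.load(io.BytesIO(base64.b64decode(encoded)), weights_only=True)
assert torch.equal(decoded, t)
```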