[Bugfix] Fix RuntimeError: Index put requires the source and destination dtypes match (#22065)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent 3303f134e0
commit 17eaaef595
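The error in the title arises when user-supplied multimodal embeddings carry a different floating-point dtype than the model's input embeddings, e.g. float32 image embeds scattered into a bfloat16 model. A minimal standalone sketch (plain PyTorch, not vLLM code) of the failure and the cast this commit applies:

import torch

# Model-side input embeddings run in bfloat16; the user sent float32 embeds.
inputs_embeds = torch.zeros(8, 16, dtype=torch.bfloat16)
image_embeds = torch.randn(3, 16, dtype=torch.float32)
mask = torch.tensor([True, True, True] + [False] * 5)

# Without a cast, this raises:
#   RuntimeError: Index put requires the source and destination dtypes match
# inputs_embeds[mask] = image_embeds

# The fix: cast the source to the destination dtype first.
inputs_embeds[mask] = image_embeds.to(dtype=inputs_embeds.dtype)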
tests/v1/entrypoints/openai/test_completion_with_image_embeds.py (new file, +103)
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import io
import json

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
from transformers import AutoConfig

from tests.conftest import ImageTestAssets
from tests.utils import RemoteOpenAIServer

# any multimodal model that accepts image embeds should work here
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
MAXIMUM_IMAGES = 2


@pytest.fixture(scope="module")
def default_image_embeds_server_args() -> list[str]:
    return [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "4",
        "--enforce-eager",
        "--limit-mm-per-prompt",
        json.dumps({"image": MAXIMUM_IMAGES}),
    ]


@pytest.fixture(scope="module")
def server_with_image_embeds(default_image_embeds_server_args):
    with RemoteOpenAIServer(MODEL_NAME,
                            default_image_embeds_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_with_image_embeds(server_with_image_embeds):
    async with server_with_image_embeds.get_async_client() as async_client:
        yield async_client


def encode_image_embedding_to_base64(image_embedding) -> str:
    """Encode an image embedding tensor to a base64 string."""
    buffer = io.BytesIO()
    torch.save(image_embedding, buffer)
    buffer.seek(0)
    binary_data = buffer.read()
    return base64.b64encode(binary_data).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])
async def test_completions_with_image_embeds(
    client_with_image_embeds: openai.AsyncOpenAI,
    model_name: str,
    image_assets: ImageTestAssets,
    dtype: torch.dtype,
):
    # Test case: single image embeds input
    image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
    base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
    chat_completion = await client_with_image_embeds.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text":
                        "Describe these images separately. For each image, "
                        "reply with a short sentence (no more than 10 words).",
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding,
                    },
                ],
            },
        ],
        model=model_name,
    )
    assert chat_completion.choices[0].message.content is not None
    assert isinstance(chat_completion.choices[0].message.content, str)
    assert len(chat_completion.choices[0].message.content) > 0
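The test serializes the tensor with torch.save and base64-encodes the bytes; the inverse, sketched below with a hypothetical helper name (not part of this commit), recovers the tensor from such a payload:

import base64
import io

import torch


def decode_image_embedding_from_base64(data: str) -> torch.Tensor:
    # Hypothetical inverse of encode_image_embedding_to_base64 above:
    # base64 string -> raw bytes -> torch.load back into a tensor.
    buffer = io.BytesIO(base64.b64decode(data))
    return torch.load(buffer)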
@@ -401,7 +401,7 @@ def merge_multimodal_embeddings_from_map(
     """
     flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
-    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
-        placeholder_map.src]
+    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
+        placeholder_map.src].to(dtype=inputs_embeds.dtype)
     return inputs_embeds


@@ -421,7 +421,8 @@ def _merge_multimodal_embeddings(
     flattened = _flatten_embeddings(multimodal_embeddings)
     try:
         # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
-        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
+                                      flattened.to(dtype=inputs_embeds.dtype))
     except RuntimeError as e:
         num_expected_tokens = is_multimodal.sum().item()
         assert isinstance(num_expected_tokens, int)
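masked_scatter_ enforces the same rule as index put: self and source must share a dtype, or it raises a RuntimeError. A standalone sketch of the patched call, with assumed shapes (not vLLM code):

import torch

inputs_embeds = torch.zeros(8, 16, dtype=torch.bfloat16)  # 8 tokens, dim 16
flattened = torch.randn(3, 16, dtype=torch.float32)  # 3 multimodal tokens
is_multimodal = torch.tensor([True] * 3 + [False] * 5)

# The mask broadcasts over the hidden dim; casting the source avoids the
# dtype-mismatch RuntimeError that masked_scatter_ would otherwise raise.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
                              flattened.to(dtype=inputs_embeds.dtype))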