[Frontend] Require flag for loading text and image embeds (#27204)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Russell Bryant 2025-10-22 11:52:02 -04:00 committed by GitHub
parent db6f28d898
commit 58fab50d82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 203 additions and 64 deletions

View File

@ -359,13 +359,19 @@ Full example: [examples/offline_inference/audio_language.py](../../examples/offl
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
You must enable this feature via `enable_mm_embeds=True`.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
??? code
```python
from vllm import LLM
# Inference with image embeddings as input
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True)
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
@ -397,7 +403,11 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
image_embeds = torch.load(...)
# Qwen2-VL
llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4})
llm = LLM(
"Qwen/Qwen2-VL-2B-Instruct",
limit_mm_per_prompt={"image": 4},
enable_mm_embeds=True,
)
mm_data = {
"image": {
"image_embeds": image_embeds,
@ -407,7 +417,12 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
}
# MiniCPM-V
llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4})
llm = LLM(
"openbmb/MiniCPM-V-2_6",
trust_remote_code=True,
limit_mm_per_prompt={"image": 4},
enable_mm_embeds=True,
)
mm_data = {
"image": {
"image_embeds": image_embeds,
@ -732,7 +747,13 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo
### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape to the corresponding field of the multi-modal dictionary.
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
#### Image Embedding Inputs

View File

@ -20,12 +20,16 @@ You can pass prompt embeddings from Hugging Face Transformers models to the `'p
## Online Serving
Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package.
Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package and are enabled by the `--enable-prompt-embeds` flag in `vllm serve`.
When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first.
Prompt embeddings are passed in as base64 encoded torch tensors.
!!! warning
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
### Transformers Inputs via OpenAI Client
First, launch the OpenAI-compatible server:

View File

@ -49,6 +49,7 @@ class PrithviMAE:
dtype="float16",
enforce_eager=True,
model_impl="terratorch",
enable_mm_embeds=True,
)
def run(self, input_data, location_coords):

View File

@ -38,6 +38,7 @@ def main():
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff",
model_impl="terratorch",
enable_mm_embeds=True,
)
pooling_params = PoolingParams(task="token_classify", activation=False)

View File

@ -19,6 +19,7 @@ import requests
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff
# --enable-mm-embeds
def main():

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm import LLM
@ -12,8 +13,22 @@ def test_empty_prompt():
llm.generate([""])
@pytest.mark.skip_v1
def test_out_of_vocab_token():
llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match="out of vocabulary"):
llm.generate({"prompt_token_ids": [999999]})
def test_require_mm_embeds():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
enforce_eager=True,
enable_mm_embeds=False,
)
with pytest.raises(ValueError, match="--enable-mm-embeds"):
llm.generate(
{
"prompt": "<image>",
"multi_modal_data": {"image": torch.empty(1, 1, 1)},
}
)

View File

@ -292,3 +292,16 @@ async def test_prompt_logprobs_raises_error(
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
)
@pytest.mark.asyncio
async def test_empty_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI,
) -> None:
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="Hello",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from unittest.mock import Mock
# imports for structured outputs tests
import openai
@ -10,7 +11,8 @@ import pytest
import regex as re
import torch
from vllm.entrypoints.renderer import BaseRenderer
from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer
from ...utils import RemoteOpenAIServer
@ -59,6 +61,10 @@ async def test_out_of_vocab_token_ids():
def test_load_prompt_embeds(
dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int
):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True
renderer = CompletionRenderer(model_config, tokenizer=None)
# construct arbitrary tensors of various dtypes, layouts, and sizes.
# We need to check against different layouts to make sure that if a user
# uses sparse tensors to reduce the transmission size of prompt embeddings,
@ -83,7 +89,7 @@ def test_load_prompt_embeds(
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor)
loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu"
@ -91,3 +97,22 @@ def test_load_prompt_embeds(
torch.testing.assert_close(
loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("seq_len", [2])
@pytest.mark.parametrize("hidden_size", [2])
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False
renderer = CompletionRenderer(model_config, tokenizer=None)
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
renderer.load_prompt_embeds(encoded_tensor)

View File

@ -15,30 +15,7 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"
@pytest.fixture(scope="module")
def server():
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"--enforce-eager",
"--trust-remote-code",
"--skip-tokenizer-init",
"--max-num-seqs",
"32",
"--model-impl",
"terratorch",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(server: RemoteOpenAIServer, model_name: str):
def _terratorch_dummy_inputs(model_name: str):
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
@ -54,7 +31,7 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
binary_data = buffer_coord.read()
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
prompt = {
return {
"model": model_name,
"additional_data": {"prompt_token_ids": [1]},
"encoding_format": "base64",
@ -74,12 +51,33 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
],
}
# test single pooling
response = requests.post(server.url_for("pooling"), json=prompt)
response.raise_for_status()
output = response.json()["data"][0]["data"]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(model_name: str):
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"--enforce-eager",
"--trust-remote-code",
"--max-num-seqs",
"32",
"--model-impl",
"terratorch",
"--skip-tokenizer-init",
"--enable-mm-embeds",
]
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
with RemoteOpenAIServer(MODEL_NAME, args) as server:
prompt = _terratorch_dummy_inputs(model_name)
assert len(np_response) == 524288
# test single pooling
response = requests.post(server.url_for("pooling"), json=prompt)
response.raise_for_status()
output = response.json()["data"][0]["data"]
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
assert len(np_response) == 524288

View File

@ -73,6 +73,19 @@ def phi3v_model_config_mm_interleaved():
)
@pytest.fixture(scope="function")
def phi3v_model_config_image_embeds():
return ModelConfig(
PHI3V_MODEL_ID,
runner="generate",
trust_remote_code=True,
limit_mm_per_prompt={
"image": 2,
},
enable_mm_embeds=True,
)
@pytest.fixture(scope="module")
def phi3v_tokenizer():
return get_tokenizer(PHI3V_MODEL_ID)
@ -799,7 +812,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
def test_parse_chat_messages_empty_image_embeds_with_uuid(
phi3v_model_config,
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
@ -813,7 +826,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
],
}
],
phi3v_model_config,
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)
@ -832,7 +845,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config,
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
@ -846,7 +859,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
],
}
],
phi3v_model_config,
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)

View File

@ -17,6 +17,7 @@ from vllm.inputs.data import is_embeds_prompt
class MockModelConfig:
max_model_len: int = 100
encoder_config: dict | None = None
enable_prompt_embeds: bool = True
class MockTokenizerResult:

View File

@ -109,8 +109,7 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4},
)
],
# TODO: Revert to "auto" when CPU backend can use torch > 2.6
dtype="bfloat16" if current_platform.is_cpu() else "auto",
vllm_runner_kwargs={"enable_mm_embeds": True},
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_vl": VLMTestInfo(

View File

@ -292,6 +292,7 @@ def run_embedding_input_test(
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
default_torch_num_threads=1,
enable_mm_embeds=True,
) as vllm_model:
outputs_per_case_for_original_input = [
vllm_model.generate_greedy_logprobs(

View File

@ -34,6 +34,7 @@ def _run_test(
dtype="half",
enforce_eager=True,
skip_tokenizer_init=True,
enable_mm_embeds=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,

View File

@ -104,6 +104,11 @@ def can_initialize(
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
extra_args = {}
if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"):
extra_args["enable_mm_embeds"] = True
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
@ -128,6 +133,7 @@ def can_initialize(
else "vllm",
hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs,
**extra_args,
)

View File

@ -32,6 +32,7 @@ def test_inference(
dtype="half",
enforce_eager=True,
skip_tokenizer_init=True,
enable_mm_embeds=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,

View File

@ -38,6 +38,7 @@ def server():
"prithvi_to_tiff",
"--model-impl",
"terratorch",
"--enable-mm-embeds",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@ -6,7 +6,6 @@ import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import regex as re
import requests
from openai import BadRequestError
from tests.utils import RemoteOpenAIServer
@ -686,17 +685,3 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
"structured_outputs": {"grammar": invalid_simplified_sql_grammar}
},
)
@pytest.mark.asyncio
async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None:
"""Test completion with empty prompt embeds."""
payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
headers: dict[str, str] = {"Content-Type": "application/json"}
# base_url = http://localhost:8000/v1/completions
response = requests.post(
f"{client.base_url}completions", headers=headers, json=payload
)
assert response.status_code == 200, (
f"Expected status code 200, got {response.status_code}. "
)

View File

@ -32,6 +32,7 @@ def default_image_embeds_server_args() -> list[str]:
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"image": MAXIMUM_IMAGES}),
"--enable-mm-embeds",
]

View File

@ -232,8 +232,10 @@ class ModelConfig:
output will contain token ids."""
enable_prompt_embeds: bool = False
"""If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required
for graph compilation."""
`prompt_embeds` key.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
served_model_name: str | list[str] | None = None
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
@ -303,6 +305,7 @@ class ModelConfig:
"""Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`."""
limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None
enable_mm_embeds: InitVar[bool | None] = None
media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None
mm_processor_kwargs: InitVar[dict[str, Any] | None] = None
mm_processor_cache_gb: InitVar[float | None] = None
@ -421,6 +424,7 @@ class ModelConfig:
self,
# Multimodal config init vars
limit_mm_per_prompt: dict[str, int] | None,
enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None,
mm_processor_kwargs: dict[str, Any] | None,
mm_processor_cache_gb: float | None,
@ -731,6 +735,7 @@ class ModelConfig:
mm_config_kwargs = dict(
limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
mm_processor_cache_gb=mm_processor_cache_gb,

View File

@ -75,6 +75,14 @@ class MultiModalConfig:
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}}
"""
enable_mm_embeds: bool = False
"""If `True`, enables passing multimodal embeddings:
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
for the OpenAI-compatible server, this refers to chat messages with content
`"type": "*_embeds"`.
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!"""
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set

View File

@ -438,6 +438,7 @@ class EngineArgs:
limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
MultiModalConfig, "limit_per_prompt"
)
enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, Any]] = get_field(
MultiModalConfig, "media_io_kwargs"
@ -896,6 +897,9 @@ class EngineArgs:
multimodal_group.add_argument(
"--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
)
multimodal_group.add_argument(
"--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
)
multimodal_group.add_argument(
"--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
)
@ -1159,6 +1163,7 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
enable_mm_embeds=self.enable_mm_embeds,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,

View File

@ -811,6 +811,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None
@ -822,6 +826,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
@ -886,6 +896,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
@property
def model_config(self) -> ModelConfig:
return self._tracker.model_config
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
@ -897,6 +911,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
if isinstance(image_embeds, dict):

View File

@ -156,14 +156,17 @@ class BaseRenderer(ABC):
"""
raise NotImplementedError
@classmethod
def load_prompt_embeds(
cls,
self,
prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(

View File

@ -1308,6 +1308,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.validate_num_items(modality, len(items))