Add hidden dimension validation for multimodal embedding inputs (#30968)

Signed-off-by: Wenqi Glantz <wglantz@nvidia.com>
Wenqi Glantz authored 2025-12-19 02:59:36 -05:00, committed by GitHub
parent 096b25c9ed
commit 4924ac582c
4 changed files with 589 additions and 10 deletions
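A minimal sketch of the behavior this change adds, using only the parser API shown in the diffs below (the 768 and 4096 hidden sizes are illustrative, not tied to any particular model):

import torch

from vllm.multimodal.parse import MultiModalDataParser

# Parser configured with the model's expected hidden size.
parser = MultiModalDataParser(expected_hidden_size=768)

# Matching hidden dimension: parsed into ImageEmbeddingItems as before.
parser.parse_mm_data({"image": torch.randn(2, 100, 768)})

# Mismatched hidden dimension: rejected up front with a ValueError that
# mentions "hidden dimension mismatch", instead of crashing later during inference.
try:
    parser.parse_mm_data({"image": torch.randn(2, 100, 4096)})
except ValueError as exc:
    print(exc)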

View File

@@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Embedding shape validation in multimodal APIs.
Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference.
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes.
"""
import pytest
import torch

from vllm.multimodal.parse import (
    AudioEmbeddingItems,
    ImageEmbeddingItems,
    MultiModalDataParser,
    VideoEmbeddingItems,
)
class TestMultiModalParserShapeValidation:
    """Test hidden_size validation in MultiModalDataParser."""

    def test_image_embeddings_correct_hidden_size_accepted(self):
        """Baseline: Image embeddings with correct hidden_size should work."""
        expected_hidden_size = 768
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        valid_embeds = torch.randn(2, 100, expected_hidden_size)
        result = parser.parse_mm_data({"image": valid_embeds})
        assert "image" in result
        assert isinstance(result["image"], ImageEmbeddingItems)
        assert result["image"].get_count() == 2

    def test_image_embeddings_wrong_hidden_size_rejected(self):
        """Security: Image embeddings with wrong hidden_size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
        with pytest.raises(ValueError) as exc_info:
            parser.parse_mm_data({"image": invalid_embeds})
        error_msg = str(exc_info.value).lower()
        assert "image" in error_msg
        assert "hidden dimension mismatch" in error_msg

    def test_audio_embeddings_wrong_hidden_size_rejected(self):
        """Security: Audio embeddings with wrong hidden_size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 2048
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
        with pytest.raises(ValueError) as exc_info:
            parser.parse_mm_data({"audio": invalid_embeds})
        error_msg = str(exc_info.value).lower()
        assert "audio" in error_msg
        assert "hidden dimension mismatch" in error_msg

    def test_video_embeddings_wrong_hidden_size_rejected(self):
        """Security: Video embeddings with wrong hidden_size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 512
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
        with pytest.raises(ValueError) as exc_info:
            parser.parse_mm_data({"video": invalid_embeds})
        error_msg = str(exc_info.value).lower()
        assert "video" in error_msg
        assert "hidden dimension mismatch" in error_msg

    def test_list_of_embeddings_validates_each(self):
        """Security: Each embedding in list should be validated."""
        expected_hidden_size = 768
        wrong_hidden_size = 1024
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        # List with second tensor having wrong hidden_size
        invalid_embeds = [
            torch.randn(100, expected_hidden_size),
            torch.randn(100, wrong_hidden_size),
        ]
        with pytest.raises(ValueError) as exc_info:
            parser.parse_mm_data({"image": invalid_embeds})
        # Should identify which embedding failed
        assert "[1]" in str(exc_info.value)
    def test_validation_disabled_allows_any_size(self):
        """When validation is disabled (legacy behavior), any hidden_size is allowed."""
        parser = MultiModalDataParser(expected_hidden_size=None)
        any_hidden_size = 12345
        embeds = torch.randn(2, 100, any_hidden_size)
        # Should not raise
        result = parser.parse_mm_data({"image": embeds})
        assert "image" in result
        assert isinstance(result["image"], ImageEmbeddingItems)
class TestEmbeddingItemsDirectValidation:
    """Direct tests for EmbeddingItems hidden_size validation."""

    def test_image_embedding_items_validates_batched_tensor(self):
        """Test validation for batched (3D) image embeddings."""
        expected = 768
        wrong = 1024
        # Valid
        valid = torch.randn(2, 100, expected)
        items = ImageeEmbeddingItems(valid, expected_hidden_size=expected) if False else ImageEmbeddingItems(valid, expected_hidden_size=expected)
        assert items.get_count() == 2
        # Invalid
        invalid = torch.randn(2, 100, wrong)
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid, expected_hidden_size=expected)
        assert str(wrong) in str(exc_info.value)
        assert str(expected) in str(exc_info.value)

    def test_image_embedding_items_validates_list_of_tensors(self):
        """Test validation for list of 2D image embeddings."""
        expected = 768
        wrong = 512
        # Valid list
        valid_list = [torch.randn(100, expected), torch.randn(50, expected)]
        items = ImageEmbeddingItems(valid_list, expected_hidden_size=expected)
        assert items.get_count() == 2
        # Invalid list
        invalid_list = [torch.randn(100, expected), torch.randn(50, wrong)]
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid_list, expected_hidden_size=expected)
        assert "[1]" in str(exc_info.value)

    def test_audio_embedding_items_validates(self):
        """Test validation for audio embeddings."""
        expected = 768
        wrong = 256
        invalid = torch.randn(2, 100, wrong)
        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(invalid, expected_hidden_size=expected)
        assert "audio" in str(exc_info.value).lower()

    def test_video_embedding_items_validates(self):
        """Test validation for video embeddings."""
        expected = 768
        wrong = 384
        invalid = torch.randn(2, 100, wrong)
        with pytest.raises(ValueError) as exc_info:
            VideoEmbeddingItems(invalid, expected_hidden_size=expected)
        assert "video" in str(exc_info.value).lower()
class TestShapeValidationIntegration:
    """Integration tests verifying attack scenarios are blocked."""

    def test_attack_scenario_multimodal_image(self):
        """
        Simulate attack through Chat API with image embeddings.
        Verifies validation occurs in multimodal parser path.
        """
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
        with pytest.raises(ValueError):
            parser.parse_mm_data({"image": attack_tensor})

    def test_attack_scenario_multimodal_audio(self):
        """
        Simulate attack through Chat API with audio embeddings.
        Verifies validation occurs in multimodal parser path.
        """
        expected_hidden_size = 768
        wrong_hidden_size = 2048
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
        with pytest.raises(ValueError):
            parser.parse_mm_data({"audio": attack_tensor})

    def test_attack_scenario_multimodal_video(self):
        """
        Simulate attack through Chat API with video embeddings.
        Verifies validation occurs in multimodal parser path.
        """
        expected_hidden_size = 768
        wrong_hidden_size = 1024
        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
        with pytest.raises(ValueError):
            parser.parse_mm_data({"video": attack_tensor})

View File

@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for embedding shape validation.
Simple, fast unit tests that can run without server fixtures.
Run with: pytest tests/multimodal/test_embedding_shape_validation_unit.py -v
"""
import pytest
import torch

from vllm.multimodal.parse import (
    AudioEmbeddingItems,
    ImageEmbeddingItems,
)
class TestImageEmbedBasicValidation:
    """Test basic ndim validation in image embeddings via ImageEmbeddingItems."""

    def test_valid_2d_tensor_accepted(self):
        """Baseline: 2D tensors should be accepted."""
        valid_tensor = torch.randn(10, 768, dtype=torch.float32)
        # Should not raise - 2D is valid
        items = ImageEmbeddingItems(valid_tensor)
        assert items.get_count() == 10

    def test_valid_3d_tensor_accepted(self):
        """Baseline: 3D tensors should be accepted."""
        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)
        # Should not raise - 3D is valid
        items = ImageEmbeddingItems(valid_tensor)
        assert items.get_count() == 2

    def test_valid_list_of_2d_tensors_accepted(self):
        """Baseline: List of 2D tensors should be accepted."""
        tensors = [
            torch.randn(10, 768, dtype=torch.float32),
            torch.randn(15, 768, dtype=torch.float32),
        ]
        # Should not raise
        items = ImageEmbeddingItems(tensors)
        assert items.get_count() == 2

    def test_1d_tensor_rejected(self):
        """Security: 1D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid_tensor)
        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)

    def test_4d_tensor_rejected(self):
        """Security: 4D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(1, 2, 10, 768, dtype=torch.float32)  # 4D
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid_tensor)
        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)

    def test_hidden_size_validation_correct_size(self):
        """Embeddings with correct hidden size should be accepted."""
        expected_hidden_size = 768
        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)
        # Should not raise
        items = ImageEmbeddingItems(
            valid_tensor, expected_hidden_size=expected_hidden_size
        )
        assert items.get_count() == 10

    def test_hidden_size_validation_wrong_size_rejected(self):
        """Embeddings with wrong hidden size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(wrong_hidden_size) in error_msg
        assert str(expected_hidden_size) in error_msg
class TestAudioEmbedBasicValidation:
    """Test basic ndim validation in audio embeddings via AudioEmbeddingItems."""

    def test_valid_2d_tensor_accepted(self):
        """Baseline: 2D tensors should be accepted."""
        valid_tensor = torch.randn(10, 768, dtype=torch.float32)
        # Should not raise - 2D is valid
        items = AudioEmbeddingItems(valid_tensor)
        assert items.get_count() == 10

    def test_valid_3d_tensor_accepted(self):
        """Baseline: 3D tensors should be accepted."""
        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)
        # Should not raise - 3D is valid
        items = AudioEmbeddingItems(valid_tensor)
        assert items.get_count() == 2

    def test_valid_list_of_2d_tensors_accepted(self):
        """Baseline: List of 2D tensors should be accepted."""
        tensors = [
            torch.randn(10, 768, dtype=torch.float32),
            torch.randn(15, 768, dtype=torch.float32),
        ]
        # Should not raise
        items = AudioEmbeddingItems(tensors)
        assert items.get_count() == 2

    def test_1d_tensor_rejected(self):
        """Security: 1D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D
        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(invalid_tensor)
        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)

    def test_scalar_rejected(self):
        """Security: Scalar tensors should be rejected."""
        invalid_tensor = torch.tensor(1.0)  # 0D (scalar)
        with pytest.raises(ValueError):
            AudioEmbeddingItems(invalid_tensor)

    def test_hidden_size_validation_correct_size(self):
        """Embeddings with correct hidden size should be accepted."""
        expected_hidden_size = 768
        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)
        # Should not raise
        items = AudioEmbeddingItems(
            valid_tensor, expected_hidden_size=expected_hidden_size
        )
        assert items.get_count() == 10

    def test_hidden_size_validation_wrong_size_rejected(self):
        """Embeddings with wrong hidden size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)
        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(wrong_hidden_size) in error_msg
        assert str(expected_hidden_size) in error_msg
class TestShapeValidationDoSPrevention:
    """
    Tests for DoS prevention through shape validation.
    Verifies that embeddings with incorrect shapes are rejected early,
    preventing crashes during model inference.
    """

    def test_prevent_crash_from_wrong_shape_image_embeds(self):
        """
        Prevent crash scenario: wrong hidden size in image embeddings.
        Without validation, this would pass initial checks but crash later
        during model forward pass when dimensions don't match.
        """
        expected_hidden_size = 768  # Typical model hidden size
        wrong_hidden_size = 4096  # Wrong size (e.g., Llama-sized)
        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)
        # Should be rejected at instantiation time, not during inference
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                wrong_embedding, expected_hidden_size=expected_hidden_size
            )
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(expected_hidden_size) in error_msg  # Expected
        assert str(wrong_hidden_size) in error_msg  # Received

    def test_prevent_crash_from_wrong_shape_audio_embeds(self):
        """
        Prevent crash scenario: wrong hidden size in audio embeddings.
        """
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)
        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(
                wrong_embedding, expected_hidden_size=expected_hidden_size
            )
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()

    def test_extremely_large_hidden_size_rejected(self):
        """Security: Prevent DoS from extremely large embeddings."""
        expected_hidden_size = 768
        huge_hidden_size = 100000  # Large but not extreme to avoid test OOM
        invalid_tensor = torch.randn(10, huge_hidden_size, dtype=torch.float32)
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )
        assert "hidden dimension mismatch" in str(exc_info.value).lower()

    def test_batch_with_mixed_hidden_sizes_rejected(self):
        """All embeddings in a list must have the same hidden size."""
        expected_hidden_size = 768
        # One correct, one wrong
        batch = [
            torch.randn(10, expected_hidden_size, dtype=torch.float32),
            torch.randn(10, expected_hidden_size + 100, dtype=torch.float32),  # Wrong!
        ]
        # Should fail on the second one
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(batch, expected_hidden_size=expected_hidden_size)
        assert "hidden dimension mismatch" in str(exc_info.value).lower()


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])

View File

@@ -126,6 +126,30 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
        return {}


def validate_embedding_ndim(
    tensor: torch.Tensor,
    modality: str,
    index: int | None = None,
) -> None:
    """Validate tensor ndim for multimodal embeddings.

    Single embeddings should be 2D (seq_len, hidden_size).
    Batched embeddings should be 3D (batch, seq_len, hidden_size).

    Args:
        tensor: The tensor to validate.
        modality: The modality name for error messages (e.g., "image", "audio").
        index: Optional index for list items, included in error messages.
    """
    if tensor.ndim < 2 or tensor.ndim > 3:
        idx_str = f" [{index}]" if index is not None else ""
        raise ValueError(
            f"{modality.capitalize()} embedding{idx_str} must be 2D "
            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
        )


class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
@@ -134,6 +158,63 @@ class EmbeddingItems(
    or a list of embedding tensors (one per item).
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, modality)
        # Validate ndim first (before hidden_size, which depends on correct ndim)
        self._validate_ndim()
        # Validate hidden dimension if expected size is provided
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _validate_ndim(self) -> None:
        """Validate that embedding tensors have correct ndim (2D or 3D)."""
        if isinstance(self.data, torch.Tensor):
            validate_embedding_ndim(self.data, self.modality)
        else:
            # List of tensors: each should be 2D (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                if tensor.ndim != 2:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] must be "
                        f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                        f"with shape {tuple(tensor.shape)}"
                    )
    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Validate that the embedding hidden dimension matches the expected size.

        This guards against malformed or malicious embedding inputs: embeddings
        with the correct ndim but a wrong hidden dimension would otherwise pass
        the initial checks and crash the model during inference when the
        dimensions do not match.
        """
        if isinstance(self.data, torch.Tensor):
            # Single tensor (2D or 3D): hidden size is the last dimension
            actual_hidden_size = self.data.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual_hidden_size}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
                )
        else:
            # List of tensors: each has shape (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                actual_hidden_size = tensor.shape[-1]
                if actual_hidden_size != expected_hidden_size:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] hidden "
                        f"dimension mismatch: got {actual_hidden_size}, but model "
                        f"expects {expected_hidden_size}. "
                        f"Embedding shape: {tuple(tensor.shape)}"
                    )
    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:

@@ -228,8 +309,12 @@ class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):

class AudioEmbeddingItems(EmbeddingItems):
    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "audio")
    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "audio", expected_hidden_size)


class ImageSize(NamedTuple):

@@ -256,8 +341,12 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):

class ImageEmbeddingItems(EmbeddingItems):
    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "image")
    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "image", expected_hidden_size)


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):

@@ -287,8 +376,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):

class VideoEmbeddingItems(EmbeddingItems):
    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data, "video")
    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "video", expected_hidden_size)


_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
@@ -363,6 +456,10 @@ class MultiModalDataParser:
    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
    """

    def __init__(

@@ -371,6 +468,7 @@
        target_sr: float | None = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__()

@@ -379,6 +477,7 @@
            method=audio_resample_method,
        )
        self.video_needs_metadata = video_needs_metadata
        self.expected_hidden_size = expected_hidden_size

    @classmethod
    def is_embeddings(
@@ -443,7 +542,7 @@
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data)
            return AudioEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[AudioItem]
        if (

@@ -481,7 +580,7 @@
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data)
            return ImageEmbeddingItems(data, self.expected_hidden_size)

        if (
            isinstance(data, (PILImage.Image, MediaWithBytes))

@@ -507,7 +606,7 @@
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data)
            return VideoEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[VideoItem]
        if (

View File

@@ -1330,7 +1330,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        that has additional subparsers.
        """
        return MultiModalDataParser()
        # Get expected hidden size for embedding validation if mm_embeds enabled
        # This validates hidden dimensions to prevent vulnerabilities: embeddings
        # with correct ndim but wrong shape could cause crashes at inference time
        mm_config = self.info.ctx.model_config.get_multimodal_config()
        expected_hidden_size = None
        if mm_config.enable_mm_embeds:
            expected_hidden_size = self.info.ctx.model_config.get_inputs_embeds_size()
        return MultiModalDataParser(expected_hidden_size=expected_hidden_size)

    def validate_num_items(
        self,