From 4924ac582c5007fbf7b15d719708812e06961009 Mon Sep 17 00:00:00 2001
From: Wenqi Glantz
Date: Fri, 19 Dec 2025 02:59:36 -0500
Subject: [PATCH] Add hidden dimension validation for multimodal embedding
 inputs (#30968)

Signed-off-by: Wenqi Glantz
---
 .../openai/test_embedding_shape_validation.py | 223 ++++++++++++++++
 .../test_embedding_shape_validation_unit.py   | 249 ++++++++++++++++++
 vllm/multimodal/parse.py                      | 117 +++++++-
 vllm/multimodal/processing.py                 |  10 +-
 4 files changed, 589 insertions(+), 10 deletions(-)
 create mode 100644 tests/entrypoints/openai/test_embedding_shape_validation.py
 create mode 100644 tests/multimodal/test_embedding_shape_validation_unit.py

diff --git a/tests/entrypoints/openai/test_embedding_shape_validation.py b/tests/entrypoints/openai/test_embedding_shape_validation.py
new file mode 100644
index 0000000000000..27060e0be5aee
--- /dev/null
+++ b/tests/entrypoints/openai/test_embedding_shape_validation.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Embedding shape validation in multimodal APIs.
+
+Tests verify that embeddings with correct ndim but incorrect hidden_size
+are rejected before they can cause crashes during model inference.
+
+Validation is performed by the parser (MultiModalDataParser) and the
+EmbeddingItems classes, not by CompletionRenderer or MediaIO classes.
+"""
+
+import pytest
+import torch
+
+from vllm.multimodal.parse import (
+    AudioEmbeddingItems,
+    ImageEmbeddingItems,
+    MultiModalDataParser,
+    VideoEmbeddingItems,
+)
+
+
+class TestMultiModalParserShapeValidation:
+    """Test hidden_size validation in MultiModalDataParser."""
+
+    def test_image_embeddings_correct_hidden_size_accepted(self):
+        """Baseline: Image embeddings with the correct hidden_size should work."""
+        expected_hidden_size = 768
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        valid_embeds = torch.randn(2, 100, expected_hidden_size)
+
+        result = parser.parse_mm_data({"image": valid_embeds})
+
+        assert "image" in result
+        assert isinstance(result["image"], ImageEmbeddingItems)
+        assert result["image"].get_count() == 2
+
+    def test_image_embeddings_wrong_hidden_size_rejected(self):
+        """Security: Image embeddings with the wrong hidden_size should be rejected."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 4096
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError) as exc_info:
+            parser.parse_mm_data({"image": invalid_embeds})
+
+        error_msg = str(exc_info.value).lower()
+        assert "image" in error_msg
+        assert "hidden dimension mismatch" in error_msg
+
+    def test_audio_embeddings_wrong_hidden_size_rejected(self):
+        """Security: Audio embeddings with the wrong hidden_size should be rejected."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 2048
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError) as exc_info:
+            parser.parse_mm_data({"audio": invalid_embeds})
+
+        error_msg = str(exc_info.value).lower()
+        assert "audio" in error_msg
+        assert "hidden dimension mismatch" in error_msg
+
+    def test_video_embeddings_wrong_hidden_size_rejected(self):
+        """Security: Video embeddings with the wrong hidden_size should be rejected."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 512
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError) as exc_info:
+            parser.parse_mm_data({"video": invalid_embeds})
+
+        error_msg = str(exc_info.value).lower()
+        assert "video" in error_msg
+        assert "hidden dimension mismatch" in error_msg
+
+    def test_list_of_embeddings_validates_each(self):
+        """Security: Each embedding in a list should be validated."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 1024
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        # List with the second tensor having the wrong hidden_size
+        invalid_embeds = [
+            torch.randn(100, expected_hidden_size),
+            torch.randn(100, wrong_hidden_size),
+        ]
+
+        with pytest.raises(ValueError) as exc_info:
+            parser.parse_mm_data({"image": invalid_embeds})
+
+        # Should identify which embedding failed
+        assert "[1]" in str(exc_info.value)
+
+    def test_validation_disabled_allows_any_size(self):
+        """When validation is disabled (legacy), any hidden_size is allowed."""
+        parser = MultiModalDataParser(expected_hidden_size=None)
+
+        any_hidden_size = 12345
+        embeds = torch.randn(2, 100, any_hidden_size)
+
+        # Should not raise
+        result = parser.parse_mm_data({"image": embeds})
+        assert "image" in result
+        assert isinstance(result["image"], ImageEmbeddingItems)
+
+
+class TestEmbeddingItemsDirectValidation:
+    """Direct tests for EmbeddingItems hidden_size validation."""
+
+    def test_image_embedding_items_validates_batched_tensor(self):
+        """Test validation for batched (3D) image embeddings."""
+        expected = 768
+        wrong = 1024
+
+        # Valid
+        valid = torch.randn(2, 100, expected)
+        items = ImageEmbeddingItems(valid, expected_hidden_size=expected)
+        assert items.get_count() == 2
+
+        # Invalid
+        invalid = torch.randn(2, 100, wrong)
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(invalid, expected_hidden_size=expected)
+
+        assert str(wrong) in str(exc_info.value)
+        assert str(expected) in str(exc_info.value)
+
+    def test_image_embedding_items_validates_list_of_tensors(self):
+        """Test validation for a list of 2D image embeddings."""
+        expected = 768
+        wrong = 512
+
+        # Valid list
+        valid_list = [torch.randn(100, expected), torch.randn(50, expected)]
+        items = ImageEmbeddingItems(valid_list, expected_hidden_size=expected)
+        assert items.get_count() == 2
+
+        # Invalid list
+        invalid_list = [torch.randn(100, expected), torch.randn(50, wrong)]
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(invalid_list, expected_hidden_size=expected)
+
+        assert "[1]" in str(exc_info.value)
+
+    def test_audio_embedding_items_validates(self):
+        """Test validation for audio embeddings."""
+        expected = 768
+        wrong = 256
+
+        invalid = torch.randn(2, 100, wrong)
+        with pytest.raises(ValueError) as exc_info:
+            AudioEmbeddingItems(invalid, expected_hidden_size=expected)
+
+        assert "audio" in str(exc_info.value).lower()
+
+    def test_video_embedding_items_validates(self):
+        """Test validation for video embeddings."""
+        expected = 768
+        wrong = 384
+
+        invalid = torch.randn(2, 100, wrong)
+        with pytest.raises(ValueError) as exc_info:
+            VideoEmbeddingItems(invalid, expected_hidden_size=expected)
+
+        assert "video" in str(exc_info.value).lower()
+
+
+class TestShapeValidationIntegration:
+    """Integration tests verifying attack scenarios are blocked."""
+
+    def test_attack_scenario_multimodal_image(self):
+        """
+        Simulate an attack through the Chat API with image embeddings.
+
+        Verifies that validation occurs in the multimodal parser path.
+        """
+        expected_hidden_size = 768
+        wrong_hidden_size = 4096
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError):
+            parser.parse_mm_data({"image": attack_tensor})
+
+    def test_attack_scenario_multimodal_audio(self):
+        """
+        Simulate an attack through the Chat API with audio embeddings.
+
+        Verifies that validation occurs in the multimodal parser path.
+        """
+        expected_hidden_size = 768
+        wrong_hidden_size = 2048
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError):
+            parser.parse_mm_data({"audio": attack_tensor})
+
+    def test_attack_scenario_multimodal_video(self):
+        """
+        Simulate an attack through the Chat API with video embeddings.
+
+        Verifies that validation occurs in the multimodal parser path.
+        """
+        expected_hidden_size = 768
+        wrong_hidden_size = 1024
+        parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
+
+        attack_tensor = torch.randn(1, 100, wrong_hidden_size)
+
+        with pytest.raises(ValueError):
+            parser.parse_mm_data({"video": attack_tensor})
diff --git a/tests/multimodal/test_embedding_shape_validation_unit.py b/tests/multimodal/test_embedding_shape_validation_unit.py
new file mode 100644
index 0000000000000..7966aad4e988c
--- /dev/null
+++ b/tests/multimodal/test_embedding_shape_validation_unit.py
@@ -0,0 +1,249 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for embedding shape validation.
+
+Simple, fast unit tests that can run without server fixtures.
+Run with: pytest tests/multimodal/test_embedding_shape_validation_unit.py -v
+"""
+
+import pytest
+import torch
+
+from vllm.multimodal.parse import (
+    AudioEmbeddingItems,
+    ImageEmbeddingItems,
+)
+
+
+class TestImageEmbedBasicValidation:
+    """Test basic ndim validation of image embeddings via ImageEmbeddingItems."""
+
+    def test_valid_2d_tensor_accepted(self):
+        """Baseline: 2D tensors should be accepted."""
+        valid_tensor = torch.randn(10, 768, dtype=torch.float32)
+
+        # Should not raise - 2D is valid
+        items = ImageEmbeddingItems(valid_tensor)
+        assert items.get_count() == 10
+
+    def test_valid_3d_tensor_accepted(self):
+        """Baseline: 3D tensors should be accepted."""
+        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)
+
+        # Should not raise - 3D is valid
+        items = ImageEmbeddingItems(valid_tensor)
+        assert items.get_count() == 2
+
+    def test_valid_list_of_2d_tensors_accepted(self):
+        """Baseline: A list of 2D tensors should be accepted."""
+        tensors = [
+            torch.randn(10, 768, dtype=torch.float32),
+            torch.randn(15, 768, dtype=torch.float32),
+        ]
+
+        # Should not raise
+        items = ImageEmbeddingItems(tensors)
+        assert items.get_count() == 2
+
+    def test_1d_tensor_rejected(self):
+        """Security: 1D tensors should be rejected (invalid ndim)."""
+        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D
+
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(invalid_tensor)
+
+        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)
+
+    def test_4d_tensor_rejected(self):
+        """Security: 4D tensors should be rejected (invalid ndim)."""
+        invalid_tensor = torch.randn(1, 2, 10, 768, dtype=torch.float32)  # 4D
+
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(invalid_tensor)
+
+        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)
+
+    def test_hidden_size_validation_correct_size(self):
+        """Embeddings with the correct hidden size should be accepted."""
+        expected_hidden_size = 768
+        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)
+
+        # Should not raise
+        items = ImageEmbeddingItems(
+            valid_tensor, expected_hidden_size=expected_hidden_size
+        )
+        assert items.get_count() == 10
+
+    def test_hidden_size_validation_wrong_size_rejected(self):
+        """Embeddings with the wrong hidden size should be rejected."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 4096
+        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)
+
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(
+                invalid_tensor, expected_hidden_size=expected_hidden_size
+            )
+
+        error_msg = str(exc_info.value)
+        assert "hidden dimension mismatch" in error_msg.lower()
+        assert str(wrong_hidden_size) in error_msg
+        assert str(expected_hidden_size) in error_msg
+
+
+class TestAudioEmbedBasicValidation:
+    """Test basic ndim validation of audio embeddings via AudioEmbeddingItems."""
+
+    def test_valid_2d_tensor_accepted(self):
+        """Baseline: 2D tensors should be accepted."""
+        valid_tensor = torch.randn(10, 768, dtype=torch.float32)
+
+        # Should not raise - 2D is valid
+        items = AudioEmbeddingItems(valid_tensor)
+        assert items.get_count() == 10
+
+    def test_valid_3d_tensor_accepted(self):
+        """Baseline: 3D tensors should be accepted."""
+        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)
+
+        # Should not raise - 3D is valid
+        items = AudioEmbeddingItems(valid_tensor)
+        assert items.get_count() == 2
+
+    def test_valid_list_of_2d_tensors_accepted(self):
+        """Baseline: A list of 2D tensors should be accepted."""
+        tensors = [
+            torch.randn(10, 768, dtype=torch.float32),
+            torch.randn(15, 768, dtype=torch.float32),
+        ]
+
+        # Should not raise
+        items = AudioEmbeddingItems(tensors)
+        assert items.get_count() == 2
+
+    def test_1d_tensor_rejected(self):
+        """Security: 1D tensors should be rejected (invalid ndim)."""
+        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D
+
+        with pytest.raises(ValueError) as exc_info:
+            AudioEmbeddingItems(invalid_tensor)
+
+        assert "must be 2D" in str(exc_info.value) or "3D" in str(exc_info.value)
+
+    def test_scalar_rejected(self):
+        """Security: Scalar tensors should be rejected."""
+        invalid_tensor = torch.tensor(1.0)  # 0D (scalar)
+
+        with pytest.raises(ValueError):
+            AudioEmbeddingItems(invalid_tensor)
+
+    def test_hidden_size_validation_correct_size(self):
+        """Embeddings with the correct hidden size should be accepted."""
+        expected_hidden_size = 768
+        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)
+
+        # Should not raise
+        items = AudioEmbeddingItems(
+            valid_tensor, expected_hidden_size=expected_hidden_size
+        )
+        assert items.get_count() == 10
+
+    def test_hidden_size_validation_wrong_size_rejected(self):
+        """Embeddings with the wrong hidden size should be rejected."""
+        expected_hidden_size = 768
+        wrong_hidden_size = 4096
+        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)
+
+        with pytest.raises(ValueError) as exc_info:
+            AudioEmbeddingItems(
+                invalid_tensor, expected_hidden_size=expected_hidden_size
+            )
+
+        error_msg = str(exc_info.value)
+        assert "hidden dimension mismatch" in error_msg.lower()
+        assert str(wrong_hidden_size) in error_msg
+        assert str(expected_hidden_size) in error_msg
+
+
+class TestShapeValidationDoSPrevention:
+    """
+    Tests for DoS prevention through shape validation.
+
+    Verifies that embeddings with incorrect shapes are rejected early,
+    preventing crashes during model inference.
+    """
+
+    def test_prevent_crash_from_wrong_shape_image_embeds(self):
+        """
+        Prevent crash scenario: wrong hidden size in image embeddings.
+
+        Without validation, this would pass the initial checks but crash later
+        during the model forward pass when dimensions don't match.
+        """
+        expected_hidden_size = 768  # Typical model hidden size
+        wrong_hidden_size = 4096  # Wrong size (e.g., Llama-sized)
+
+        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)
+
+        # Should be rejected at instantiation time, not during inference
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(
+                wrong_embedding, expected_hidden_size=expected_hidden_size
+            )
+
+        error_msg = str(exc_info.value)
+        assert "hidden dimension mismatch" in error_msg.lower()
+        assert str(expected_hidden_size) in error_msg  # Expected
+        assert str(wrong_hidden_size) in error_msg  # Received
+
+    def test_prevent_crash_from_wrong_shape_audio_embeds(self):
+        """
+        Prevent crash scenario: wrong hidden size in audio embeddings.
+        """
+        expected_hidden_size = 768
+        wrong_hidden_size = 4096
+
+        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)
+
+        with pytest.raises(ValueError) as exc_info:
+            AudioEmbeddingItems(
+                wrong_embedding, expected_hidden_size=expected_hidden_size
+            )
+
+        error_msg = str(exc_info.value)
+        assert "hidden dimension mismatch" in error_msg.lower()
+
+    def test_extremely_large_hidden_size_rejected(self):
+        """Security: Prevent DoS from extremely large embeddings."""
+        expected_hidden_size = 768
+        huge_hidden_size = 100000  # Large but not extreme, to avoid OOM in the test
+
+        invalid_tensor = torch.randn(10, huge_hidden_size, dtype=torch.float32)
+
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(
+                invalid_tensor, expected_hidden_size=expected_hidden_size
+            )
+
+        assert "hidden dimension mismatch" in str(exc_info.value).lower()
+
+    def test_batch_with_mixed_hidden_sizes_rejected(self):
+        """All embeddings in a list must have the same hidden size."""
+        expected_hidden_size = 768
+
+        # One correct, one wrong
+        batch = [
+            torch.randn(10, expected_hidden_size, dtype=torch.float32),
+            torch.randn(10, expected_hidden_size + 100, dtype=torch.float32),  # Wrong!
+        ]
+
+        # Should fail on the second one
+        with pytest.raises(ValueError) as exc_info:
+            ImageEmbeddingItems(batch, expected_hidden_size=expected_hidden_size)
+
+        assert "hidden dimension mismatch" in str(exc_info.value).lower()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "--tb=short"])
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index a69afc3176cab..64c03f8d4da94 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -126,6 +126,30 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
         return {}
 
 
+def validate_embedding_ndim(
+    tensor: torch.Tensor,
+    modality: str,
+    index: int | None = None,
+) -> None:
+    """Validate tensor ndim for multimodal embeddings.
+
+    Single embeddings should be 2D (seq_len, hidden_size).
+    Batched embeddings should be 3D (batch, seq_len, hidden_size).
+
+    Args:
+        tensor: The tensor to validate.
+        modality: The modality name for error messages (e.g., "image", "audio").
+        index: Optional index for list items, included in error messages.
+    """
+    if tensor.ndim < 2 or tensor.ndim > 3:
+        idx_str = f" [{index}]" if index is not None else ""
+        raise ValueError(
+            f"{modality.capitalize()} embedding{idx_str} must be 2D "
+            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
+            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
+        )
+
+
 class EmbeddingItems(
     ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
 ):
@@ -134,6 +158,63 @@ class EmbeddingItems(
     or a list of embedding tensors (one per item).
     """
 
+    def __init__(
+        self,
+        data: torch.Tensor | list[torch.Tensor],
+        modality: str,
+        expected_hidden_size: int | None = None,
+    ) -> None:
+        super().__init__(data, modality)
+
+        # Validate ndim first (hidden_size validation relies on correct ndim)
+        self._validate_ndim()
+
+        # Validate the hidden dimension if an expected size is provided
+        if expected_hidden_size is not None:
+            self._validate_hidden_size(expected_hidden_size)
+
+    def _validate_ndim(self) -> None:
+        """Validate that embedding tensors have correct ndim (2D or 3D)."""
+        if isinstance(self.data, torch.Tensor):
+            validate_embedding_ndim(self.data, self.modality)
+        else:
+            # List of tensors: each should be 2D (seq_len, hidden_size)
+            for idx, tensor in enumerate(self.data):
+                if tensor.ndim != 2:
+                    raise ValueError(
+                        f"{self.modality.capitalize()} embedding [{idx}] must be "
+                        f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
+                        f"with shape {tuple(tensor.shape)}"
+                    )
+
+    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
+        """Validate that the embedding hidden dimension matches the expected size.
+
+        Embeddings with correct ndim but the wrong hidden dimension would pass
+        the initial checks and then crash during model inference when the
+        dimensions don't match, so they are rejected here.
+        """
+        if isinstance(self.data, torch.Tensor):
+            # Batched tensor: shape is (batch, seq_len, hidden_size)
+            actual_hidden_size = self.data.shape[-1]
+            if actual_hidden_size != expected_hidden_size:
+                raise ValueError(
+                    f"{self.modality.capitalize()} embedding hidden dimension "
+                    f"mismatch: got {actual_hidden_size}, but model expects "
+                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
+                )
+        else:
+            # List of tensors: each has shape (seq_len, hidden_size)
+            for idx, tensor in enumerate(self.data):
+                actual_hidden_size = tensor.shape[-1]
+                if actual_hidden_size != expected_hidden_size:
+                    raise ValueError(
+                        f"{self.modality.capitalize()} embedding [{idx}] hidden "
+                        f"dimension mismatch: got {actual_hidden_size}, but model "
+                        f"expects {expected_hidden_size}. "
+                        f"Embedding shape: {tuple(tensor.shape)}"
+                    )
+
     def _unwrap(
         self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
     ) -> torch.Tensor:
@@ -228,8 +309,12 @@ class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
 
 
 class AudioEmbeddingItems(EmbeddingItems):
-    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
-        super().__init__(data, "audio")
+    def __init__(
+        self,
+        data: torch.Tensor | list[torch.Tensor],
+        expected_hidden_size: int | None = None,
+    ) -> None:
+        super().__init__(data, "audio", expected_hidden_size)
 
 
 class ImageSize(NamedTuple):
@@ -256,8 +341,12 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
 
 
 class ImageEmbeddingItems(EmbeddingItems):
-    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
-        super().__init__(data, "image")
+    def __init__(
+        self,
+        data: torch.Tensor | list[torch.Tensor],
+        expected_hidden_size: int | None = None,
+    ) -> None:
+        super().__init__(data, "image", expected_hidden_size)
 
 
 class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
@@ -287,8 +376,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
 
 
 class VideoEmbeddingItems(EmbeddingItems):
-    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
-        super().__init__(data, "video")
+    def __init__(
+        self,
+        data: torch.Tensor | list[torch.Tensor],
+        expected_hidden_size: int | None = None,
+    ) -> None:
+        super().__init__(data, "video", expected_hidden_size)
 
 
 _D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
@@ -363,6 +456,10 @@ class MultiModalDataParser:
     Args:
         target_sr (float, optional): Enables automatic resampling of audio
             items to the model's expected sampling rate.
+        expected_hidden_size (int, optional): Expected hidden dimension for
+            embedding inputs. If provided, validates that user-supplied
+            embeddings have the correct hidden size to prevent crashes
+            during model inference.
     """
 
     def __init__(
@@ -371,6 +468,7 @@ class MultiModalDataParser:
         target_sr: float | None = None,
         audio_resample_method: Literal["librosa", "scipy"] = "librosa",
         video_needs_metadata: bool = False,
+        expected_hidden_size: int | None = None,
     ) -> None:
         super().__init__()
 
@@ -379,6 +477,7 @@ class MultiModalDataParser:
             method=audio_resample_method,
         )
         self.video_needs_metadata = video_needs_metadata
+        self.expected_hidden_size = expected_hidden_size
 
     @classmethod
     def is_embeddings(
@@ -443,7 +542,7 @@ class MultiModalDataParser:
             return None
 
         if self.is_embeddings(data):
-            return AudioEmbeddingItems(data)
+            return AudioEmbeddingItems(data, self.expected_hidden_size)
 
         data_items: list[AudioItem]
         if (
@@ -481,7 +580,7 @@ class MultiModalDataParser:
             return None
 
         if self.is_embeddings(data):
-            return ImageEmbeddingItems(data)
+            return ImageEmbeddingItems(data, self.expected_hidden_size)
 
         if (
             isinstance(data, (PILImage.Image, MediaWithBytes))
@@ -507,7 +606,7 @@ class MultiModalDataParser:
             return None
 
        if self.is_embeddings(data):
-            return VideoEmbeddingItems(data)
+            return VideoEmbeddingItems(data, self.expected_hidden_size)
 
         data_items: list[VideoItem]
         if (
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 0390773783961..3bbdab3b393c5 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1330,7 +1330,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
         that has additional subparsers.
         """
-        return MultiModalDataParser()
+        # Get the expected hidden size for embedding validation if mm_embeds
+        # is enabled. Embeddings with correct ndim but the wrong hidden size
+        # would otherwise pass the initial checks and crash at inference time.
+        mm_config = self.info.ctx.model_config.get_multimodal_config()
+        expected_hidden_size = None
+        if mm_config.enable_mm_embeds:
+            expected_hidden_size = self.info.ctx.model_config.get_inputs_embeds_size()
+
+        return MultiModalDataParser(expected_hidden_size=expected_hidden_size)
 
     def validate_num_items(
         self,