mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-21 18:24:38 +08:00
Add hidden dimension validation for multimodal embedding inputs (#30968)
Signed-off-by: Wenqi Glantz <wglantz@nvidia.com>
This commit is contained in:
parent
096b25c9ed
commit
4924ac582c
223
tests/entrypoints/openai/test_embedding_shape_validation.py
Normal file
223
tests/entrypoints/openai/test_embedding_shape_validation.py
Normal file
@ -0,0 +1,223 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Embedding shape validation in multimodal APIs.
|
||||
|
||||
Tests verify that embeddings with correct ndim but incorrect hidden_size
|
||||
are rejected before they can cause crashes during model inference.
|
||||
|
||||
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
|
||||
classes, not by CompletionRenderer or MediaIO classes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.parse import (
|
||||
AudioEmbeddingItems,
|
||||
ImageEmbeddingItems,
|
||||
MultiModalDataParser,
|
||||
VideoEmbeddingItems,
|
||||
)
|
||||
|
||||
|
||||
class TestMultiModalParserShapeValidation:
    """Hidden-size checks applied when embeddings pass through the parser."""

    def test_image_embeddings_correct_hidden_size_accepted(self):
        """Baseline: image embeddings with the expected last dim are parsed."""
        hidden = 768
        parser = MultiModalDataParser(expected_hidden_size=hidden)
        embeds = torch.randn(2, 100, hidden)

        parsed = parser.parse_mm_data({"image": embeds})

        assert "image" in parsed
        image_items = parsed["image"]
        assert isinstance(image_items, ImageEmbeddingItems)
        assert image_items.get_count() == 2

    def test_image_embeddings_wrong_hidden_size_rejected(self):
        """Security: image embeddings with an unexpected last dim are refused."""
        parser = MultiModalDataParser(expected_hidden_size=768)
        bad_embeds = torch.randn(2, 100, 4096)

        with pytest.raises(ValueError) as caught:
            parser.parse_mm_data({"image": bad_embeds})

        message = str(caught.value).lower()
        assert "image" in message
        assert "hidden dimension mismatch" in message

    def test_audio_embeddings_wrong_hidden_size_rejected(self):
        """Security: audio embeddings with an unexpected last dim are refused."""
        parser = MultiModalDataParser(expected_hidden_size=768)
        bad_embeds = torch.randn(2, 100, 2048)

        with pytest.raises(ValueError) as caught:
            parser.parse_mm_data({"audio": bad_embeds})

        message = str(caught.value).lower()
        assert "audio" in message
        assert "hidden dimension mismatch" in message

    def test_video_embeddings_wrong_hidden_size_rejected(self):
        """Security: video embeddings with an unexpected last dim are refused."""
        parser = MultiModalDataParser(expected_hidden_size=768)
        bad_embeds = torch.randn(2, 100, 512)

        with pytest.raises(ValueError) as caught:
            parser.parse_mm_data({"video": bad_embeds})

        message = str(caught.value).lower()
        assert "video" in message
        assert "hidden dimension mismatch" in message

    def test_list_of_embeddings_validates_each(self):
        """Security: every tensor in a list input is checked individually."""
        hidden = 768
        parser = MultiModalDataParser(expected_hidden_size=hidden)

        # Only the second entry carries the wrong hidden size.
        mixed_embeds = [
            torch.randn(100, hidden),
            torch.randn(100, 1024),
        ]

        with pytest.raises(ValueError) as caught:
            parser.parse_mm_data({"image": mixed_embeds})

        # The message must point at the offending list position.
        assert "[1]" in str(caught.value)

    def test_validation_disabled_allows_any_size(self):
        """With expected_hidden_size=None (legacy), any last dim is accepted."""
        parser = MultiModalDataParser(expected_hidden_size=None)
        embeds = torch.randn(2, 100, 12345)

        # Must parse without raising.
        parsed = parser.parse_mm_data({"image": embeds})

        assert "image" in parsed
        assert isinstance(parsed["image"], ImageEmbeddingItems)
|
||||
|
||||
|
||||
class TestEmbeddingItemsDirectValidation:
    """Exercise hidden_size validation by building EmbeddingItems directly."""

    def test_image_embedding_items_validates_batched_tensor(self):
        """Batched (3D) image input: matching width passes, other width fails."""
        hidden, other = 768, 1024

        # Matching width is accepted; one item per batch entry.
        good = torch.randn(2, 100, hidden)
        accepted = ImageEmbeddingItems(good, expected_hidden_size=hidden)
        assert accepted.get_count() == 2

        # Mismatching width is refused and both sizes appear in the message.
        bad = torch.randn(2, 100, other)
        with pytest.raises(ValueError) as caught:
            ImageEmbeddingItems(bad, expected_hidden_size=hidden)

        message = str(caught.value)
        assert str(other) in message
        assert str(hidden) in message

    def test_image_embedding_items_validates_list_of_tensors(self):
        """List of 2D image tensors: each entry's width is checked."""
        hidden, other = 768, 512

        # Entries may differ in sequence length as long as the width matches.
        good_list = [torch.randn(100, hidden), torch.randn(50, hidden)]
        accepted = ImageEmbeddingItems(good_list, expected_hidden_size=hidden)
        assert accepted.get_count() == 2

        # The second entry has the wrong width and must be named by index.
        bad_list = [torch.randn(100, hidden), torch.randn(50, other)]
        with pytest.raises(ValueError) as caught:
            ImageEmbeddingItems(bad_list, expected_hidden_size=hidden)

        assert "[1]" in str(caught.value)

    def test_audio_embedding_items_validates(self):
        """Audio embeddings with the wrong width raise an audio-labelled error."""
        bad = torch.randn(2, 100, 256)

        with pytest.raises(ValueError) as caught:
            AudioEmbeddingItems(bad, expected_hidden_size=768)

        assert "audio" in str(caught.value).lower()

    def test_video_embedding_items_validates(self):
        """Video embeddings with the wrong width raise a video-labelled error."""
        bad = torch.randn(2, 100, 384)

        with pytest.raises(ValueError) as caught:
            VideoEmbeddingItems(bad, expected_hidden_size=768)

        assert "video" in str(caught.value).lower()
|
||||
|
||||
|
||||
class TestShapeValidationIntegration:
    """Integration-style checks that wrong-width payloads are blocked."""

    def test_attack_scenario_multimodal_image(self):
        """
        Simulate an attack through the Chat API using image embeddings.

        The multimodal parser path must reject the payload.
        """
        parser = MultiModalDataParser(expected_hidden_size=768)
        payload = torch.randn(1, 100, 4096)

        with pytest.raises(ValueError):
            parser.parse_mm_data({"image": payload})

    def test_attack_scenario_multimodal_audio(self):
        """
        Simulate an attack through the Chat API using audio embeddings.

        The multimodal parser path must reject the payload.
        """
        parser = MultiModalDataParser(expected_hidden_size=768)
        payload = torch.randn(1, 100, 2048)

        with pytest.raises(ValueError):
            parser.parse_mm_data({"audio": payload})

    def test_attack_scenario_multimodal_video(self):
        """
        Simulate an attack through the Chat API using video embeddings.

        The multimodal parser path must reject the payload.
        """
        parser = MultiModalDataParser(expected_hidden_size=768)
        payload = torch.randn(1, 100, 1024)

        with pytest.raises(ValueError):
            parser.parse_mm_data({"video": payload})
|
||||
249
tests/multimodal/test_embedding_shape_validation_unit.py
Normal file
249
tests/multimodal/test_embedding_shape_validation_unit.py
Normal file
@ -0,0 +1,249 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for embedding shape validation.
|
||||
|
||||
Simple, fast unit tests that can run without server fixtures.
|
||||
Run with: pytest tests/multimodal/test_embedding_shape_validation_unit.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.parse import (
|
||||
AudioEmbeddingItems,
|
||||
ImageEmbeddingItems,
|
||||
)
|
||||
|
||||
|
||||
class TestImageEmbedBasicValidation:
    """Test ndim and hidden-size validation performed by ImageEmbeddingItems."""

    def test_valid_2d_tensor_accepted(self):
        """Baseline: 2D tensors should be accepted."""
        valid_tensor = torch.randn(10, 768, dtype=torch.float32)

        # Should not raise - 2D is valid
        items = ImageEmbeddingItems(valid_tensor)
        assert items.get_count() == 10

    def test_valid_3d_tensor_accepted(self):
        """Baseline: 3D tensors should be accepted."""
        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)

        # Should not raise - 3D is valid
        items = ImageEmbeddingItems(valid_tensor)
        assert items.get_count() == 2

    def test_valid_list_of_2d_tensors_accepted(self):
        """Baseline: List of 2D tensors should be accepted."""
        tensors = [
            torch.randn(10, 768, dtype=torch.float32),
            torch.randn(15, 768, dtype=torch.float32),
        ]

        # Should not raise
        items = ImageEmbeddingItems(tensors)
        assert items.get_count() == 2

    def test_1d_tensor_rejected(self):
        """Security: 1D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D

        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid_tensor)

        # Require both halves of the "2D ... or 3D ..." wording and the
        # reported rank; an `or` between substrings of the same message would
        # pass on almost any text.
        error_msg = str(exc_info.value)
        assert "must be 2D" in error_msg
        assert "3D" in error_msg
        assert "got 1D tensor" in error_msg

    def test_4d_tensor_rejected(self):
        """Security: 4D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(1, 2, 10, 768, dtype=torch.float32)  # 4D

        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(invalid_tensor)

        error_msg = str(exc_info.value)
        assert "must be 2D" in error_msg
        assert "3D" in error_msg
        assert "got 4D tensor" in error_msg

    def test_hidden_size_validation_correct_size(self):
        """Embeddings with correct hidden size should be accepted."""
        expected_hidden_size = 768
        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)

        # Should not raise
        items = ImageEmbeddingItems(
            valid_tensor, expected_hidden_size=expected_hidden_size
        )
        assert items.get_count() == 10

    def test_hidden_size_validation_wrong_size_rejected(self):
        """Embeddings with wrong hidden size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)

        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )

        # The message must name both the received and the expected size.
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(wrong_hidden_size) in error_msg
        assert str(expected_hidden_size) in error_msg
|
||||
|
||||
|
||||
class TestAudioEmbedBasicValidation:
    """Test ndim and hidden-size validation performed by AudioEmbeddingItems."""

    def test_valid_2d_tensor_accepted(self):
        """Baseline: 2D tensors should be accepted."""
        valid_tensor = torch.randn(10, 768, dtype=torch.float32)

        # Should not raise - 2D is valid
        items = AudioEmbeddingItems(valid_tensor)
        assert items.get_count() == 10

    def test_valid_3d_tensor_accepted(self):
        """Baseline: 3D tensors should be accepted."""
        valid_tensor = torch.randn(2, 10, 768, dtype=torch.float32)

        # Should not raise - 3D is valid
        items = AudioEmbeddingItems(valid_tensor)
        assert items.get_count() == 2

    def test_valid_list_of_2d_tensors_accepted(self):
        """Baseline: List of 2D tensors should be accepted."""
        tensors = [
            torch.randn(10, 768, dtype=torch.float32),
            torch.randn(15, 768, dtype=torch.float32),
        ]

        # Should not raise
        items = AudioEmbeddingItems(tensors)
        assert items.get_count() == 2

    def test_1d_tensor_rejected(self):
        """Security: 1D tensors should be rejected (invalid ndim)."""
        invalid_tensor = torch.randn(768, dtype=torch.float32)  # 1D

        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(invalid_tensor)

        # Require both halves of the "2D ... or 3D ..." wording and the
        # reported rank; an `or` between substrings of the same message would
        # pass on almost any text.
        error_msg = str(exc_info.value)
        assert "must be 2D" in error_msg
        assert "3D" in error_msg
        assert "got 1D tensor" in error_msg

    def test_scalar_rejected(self):
        """Security: Scalar tensors should be rejected."""
        invalid_tensor = torch.tensor(1.0)  # 0D (scalar)

        with pytest.raises(ValueError):
            AudioEmbeddingItems(invalid_tensor)

    def test_hidden_size_validation_correct_size(self):
        """Embeddings with correct hidden size should be accepted."""
        expected_hidden_size = 768
        valid_tensor = torch.randn(10, expected_hidden_size, dtype=torch.float32)

        # Should not raise
        items = AudioEmbeddingItems(
            valid_tensor, expected_hidden_size=expected_hidden_size
        )
        assert items.get_count() == 10

    def test_hidden_size_validation_wrong_size_rejected(self):
        """Embeddings with wrong hidden size should be rejected."""
        expected_hidden_size = 768
        wrong_hidden_size = 4096
        invalid_tensor = torch.randn(10, wrong_hidden_size, dtype=torch.float32)

        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )

        # The message must name both the received and the expected size.
        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(wrong_hidden_size) in error_msg
        assert str(expected_hidden_size) in error_msg
|
||||
|
||||
|
||||
class TestShapeValidationDoSPrevention:
    """
    Tests for DoS prevention through shape validation.

    Verifies that embeddings with incorrect shapes are rejected early,
    preventing crashes during model inference.
    """

    def test_prevent_crash_from_wrong_shape_image_embeds(self):
        """
        Prevent crash scenario: wrong hidden size in image embeddings.

        Without validation, this would pass initial checks but crash later
        during model forward pass when dimensions don't match.
        """
        expected_hidden_size = 768  # Typical model hidden size
        wrong_hidden_size = 4096  # Wrong size (e.g., Llama-sized)

        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)

        # Should be rejected at instantiation time, not during inference
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                wrong_embedding, expected_hidden_size=expected_hidden_size
            )

        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        assert str(expected_hidden_size) in error_msg  # Expected
        assert str(wrong_hidden_size) in error_msg  # Received

    def test_prevent_crash_from_wrong_shape_audio_embeds(self):
        """
        Prevent crash scenario: wrong hidden size in audio embeddings.
        """
        expected_hidden_size = 768
        wrong_hidden_size = 4096

        wrong_embedding = torch.randn(100, wrong_hidden_size, dtype=torch.float32)

        with pytest.raises(ValueError) as exc_info:
            AudioEmbeddingItems(
                wrong_embedding, expected_hidden_size=expected_hidden_size
            )

        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()

    def test_extremely_large_hidden_size_rejected(self):
        """Security: Prevent DoS from extremely large embeddings."""
        expected_hidden_size = 768
        huge_hidden_size = 100000  # Large but not extreme to avoid test OOM

        invalid_tensor = torch.randn(10, huge_hidden_size, dtype=torch.float32)

        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(
                invalid_tensor, expected_hidden_size=expected_hidden_size
            )

        assert "hidden dimension mismatch" in str(exc_info.value).lower()

    def test_batch_with_mixed_hidden_sizes_rejected(self):
        """Every embedding in a list must match the expected hidden size."""
        expected_hidden_size = 768

        # One correct, one wrong
        batch = [
            torch.randn(10, expected_hidden_size, dtype=torch.float32),
            torch.randn(10, expected_hidden_size + 100, dtype=torch.float32),  # Wrong!
        ]

        # Should fail on the second one
        with pytest.raises(ValueError) as exc_info:
            ImageEmbeddingItems(batch, expected_hidden_size=expected_hidden_size)

        error_msg = str(exc_info.value)
        assert "hidden dimension mismatch" in error_msg.lower()
        # Confirm it is the second entry (index 1) that was flagged, not the
        # first; the list-branch error message embeds the offending index.
        assert "[1]" in error_msg
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main() returns the session's exit code. Hand it to the interpreter
    # so that running this file directly exits non-zero when a test fails,
    # instead of always reporting success.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|
||||
@ -126,6 +126,30 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
||||
return {}
|
||||
|
||||
|
||||
def validate_embedding_ndim(
|
||||
tensor: torch.Tensor,
|
||||
modality: str,
|
||||
index: int | None = None,
|
||||
) -> None:
|
||||
"""Validate tensor ndim for multimodal embeddings.
|
||||
|
||||
Single embeddings should be 2D (seq_len, hidden_size).
|
||||
Batched embeddings should be 3D (batch, seq_len, hidden_size).
|
||||
|
||||
Args:
|
||||
tensor: The tensor to validate.
|
||||
modality: The modality name for error messages (e.g., "image", "audio").
|
||||
index: Optional index for list items, included in error messages.
|
||||
"""
|
||||
if tensor.ndim < 2 or tensor.ndim > 3:
|
||||
idx_str = f" [{index}]" if index is not None else ""
|
||||
raise ValueError(
|
||||
f"{modality.capitalize()} embedding{idx_str} must be 2D "
|
||||
f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
|
||||
f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingItems(
|
||||
ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
|
||||
):
|
||||
@ -134,6 +158,63 @@ class EmbeddingItems(
|
||||
or a list of embedding tensors (one per item).
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    data: torch.Tensor | list[torch.Tensor],
    modality: str,
    expected_hidden_size: int | None = None,
) -> None:
    """Store the embeddings and validate their shape.

    Args:
        data: A single embedding tensor or a list of embedding tensors.
        modality: Modality name, passed to the base class and used in
            validation error messages.
        expected_hidden_size: When not ``None``, the size the last dimension
            of every embedding must have.

    Raises:
        ValueError: If the rank check or the hidden-size check fails.
    """
    super().__init__(data, modality)

    # Rank comes first: the hidden-size check reads the last axis and is only
    # meaningful once the tensor rank is known to be acceptable.
    self._validate_ndim()

    if expected_hidden_size is None:
        # No expected width supplied, so the hidden-size check is skipped.
        return
    self._validate_hidden_size(expected_hidden_size)
|
||||
|
||||
def _validate_ndim(self) -> None:
    """Check the rank of the stored embeddings.

    A tensor input is delegated to ``validate_embedding_ndim``. For a list
    input, every entry must be exactly 2D ``(seq_len, hidden_size)``.

    Raises:
        ValueError: If a tensor has an unacceptable rank; for list input the
            message includes the index of the first offending entry.
    """
    embeddings = self.data
    if isinstance(embeddings, torch.Tensor):
        validate_embedding_ndim(embeddings, self.modality)
        return

    # List input: walk the entries and stop at the first one that is not 2D.
    for position, entry in enumerate(embeddings):
        if entry.ndim == 2:
            continue
        raise ValueError(
            f"{self.modality.capitalize()} embedding [{position}] must be "
            f"2D (seq_len, hidden_size), got {entry.ndim}D tensor "
            f"with shape {tuple(entry.shape)}"
        )
|
||||
|
||||
def _validate_hidden_size(self, expected_hidden_size: int) -> None:
    """Check that the last dimension of every embedding is the expected size.

    Embeddings with an acceptable rank but the wrong hidden dimension would
    otherwise get past the rank check and only fail later, during model
    inference, when the dimensions do not line up.

    Args:
        expected_hidden_size: Required size of the last dimension.

    Raises:
        ValueError: On the first mismatch; for list input the message
            includes the index of the offending entry.
    """
    embeddings = self.data

    if isinstance(embeddings, torch.Tensor):
        # Single tensor input: the hidden size is read from the last axis.
        width = embeddings.shape[-1]
        if width != expected_hidden_size:
            raise ValueError(
                f"{self.modality.capitalize()} embedding hidden dimension "
                f"mismatch: got {width}, but model expects "
                f"{expected_hidden_size}. Embedding shape: {tuple(embeddings.shape)}"
            )
        return

    # List input: each entry is inspected on its own last axis.
    for position, entry in enumerate(embeddings):
        width = entry.shape[-1]
        if width == expected_hidden_size:
            continue
        raise ValueError(
            f"{self.modality.capitalize()} embedding [{position}] hidden "
            f"dimension mismatch: got {width}, but model "
            f"expects {expected_hidden_size}. "
            f"Embedding shape: {tuple(entry.shape)}"
        )
|
||||
|
||||
def _unwrap(
|
||||
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
@ -228,8 +309,12 @@ class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
||||
|
||||
|
||||
class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items tagged with the ``"audio"`` modality.

    Args:
        data: A single embedding tensor or a list of embedding tensors.
        expected_hidden_size: Forwarded to ``EmbeddingItems``; when not
            ``None`` the hidden dimension of ``data`` is validated against it.
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "audio", expected_hidden_size)
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
|
||||
@ -256,8 +341,12 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
||||
|
||||
|
||||
class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items tagged with the ``"image"`` modality.

    Args:
        data: A single embedding tensor or a list of embedding tensors.
        expected_hidden_size: Forwarded to ``EmbeddingItems``; when not
            ``None`` the hidden dimension of ``data`` is validated against it.
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "image", expected_hidden_size)
|
||||
|
||||
|
||||
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
||||
@ -287,8 +376,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items tagged with the ``"video"`` modality.

    Args:
        data: A single embedding tensor or a list of embedding tensors.
        expected_hidden_size: Forwarded to ``EmbeddingItems``; when not
            ``None`` the hidden dimension of ``data`` is validated against it.
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "video", expected_hidden_size)
|
||||
|
||||
|
||||
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
||||
@ -363,6 +456,10 @@ class MultiModalDataParser:
|
||||
Args:
|
||||
target_sr (float, optional): Enables automatic resampling of audio
|
||||
items to the model's expected sampling rate.
|
||||
expected_hidden_size (int, optional): Expected hidden dimension for
|
||||
embedding inputs. If provided, validates that user-supplied
|
||||
embeddings have the correct hidden size to prevent crashes
|
||||
during model inference.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -371,6 +468,7 @@ class MultiModalDataParser:
|
||||
target_sr: float | None = None,
|
||||
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
|
||||
video_needs_metadata: bool = False,
|
||||
expected_hidden_size: int | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -379,6 +477,7 @@ class MultiModalDataParser:
|
||||
method=audio_resample_method,
|
||||
)
|
||||
self.video_needs_metadata = video_needs_metadata
|
||||
self.expected_hidden_size = expected_hidden_size
|
||||
|
||||
@classmethod
|
||||
def is_embeddings(
|
||||
@ -443,7 +542,7 @@ class MultiModalDataParser:
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return AudioEmbeddingItems(data)
|
||||
return AudioEmbeddingItems(data, self.expected_hidden_size)
|
||||
|
||||
data_items: list[AudioItem]
|
||||
if (
|
||||
@ -481,7 +580,7 @@ class MultiModalDataParser:
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return ImageEmbeddingItems(data)
|
||||
return ImageEmbeddingItems(data, self.expected_hidden_size)
|
||||
|
||||
if (
|
||||
isinstance(data, (PILImage.Image, MediaWithBytes))
|
||||
@ -507,7 +606,7 @@ class MultiModalDataParser:
|
||||
return None
|
||||
|
||||
if self.is_embeddings(data):
|
||||
return VideoEmbeddingItems(data)
|
||||
return VideoEmbeddingItems(data, self.expected_hidden_size)
|
||||
|
||||
data_items: list[VideoItem]
|
||||
if (
|
||||
|
||||
@ -1330,7 +1330,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
|
||||
that has additional subparsers.
|
||||
"""
|
||||
return MultiModalDataParser()
|
||||
# Get expected hidden size for embedding validation if mm_embeds enabled
|
||||
# This validates hidden dimensions to prevent vulnerabilities: embeddings
|
||||
# with correct ndim but wrong shape could cause crashes at inference time
|
||||
mm_config = self.info.ctx.model_config.get_multimodal_config()
|
||||
expected_hidden_size = None
|
||||
if mm_config.enable_mm_embeds:
|
||||
expected_hidden_size = self.info.ctx.model_config.get_inputs_embeds_size()
|
||||
|
||||
return MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
def validate_num_items(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user