diff --git a/tests/entrypoints/openai/test_sparse_tensor_validation.py b/tests/entrypoints/openai/test_sparse_tensor_validation.py new file mode 100644 index 0000000000000..907c82b57dead --- /dev/null +++ b/tests/entrypoints/openai/test_sparse_tensor_validation.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Sparse tensor validation in embedding APIs. + +Tests verify that malicious sparse tensors are rejected before they can trigger +out-of-bounds memory writes during to_dense() operations. +""" + +import base64 +import io + +import pytest +import torch + +from vllm.entrypoints.renderer import CompletionRenderer +from vllm.multimodal.audio import AudioEmbeddingMediaIO +from vllm.multimodal.image import ImageEmbeddingMediaIO + + +def _encode_tensor(tensor: torch.Tensor) -> bytes: + """Helper to encode a tensor as base64 bytes.""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return base64.b64encode(buffer.read()) + + +def _create_malicious_sparse_tensor() -> torch.Tensor: + """ + Create a malicious sparse COO tensor with out-of-bounds indices. + + This tensor has indices that point beyond the declared shape, which would + cause an out-of-bounds write when converted to dense format without + validation. + """ + # Create a 3x3 sparse tensor but with indices pointing to (10, 10) + indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape + values = torch.tensor([1.0]) + shape = (3, 3) + + # Create sparse tensor (this will be invalid) + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_sparse_tensor() -> torch.Tensor: + """Create a valid sparse COO tensor for baseline testing.""" + indices = torch.tensor([[0, 1, 2], [0, 1, 2]]) + values = torch.tensor([1.0, 2.0, 3.0]) + shape = (3, 3) + + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_dense_tensor() -> torch.Tensor: + """Create a valid dense tensor for baseline testing.""" + return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size) + + +class TestPromptEmbedsValidation: + """Test sparse tensor validation in prompt embeddings (Completions API).""" + + def test_valid_dense_tensor_accepted(self, model_config): + """Baseline: Valid dense tensors should work normally.""" + renderer = CompletionRenderer(model_config) + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = renderer.load_prompt_embeds(encoded) + assert len(result) == 1 + assert result[0]["prompt_embeds"].shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = ImageEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self, model_config): + """Security: Malicious sparse tensors should be rejected.""" + renderer = CompletionRenderer(model_config) + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + renderer.load_prompt_embeds(encoded) + + # Error should indicate sparse tensor validation failure + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_extremely_large_indices_rejected(self, model_config): + """Security: Sparse tensors with extremely large indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with indices far beyond reasonable bounds + indices = torch.tensor([[999999], [999999]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(encoded) + + def test_negative_indices_rejected(self, model_config): + """Security: Sparse tensors with negative indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with negative indices + indices = torch.tensor([[-1], [-1]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(encoded) + + +class TestImageEmbedsValidation: + """Test sparse tensor validation in image embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = ImageEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestAudioEmbedsValidation: + """Test sparse tensor validation in audio embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = AudioEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should be converted successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.is_sparse is False + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestSparseTensorValidationIntegration: + """ + These tests verify the complete attack chain is blocked at all entry points. + """ + + def test_attack_scenario_completions_api(self, model_config): + """ + Simulate a complete attack through the Completions API. + + Attack scenario: + 1. Attacker crafts malicious sparse tensor + 2. Encodes it as base64 + 3. Sends to /v1/completions with prompt_embeds parameter + 4. Server should reject before memory corruption occurs + """ + renderer = CompletionRenderer(model_config) + + # Step 1-2: Attacker creates malicious payload + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + # Step 3-4: Server processes and should reject + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(attack_payload) + + def test_attack_scenario_chat_api_image(self): + """ + Simulate attack through Chat API with image_embeds. + + Verifies the image embeddings path is protected. + """ + io_handler = ImageEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_attack_scenario_chat_api_audio(self): + """ + Simulate attack through Chat API with audio_embeds. + + Verifies the audio embeddings path is protected. + """ + io_handler = AudioEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_multiple_valid_embeddings_in_batch(self, model_config): + """ + Regression test: Multiple valid embeddings should still work. + + Ensures the fix doesn't break legitimate batch processing. + """ + renderer = CompletionRenderer(model_config) + + valid_tensors = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should process all without error + result = renderer.load_prompt_embeds(valid_tensors) + assert len(result) == 3 + + def test_mixed_valid_and_malicious_rejected(self, model_config): + """ + Security: Batch with one malicious tensor should be rejected. + + Even if most tensors are valid, a single malicious one should + cause rejection of the entire batch. + """ + renderer = CompletionRenderer(model_config) + + mixed_batch = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_malicious_sparse_tensor()), # Malicious + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should fail on the malicious tensor + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(mixed_batch) + + +# Pytest fixtures +@pytest.fixture +def model_config(): + """Mock ModelConfig for testing.""" + from vllm.config import ModelConfig + + return ModelConfig( + model="facebook/opt-125m", + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float32", + seed=0, + enable_prompt_embeds=True, # Required for prompt embeds tests + ) diff --git a/tests/multimodal/test_sparse_tensor_validation_unit.py b/tests/multimodal/test_sparse_tensor_validation_unit.py new file mode 100644 index 0000000000000..2eec8ea8283a2 --- /dev/null +++ b/tests/multimodal/test_sparse_tensor_validation_unit.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for sparse tensor validation. + +Simple, fast unit tests that can run without server fixtures. +Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v +""" + +import io + +import pytest +import torch + + +class TestSparseTensorValidationContextManager: + """Test that torch.sparse.check_sparse_tensor_invariants() works as expected.""" + + def test_valid_sparse_tensor_passes(self): + """Valid sparse tensors should pass validation.""" + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + shape = (2, 2) + + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + dense = tensor.to_dense() + + assert dense.shape == shape + + def test_out_of_bounds_indices_rejected(self): + """Sparse tensors with out-of-bounds indices should be rejected.""" + indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2 + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + assert ( + "index" in str(exc_info.value).lower() + or "bound" in str(exc_info.value).lower() + ) + + def test_negative_indices_rejected(self): + """Sparse tensors with negative indices should be rejected.""" + indices = torch.tensor([[-1], [0]]) + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + def test_without_context_manager_allows_invalid(self): + """ + WITHOUT validation, invalid tensors may not immediately error. + + This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate + by default, which can lead to memory corruption. + """ + indices = torch.tensor([[100], [100]]) # Way out of bounds + values = torch.tensor([1.0]) + shape = (2, 2) + + # Without validation context, this might create an invalid tensor + # (actual behavior depends on PyTorch version) + tensor = torch.sparse_coo_tensor(indices, values, shape) + + # The tensor object is created, but it's invalid + assert tensor.is_sparse + + +class TestTorchLoadWithValidation: + """Test torch.load() with sparse tensor validation.""" + + def test_load_valid_sparse_tensor_with_validation(self): + """Valid sparse tensors should load successfully with validation.""" + # Create and save a valid sparse tensor + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + dense = loaded.to_dense() + + assert dense.shape == (2, 2) + + def test_load_invalid_sparse_tensor_rejected(self): + """Invalid sparse tensors should be caught when loaded with validation.""" + # Create an invalid sparse tensor (out of bounds) + indices = torch.tensor([[10], [10]]) + values = torch.tensor([1.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation - should fail on to_dense() + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + loaded.to_dense() + + def test_load_dense_tensor_unaffected(self): + """Dense tensors should work normally with the validation context.""" + # Create and save a dense tensor + tensor = torch.randn(10, 20) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation (should have no effect on dense tensors) + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + + assert loaded.shape == (10, 20) + assert not loaded.is_sparse + + +if __name__ == "__main__": + # Allow running directly for quick testing + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index 22f3c61ff73fa..0f89c840be80f 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -167,17 +167,20 @@ class BaseRenderer(ABC): ) def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load( - io.BytesIO(pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu"), - ) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2 diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 062547401c3cf..51b8f77f29088 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: return tensor2base64(media) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 789421e9e0c3b..1506ecb8c7aa0 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: return pybase64.b64encode(media.numpy()).decode("utf-8")