mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 00:55:01 +08:00
135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Unit tests for sparse tensor validation.
|
|
|
|
Simple, fast unit tests that can run without server fixtures.
|
|
Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v
|
|
"""
|
|
|
|
import io
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
|
|
class TestSparseTensorValidationContextManager:
|
|
"""Test that torch.sparse.check_sparse_tensor_invariants() works as expected."""
|
|
|
|
def test_valid_sparse_tensor_passes(self):
|
|
"""Valid sparse tensors should pass validation."""
|
|
indices = torch.tensor([[0, 1], [0, 1]])
|
|
values = torch.tensor([1.0, 2.0])
|
|
shape = (2, 2)
|
|
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
|
dense = tensor.to_dense()
|
|
|
|
assert dense.shape == shape
|
|
|
|
def test_out_of_bounds_indices_rejected(self):
|
|
"""Sparse tensors with out-of-bounds indices should be rejected."""
|
|
indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2
|
|
values = torch.tensor([1.0])
|
|
shape = (2, 2)
|
|
|
|
with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
|
tensor.to_dense()
|
|
|
|
assert (
|
|
"index" in str(exc_info.value).lower()
|
|
or "bound" in str(exc_info.value).lower()
|
|
)
|
|
|
|
def test_negative_indices_rejected(self):
|
|
"""Sparse tensors with negative indices should be rejected."""
|
|
indices = torch.tensor([[-1], [0]])
|
|
values = torch.tensor([1.0])
|
|
shape = (2, 2)
|
|
|
|
with pytest.raises(RuntimeError): # noqa: SIM117
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
|
tensor.to_dense()
|
|
|
|
def test_without_context_manager_allows_invalid(self):
|
|
"""
|
|
WITHOUT validation, invalid tensors may not immediately error.
|
|
|
|
This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate
|
|
by default, which can lead to memory corruption.
|
|
"""
|
|
indices = torch.tensor([[100], [100]]) # Way out of bounds
|
|
values = torch.tensor([1.0])
|
|
shape = (2, 2)
|
|
|
|
# Without validation context, this might create an invalid tensor
|
|
# (actual behavior depends on PyTorch version)
|
|
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
|
|
|
# The tensor object is created, but it's invalid
|
|
assert tensor.is_sparse
|
|
|
|
|
|
class TestTorchLoadWithValidation:
|
|
"""Test torch.load() with sparse tensor validation."""
|
|
|
|
def test_load_valid_sparse_tensor_with_validation(self):
|
|
"""Valid sparse tensors should load successfully with validation."""
|
|
# Create and save a valid sparse tensor
|
|
indices = torch.tensor([[0, 1], [0, 1]])
|
|
values = torch.tensor([1.0, 2.0])
|
|
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
|
|
|
buffer = io.BytesIO()
|
|
torch.save(tensor, buffer)
|
|
buffer.seek(0)
|
|
|
|
# Load with validation
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
loaded = torch.load(buffer, weights_only=True)
|
|
dense = loaded.to_dense()
|
|
|
|
assert dense.shape == (2, 2)
|
|
|
|
def test_load_invalid_sparse_tensor_rejected(self):
|
|
"""Invalid sparse tensors should be caught when loaded with validation."""
|
|
# Create an invalid sparse tensor (out of bounds)
|
|
indices = torch.tensor([[10], [10]])
|
|
values = torch.tensor([1.0])
|
|
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
|
|
|
buffer = io.BytesIO()
|
|
torch.save(tensor, buffer)
|
|
buffer.seek(0)
|
|
|
|
# Load with validation - should fail on to_dense()
|
|
with pytest.raises(RuntimeError): # noqa: SIM117
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
loaded = torch.load(buffer, weights_only=True)
|
|
loaded.to_dense()
|
|
|
|
def test_load_dense_tensor_unaffected(self):
|
|
"""Dense tensors should work normally with the validation context."""
|
|
# Create and save a dense tensor
|
|
tensor = torch.randn(10, 20)
|
|
|
|
buffer = io.BytesIO()
|
|
torch.save(tensor, buffer)
|
|
buffer.seek(0)
|
|
|
|
# Load with validation (should have no effect on dense tensors)
|
|
with torch.sparse.check_sparse_tensor_invariants():
|
|
loaded = torch.load(buffer, weights_only=True)
|
|
|
|
assert loaded.shape == (10, 20)
|
|
assert not loaded.is_sparse
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running directly for quick testing. pytest.main() returns the exit
    # code instead of raising, so forward it to make failures visible to
    # shells and CI.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|