mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 17:45:17 +08:00
additional protection for CVE-2025-62164 (#30649)
Signed-off-by: Wenqi Glantz <wglantz@nvidia.com>
This commit is contained in:
parent
738648fb81
commit
84e23d103d
342
tests/entrypoints/openai/test_sparse_tensor_validation.py
Normal file
342
tests/entrypoints/openai/test_sparse_tensor_validation.py
Normal file
@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Sparse tensor validation in embedding APIs.
|
||||
|
||||
Tests verify that malicious sparse tensors are rejected before they can trigger
|
||||
out-of-bounds memory writes during to_dense() operations.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.renderer import CompletionRenderer
|
||||
from vllm.multimodal.audio import AudioEmbeddingMediaIO
|
||||
from vllm.multimodal.image import ImageEmbeddingMediaIO
|
||||
|
||||
|
||||
def _encode_tensor(tensor: torch.Tensor) -> bytes:
|
||||
"""Helper to encode a tensor as base64 bytes."""
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
return base64.b64encode(buffer.read())
|
||||
|
||||
|
||||
def _create_malicious_sparse_tensor() -> torch.Tensor:
|
||||
"""
|
||||
Create a malicious sparse COO tensor with out-of-bounds indices.
|
||||
|
||||
This tensor has indices that point beyond the declared shape, which would
|
||||
cause an out-of-bounds write when converted to dense format without
|
||||
validation.
|
||||
"""
|
||||
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
|
||||
indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape
|
||||
values = torch.tensor([1.0])
|
||||
shape = (3, 3)
|
||||
|
||||
# Create sparse tensor (this will be invalid)
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
|
||||
return sparse_tensor
|
||||
|
||||
|
||||
def _create_valid_sparse_tensor() -> torch.Tensor:
|
||||
"""Create a valid sparse COO tensor for baseline testing."""
|
||||
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
||||
values = torch.tensor([1.0, 2.0, 3.0])
|
||||
shape = (3, 3)
|
||||
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
|
||||
return sparse_tensor
|
||||
|
||||
|
||||
def _create_valid_dense_tensor() -> torch.Tensor:
|
||||
"""Create a valid dense tensor for baseline testing."""
|
||||
return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size)
|
||||
|
||||
|
||||
class TestPromptEmbedsValidation:
|
||||
"""Test sparse tensor validation in prompt embeddings (Completions API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self, model_config):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = renderer.load_prompt_embeds(encoded)
|
||||
assert len(result) == 1
|
||||
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should load successfully."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception (sparse tensors remain sparse)
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_sparse.shape
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self, model_config):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
# Error should indicate sparse tensor validation failure
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_extremely_large_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with extremely large indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with indices far beyond reasonable bounds
|
||||
indices = torch.tensor([[999999], [999999]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (10, 10)
|
||||
|
||||
malicious_tensor = torch.sparse_coo_tensor(
|
||||
indices, values, shape, dtype=torch.float32
|
||||
)
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
def test_negative_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with negative indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with negative indices
|
||||
indices = torch.tensor([[-1], [-1]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (10, 10)
|
||||
|
||||
malicious_tensor = torch.sparse_coo_tensor(
|
||||
indices, values, shape, dtype=torch.float32
|
||||
)
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
|
||||
class TestImageEmbedsValidation:
|
||||
"""Test sparse tensor validation in image embeddings (Chat API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should load successfully."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception (sparse tensors remain sparse)
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_sparse.shape
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_load_bytes_validates(self):
|
||||
"""Security: Validation should also work for load_bytes method."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
buffer = io.BytesIO()
|
||||
torch.save(malicious_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
|
||||
class TestAudioEmbedsValidation:
|
||||
"""Test sparse tensor validation in audio embeddings (Chat API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should be converted successfully."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.is_sparse is False
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_load_bytes_validates(self):
|
||||
"""Security: Validation should also work for load_bytes method."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
buffer = io.BytesIO()
|
||||
torch.save(malicious_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
|
||||
class TestSparseTensorValidationIntegration:
|
||||
"""
|
||||
These tests verify the complete attack chain is blocked at all entry points.
|
||||
"""
|
||||
|
||||
def test_attack_scenario_completions_api(self, model_config):
|
||||
"""
|
||||
Simulate a complete attack through the Completions API.
|
||||
|
||||
Attack scenario:
|
||||
1. Attacker crafts malicious sparse tensor
|
||||
2. Encodes it as base64
|
||||
3. Sends to /v1/completions with prompt_embeds parameter
|
||||
4. Server should reject before memory corruption occurs
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Step 1-2: Attacker creates malicious payload
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
# Step 3-4: Server processes and should reject
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(attack_payload)
|
||||
|
||||
def test_attack_scenario_chat_api_image(self):
|
||||
"""
|
||||
Simulate attack through Chat API with image_embeds.
|
||||
|
||||
Verifies the image embeddings path is protected.
|
||||
"""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_base64("", attack_payload.decode("utf-8"))
|
||||
|
||||
def test_attack_scenario_chat_api_audio(self):
|
||||
"""
|
||||
Simulate attack through Chat API with audio_embeds.
|
||||
|
||||
Verifies the audio embeddings path is protected.
|
||||
"""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_base64("", attack_payload.decode("utf-8"))
|
||||
|
||||
def test_multiple_valid_embeddings_in_batch(self, model_config):
|
||||
"""
|
||||
Regression test: Multiple valid embeddings should still work.
|
||||
|
||||
Ensures the fix doesn't break legitimate batch processing.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensors = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should process all without error
|
||||
result = renderer.load_prompt_embeds(valid_tensors)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_mixed_valid_and_malicious_rejected(self, model_config):
|
||||
"""
|
||||
Security: Batch with one malicious tensor should be rejected.
|
||||
|
||||
Even if most tensors are valid, a single malicious one should
|
||||
cause rejection of the entire batch.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
mixed_batch = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should fail on the malicious tensor
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(mixed_batch)
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
"""Mock ModelConfig for testing."""
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
return ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float32",
|
||||
seed=0,
|
||||
enable_prompt_embeds=True, # Required for prompt embeds tests
|
||||
)
|
||||
134
tests/multimodal/test_sparse_tensor_validation_unit.py
Normal file
134
tests/multimodal/test_sparse_tensor_validation_unit.py
Normal file
@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for sparse tensor validation.
|
||||
|
||||
Simple, fast unit tests that can run without server fixtures.
|
||||
Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestSparseTensorValidationContextManager:
|
||||
"""Test that torch.sparse.check_sparse_tensor_invariants() works as expected."""
|
||||
|
||||
def test_valid_sparse_tensor_passes(self):
|
||||
"""Valid sparse tensors should pass validation."""
|
||||
indices = torch.tensor([[0, 1], [0, 1]])
|
||||
values = torch.tensor([1.0, 2.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
dense = tensor.to_dense()
|
||||
|
||||
assert dense.shape == shape
|
||||
|
||||
def test_out_of_bounds_indices_rejected(self):
|
||||
"""Sparse tensors with out-of-bounds indices should be rejected."""
|
||||
indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
tensor.to_dense()
|
||||
|
||||
assert (
|
||||
"index" in str(exc_info.value).lower()
|
||||
or "bound" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
def test_negative_indices_rejected(self):
|
||||
"""Sparse tensors with negative indices should be rejected."""
|
||||
indices = torch.tensor([[-1], [0]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with pytest.raises(RuntimeError): # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
tensor.to_dense()
|
||||
|
||||
def test_without_context_manager_allows_invalid(self):
|
||||
"""
|
||||
WITHOUT validation, invalid tensors may not immediately error.
|
||||
|
||||
This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate
|
||||
by default, which can lead to memory corruption.
|
||||
"""
|
||||
indices = torch.tensor([[100], [100]]) # Way out of bounds
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
# Without validation context, this might create an invalid tensor
|
||||
# (actual behavior depends on PyTorch version)
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
|
||||
# The tensor object is created, but it's invalid
|
||||
assert tensor.is_sparse
|
||||
|
||||
|
||||
class TestTorchLoadWithValidation:
|
||||
"""Test torch.load() with sparse tensor validation."""
|
||||
|
||||
def test_load_valid_sparse_tensor_with_validation(self):
|
||||
"""Valid sparse tensors should load successfully with validation."""
|
||||
# Create and save a valid sparse tensor
|
||||
indices = torch.tensor([[0, 1], [0, 1]])
|
||||
values = torch.tensor([1.0, 2.0])
|
||||
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
dense = loaded.to_dense()
|
||||
|
||||
assert dense.shape == (2, 2)
|
||||
|
||||
def test_load_invalid_sparse_tensor_rejected(self):
|
||||
"""Invalid sparse tensors should be caught when loaded with validation."""
|
||||
# Create an invalid sparse tensor (out of bounds)
|
||||
indices = torch.tensor([[10], [10]])
|
||||
values = torch.tensor([1.0])
|
||||
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation - should fail on to_dense()
|
||||
with pytest.raises(RuntimeError): # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
loaded.to_dense()
|
||||
|
||||
def test_load_dense_tensor_unaffected(self):
|
||||
"""Dense tensors should work normally with the validation context."""
|
||||
# Create and save a dense tensor
|
||||
tensor = torch.randn(10, 20)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation (should have no effect on dense tensors)
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
|
||||
assert loaded.shape == (10, 20)
|
||||
assert not loaded.is_sparse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running directly for quick testing
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@ -167,17 +167,20 @@ class BaseRenderer(ABC):
|
||||
)
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
|
||||
@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
return torch.load(buffer, weights_only=True)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
return torch.load(filepath, weights_only=True)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return tensor2base64(media)
|
||||
|
||||
@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
return torch.load(buffer, weights_only=True)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
return torch.load(filepath, weights_only=True)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return pybase64.b64encode(media.numpy()).decode("utf-8")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user