diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index e31a1d077608f..4197583074dfe 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import io
+
 # imports for guided decoding tests
 import openai
+import pybase64
 import pytest
 import regex as re
+import torch
+
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 from ...utils import RemoteOpenAIServer
 
 
@@ -42,3 +48,46 @@ async def test_out_of_vocab_token_ids():
                                             prompt=[999999],
                                             max_tokens=5,
                                             temperature=0.0)
+
+
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize(
+    "layout",
+    [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
+@pytest.mark.parametrize("seq_len", [2, 10])
+@pytest.mark.parametrize("hidden_size", [2, 10])
+def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
+                            seq_len: int, hidden_size: int):
+    # Construct arbitrary tensors of various dtypes, layouts, and sizes.
+    # We check different layouts to make sure that if a user sends sparse
+    # tensors to reduce the transmission size of prompt embeddings, they are
+    # cast to dense/strided tensors before being passed into the engine.
+    # We only use CPU tensors in this test to avoid prematurely initializing
+    # CUDA, which would break other tests in the suite that fork processes.
+    # We also need to make sure that we only use devices that are actually
+    # available in the environment the test is running on. For simplicity,
+    # we just test against CPU.
+    tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
+    if layout == torch.strided:
+        tensor = tensor.contiguous()
+    elif layout == torch.sparse_coo:
+        tensor = tensor.to_sparse_coo()
+    elif layout == torch.sparse_csc:
+        tensor = tensor.to_sparse_csc()
+    elif layout == torch.sparse_csr:
+        tensor = tensor.to_sparse_csr()
+
+    buffer = io.BytesIO()
+    torch.save(tensor, buffer)
+    buffer.seek(0)
+    encoded_tensor = pybase64.b64encode(buffer.getvalue())
+
+    loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
+    assert len(loaded_prompt_embeds) == 1
+    loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
+    assert loaded_tensor.device.type == "cpu"
+    assert loaded_tensor.layout == torch.strided
+    torch.testing.assert_close(loaded_tensor,
+                               tensor.to("cpu").to_dense(),
+                               equal_nan=True)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index d6f92a63301e8..0f4a7c0186b65 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1006,8 +1006,8 @@ class OpenAIServing:
         # OPTIMIZATION
         priority = orig_priority - 1
 
+    @staticmethod
     def _load_prompt_embeds(
-        self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     ) -> list[EmbedsPrompt]:
@@ -1015,12 +1015,14 @@ class OpenAIServing:
         def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
             tensor = torch.load(io.BytesIO(
                 pybase64.b64decode(embed, validate=True)),
-                                weights_only=True)
+                                weights_only=True,
+                                map_location=torch.device("cpu"))
             assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                 torch.float32,
                 torch.bfloat16,
                 torch.float16,
             )
+            tensor = tensor.to_dense()
             if tensor.dim() > 2:
                 tensor = tensor.squeeze(0)
                 assert tensor.dim() == 2
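
For reference, a client produces the base64 payload that `_load_prompt_embeds` consumes by serializing a tensor with `torch.save` and base64-encoding the bytes, mirroring the encode path in the new test. The sketch below is illustrative only: the `prompt_embeds` field passed via `extra_body`, the server URL, the model name, and the placeholder `prompt` are assumptions not established by this diff.

# Illustrative client-side sketch (not part of the patch). Assumes a running
# vLLM OpenAI-compatible server that accepts base64-encoded prompt embeddings
# through extra_body["prompt_embeds"].
import io

import openai
import pybase64
import torch


def encode_prompt_embeds(tensor: torch.Tensor) -> str:
    # Sparse layouts (COO/CSR/CSC) are acceptable here: the server now casts
    # the loaded tensor to a dense/strided layout before it reaches the engine.
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return pybase64.b64encode(buffer.getvalue()).decode("utf-8")


client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
embeds = torch.randn(8, 4096, dtype=torch.float16)  # (seq_len, hidden_size)
completion = client.completions.create(
    model="my-model",  # hypothetical model name
    prompt="",  # placeholder; how prompt interacts with embeds is not covered here
    max_tokens=5,
    extra_body={"prompt_embeds": encode_prompt_embeds(embeds)},
)
print(completion.choices[0].text)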