diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9d848679d5d98..71976fea1ee77 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -957,9 +957,11 @@ class OpenAIServing: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: tensor = torch.load(io.BytesIO(base64.b64decode(embed)), weights_only=True) - assert isinstance( - tensor, - (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2