[Bugfix] Fix isinstance check for tensor types in _load_prompt_embeds to use dtype comparison (#21612)

Signed-off-by: Alexandre Juan <a.juan@netheos.net>
This commit is contained in:
Alexandre JUAN 2025-07-26 05:11:10 +02:00 committed by GitHub
parent a55c95096b
commit 2f6e6b33fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -957,9 +957,11 @@ class OpenAIServing:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
weights_only=True)
assert isinstance(
tensor,
(torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2