[Llama4] [multimodal] Fix misplaced dtype cast of cos_sin_cache in Llama4VisionRotaryEmbedding (#25889)

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
cjackal 2025-10-01 05:35:15 +09:00 committed by yewentao256
parent bb2e04e41e
commit 8ecccdd15f

View File

@ -59,7 +59,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self._match_cos_sin_cache_dtype(query)
# self.cos_sin_cache here is a complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(