From 8ecccdd15fa530dab7bae0b9c98b8278c60d04d3 Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Wed, 1 Oct 2025 05:35:15 +0900
Subject: [PATCH] [Llama4] [multimodal] Fix misplaced dtype cast of
 `cos_sin_cache` in `Llama4VisionRotaryEmbedding` (#25889)

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: yewentao256
---
 .../layers/rotary_embedding/llama4_vision_rope.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
index 8717280353068..c98a426a2a1ef 100644
--- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -59,7 +59,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
-        self._match_cos_sin_cache_dtype(query)
+        # self.cos_sin_cache here is a complex tensor, so we cannot cast it
+        # into query's dtype directly with self._match_cos_sin_cache_dtype
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
         query_ = torch.view_as_complex(query.float().reshape(
             *query.shape[:-1], -1, 2))
         key_ = torch.view_as_complex(key.float().reshape(
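
A minimal sketch of why the dtype cast was misplaced. The names, shapes,
and cache construction below (`angles`, `freqs_cis`, the toy query) are
illustrative assumptions, not vLLM's actual code; the point is that casting
a complex tensor to a real dtype such as bfloat16 silently discards the
imaginary (sine) half of each rotation, whereas a device-only move with
`.to(query.device)` leaves the complex cache intact.

import torch

# Toy complex rotary cache, standing in for the complex cos_sin_cache of
# Llama4VisionRotaryEmbedding (names and shapes are illustrative only).
angles = torch.arange(8, dtype=torch.float32)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64

# What the removed call effectively did: cast the cache to the query's
# dtype. PyTorch warns "Casting complex values to real discards the
# imaginary part" and keeps only the real (cosine) component, so the
# sine half of every rotation is silently lost.
broken = freqs_cis.to(torch.bfloat16)
assert broken.dtype == torch.bfloat16  # no longer complex

# What the patch does instead: a device-only move that preserves complex64.
query = torch.randn(2, 8, 16, dtype=torch.bfloat16)  # toy query tensor
freqs_cis = freqs_cis.to(query.device)

# The forward pass upcasts the query itself, since torch.view_as_complex
# requires a float32/float64 tensor whose last dimension has size 2.
query_ = torch.view_as_complex(
    query.float().reshape(*query.shape[:-1], -1, 2))
rotated = torch.view_as_real(query_ * freqs_cis).flatten(-2).to(query.dtype)
assert rotated.shape == query.shape and rotated.dtype == query.dtype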