diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 84108200e914e..5bef4129bfa87 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -297,13 +297,8 @@ class Qwen2_5_VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN - q = apply_rotary_pos_emb_vision(q, - rotary_pos_emb, - use_flash_attn=use_flash_attn) - k = apply_rotary_pos_emb_vision(k, - rotary_pos_emb, - use_flash_attn=use_flash_attn) + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 95f0c29d4858d..a00b756ecec07 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -64,7 +64,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import ( @@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor, - use_flash_attn=False) -> torch.Tensor: + freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() apply_rotary_emb = apply_rotary_emb_torch - if use_flash_attn: - from flash_attn.layers.rotary import apply_rotary_emb + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb output = apply_rotary_emb(t_, cos, sin).type_as(t) return output