[Misc] Use apply_rotary_emb from vllm_flash_attn for Qwen2-VL vision RoPE (#17726)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-05-07 12:51:33 +08:00 committed by GitHub
parent 822de7fb94
commit c3e9d5060e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 12 deletions

View File

@@ -297,13 +297,8 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
q = apply_rotary_pos_emb_vision(q,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (

View File

@@ -64,7 +64,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
@@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor,
use_flash_attn=False) -> torch.Tensor:
freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn:
from flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output