mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 17:01:52 +08:00
[Misc] Use apply_rotary_emb from vllm_flash_attn for Qwen2-VL vision RoPE (#17726)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
822de7fb94
commit
c3e9d5060e
@ -297,13 +297,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
||||
for x in (q, k, v))
|
||||
if rotary_pos_emb is not None:
|
||||
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
|
||||
q = apply_rotary_pos_emb_vision(q,
|
||||
rotary_pos_emb,
|
||||
use_flash_attn=use_flash_attn)
|
||||
k = apply_rotary_pos_emb_vision(k,
|
||||
rotary_pos_emb,
|
||||
use_flash_attn=use_flash_attn)
|
||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
# from vllm_flash_attn.flash_attn_interface import (
|
||||
|
||||
@ -64,7 +64,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.processor import (
|
||||
@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
use_flash_attn=False) -> torch.Tensor:
|
||||
freqs: torch.Tensor) -> torch.Tensor:
|
||||
t_ = t.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
apply_rotary_emb = apply_rotary_emb_torch
|
||||
if use_flash_attn:
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
output = apply_rotary_emb(t_, cos, sin).type_as(t)
|
||||
return output
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user