Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 11:55:00 +08:00)
[Bugfix][Qwen][Multimodal] Move Qwen2_5_vl sdpa to custom op and reenable compile (#27764)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent a4398fbb5e
commit 55011aef24
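Background on the pattern this commit applies (not part of the diff itself): the per-sequence SDPA loop slices tensors with data-dependent indices taken from cu_seqlens, which torch.compile cannot trace cleanly, so the loop is hidden behind a registered custom op with a shape-only "fake" implementation. The sketch below is a rough analogue of that pattern using the public torch.library API rather than vLLM's direct_register_custom_op helper; the op name demo::varlen_sdpa, the tensor layout, and the simplified transposes are illustrative assumptions, not vLLM code.

import torch
import torch.nn.functional as F


# Hypothetical demo op, not the vLLM implementation.
@torch.library.custom_op("demo::varlen_sdpa", mutates_args=())
def varlen_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    # Eager path: run SDPA separately on each variable-length segment.
    # q, k, v are (batch, seq, heads, head_dim); cu_seqlens holds segment
    # boundaries, e.g. [0, 16, 32, 48].
    outputs = []
    for i in range(1, len(cu_seqlens)):
        s, e = cu_seqlens[i - 1], cu_seqlens[i]
        q_i, k_i, v_i = (x[:, s:e].transpose(1, 2) for x in (q, k, v))
        o = F.scaled_dot_product_attention(q_i, k_i, v_i).transpose(1, 2)
        outputs.append(o)
    out = torch.cat(outputs, dim=1)  # (batch, seq, heads, head_dim)
    b, s_len, h, d = out.shape
    return out.permute(1, 0, 2, 3).reshape(s_len, b, h * d).contiguous()


@varlen_sdpa.register_fake
def _(q, k, v, cu_seqlens):
    # Shape-only implementation used while tracing/compiling; it must return
    # the same output shape as the real op so the graph can be built without
    # executing the data-dependent loop.
    b, s_len, h, d = q.shape
    return q.new_empty((s_len, b, h * d))

Under torch.compile, calls to such an op show up as a single opaque node instead of a trace of dynamic slicing, which is what allows the vision block's compilation to be turned back on below.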
@@ -14,6 +14,7 @@ To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
 
 import einops
 import torch
+import torch.nn.functional as F
 
 from vllm.utils.torch_utils import direct_register_custom_op
 
@@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
     )
+
+
+# TODO: Once we have a torch 2.10, we can use tensor slices
+# so we won't need to wrap this in custom ops
+def torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    outputs = []
+    for i in range(1, len(cu_seqlens)):
+        start_idx = cu_seqlens[i - 1]
+        end_idx = cu_seqlens[i]
+        q_i = q[:, start_idx:end_idx]
+        k_i = k[:, start_idx:end_idx]
+        v_i = v[:, start_idx:end_idx]
+        q_i, k_i, v_i = (
+            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+        )
+        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
+        outputs.append(output_i)
+    context_layer = torch.cat(outputs, dim=1)
+    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+    return context_layer
+
+
+def torch_sdpa_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    b, s, h, d = q.shape
+    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
+
+
+direct_register_custom_op(
+    op_name="torch_sdpa_wrapper",
+    op_func=torch_sdpa_wrapper,
+    fake_impl=torch_sdpa_wrapper_fake,
+)
+
+
+def vit_torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
@@ -46,6 +46,7 @@ from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_torch_sdpa_wrapper,
     vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
@@ -442,23 +443,12 @@ class Qwen2_5_VisionAttention(nn.Module):
             q = q.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = einops.rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_torch_sdpa_wrapper(
+                q,
+                k,
+                v,
+                cu_seqlens,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
 
@@ -466,17 +456,15 @@ class Qwen2_5_VisionAttention(nn.Module):
         return output
 
 
-# (FIXME): Enable this after dynamic slicing is fixed
-# See https://github.com/vllm-project/vllm/pull/27760
-# @support_torch_compile(
-#     dynamic_arg_dims={
-#         "x": 0,
-#         "cu_seqlens": 0,
-#         "rotary_pos_emb": 0,
-#         "seqlens": 0,
-#     },
-#     mark_unbacked_dims={"seqlens": 0},
-# )
+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
         self,