diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 6cefe7441668..06a9f7cd8226 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -14,6 +14,7 @@ To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
 import einops
 import torch
+import torch.nn.functional as F
 
 from vllm.utils.torch_utils import direct_register_custom_op
 
 
@@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
     )
+
+
+# TODO: Once we have torch 2.10, we can use tensor slices
+# so we won't need to wrap this in custom ops
+def torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    outputs = []
+    for i in range(1, len(cu_seqlens)):
+        start_idx = cu_seqlens[i - 1]
+        end_idx = cu_seqlens[i]
+        q_i = q[:, start_idx:end_idx]
+        k_i = k[:, start_idx:end_idx]
+        v_i = v[:, start_idx:end_idx]
+        q_i, k_i, v_i = (
+            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+        )
+        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
+        outputs.append(output_i)
+    context_layer = torch.cat(outputs, dim=1)
+    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+    return context_layer
+
+
+def torch_sdpa_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    b, s, h, d = q.shape
+    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
+
+
+direct_register_custom_op(
+    op_name="torch_sdpa_wrapper",
+    op_func=torch_sdpa_wrapper,
+    fake_impl=torch_sdpa_wrapper_fake,
+)
+
+
+def vit_torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 3585783e4ccc..2b04608dfd03 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -46,6 +46,7 @@ from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_torch_sdpa_wrapper,
     vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
@@ -442,23 +443,12 @@ class Qwen2_5_VisionAttention(nn.Module):
             q = q.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = einops.rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_torch_sdpa_wrapper(
+                q,
+                k,
+                v,
+                cu_seqlens,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
 
@@ -466,17 +456,15 @@ class Qwen2_5_VisionAttention(nn.Module):
         return output
 
 
-# (FIXME): Enable this after dynamic slicing is fixed
-# See https://github.com/vllm-project/vllm/pull/27760
-# @support_torch_compile(
-#     dynamic_arg_dims={
-#         "x": 0,
-#         "cu_seqlens": 0,
-#         "rotary_pos_emb": 0,
-#         "seqlens": 0,
-#     },
-#     mark_unbacked_dims={"seqlens": 0},
-# )
+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
         self,
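
For reference, a minimal usage sketch (not part of the patch) that exercises the new vit_torch_sdpa_wrapper entry point on a toy packed batch of two variable-length sequences. It assumes a vLLM build that already includes this change; the tensor sizes and cu_seqlens values below are made up for illustration, and the expected output layout follows torch_sdpa_wrapper_fake above.

import torch

from vllm.attention.ops.vit_attn_wrappers import vit_torch_sdpa_wrapper

# Toy inputs in the (batch, seq, heads, head_dim) layout expected by the wrapper.
b, s, h, d = 1, 6, 2, 8
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
v = torch.randn(b, s, h, d)

# Cumulative sequence boundaries: two packed sequences of length 4 and 2.
cu_seqlens = torch.tensor([0, 4, 6], dtype=torch.int32)

out = vit_torch_sdpa_wrapper(q, k, v, cu_seqlens)
# The op returns (seq, batch, heads * head_dim), matching torch_sdpa_wrapper_fake.
assert out.shape == (s, b, h * d)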