Feature/vit attention unification #23880 (#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent 72d30108a0
commit 6cbd41909e
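
As context for the hunks below, the usage pattern this commit unifies across the vision encoders can be read as the following sketch. The module and variable names here are illustrative only (not part of the diff); the sketch mirrors the Idefics2/InternViT/Mllama/Step3 changes, where q, k, v are handed to MultiHeadAttention as (batch, seq_len, num_heads * head_dim) tensors with no manual view/transpose, and the attention backend is selected inside the layer.

# Illustrative sketch only; assumes vLLM is installed and a supported
# platform is detected. Names below are hypothetical.
import torch
import torch.nn as nn

from vllm.attention.layer import MultiHeadAttention


class ToyVisionAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Backend (torch SDPA, xFormers, Pallas, ROCm AITER FA) is chosen
        # inside MultiHeadAttention based on the current platform.
        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
                                       self.head_dim**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # q, k, v stay in (batch, seq_len, num_heads * head_dim) layout;
        # no per-model reshape or transpose is needed.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        out = self.attn(q, k, v)  # (batch, seq_len, num_heads * head_dim)
        return self.proj(out)
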
@@ -23,6 +23,9 @@ def clear_cache():
     """Clear lru cache to ensure each test case runs without caching.
     """
     _cached_get_attn_backend.cache_clear()
+    # Clear xformers availability cache
+    import vllm.attention.layer as layer_module
+    layer_module.USE_XFORMERS_OPS = None


 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)

     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CpuPlatform()), \
+             patch("vllm.platforms.current_platform", CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA
+            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   RocmPlatform()), \
+             patch("vllm.platforms.current_platform", RocmPlatform()), \
+             patch("vllm.attention.layer.current_platform", RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS

-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS

@@ -360,13 +360,13 @@ class MultiHeadAttention(nn.Module):
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
-            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
-                           _Backend.FLEX_ATTENTION):
-                backend = _Backend.XFORMERS
-
             self.attn_backend = backend if backend in {
-                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
-            } else _Backend.TORCH_SDPA
+                _Backend.TORCH_SDPA,
+                _Backend.TORCH_SDPA_VLLM_V1,
+                _Backend.XFORMERS,
+                _Backend.PALLAS_VLLM_V1,
+                _Backend.ROCM_AITER_FA,
+            } else current_platform.get_vit_attn_backend()

         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
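
Read as standalone logic, the selection rule in the hunk above amounts to the sketch below: a requested backend is kept only if ViT attention supports it, and anything else (including the previously remapped FLASH_ATTN/FLEX_ATTENTION choices) now falls back to the platform's own get_vit_attn_backend(). The helper name and the _VIT_BACKENDS set are illustrative, not part of the diff.

# Illustrative helper (names are hypothetical); mirrors the set-membership
# check and platform fallback introduced in the hunk above.
from vllm.attention.selector import _Backend

_VIT_BACKENDS = {
    _Backend.TORCH_SDPA,
    _Backend.TORCH_SDPA_VLLM_V1,
    _Backend.XFORMERS,
    _Backend.PALLAS_VLLM_V1,
    _Backend.ROCM_AITER_FA,
}


def resolve_vit_backend(requested, platform):
    # Keep the requested backend only if ViT attention supports it;
    # otherwise defer to the platform's preferred ViT backend.
    if requested in _VIT_BACKENDS:
        return requested
    return platform.get_vit_attn_backend()
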
@@ -413,6 +413,19 @@ class MultiHeadAttention(nn.Module):
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+
+            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
+            out = flash_attn_varlen_func(query,
+                                         key,
+                                         value,
+                                         softmax_scale=self.scale)
+        else:
+            # ViT attention hasn't supported this backend yet
+            raise NotImplementedError(
+                f"ViT attention hasn't supported {self.attn_backend} "
+                f"backend yet.")

         return out.reshape(bsz, q_len, -1)

@@ -170,6 +170,7 @@ class Idefics2VisionAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
+        # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)

@@ -181,6 +182,8 @@ class Idefics2VisionAttention(nn.Module):
             hidden_states
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+
+        # Use unified MultiHeadAttention implementation
         out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output
@@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):

         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
@@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)

-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)

         x = self.proj(x)
         return x
@@ -12,10 +12,10 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 from transformers.utils import torch_int

+from vllm.attention.layer import MultiHeadAttention
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module):

         self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape

@@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module):
         k = self.k_proj(x)
         v = self.v_proj(x)

-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)

-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)

         x = self.projection_layer(x)
         return x
@@ -35,6 +35,7 @@ from transformers.models.mllama.processing_mllama import (

 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
@@ -517,6 +518,10 @@ class MllamaVisionSdpaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+                                       1.0 / math.sqrt(self.head_dim))
+
     def forward(
         self,
         hidden_state: torch.Tensor,
@@ -524,21 +529,10 @@ class MllamaVisionSdpaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_state)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)

-        # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     attn_mask=attention_mask,
-                                                     dropout_p=0.0)
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)

-        attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(attn_output.shape[0],
                                           attn_output.shape[1], -1)
         output, _ = self.o_proj(attn_output)
@@ -16,6 +16,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BatchFeature, PretrainedConfig, TensorType

+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module):
             prefix=f"{prefix}.out_proj",
             disable_tp=use_data_parallel)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)

     def forward(
         self,
|||||||
# get query proj
|
# get query proj
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
|
||||||
k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
# Use unified MultiHeadAttention with automatic backend selection
|
||||||
v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
attn_output = self.attn(q, k, v)
|
||||||
q = q.transpose(1, 2)
|
|
||||||
k = k.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
attn_output = F.scaled_dot_product_attention(q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
scale=self.scale,
|
|
||||||
is_causal=False)
|
|
||||||
attn_output = attn_output.transpose(1, 2).reshape(
|
|
||||||
bsz, tgt_len, self.num_heads * self.head_dim)
|
|
||||||
|
|
||||||
attn_output, _ = self.out_proj(attn_output)
|
attn_output, _ = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
|||||||
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
     uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
     if post_layer_norm is not None and uses_last_layer:
         hs_pool[-1] = post_layer_norm(encoder_outputs)
     return torch.cat(hs_pool, dim=-1)
@@ -48,6 +48,7 @@ class _Backend(enum.Enum):
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
+    TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1