mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 13:37:03 +08:00
[Model] Remove transformers attention porting in VITs (#10414)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
5be4e52b65
commit
e7ebb662d7
@ -4,10 +4,11 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
from transformers.models.blip.modeling_blip import BlipAttention
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -21,11 +22,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -168,7 +165,7 @@ class BlipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BlipParallelAttention(nn.Module):
|
||||
class BlipAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
@ -208,6 +205,12 @@ class BlipParallelAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"BLIP does not support {self.attn_backend} backend now.")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
@ -231,11 +234,26 @@ class BlipParallelAttention(nn.Module):
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.projection(out)
|
||||
|
||||
@ -285,18 +303,11 @@ class BlipEncoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
# fallback to sdpa attention if tp unavailable
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = BlipParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
else:
|
||||
# Blip doesn't have SDPA attention implemented in transformers
|
||||
# use eager attention instead for cpu backend
|
||||
self.self_attn = BlipAttention(config)
|
||||
self.self_attn = BlipAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = BlipMLP(config,
|
||||
@ -374,11 +385,6 @@ class BlipVisionModel(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BlipVisionEmbeddings(config)
|
||||
@ -422,7 +428,7 @@ class BlipVisionModel(nn.Module):
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
@ -5,10 +5,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -23,11 +24,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -197,7 +194,7 @@ class CLIPVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPParallelAttention(nn.Module):
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
@ -237,6 +234,12 @@ class CLIPParallelAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"CLIP does not support {self.attn_backend} backend now.")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
@ -261,11 +264,26 @@ class CLIPParallelAttention(nn.Module):
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
@ -311,17 +329,11 @@ class CLIPEncoderLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = CLIPParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
else:
|
||||
self.self_attn = CLIPSdpaAttention(config)
|
||||
self.self_attn = CLIPAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config,
|
||||
@ -461,11 +473,6 @@ class CLIPVisionModel(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@ -490,7 +497,7 @@ class CLIPVisionModel(nn.Module):
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
@ -12,6 +12,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
|
||||
prefix=f"{prefix}.proj",
|
||||
)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"InternViT does not support {self.attn_backend} backend now.")
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||
if self.tp_size > 1:
|
||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||
@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
|
||||
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
|
||||
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
|
||||
x = x.view(B, N, -1)
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
x, _ = self.proj(x)
|
||||
return x
|
||||
out = xops.memory_efficient_attention_forward(q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(B, N, -1)
|
||||
out, _ = self.proj(out)
|
||||
return out
|
||||
|
||||
|
||||
class InternSdpaAttention(nn.Module):
|
||||
@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
|
||||
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
|
||||
if (num_heads + num_dummy_heads) % tp_size == 0:
|
||||
return InternParallelAttention(config,
|
||||
quant_config=quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
|
||||
@ -187,7 +187,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend: _Backend = get_vit_attn_backend()
|
||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
}:
|
||||
|
||||
@ -260,7 +260,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.proj")
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend: _Backend = get_vit_attn_backend()
|
||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
}:
|
||||
|
||||
@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipParallelAttention(nn.Module):
|
||||
class SiglipAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"SIGLIP does not support {self.attn_backend} backend now.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -313,11 +315,26 @@ class SiglipParallelAttention(nn.Module):
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(batch_size, q_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = SiglipParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
else:
|
||||
self.self_attn = SiglipSdpaAttention(config)
|
||||
|
||||
self.self_attn = SiglipAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@ -569,10 +580,6 @@ class SiglipVisionModel(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
config,
|
||||
quant_config,
|
||||
@ -601,7 +608,7 @@ class SiglipVisionModel(nn.Module):
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
@ -587,7 +587,11 @@ class LLMWrapper(nn.Module):
|
||||
return llm(*args, **kwargs)
|
||||
|
||||
|
||||
def get_vit_attn_backend() -> _Backend:
|
||||
def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
|
||||
"""
|
||||
Get the available attention backend for Vision Transformer.
|
||||
"""
|
||||
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
|
||||
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
||||
if selected_backend is None:
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend:
|
||||
if selected_backend is None:
|
||||
# For Volta and Turing GPUs, use xformers instead.
|
||||
device_available = current_platform.has_device_capability(80)
|
||||
if device_available:
|
||||
if device_available and support_fa:
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
if is_flash_attn_2_available():
|
||||
selected_backend = _Backend.FLASH_ATTN
|
||||
@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend:
|
||||
"so we use xformers backend instead. You can run "
|
||||
"`pip install flash-attn` to use flash-attention backend.")
|
||||
selected_backend = _Backend.XFORMERS
|
||||
elif current_platform.is_cpu():
|
||||
elif current_platform.is_cpu() or current_platform.is_rocm():
|
||||
# ROCM doesn't support xformers
|
||||
selected_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
selected_backend = _Backend.XFORMERS
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user