[Model] Consolidate ViTs attention implementation without mask (#10893)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-12-05 02:11:08 +08:00 committed by GitHub
parent 01d079fd8e
commit 10398b4706
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 107 additions and 224 deletions

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
@ -168,6 +169,68 @@ class Attention(nn.Module):
return s return s
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    The layer applies plain (unmasked, non-causal) scaled dot-product
    attention to already-projected query/key/value tensors. It holds no
    parameters of its own: callers keep their own QKV / output projections
    and only delegate the attention kernel to this module.

    Args:
        num_heads: Number of query heads.
        head_size: Size of each attention head.
        scale: Scaling factor applied to the attention logits.
        num_kv_heads: Number of key/value heads. Defaults to ``num_heads``.
            When smaller than ``num_heads`` (GQA/MQA) it must divide
            ``num_heads`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not a positive multiple of
            ``num_kv_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        if self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a positive multiple "
                f"of num_kv_heads ({self.num_kv_heads}).")
        # How many query heads share one key/value head (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        # The selector hands back a backend implementation rather than a
        # `_Backend` member, so comparing it directly against enum members
        # never matches and everything silently falls through to SDPA.
        # Normalise to the enum first (a value that already is an enum
        # member is kept as-is).
        if isinstance(attn_backend, _Backend):
            backend = attn_backend
        else:
            backend = backend_name_to_enum(attn_backend.get_name())

        # FlashAttention is not wired up in `forward` yet (see TODO there),
        # so route it to xFormers instead.
        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            backend = _Backend.XFORMERS

        # Anything `forward` cannot handle falls back to torch SDPA.
        self.attn_backend = backend if backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run unmasked attention.

        Input shape: batch_size x seq_len x hidden_size

        Args:
            query: ``(batch_size, q_len, num_heads * head_size)``.
            key: ``(batch_size, kv_len, num_kv_heads * head_size)``.
            value: ``(batch_size, kv_len, num_kv_heads * head_size)``.

        Returns:
            Tensor of shape ``(batch_size, q_len, num_heads * head_size)``.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if self.num_queries_per_kv > 1:
            # GQA/MQA: both kernels below are called with matching head
            # counts, so expand each KV head to cover its query-head group.
            key = torch.repeat_interleave(key,
                                          self.num_queries_per_kv,
                                          dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            # Imported lazily so xformers is only required when selected.
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        else:
            raise RuntimeError(
                f"Unsupported attention backend: {self.attn_backend}")

        # `reshape` instead of `view`: the transposed SDPA output is not
        # guaranteed to be contiguous, in which case `view` would raise.
        return out.reshape(bsz, q_len, -1)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,

View File

@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.selector import _Backend from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation. self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.attn_backend = get_vit_attn_backend(support_fa=False) self.head_dim, self.scale)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, return tensor.view(bsz, seq_len, self.num_heads,
@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
): ):
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv(hidden_states) qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len, out = self.attn(query_states, key_states, value_states)
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out) attn_output, _ = self.projection(out)
return attn_output, None return attn_output, None

View File

@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.attention.selector import _Backend from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs) resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation. self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.attn_backend = get_vit_attn_backend(support_fa=False) self.head_dim, self.scale)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, return tensor.view(bsz, seq_len, self.num_heads,
@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
): ):
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
out = self.attn(query_states, key_states, value_states)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output, None return attn_output, None

View File

@ -8,6 +8,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -77,27 +78,16 @@ class Attention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
) )
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
self.scale)
self.output_dropout = torch.nn.Dropout(config.dropout_prob) self.output_dropout = torch.nn.Dropout(config.dropout_prob)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1) q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
k = k.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
v = v.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
out = torch.nn.functional.scaled_dot_product_attention(q, out = self.attn(q, k, v)
k, output, _ = self.dense(out)
v,
attn_mask=None,
dropout_p=0.,
is_causal=False)
output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))
output = self.output_dropout(output) output = self.output_dropout(output)
return output return output

View File

@ -21,8 +21,8 @@ import torch
from torch import nn from torch import nn
from transformers.models.idefics2.configuration_idefics2 import ( from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig) Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
) )
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.is_causal = False self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj( qkv, _ = self.qkv_proj(
hidden_states hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1) query_states, key_states, value_states = qkv.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len, out = self.attn(query_states, key_states, value_states)
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale,
)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output

View File

@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.selector import _Backend from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .utils import get_vit_attn_backend
NORM2FN = { NORM2FN = {
'rms_norm': RMSNorm, 'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm, 'layer_norm': nn.LayerNorm,
@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
prefix=f"{prefix}.proj", prefix=f"{prefix}.proj",
) )
self.attn_backend = get_vit_attn_backend(support_fa=False) self.attn = MultiHeadAttention(self.num_heads_per_partition,
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: self.head_dim, self.scale)
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1: if self.tp_size > 1:
@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
if self.qk_normalization: if self.qk_normalization:
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim) out = self.attn(q, k, v)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)
out = out.view(B, N, -1)
out, _ = self.proj(out) out, _ = self.proj(out)
return out return out

View File

@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None self.img_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds return image_embeds
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.is_mono: if self.is_mono:
visual_token_mask = ( self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1) input_ids == self.img_context_token_id).reshape(-1, 1)
else: else:
visual_token_mask = None self.visual_token_mask = None
return visual_token_mask
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
assert self.img_context_token_id is not None assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.img_context_token_id) self.img_context_token_id)
@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
**kwargs: object, **kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]: ) -> Union[SamplerOutput, IntermediateTensors]:
visual_token_mask = None
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None input_ids = None
inputs_embeds = None inputs_embeds = None
@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"intermediate_tensors": intermediate_tensors, "intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
} }
if self.img_context_token_id is not None:
visual_token_mask = self._get_visual_token_mask(input_ids)
# We always overwrite it back to None after computing visual token if self.visual_token_mask is not None:
# mask so that this doesn't need to depend on encoder output # overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
self.img_context_token_id = None self.img_context_token_id = None
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})
hidden_states = self.language_model.model(**forward_kwargs) hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states return hidden_states

View File

@ -13,6 +13,7 @@ from torch.nn import functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
) )
# Detect attention implementation. self.scale = self.head_dim**-0.5
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) self.attn = MultiHeadAttention(self.num_heads,
if self.attn_backend not in { self.head_dim,
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS self.scale,
}: num_kv_heads=self.num_kv_heads)
raise RuntimeError(
f"Molmo does not support {self.attn_backend} backend now.")
def forward(self, def forward(self,
inputs_q: torch.Tensor, inputs_q: torch.Tensor,
@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
xq, _ = self.wq(inputs_q) xq, _ = self.wq(inputs_q)
xk, _ = self.wk(inputs_k) xk, _ = self.wk(inputs_k)
xv, _ = self.wv(inputs_v) xv, _ = self.wv(inputs_v)
q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
xq = xq.view(*q_shape)
xk = xk.view(*kv_shape)
xv = xv.view(*kv_shape)
if self.attn_backend == _Backend.FLASH_ATTN: output = self.attn(xq, xk, xv)
from flash_attn import flash_attn_func
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
elif self.attn_backend == _Backend.TORCH_SDPA:
xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
for x in (xq, xk, xv))
output = F.scaled_dot_product_attention(xq, xk, xv)
output = rearrange(output, "b h s d -> b s h d ")
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
output = rearrange(output, "b s h d -> b s (h d)").contiguous()
output, _ = self.wo(output) output, _ = self.wo(output)
return output return output

View File

@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from vllm.attention.selector import _Backend from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.inputs import DecoderOnlyInputs, token_inputs
@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs) resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible # Since interpolation is applied, the image size need not be divisible
@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn_backend = get_vit_attn_backend(support_fa=False) self.attn = MultiHeadAttention(self.num_heads_per_partition,
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: self.head_dim, self.scale)
raise RuntimeError(
f"SIGLIP does not support {self.attn_backend} backend now.")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len, out = self.attn(query_states, key_states, value_states)
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output, None return attn_output, None