mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 02:27:12 +08:00
[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
01d079fd8e
commit
10398b4706
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata, AttentionType
|
from vllm.attention import AttentionMetadata, AttentionType
|
||||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||||
@ -168,6 +169,68 @@ class Attention(nn.Module):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    No attention mask and no KV cache are involved: the layer runs plain
    scaled dot-product attention over the given query/key/value tensors
    using either xFormers or ``torch.nn.functional`` SDPA.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Store the head layout and pick the attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scaling factor passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads`` when omitted.

        Raises:
            ValueError: If ``num_heads`` is not a positive multiple of
                ``num_kv_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        # Fail early instead of surfacing an opaque shape error in forward().
        if self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a positive multiple "
                f"of num_kv_heads ({self.num_kv_heads}).")

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        # NOTE(review): the selector appears to return a backend class
        # (exposing ``get_name()``) rather than a ``_Backend`` member in some
        # versions; normalise it so the membership checks below can match.
        # Confirm against the selector's actual return type.
        get_name = getattr(attn_backend, "get_name", None)
        if callable(get_name):
            attn_backend = backend_name_to_enum(get_name())

        # FlashAttention is not wired up in forward() yet; use xFormers.
        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            attn_backend = _Backend.XFORMERS

        # Anything other than the two implemented backends falls back to SDPA.
        self.attn_backend = attn_backend if attn_backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Args:
            query: Query tensor; its last dimension is split into
                ``num_heads x head_size``.
            key: Key tensor; its last dimension is split into
                ``num_kv_heads x head_size``.
            value: Value tensor with the same layout as ``key``.

        Returns:
            Attention output with shape
            ``batch_size x q_len x (num_heads * head_size)``.

        Raises:
            RuntimeError: If ``self.attn_backend`` is not one of the
                implemented backends.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        # MQA/GQA: both kernels below are called with matching head counts,
        # so expand the KV heads to line up with the query heads.
        num_repeat = self.num_heads // self.num_kv_heads
        if num_repeat > 1:
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            # Imported lazily so xFormers is only needed when selected.
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        else:
            raise RuntimeError(
                f"Unsupported attention backend: {self.attn_backend}")

        # reshape (not view): the transposed SDPA output may be
        # non-contiguous, in which case view() would raise.
        return out.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
|
|
||||||
def unified_attention(
|
def unified_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
|||||||
@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||||
|
|
||||||
from vllm.attention.selector import _Backend
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||||
@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
|
|
||||||
from .utils import get_vit_attn_backend
|
|
||||||
|
|
||||||
|
|
||||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
assert image_size % patch_size == 0
|
assert image_size % patch_size == 0
|
||||||
@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
# Detect attention implementation.
|
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
self.head_dim, self.scale)
|
||||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"BLIP does not support {self.attn_backend} backend now.")
|
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads,
|
return tensor.view(bsz, seq_len, self.num_heads,
|
||||||
@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
qkv_states, _ = self.qkv(hidden_states)
|
qkv_states, _ = self.qkv(hidden_states)
|
||||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||||
query_states = query_states.view(bsz, tgt_len,
|
out = self.attn(query_states, key_states, value_states)
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
key_states = key_states.view(bsz, tgt_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
value_states = value_states.view(bsz, tgt_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
|
|
||||||
if self.attn_backend == _Backend.XFORMERS:
|
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
out = xops.memory_efficient_attention_forward(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
|
||||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
|
||||||
for x in (query_states,
|
|
||||||
key_states,
|
|
||||||
value_states))
|
|
||||||
out = F.scaled_dot_product_attention(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
dropout_p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
|
|
||||||
out = out.view(bsz, tgt_len, -1)
|
|
||||||
attn_output, _ = self.projection(out)
|
attn_output, _ = self.projection(out)
|
||||||
|
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
|
|||||||
@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPVisionConfig
|
from transformers import CLIPVisionConfig
|
||||||
|
|
||||||
from vllm.attention.selector import _Backend
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||||
@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
resolve_visual_encoder_outputs)
|
resolve_visual_encoder_outputs)
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
|
|
||||||
from .utils import get_vit_attn_backend
|
|
||||||
|
|
||||||
|
|
||||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
assert image_size % patch_size == 0
|
assert image_size % patch_size == 0
|
||||||
@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
# Detect attention implementation.
|
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
self.head_dim, self.scale)
|
||||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"CLIP does not support {self.attn_backend} backend now.")
|
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads,
|
return tensor.view(bsz, seq_len, self.num_heads,
|
||||||
@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
):
|
):
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||||
|
out = self.attn(query_states, key_states, value_states)
|
||||||
query_states = query_states.view(bsz, tgt_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
key_states = key_states.view(bsz, tgt_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
value_states = value_states.view(bsz, tgt_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
|
|
||||||
if self.attn_backend == _Backend.XFORMERS:
|
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
out = xops.memory_efficient_attention_forward(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
|
||||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
|
||||||
for x in (query_states,
|
|
||||||
key_states,
|
|
||||||
value_states))
|
|
||||||
out = F.scaled_dot_product_attention(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
dropout_p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
|
|
||||||
out = out.view(bsz, tgt_len, -1)
|
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
|
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -77,27 +78,16 @@ class Attention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
|
||||||
|
self.scale)
|
||||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
B, L, _ = x.shape
|
|
||||||
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
|
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
|
||||||
q, k, v = qkv.chunk(3, dim=-1)
|
q, k, v = qkv.chunk(3, dim=-1)
|
||||||
q = q.reshape(B, L, self.num_heads_per_rank,
|
|
||||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
|
||||||
k = k.reshape(B, L, self.num_heads_per_rank,
|
|
||||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
|
||||||
v = v.reshape(B, L, self.num_heads_per_rank,
|
|
||||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
|
||||||
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q,
|
out = self.attn(q, k, v)
|
||||||
k,
|
output, _ = self.dense(out)
|
||||||
v,
|
|
||||||
attn_mask=None,
|
|
||||||
dropout_p=0.,
|
|
||||||
is_causal=False)
|
|
||||||
|
|
||||||
output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))
|
|
||||||
output = self.output_dropout(output)
|
output = self.output_dropout(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@ -21,8 +21,8 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.idefics2.configuration_idefics2 import (
|
from transformers.models.idefics2.configuration_idefics2 import (
|
||||||
Idefics2Config, Idefics2VisionConfig)
|
Idefics2Config, Idefics2VisionConfig)
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
self.is_causal = False
|
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||||
|
self.head_dim, self.scale)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
batch_size, q_len, _ = hidden_states.size()
|
|
||||||
qkv, _ = self.qkv_proj(
|
qkv, _ = self.qkv_proj(
|
||||||
hidden_states
|
hidden_states
|
||||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||||
query_states = query_states.view(batch_size, q_len,
|
out = self.attn(query_states, key_states, value_states)
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
key_states = key_states.view(batch_size, q_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
value_states = value_states.view(batch_size, q_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
# see: https://facebookresearch.github.io/xformers/components/ops.html
|
|
||||||
out = xops.memory_efficient_attention_forward(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
p=self.dropout,
|
|
||||||
scale=self.scale,
|
|
||||||
)
|
|
||||||
out = out.view(batch_size, q_len, -1)
|
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.selector import _Backend
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
split_tensor_along_last_dim,
|
split_tensor_along_last_dim,
|
||||||
@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from .utils import get_vit_attn_backend
|
|
||||||
|
|
||||||
NORM2FN = {
|
NORM2FN = {
|
||||||
'rms_norm': RMSNorm,
|
'rms_norm': RMSNorm,
|
||||||
'layer_norm': nn.LayerNorm,
|
'layer_norm': nn.LayerNorm,
|
||||||
@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.proj",
|
prefix=f"{prefix}.proj",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
self.head_dim, self.scale)
|
||||||
raise RuntimeError(
|
|
||||||
f"InternViT does not support {self.attn_backend} backend now.")
|
|
||||||
|
|
||||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
|
|||||||
if self.qk_normalization:
|
if self.qk_normalization:
|
||||||
q, k = self._apply_qk_norm(q, k)
|
q, k = self._apply_qk_norm(q, k)
|
||||||
|
|
||||||
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
|
out = self.attn(q, k, v)
|
||||||
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
|
|
||||||
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
|
|
||||||
|
|
||||||
if self.attn_backend == _Backend.XFORMERS:
|
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
out = xops.memory_efficient_attention_forward(q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
scale=self.scale)
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
|
||||||
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
|
|
||||||
out = out.view(B, N, -1)
|
|
||||||
out, _ = self.proj(out)
|
out, _ = self.proj(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self.mlp1 = self._init_mlp1(config)
|
self.mlp1 = self._init_mlp1(config)
|
||||||
|
|
||||||
self.img_context_token_id = None
|
self.img_context_token_id = None
|
||||||
|
self.visual_token_mask = None
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.language_model.make_empty_intermediate_tensors)
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
if self.is_mono:
|
if self.is_mono:
|
||||||
visual_token_mask = (
|
self.visual_token_mask = (
|
||||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||||
else:
|
else:
|
||||||
visual_token_mask = None
|
self.visual_token_mask = None
|
||||||
return visual_token_mask
|
|
||||||
|
|
||||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
if multimodal_embeddings is not None:
|
if multimodal_embeddings is not None:
|
||||||
assert self.img_context_token_id is not None
|
assert self.img_context_token_id is not None
|
||||||
|
self._set_visual_token_mask(input_ids)
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
input_ids, inputs_embeds, multimodal_embeddings,
|
||||||
self.img_context_token_id)
|
self.img_context_token_id)
|
||||||
@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> Union[SamplerOutput, IntermediateTensors]:
|
) -> Union[SamplerOutput, IntermediateTensors]:
|
||||||
|
|
||||||
visual_token_mask = None
|
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
"intermediate_tensors": intermediate_tensors,
|
"intermediate_tensors": intermediate_tensors,
|
||||||
"inputs_embeds": inputs_embeds,
|
"inputs_embeds": inputs_embeds,
|
||||||
}
|
}
|
||||||
if self.img_context_token_id is not None:
|
|
||||||
visual_token_mask = self._get_visual_token_mask(input_ids)
|
|
||||||
|
|
||||||
# We always overwrite it back to None after computing visual token
|
if self.visual_token_mask is not None:
|
||||||
# mask so that this doesn't need to depend on encoder output
|
# overwrite visual_token_mask and img_context_token_id back to None,
|
||||||
|
# so that this doesn't need to depend on encoder output
|
||||||
|
forward_kwargs.update(
|
||||||
|
{"visual_token_mask": self.visual_token_mask})
|
||||||
|
self.visual_token_mask = None
|
||||||
self.img_context_token_id = None
|
self.img_context_token_id = None
|
||||||
|
|
||||||
if self.is_mono:
|
|
||||||
forward_kwargs.update({"visual_token_mask": visual_token_mask})
|
|
||||||
|
|
||||||
hidden_states = self.language_model.model(**forward_kwargs)
|
hidden_states = self.language_model.model(**forward_kwargs)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from torch.nn import functional as F
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import NestedTensors
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
from vllm.platforms import _Backend
|
|
||||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||||
SequenceData)
|
SequenceData)
|
||||||
from vllm.transformers_utils.processor import get_processor
|
from vllm.transformers_utils.processor import get_processor
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
|
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||||
is_pp_missing_parameter,
|
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Detect attention implementation.
|
self.scale = self.head_dim**-0.5
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn = MultiHeadAttention(self.num_heads,
|
||||||
if self.attn_backend not in {
|
self.head_dim,
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
self.scale,
|
||||||
}:
|
num_kv_heads=self.num_kv_heads)
|
||||||
raise RuntimeError(
|
|
||||||
f"Molmo does not support {self.attn_backend} backend now.")
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
inputs_q: torch.Tensor,
|
inputs_q: torch.Tensor,
|
||||||
@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
|
|||||||
xq, _ = self.wq(inputs_q)
|
xq, _ = self.wq(inputs_q)
|
||||||
xk, _ = self.wk(inputs_k)
|
xk, _ = self.wk(inputs_k)
|
||||||
xv, _ = self.wv(inputs_v)
|
xv, _ = self.wv(inputs_v)
|
||||||
q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
|
|
||||||
kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
|
|
||||||
xq = xq.view(*q_shape)
|
|
||||||
xk = xk.view(*kv_shape)
|
|
||||||
xv = xv.view(*kv_shape)
|
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
output = self.attn(xq, xk, xv)
|
||||||
from flash_attn import flash_attn_func
|
|
||||||
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
|
||||||
xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
|
|
||||||
for x in (xq, xk, xv))
|
|
||||||
output = F.scaled_dot_product_attention(xq, xk, xv)
|
|
||||||
output = rearrange(output, "b h s d -> b s h d ")
|
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
|
||||||
from xformers import ops as xops
|
|
||||||
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
|
|
||||||
|
|
||||||
output = rearrange(output, "b s h d -> b s (h d)").contiguous()
|
|
||||||
output, _ = self.wo(output)
|
output, _ = self.wo(output)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import SiglipVisionConfig
|
from transformers import SiglipVisionConfig
|
||||||
|
|
||||||
from vllm.attention.selector import _Backend
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||||
@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
resolve_visual_encoder_outputs)
|
resolve_visual_encoder_outputs)
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
|
|
||||||
from .utils import get_vit_attn_backend
|
|
||||||
|
|
||||||
|
|
||||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
# Since interpolation is applied, the image size need not be divisible
|
# Since interpolation is applied, the image size need not be divisible
|
||||||
@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
self.head_dim, self.scale)
|
||||||
raise RuntimeError(
|
|
||||||
f"SIGLIP does not support {self.attn_backend} backend now.")
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
batch_size, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||||
|
|
||||||
query_states = query_states.view(batch_size, q_len,
|
out = self.attn(query_states, key_states, value_states)
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
key_states = key_states.view(batch_size, q_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
value_states = value_states.view(batch_size, q_len,
|
|
||||||
self.num_heads_per_partition,
|
|
||||||
self.head_dim)
|
|
||||||
|
|
||||||
if self.attn_backend == _Backend.XFORMERS:
|
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
out = xops.memory_efficient_attention_forward(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
|
||||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
|
||||||
for x in (query_states,
|
|
||||||
key_states,
|
|
||||||
value_states))
|
|
||||||
out = F.scaled_dot_product_attention(query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
dropout_p=self.dropout,
|
|
||||||
scale=self.scale)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
|
|
||||||
out = out.view(batch_size, q_len, -1)
|
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
|
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user