[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)
parent 0fbc6696c2
commit ec266536b7
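Note: the snippet below is not part of the diff. It is a minimal, self-contained sketch of the selection pattern this commit applies to every vision encoder (BLIP, CLIP, InternViT, SigLIP): probe for xformers at import time and use the tensor-parallel attention only when the library is available and the head count divides evenly across the tensor-parallel ranks; otherwise fall back to an SDPA or eager implementation that runs on the CPU backend. The class and function names here are placeholders, not vLLM symbols.

import torch.nn as nn

try:
    from xformers import ops as xops  # optional; absent on CPU-only installs

    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False


class ParallelAttention(nn.Module):
    """Placeholder for the xformers-backed, tensor-parallel attention."""


class SdpaAttention(nn.Module):
    """Placeholder for the torch-SDPA / eager fallback attention."""


def select_attention(num_attention_heads: int, tp_size: int) -> nn.Module:
    # Use the parallel path only when xformers is importable and the
    # head count shards evenly across tensor-parallel ranks.
    if USE_XFORMERS_OPS and num_attention_heads % tp_size == 0:
        return ParallelAttention()
    return SdpaAttention()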
vllm/model_executor/models/blip.py

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from xformers import ops as xops
+from transformers.models.blip.modeling_blip import BlipAttention

 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -21,6 +21,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

+try:
+    from xformers import ops as xops
+
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+

 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -156,7 +162,7 @@ class BlipVisionEmbeddings(nn.Module):
         return embeddings


-class BlipAttention(nn.Module):
+class BlipParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -224,7 +230,7 @@ class BlipAttention(nn.Module):
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.projection(out)

-        return attn_output
+        return attn_output, None


 class BlipMLP(nn.Module):
@@ -261,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()

-        self.self_attn = BlipAttention(config, quant_config=quant_config)
+        # fallback to sdpa attention if tp unavailable
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = BlipParallelAttention(config,
+                                                   quant_config=quant_config)
+        else:
+            # Blip doesn't have SDPA attention implemented in transformers
+            # use eager attention instead for cpu backend
+            self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -272,7 +287,7 @@
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
vllm/model_executor/models/clip.py

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import CLIPVisionConfig
-from xformers import ops as xops
+from transformers.models.clip.modeling_clip import CLIPSdpaAttention

 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

+try:
+    from xformers import ops as xops
+
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+

 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module):
         return embeddings


-class CLIPAttention(nn.Module):
+class CLIPParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -231,7 +237,7 @@ class CLIPAttention(nn.Module):
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.out_proj(out)

-        return attn_output
+        return attn_output, None


 class CLIPMLP(nn.Module):
@@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()

-        self.self_attn = CLIPAttention(config, quant_config=quant_config)
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = CLIPParallelAttention(config,
+                                                   quant_config=quant_config)
+        else:
+            self.self_attn = CLIPSdpaAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
@@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module):
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
@@ -386,7 +402,7 @@
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ]
+        ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
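Note (not part of the diff): the `] if self.shard_weight else []` change above matters because the parallel attention fuses the q/k/v projections into a single qkv_proj parameter, so checkpoint names such as q_proj must be remapped at load time; with the Hugging Face fallback attention the original parameter names are kept and the mapping is emptied. A simplified, illustrative sketch of that remapping (remap_checkpoint_name is a hypothetical helper, not the vLLM loading loop):

from typing import List, Tuple


def remap_checkpoint_name(name: str,
                          stacked_params_mapping: List[Tuple[str, str, str]]):
    """Return (target_param_name, shard_id) for a checkpoint weight name."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # Parallel attention: q_proj/k_proj/v_proj load into the
            # fused qkv_proj parameter, each as one shard.
            return name.replace(shard_name, param_name), shard_id
    # Fallback attention keeps the original (unfused) parameter names.
    return name, None


# With shard_weight=True the mapping is populated; with False it is [].
mapping = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"),
           ("qkv_proj", "v_proj", "v")]
weight = "vision_model.encoder.layers.0.self_attn.q_proj.weight"
print(remap_checkpoint_name(weight, mapping))  # remapped to qkv_proj, shard "q"
print(remap_checkpoint_name(weight, []))       # unchanged, default loader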
vllm/model_executor/models/intern_vit.py

@@ -10,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from xformers import ops as xops

 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+try:
+    from xformers import ops as xops
+
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 NORM2FN = {
     'rms_norm': RMSNorm,
     'layer_norm': nn.LayerNorm,
@@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module):
         return embeddings


-class InternAttention(nn.Module):
+class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -140,18 +145,67 @@ class InternAttention(nn.Module):
             k = self.k_norm.forward_native(k.flatten(-2,
                                                      -1)).view(B_, N_, H_, D_)

-        x = xops.memory_efficient_attention_forward(
-            q,
-            k,
-            v,
-            scale=self.scale,
-        )
+        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)

         x, _ = self.proj(x)
         return x


+class InternSdpaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f'embed_dim must be divisible by num_heads '
+                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                f' {self.num_heads}).')
+
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(self.embed_dim,
+                             3 * self.embed_dim,
+                             bias=config.qkv_bias)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(B, N, self.num_heads, self.head_dim)
+        k = k.view(B, N, self.num_heads, self.head_dim)
+        v = v.view(B, N, self.num_heads, self.head_dim)
+
+        if self.qk_normalization:
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+        x = x.transpose(1, 2).view(B, N, -1)
+
+        x = self.proj(x)
+        return x
+
+
 class InternMLP(nn.Module):

     def __init__(self,
@@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module):
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type

-        self.attn = InternAttention(config, quant_config=quant_config)
+        # fallback to sdpa attention if tp unavailable
+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.attn = InternParallelAttention(config,
+                                                quant_config=quant_config)
+        else:
+            self.attn = InternSdpaAttention(config)
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
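Note (not part of the diff): the new InternSdpaAttention keeps tensors in (batch, seq, heads, head_dim) through the optional QK normalization, then transposes to (batch, heads, seq, head_dim) before calling torch's scaled_dot_product_attention, whereas the xformers kernel on the parallel path consumes the untransposed layout. A minimal sketch of that reshaping (assumes PyTorch >= 2.1 for the scale kwarg; shapes are illustrative):

import torch
import torch.nn.functional as F

B, N, H, D = 2, 16, 4, 32          # batch, tokens, heads, head_dim
q = torch.randn(B, N, H, D)
k = torch.randn(B, N, H, D)
v = torch.randn(B, N, H, D)

# F.scaled_dot_product_attention expects (B, H, N, D).
out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                     k.transpose(1, 2),
                                     v.transpose(1, 2),
                                     scale=D**-0.5)
# Back to (B, N, H * D) to feed the output projection, as in the diff.
out = out.transpose(1, 2).reshape(B, N, H * D)
print(out.shape)  # torch.Size([2, 16, 128])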
vllm/model_executor/models/paligemma.py

@@ -307,26 +307,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
             use_default_weight_loading = False
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with
-                # embed_token. To prevent errors, skip loading
-                # lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                use_default_weight_loading = True
+            if "vision" not in name or self.vision_tower.shard_weight:
+                for (param_name, shard_name,
+                     shard_id) in stacked_params_mapping:
+                    if shard_name not in name:
+                        continue
+                    name = name.replace(shard_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # lm_head is not used in vllm as it is tied with
+                    # embed_token. To prevent errors, skip loading
+                    # lm_head.weight.
+                    if "lm_head.weight" in name:
+                        continue
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    use_default_weight_loading = True
+            else:
+                use_default_weight_loading = True

             if use_default_weight_loading:
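Note (not part of the diff): the guard added to PaliGemma's load_weights applies the stacked-params remapping to vision weights only when the vision tower was built with the sharded (xformers) attention; otherwise those weights fall through to the default loader under their original names. An illustrative sketch of that decision (the helper and weight names below are hypothetical):

def use_stacked_mapping(weight_name: str, vision_shard_weight: bool) -> bool:
    """Remap q/k/v into qkv_proj, except for vision weights destined for
    the unsharded (SDPA/eager) attention classes."""
    if "vision" in weight_name and not vision_shard_weight:
        return False  # default weight loader, original parameter names
    return True


print(use_stacked_mapping(
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight",
    False))  # False: load as-is into the fallback attention
print(use_stacked_mapping(
    "language_model.model.layers.0.self_attn.q_proj.weight",
    False))  # True: language model weights are always remapped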
vllm/model_executor/models/siglip.py

@@ -9,7 +9,7 @@ import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from xformers import ops as xops
+from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention

 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

+try:
+    from xformers import ops as xops
+
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+

 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     # Since interpolation is applied, the image size need not be divisible
@@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings


-class SiglipAttention(nn.Module):
+class SiglipParallelAttention(nn.Module):

     def __init__(
         self,
@@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
         out = out.view(batch_size, q_len, -1)
         attn_output, _ = self.out_proj(out)

-        return attn_output
+        return attn_output, None


 class SiglipMLP(nn.Module):
@@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        self.self_attn = SiglipAttention(config, quant_config=quant_config)
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = SiglipParallelAttention(config,
+                                                     quant_config=quant_config)
+        else:
+            self.self_attn = SiglipSdpaAttention(config)
+
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
@@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
         num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,