[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)

This commit is contained in:
Isotr0py 2024-09-03 21:37:52 +08:00 committed by GitHub
parent 0fbc6696c2
commit ec266536b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 157 additions and 44 deletions

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from xformers import ops as xops from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -21,6 +21,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
@ -156,7 +162,7 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class BlipAttention(nn.Module): class BlipParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
@ -224,7 +230,7 @@ class BlipAttention(nn.Module):
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out) attn_output, _ = self.projection(out)
return attn_output return attn_output, None
class BlipMLP(nn.Module): class BlipMLP(nn.Module):
@ -261,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = BlipAttention(config, quant_config=quant_config) # fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config) self.mlp = BlipMLP(config, quant_config=quant_config)
@ -272,7 +287,7 @@ class BlipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from xformers import ops as xops from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
class CLIPAttention(nn.Module): class CLIPParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
@ -231,7 +237,7 @@ class CLIPAttention(nn.Module):
out = out.view(bsz, tgt_len, -1) out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output, None
class CLIPMLP(nn.Module): class CLIPMLP(nn.Module):
@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = CLIPAttention(config, quant_config=quant_config) num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config) self.mlp = CLIPMLP(config, quant_config=quant_config)
@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
@ -386,7 +402,7 @@ class CLIPVisionModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)

View File

@ -10,7 +10,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
NORM2FN = { NORM2FN = {
'rms_norm': RMSNorm, 'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm, 'layer_norm': nn.LayerNorm,
@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module):
return embeddings return embeddings
class InternAttention(nn.Module): class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__( def __init__(
@ -140,18 +145,67 @@ class InternAttention(nn.Module):
k = self.k_norm.forward_native(k.flatten(-2, k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_) -1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward( x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
q,
k,
v,
scale=self.scale,
)
x = x.view(B, N, -1) x = x.view(B, N, -1)
x, _ = self.proj(x) x, _ = self.proj(x)
return x return x
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1)
x = self.proj(x)
return x
class InternMLP(nn.Module): class InternMLP(nn.Module):
def __init__(self, def __init__(self,
@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
self.attn = InternAttention(config, quant_config=quant_config) # fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.mlp = InternMLP(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)

View File

@ -307,26 +307,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if key_to_modify in name: if key_to_modify in name:
name = name.replace(key_to_modify, new_key) name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False use_default_weight_loading = False
for (param_name, shard_name, shard_id) in stacked_params_mapping: if "vision" not in name or self.vision_tower.shard_weight:
if shard_name not in name: for (param_name, shard_name,
continue shard_id) in stacked_params_mapping:
name = name.replace(shard_name, param_name) if shard_name not in name:
# Skip loading extra bias for GPTQ models. continue
if name.endswith(".bias") and name not in params_dict: name = name.replace(shard_name, param_name)
continue # Skip loading extra bias for GPTQ models.
param = params_dict[name] if name.endswith(".bias") and name not in params_dict:
weight_loader = param.weight_loader continue
weight_loader(param, loaded_weight, shard_id) param = params_dict[name]
break weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
else: else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True use_default_weight_loading = True
if use_default_weight_loading: if use_default_weight_loading:

View File

@ -9,7 +9,7 @@ import torch
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from xformers import ops as xops from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible # Since interpolation is applied, the image size need not be divisible
@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class SiglipAttention(nn.Module): class SiglipParallelAttention(nn.Module):
def __init__( def __init__(
self, self,
@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
out = out.view(batch_size, q_len, -1) out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output, None
class SiglipMLP(nn.Module): class SiglipMLP(nn.Module):
@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config, quant_config=quant_config) num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): ):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config, quant_config,