mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 01:42:16 +08:00
[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)
This commit is contained in:
parent
0fbc6696c2
commit
ec266536b7
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||||
from xformers import ops as xops
|
from transformers.models.blip.modeling_blip import BlipAttention
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@ -21,6 +21,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xformers import ops as xops
|
||||||
|
USE_XFORMERS_OPS = True
|
||||||
|
except ImportError:
|
||||||
|
USE_XFORMERS_OPS = False
|
||||||
|
|
||||||
|
|
||||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
assert image_size % patch_size == 0
|
assert image_size % patch_size == 0
|
||||||
@ -156,7 +162,7 @@ class BlipVisionEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class BlipAttention(nn.Module):
|
class BlipParallelAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -224,7 +230,7 @@ class BlipAttention(nn.Module):
|
|||||||
out = out.view(bsz, tgt_len, -1)
|
out = out.view(bsz, tgt_len, -1)
|
||||||
attn_output, _ = self.projection(out)
|
attn_output, _ = self.projection(out)
|
||||||
|
|
||||||
return attn_output
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
class BlipMLP(nn.Module):
|
class BlipMLP(nn.Module):
|
||||||
@ -261,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.self_attn = BlipAttention(config, quant_config=quant_config)
|
# fallback to sdpa attention if tp unavailable
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||||
|
self.self_attn = BlipParallelAttention(config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
else:
|
||||||
|
# Blip doesn't have SDPA attention implemented in transformers
|
||||||
|
# use eager attention instead for cpu backend
|
||||||
|
self.self_attn = BlipAttention(config)
|
||||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.mlp = BlipMLP(config, quant_config=quant_config)
|
self.mlp = BlipMLP(config, quant_config=quant_config)
|
||||||
@ -272,7 +287,7 @@ class BlipEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPVisionConfig
|
from transformers import CLIPVisionConfig
|
||||||
from xformers import ops as xops
|
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xformers import ops as xops
|
||||||
|
USE_XFORMERS_OPS = True
|
||||||
|
except ImportError:
|
||||||
|
USE_XFORMERS_OPS = False
|
||||||
|
|
||||||
|
|
||||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
assert image_size % patch_size == 0
|
assert image_size % patch_size == 0
|
||||||
@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class CLIPAttention(nn.Module):
|
class CLIPParallelAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -231,7 +237,7 @@ class CLIPAttention(nn.Module):
|
|||||||
out = out.view(bsz, tgt_len, -1)
|
out = out.view(bsz, tgt_len, -1)
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
|
|
||||||
return attn_output
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
class CLIPMLP(nn.Module):
|
class CLIPMLP(nn.Module):
|
||||||
@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None):
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.self_attn = CLIPAttention(config, quant_config=quant_config)
|
num_heads = config.num_attention_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||||
|
self.self_attn = CLIPParallelAttention(config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
else:
|
||||||
|
self.self_attn = CLIPSdpaAttention(config)
|
||||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
||||||
@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
num_hidden_layers_override: Optional[int] = None):
|
num_hidden_layers_override: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||||
|
|
||||||
self.vision_model = CLIPVisionTransformer(
|
self.vision_model = CLIPVisionTransformer(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@ -386,7 +402,7 @@ class CLIPVisionModel(nn.Module):
|
|||||||
("qkv_proj", "q_proj", "q"),
|
("qkv_proj", "q_proj", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
("qkv_proj", "k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
] if self.shard_weight else []
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
layer_count = len(self.vision_model.encoder.layers)
|
layer_count = len(self.vision_model.encoder.layers)
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from xformers import ops as xops
|
|
||||||
|
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xformers import ops as xops
|
||||||
|
USE_XFORMERS_OPS = True
|
||||||
|
except ImportError:
|
||||||
|
USE_XFORMERS_OPS = False
|
||||||
|
|
||||||
NORM2FN = {
|
NORM2FN = {
|
||||||
'rms_norm': RMSNorm,
|
'rms_norm': RMSNorm,
|
||||||
'layer_norm': nn.LayerNorm,
|
'layer_norm': nn.LayerNorm,
|
||||||
@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class InternAttention(nn.Module):
|
class InternParallelAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -140,18 +145,67 @@ class InternAttention(nn.Module):
|
|||||||
k = self.k_norm.forward_native(k.flatten(-2,
|
k = self.k_norm.forward_native(k.flatten(-2,
|
||||||
-1)).view(B_, N_, H_, D_)
|
-1)).view(B_, N_, H_, D_)
|
||||||
|
|
||||||
x = xops.memory_efficient_attention_forward(
|
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
scale=self.scale,
|
|
||||||
)
|
|
||||||
x = x.view(B, N, -1)
|
x = x.view(B, N, -1)
|
||||||
|
|
||||||
x, _ = self.proj(x)
|
x, _ = self.proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class InternSdpaAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.embed_dim // self.num_heads
|
||||||
|
if self.head_dim * self.num_heads != self.embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f'embed_dim must be divisible by num_heads '
|
||||||
|
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
||||||
|
f' {self.num_heads}).')
|
||||||
|
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.qkv = nn.Linear(self.embed_dim,
|
||||||
|
3 * self.embed_dim,
|
||||||
|
bias=config.qkv_bias)
|
||||||
|
|
||||||
|
self.qk_normalization = config.qk_normalization
|
||||||
|
|
||||||
|
if self.qk_normalization:
|
||||||
|
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = qkv.chunk(3, dim=-1)
|
||||||
|
|
||||||
|
q = q.view(B, N, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(B, N, self.num_heads, self.head_dim)
|
||||||
|
v = v.view(B, N, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
if self.qk_normalization:
|
||||||
|
B_, N_, H_, D_ = q.shape
|
||||||
|
q = self.q_norm.forward_native(q.flatten(-2,
|
||||||
|
-1)).view(B_, N_, H_, D_)
|
||||||
|
k = self.k_norm.forward_native(k.flatten(-2,
|
||||||
|
-1)).view(B_, N_, H_, D_)
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||||
|
x = x.transpose(1, 2).view(B, N, -1)
|
||||||
|
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class InternMLP(nn.Module):
|
class InternMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module):
|
|||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.norm_type = config.norm_type
|
self.norm_type = config.norm_type
|
||||||
|
|
||||||
self.attn = InternAttention(config, quant_config=quant_config)
|
# fallback to sdpa attention if tp unavailable
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||||
|
self.attn = InternParallelAttention(config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
else:
|
||||||
|
self.attn = InternSdpaAttention(config)
|
||||||
self.mlp = InternMLP(config, quant_config=quant_config)
|
self.mlp = InternMLP(config, quant_config=quant_config)
|
||||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|||||||
@ -307,26 +307,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
if key_to_modify in name:
|
if key_to_modify in name:
|
||||||
name = name.replace(key_to_modify, new_key)
|
name = name.replace(key_to_modify, new_key)
|
||||||
use_default_weight_loading = False
|
use_default_weight_loading = False
|
||||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
if "vision" not in name or self.vision_tower.shard_weight:
|
||||||
if shard_name not in name:
|
for (param_name, shard_name,
|
||||||
continue
|
shard_id) in stacked_params_mapping:
|
||||||
name = name.replace(shard_name, param_name)
|
if shard_name not in name:
|
||||||
# Skip loading extra bias for GPTQ models.
|
continue
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
name = name.replace(shard_name, param_name)
|
||||||
continue
|
# Skip loading extra bias for GPTQ models.
|
||||||
param = params_dict[name]
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
weight_loader = param.weight_loader
|
continue
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
param = params_dict[name]
|
||||||
break
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# lm_head is not used in vllm as it is tied with
|
||||||
|
# embed_token. To prevent errors, skip loading
|
||||||
|
# lm_head.weight.
|
||||||
|
if "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
use_default_weight_loading = True
|
||||||
else:
|
else:
|
||||||
# lm_head is not used in vllm as it is tied with
|
|
||||||
# embed_token. To prevent errors, skip loading
|
|
||||||
# lm_head.weight.
|
|
||||||
if "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
use_default_weight_loading = True
|
use_default_weight_loading = True
|
||||||
|
|
||||||
if use_default_weight_loading:
|
if use_default_weight_loading:
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import SiglipVisionConfig
|
from transformers import SiglipVisionConfig
|
||||||
from xformers import ops as xops
|
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
|||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xformers import ops as xops
|
||||||
|
USE_XFORMERS_OPS = True
|
||||||
|
except ImportError:
|
||||||
|
USE_XFORMERS_OPS = False
|
||||||
|
|
||||||
|
|
||||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||||
# Since interpolation is applied, the image size need not be divisible
|
# Since interpolation is applied, the image size need not be divisible
|
||||||
@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class SiglipAttention(nn.Module):
|
class SiglipParallelAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
|
|||||||
out = out.view(batch_size, q_len, -1)
|
out = out.view(batch_size, q_len, -1)
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
|
|
||||||
return attn_output
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
class SiglipMLP(nn.Module):
|
class SiglipMLP(nn.Module):
|
||||||
@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = SiglipAttention(config, quant_config=quant_config)
|
num_heads = config.num_attention_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||||
|
self.self_attn = SiglipParallelAttention(config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
else:
|
||||||
|
self.self_attn = SiglipSdpaAttention(config)
|
||||||
|
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
self.mlp = SiglipMLP(
|
self.mlp = SiglipMLP(
|
||||||
@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
|
|||||||
num_hidden_layers_override: Optional[int] = None,
|
num_hidden_layers_override: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||||
|
|
||||||
self.vision_model = SiglipVisionTransformer(
|
self.vision_model = SiglipVisionTransformer(
|
||||||
config,
|
config,
|
||||||
quant_config,
|
quant_config,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user