mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:25:01 +08:00
[VLM] Minor space optimization for ClipVisionModel (#6436)
This commit is contained in:
parent
22e79ee8f3
commit
6ae1597ddf
@ -214,22 +214,24 @@ class CLIPEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
self.layers = nn.ModuleList([
|
||||
CLIPEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_feature_layer: int = -1):
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
|
||||
# Encoder forward pass only up to the required layer
|
||||
num_layer = len(self.layers) + vision_feature_layer + 1
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers[:num_layer]:
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@ -239,7 +241,8 @@ class CLIPVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
@ -249,18 +252,19 @@ class CLIPVisionTransformer(nn.Module):
|
||||
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
|
||||
# the original transformers code and name of the model weights.
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
|
||||
self.encoder = CLIPEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
vision_feature_layer: int = -1,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
hidden_states = self.encoder(inputs_embeds=hidden_states,
|
||||
vision_feature_layer=vision_feature_layer)
|
||||
hidden_states = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -272,17 +276,17 @@ class CLIPVisionModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVisionTransformer(config=config,
|
||||
quant_config=quant_config)
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
|
||||
def forward(self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
vision_feature_layer: int = -1):
|
||||
def forward(self, pixel_values: Optional[torch.Tensor] = None):
|
||||
|
||||
return self.vision_model(pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer)
|
||||
return self.vision_model(pixel_values=pixel_values)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
|
||||
@ -128,8 +128,17 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
vision_feature_layer = config.vision_feature_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = CLIPVisionModel(config.vision_config)
|
||||
self.vision_tower = CLIPVisionModel(
|
||||
config.vision_config, num_hidden_layers_override=num_hidden_layers)
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
@ -193,8 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values,
|
||||
self.config.vision_feature_layer)
|
||||
image_features = vision_tower(pixel_values)
|
||||
|
||||
return self._select_image_features(
|
||||
image_features,
|
||||
@ -333,7 +341,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
if use_default_weight_loading and name in params_dict:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@ -222,8 +222,17 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
vision_feature_layer = config.vision_feature_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = CLIPVisionModel(config=config.vision_config)
|
||||
self.vision_tower = CLIPVisionModel(
|
||||
config.vision_config, num_hidden_layers_override=num_hidden_layers)
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
@ -312,8 +321,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values,
|
||||
self.config.vision_feature_layer)
|
||||
image_features = vision_tower(pixel_values)
|
||||
|
||||
return self._select_image_features(
|
||||
image_features,
|
||||
@ -561,7 +569,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
if use_default_weight_loading and name in params_dict:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@ -80,13 +80,11 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
|
||||
def get_img_features(self,
|
||||
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
|
||||
LAYER_IDX = self.layer_idx
|
||||
TYPE_FEATURE = self.type_feature
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the img_processor
|
||||
img_feature = self.img_processor(img_embeds,
|
||||
vision_feature_layer=LAYER_IDX)
|
||||
img_feature = self.img_processor(img_embeds)
|
||||
|
||||
if TYPE_FEATURE == "patch":
|
||||
patch_feature = img_feature[:, 1:]
|
||||
@ -111,7 +109,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
config, 'n_embd') else config.hidden_size
|
||||
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
self.img_processor = CLIPVisionModel(clip_config)
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
|
||||
# Initialize the CLIP only up to the required feature layer
|
||||
if self.layer_idx < 0:
|
||||
num_hidden_layers = clip_config.num_hidden_layers + \
|
||||
self.layer_idx + 1
|
||||
else:
|
||||
num_hidden_layers = self.layer_idx + 1
|
||||
|
||||
self.img_processor = CLIPVisionModel(
|
||||
clip_config, num_hidden_layers_override=num_hidden_layers)
|
||||
image_dim_out = config.img_processor['image_dim_out']
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
|
||||
@ -142,8 +150,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
self.img_projection = nn.Sequential(*layers)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor,
|
||||
@ -588,7 +594,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name in params_dict:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user