mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 16:35:43 +08:00
[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
3ff57ebfca
commit
c18e1a3418
@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
@ -21,10 +22,12 @@ class AWQConfig(QuantizationConfig):
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
@ -35,7 +38,8 @@ class AWQConfig(QuantizationConfig):
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "awq"
|
||||
@ -61,11 +65,15 @@ class AWQConfig(QuantizationConfig):
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQLinearMethod"]:
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQLinearMethod(self)
|
||||
return None
|
||||
|
||||
@ -73,6 +81,10 @@ class AWQConfig(QuantizationConfig):
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class AWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ.
|
||||
|
||||
|
||||
@ -122,7 +122,7 @@ def input_processor_for_blip(
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
||||
class BlipVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: BlipVisionConfig):
|
||||
def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -189,11 +190,13 @@ class BlipParallelAttention(nn.Module):
|
||||
self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
)
|
||||
self.projection = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.projection",
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -235,9 +238,12 @@ class BlipParallelAttention(nn.Module):
|
||||
|
||||
class BlipMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@ -246,11 +252,13 @@ class BlipMLP(nn.Module):
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@ -262,24 +270,32 @@ class BlipMLP(nn.Module):
|
||||
|
||||
class BlipEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# fallback to sdpa attention if tp unavailable
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = BlipParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
self.self_attn = BlipParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
else:
|
||||
# Blip doesn't have SDPA attention implemented in transformers
|
||||
# use eager attention instead for cpu backend
|
||||
self.self_attn = BlipAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = BlipMLP(config, quant_config=quant_config)
|
||||
self.mlp = BlipMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
|
||||
config: BlipConfig
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@ -321,8 +340,10 @@ class BlipEncoder(nn.Module):
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
BlipEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
BlipEncoderLayer(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
|
||||
config_class = BlipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -354,19 +380,24 @@ class BlipVisionModel(nn.Module):
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {config.num_hidden_layers} "
|
||||
f"The original encoder only has {num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||
)
|
||||
elif len(self.encoder.layers) == config.num_hidden_layers:
|
||||
|
||||
# If possible, skip post_layernorm to conserve memory
|
||||
if require_post_norm is None:
|
||||
require_post_norm = len(self.encoder.layers) == num_hidden_layers
|
||||
|
||||
if require_post_norm:
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
else:
|
||||
# post_layernorm is unused when we extract intermediate features
|
||||
# In this case, we can skip it to conserve memory
|
||||
self.post_layernorm = None
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -490,7 +490,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_model = BlipVisionModel(config.vision_config)
|
||||
self.vision_model = BlipVisionModel(config.vision_config, quant_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(
|
||||
torch.zeros(1, config.num_query_tokens,
|
||||
|
||||
@ -192,6 +192,7 @@ class CLIPParallelAttention(nn.Module):
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -211,12 +212,14 @@ class CLIPParallelAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -259,20 +262,25 @@ class CLIPParallelAttention(nn.Module):
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@ -284,21 +292,29 @@ class CLIPMLP(nn.Module):
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = CLIPParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
self.self_attn = CLIPParallelAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
else:
|
||||
self.self_attn = CLIPSdpaAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
||||
self.mlp = CLIPMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module):
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
@ -339,8 +359,10 @@ class CLIPEncoder(nn.Module):
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
self.layers = nn.ModuleList([
|
||||
CLIPEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
CLIPEncoderLayer(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
@ -354,11 +376,17 @@ class CLIPEncoder(nn.Module):
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
@ -370,19 +398,25 @@ class CLIPVisionTransformer(nn.Module):
|
||||
self.encoder = CLIPEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {config.num_hidden_layers} "
|
||||
f"The original encoder only has {num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||
)
|
||||
elif len(self.encoder.layers) == config.num_hidden_layers:
|
||||
|
||||
# If possible, skip post_layernorm to conserve memory
|
||||
if require_post_norm is None:
|
||||
require_post_norm = len(self.encoder.layers) == num_hidden_layers
|
||||
|
||||
if require_post_norm:
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
else:
|
||||
# post_layernorm is unused when we extract intermediate features
|
||||
# In this case, we can skip it to conserve memory
|
||||
self.post_layernorm = None
|
||||
|
||||
def forward(
|
||||
@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -418,7 +457,10 @@ class CLIPVisionModel(nn.Module):
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
require_post_norm=require_post_norm,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
return self.vision_model(pixel_values)
|
||||
|
||||
@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
|
||||
|
||||
class Idefics2EncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: Idefics2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Idefics2VisionAttention(config)
|
||||
self.self_attn = Idefics2VisionAttention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = Idefics2VisionMLP(config)
|
||||
self.mlp = Idefics2VisionMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
|
||||
config: Idefics2Config
|
||||
"""
|
||||
|
||||
def __init__(self, config: Idefics2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
Idefics2EncoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
Idefics2EncoderLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
|
||||
|
||||
class Idefics2VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self, config: Idefics2VisionConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
self.config = config
|
||||
self.embeddings = Idefics2VisionEmbeddings(config)
|
||||
self.encoder = Idefics2Encoder(config)
|
||||
self.encoder = Idefics2Encoder(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
|
||||
@ -137,6 +137,7 @@ class InternParallelAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -165,6 +166,7 @@ class InternParallelAttention(nn.Module):
|
||||
num_dummy_heads + self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
@ -181,6 +183,7 @@ class InternParallelAttention(nn.Module):
|
||||
self.dummy_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj",
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||
@ -284,20 +287,26 @@ class InternSdpaAttention(nn.Module):
|
||||
|
||||
class InternMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@ -315,6 +324,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -324,9 +334,12 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
|
||||
self.attn = self._init_attn(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads)
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
self.mlp = InternMLP(config, quant_config=quant_config)
|
||||
self.mlp = InternMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
@ -343,6 +356,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
num_dummy_heads: int,
|
||||
prefix: str = "",
|
||||
):
|
||||
# fallback to sdpa attention if tp unavailable
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -351,7 +365,8 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
|
||||
return InternParallelAttention(config,
|
||||
quant_config=quant_config,
|
||||
num_dummy_heads=num_dummy_heads)
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=prefix)
|
||||
|
||||
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
|
||||
|
||||
@ -377,6 +392,7 @@ class InternVisionEncoder(nn.Module):
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -390,8 +406,9 @@ class InternVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([
|
||||
InternVisionEncoderLayer(config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads)
|
||||
for _ in range(num_hidden_layers)
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
@ -412,7 +429,8 @@ class InternVisionModel(nn.Module):
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_dummy_heads: int = 0,
|
||||
):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@ -423,6 +441,7 @@ class InternVisionModel(nn.Module):
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -19,7 +19,8 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization import (AWQConfig,
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
InternVisionPatchModel)
|
||||
@ -418,11 +419,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self._patch_quant_config(config, quant_config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
self.select_layer = config.select_layer
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
@ -430,7 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
self.llm_arch_name = config.text_config.architectures[0]
|
||||
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
|
||||
self.vision_model = self._init_vision_model(config, self.is_mono)
|
||||
self.vision_model = self._init_vision_model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
is_mono=self.is_mono,
|
||||
prefix="vision_model",
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
@ -441,6 +447,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _patch_quant_config(self, config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig):
|
||||
# the awq models from OpenGVLab missing `modules_to_not_convert`
|
||||
# patch the quant_config to add `modules_to_not_convert` back
|
||||
if isinstance(quant_config, AWQConfig):
|
||||
text_config = config.text_config
|
||||
llm_quant_config = getattr(text_config, "quantization_config",
|
||||
None)
|
||||
if (not quant_config.modules_to_not_convert) and \
|
||||
(llm_quant_config is not None):
|
||||
quant_config.modules_to_not_convert.append("vision_model")
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.language_model, "sampler"):
|
||||
@ -448,17 +466,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return Sampler()
|
||||
|
||||
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
|
||||
def _init_vision_model(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
is_mono: bool,
|
||||
prefix: str,
|
||||
):
|
||||
if not is_mono:
|
||||
vision_feature_layer = self.select_layer
|
||||
vision_feature_layer = config.select_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
return InternVisionModel(
|
||||
config.vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers)
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
prefix=prefix,
|
||||
)
|
||||
else:
|
||||
return InternVisionPatchModel(config.vision_config)
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
|
||||
SiglipVisionConfig)
|
||||
PretrainedConfig, SiglipVisionConfig)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _init_vision_tower(hf_config: LlavaConfig):
|
||||
class LlavaLikeConfig(Protocol):
|
||||
vision_config: PretrainedConfig
|
||||
vision_feature_layer: int
|
||||
|
||||
|
||||
def init_vision_tower_for_llava(
|
||||
hf_config: LlavaLikeConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
):
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPVisionModel(
|
||||
vision_config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
require_post_norm=require_post_norm,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipVisionModel(
|
||||
vision_config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
require_post_norm=require_post_norm,
|
||||
)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
# TODO: allow layer override?
|
||||
return PixtralHFVisionModel(vision_config)
|
||||
return PixtralHFVisionModel(
|
||||
vision_config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
require_post_norm=require_post_norm,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
config.projector_hidden_act = "gelu"
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = _init_vision_tower(config)
|
||||
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
|
||||
@ -26,7 +26,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
dummy_seq_data_for_clip, get_clip_image_feature_size,
|
||||
get_clip_patch_grid_length, input_processor_for_clip)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import LlavaMultiModalProjector
|
||||
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
@ -259,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _init_vision_tower(hf_config: LlavaNextConfig):
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
vision_feature_layer = hf_config.vision_feature_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
|
||||
@ -303,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_tower = _init_vision_tower(config)
|
||||
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
|
||||
@ -26,6 +26,7 @@ from vllm.utils import is_list_of
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
@ -179,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
vision_feature_layer = hf_config.vision_feature_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
# adopted from transformers modeling_llava_next_video.py
|
||||
class LlavaNextVideoPooler(nn.Module):
|
||||
|
||||
@ -281,7 +256,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
self.vision_tower = _init_vision_tower(config)
|
||||
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
|
||||
self.vision_resampler = LlavaNextVideoPooler(config)
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
|
||||
@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
|
||||
dummy_video_for_clip, get_clip_image_feature_size,
|
||||
get_clip_patch_grid_length, input_processor_for_clip)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
|
||||
dummy_video_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
vision_feature_layer = hf_config.vision_feature_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipVisionModel(
|
||||
vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config: LlavaOnevisionConfig):
|
||||
@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# Initialize the vision tower only up to the required feature layer
|
||||
self.vision_tower = _init_vision_tower(config)
|
||||
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
|
||||
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
|
||||
@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(config, cache_config, quant_config)
|
||||
self.vpm = self.init_vision_module()
|
||||
self.vpm = self.init_vision_module(config, quant_config)
|
||||
param_dtype = torch.get_default_dtype()
|
||||
self.vpm.to(dtype=param_dtype)
|
||||
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
|
||||
@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def init_vision_module(self) -> nn.Module:
|
||||
def init_vision_module(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
|
||||
@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
quant_config=quant_config),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(self) -> nn.Module:
|
||||
def init_vision_module(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
) -> nn.Module:
|
||||
# TODO :refactor this vision model
|
||||
try:
|
||||
import timm
|
||||
@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
quant_config=quant_config),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(self) -> nn.Module:
|
||||
model = Idefics2VisionTransformer(self.config.vision_config)
|
||||
def init_vision_module(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
) -> nn.Module:
|
||||
model = Idefics2VisionTransformer(config.vision_config,
|
||||
quant_config=quant_config)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
quant_config=quant_config),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(self) -> nn.Module:
|
||||
|
||||
model = Idefics2VisionTransformer(self.config.vision_config)
|
||||
def init_vision_module(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
) -> nn.Module:
|
||||
model = Idefics2VisionTransformer(config.vision_config,
|
||||
quant_config=quant_config)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
|
||||
@ -379,9 +379,13 @@ class MllamaVisionSdpaAttention(nn.Module):
|
||||
|
||||
class MllamaVisionEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
is_gated: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
is_gated: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -390,7 +394,9 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = MllamaVisionSdpaAttention(config)
|
||||
self.mlp = CLIPMLP(config)
|
||||
self.mlp = CLIPMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(self.hidden_size,
|
||||
eps=config.norm_eps)
|
||||
@ -427,16 +433,23 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
|
||||
class MllamaVisionEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
num_layers=32,
|
||||
is_gated=False,
|
||||
output_hidden_states=None):
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
num_layers: int = 32,
|
||||
is_gated: bool = False,
|
||||
output_hidden_states=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
MllamaVisionEncoderLayer(config, is_gated)
|
||||
for _ in range(num_layers)
|
||||
MllamaVisionEncoderLayer(config,
|
||||
quant_config=quant_config,
|
||||
is_gated=is_gated,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(num_layers)
|
||||
])
|
||||
self.output_hidden_states = output_hidden_states or []
|
||||
|
||||
@ -463,8 +476,14 @@ class MllamaVisionEncoder(nn.Module):
|
||||
|
||||
class MllamaVisionModel(nn.Module):
|
||||
|
||||
def __init__(self, config: config_mllama.MllamaVisionConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
self.max_num_tiles = config.max_num_tiles
|
||||
@ -500,12 +519,19 @@ class MllamaVisionModel(nn.Module):
|
||||
# encoders
|
||||
self.transformer = MllamaVisionEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
config.num_hidden_layers,
|
||||
is_gated=False,
|
||||
output_hidden_states=config.intermediate_layers_indices)
|
||||
self.global_transformer = MllamaVisionEncoder(config,
|
||||
config.num_global_layers,
|
||||
is_gated=True)
|
||||
output_hidden_states=config.intermediate_layers_indices,
|
||||
prefix=f"{prefix}.transformer",
|
||||
)
|
||||
self.global_transformer = MllamaVisionEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
config.num_global_layers,
|
||||
is_gated=True,
|
||||
prefix=f"{prefix}.global_transformer",
|
||||
)
|
||||
|
||||
def apply_class_embedding(self,
|
||||
hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
@ -648,6 +674,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
config: Optional[config_mllama.MllamaTextConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -673,6 +700,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.num_key_value_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
@ -680,6 +708,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
|
||||
# use huggingface's instead
|
||||
@ -692,6 +721,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
self.num_local_key_value_heads,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -791,15 +821,21 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
"""Cross-attention transformer block with tanh-gated attention
|
||||
and feedforward."""
|
||||
|
||||
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig]) \
|
||||
-> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.cross_attn = MllamaTextCrossAttention(
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.cross_attn",
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -811,6 +847,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -854,10 +891,15 @@ class MllamaTextModel(nn.Module):
|
||||
config_class = config_mllama.MllamaTextConfig
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
|
||||
@ -869,13 +911,20 @@ class MllamaTextModel(nn.Module):
|
||||
if layer_idx in self.cross_attention_layers:
|
||||
layers.append(
|
||||
MllamaCrossAttentionDecoderLayer(
|
||||
config, layer_idx, quant_config=quant_config))
|
||||
config,
|
||||
layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
))
|
||||
else:
|
||||
# TODO: force LlamaDecoderLayer to config.attention_bias=False
|
||||
layers.append(
|
||||
LlamaDecoderLayer(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
LlamaDecoderLayer(
|
||||
config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
))
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -932,12 +981,19 @@ class MllamaForCausalLM(nn.Module):
|
||||
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
|
||||
]
|
||||
|
||||
def __init__(self, config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = MllamaTextModel(config, cache_config, quant_config)
|
||||
self.model = MllamaTextModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.model")
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@ -994,11 +1050,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
self.image_size = config.vision_config.image_size
|
||||
|
||||
self.vision_model = MllamaVisionModel(config.vision_config)
|
||||
self.vision_model = MllamaVisionModel(config.vision_config,
|
||||
quant_config)
|
||||
self.language_model = MllamaForCausalLM(
|
||||
config.text_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix="language_model",
|
||||
)
|
||||
self.multi_modal_projector = nn.Linear(
|
||||
config.vision_config.vision_output_dim,
|
||||
|
||||
@ -4,10 +4,13 @@
|
||||
# Copyright (c) 2024 NVIDIA
|
||||
# Licensed under Apache 2.0 License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from .intern_vit import InternVisionModel
|
||||
@ -56,9 +59,11 @@ class NVLM_D_Model(InternVLChatModel):
|
||||
)
|
||||
|
||||
def _init_vision_model(self, config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
num_hidden_layers: int):
|
||||
# We added additional dummy heads to the original num of heads to make
|
||||
# the number of heads divisible by 8.
|
||||
return InternVisionModel(config.vision_config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
num_dummy_heads=7)
|
||||
|
||||
@ -142,7 +142,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config)
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config,
|
||||
quant_config)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
projection_dim=config.vision_config.projection_dim)
|
||||
|
||||
@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
projection_dim=768)
|
||||
|
||||
|
||||
def _init_img_processor(hf_config: PretrainedConfig):
|
||||
def _init_img_processor(hf_config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
layer_idx = hf_config.img_processor.get('layer_idx', -2)
|
||||
|
||||
@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
|
||||
num_hidden_layers = layer_idx + 1
|
||||
|
||||
img_processor = CLIPVisionModel(
|
||||
clip_config, num_hidden_layers_override=num_hidden_layers)
|
||||
clip_config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
|
||||
return img_processor
|
||||
|
||||
@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
"""Phi3 Image embedding with HD transform."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
def __init__(self, config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig]) -> None:
|
||||
super().__init__()
|
||||
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
|
||||
self.img_processor = _init_img_processor(config)
|
||||
self.img_processor = _init_img_processor(config, quant_config)
|
||||
|
||||
image_dim_out = config.img_processor['image_dim_out']
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
# TODO: Optionally initializes this for supporting input embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
|
||||
|
||||
self.language_model = LlamaForCausalLM(config, cache_config,
|
||||
quant_config)
|
||||
|
||||
@ -767,9 +767,17 @@ def input_processor_for_pixtral_hf(

class PixtralHFMLP(nn.Module):

def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()

assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
@ -787,8 +795,15 @@ class PixtralHFMLP(nn.Module):

class PixtralHFAttention(nn.Module):

def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
assert not config.hidden_size % config.num_attention_heads
self.n_heads = config.num_attention_heads
@ -796,6 +811,7 @@ class PixtralHFAttention(nn.Module):

self.scale = self.head_dim**-0.5

# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
@ -840,11 +856,22 @@ class PixtralHFAttention(nn.Module):

class PixtralHFTransformerBlock(nn.Module):

def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()

self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config)
self.feed_forward = PixtralHFMLP(config)
self.attention = PixtralHFAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.feed_forward = PixtralHFMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward")
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

def forward(
@ -864,11 +891,27 @@ class PixtralHFTransformerBlock(nn.Module):

class PixtralHFTransformer(nn.Module):

def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(PixtralHFTransformerBlock(config))

if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override

self.layers = nn.ModuleList([
PixtralHFTransformerBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])

def forward(
self,
@ -883,7 +926,15 @@ class PixtralHFTransformer(nn.Module):

class PixtralHFVisionModel(nn.Module):

def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
@ -895,7 +946,24 @@ class PixtralHFVisionModel(nn.Module):
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config)
self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)

num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers.")

if require_post_norm is True:
msg = "PixtralHFVisionModel does not have post-layernorm"
raise ValueError(msg)

self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding(
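For reference, the two guards that PixtralHFVisionModel now applies to its new keyword arguments can be captured in a small standalone function (a sketch under assumed names, not the vLLM implementation):

from typing import Optional

def _check_pixtral_overrides(total_layers: int,
                             num_hidden_layers_override: Optional[int] = None,
                             require_post_norm: Optional[bool] = None) -> int:
    # How many transformer layers will actually be instantiated.
    num_layers = (total_layers if num_hidden_layers_override is None
                  else num_hidden_layers_override)
    # Guard 1: cannot request more layers than the original encoder has.
    if num_layers > total_layers:
        raise ValueError(f"The original encoder only has {total_layers} "
                         f"layers, but you requested {num_layers} layers.")
    # Guard 2: Pixtral's HF vision tower has no post-layernorm to require.
    if require_post_norm is True:
        raise ValueError("PixtralHFVisionModel does not have post-layernorm")
    return num_layers

print(_check_pixtral_overrides(24, num_hidden_layers_override=20))  # -> 20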
@ -248,8 +248,10 @@ class SiglipParallelAttention(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()

self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -266,12 +268,14 @@ class SiglipParallelAttention(nn.Module):
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)

self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)

self.tp_size = get_tensor_model_parallel_world_size()
@ -314,8 +318,10 @@ class SiglipMLP(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()

self.config = config
self.activation_fn = get_act_fn(config.hidden_act)

@ -326,11 +332,13 @@ class SiglipMLP(nn.Module):
config.hidden_size,
config.intermediate_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -346,15 +354,20 @@ class SiglipEncoderLayer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()

self.embed_dim = config.hidden_size

num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
self.self_attn = SiglipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = SiglipSdpaAttention(config)

@ -363,6 +376,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@ -392,8 +406,10 @@ class SiglipEncoder(nn.Module):
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
):
prefix: str = "",
) -> None:
super().__init__()

self.config = config

if num_hidden_layers_override is None:
@ -402,8 +418,10 @@ class SiglipEncoder(nn.Module):
num_hidden_layers = num_hidden_layers_override

self.layers = nn.ModuleList([
SiglipEncoderLayer(config, quant_config=quant_config)
for _ in range(num_hidden_layers)
SiglipEncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])

def forward(
@ -424,7 +442,8 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()

self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@ -433,7 +452,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config=config, quant_config=quant_config)
self.mlp = SiglipMLP(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")

def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size = hidden_state.shape[0]
@ -454,9 +475,13 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
):
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
embed_dim = config.hidden_size

@ -465,26 +490,34 @@ class SiglipVisionTransformer(nn.Module):
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)

num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:

# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers

if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(
config=config, quant_config=quant_config)
config=config,
quant_config=quant_config,
prefix=f"{prefix}.head",
)

def forward(
self,
@ -517,8 +550,11 @@ class SiglipVisionModel(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
):
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()

num_heads = config.num_attention_heads
@ -529,6 +565,8 @@ class SiglipVisionModel(nn.Module):
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)

def get_input_embeddings(self) -> nn.Module:
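The SigLIP change is the core of the "override post layernorm" feature: post_layernorm is kept when the full-depth encoder is built, skipped when the layer count is truncated, and either behavior can be forced via require_post_norm. The rule reduces to the following standalone sketch (assumed helper name, not vLLM code):

from typing import Optional

def _resolve_post_norm(layers_built: int,
                       total_layers: int,
                       require_post_norm: Optional[bool] = None) -> bool:
    # Default: only keep post_layernorm when the whole encoder depth is used.
    if require_post_norm is None:
        return layers_built == total_layers
    # Otherwise the caller's explicit choice wins.
    return require_post_norm

assert _resolve_post_norm(27, 27) is True                             # full tower
assert _resolve_post_norm(26, 27) is False                            # truncated tower
assert _resolve_post_norm(27, 27, require_post_norm=False) is False   # forced off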