[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung authored 2024-10-23 19:27:37 +08:00, committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@ -21,10 +22,12 @@ class AWQConfig(QuantizationConfig):
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []
if self.weight_bits != 4:
raise ValueError(
@ -35,7 +38,8 @@ class AWQConfig(QuantizationConfig):
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str:
return "awq"
@ -61,11 +65,15 @@ class AWQConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQLinearMethod"]:
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
return None
@ -73,6 +81,10 @@ class AWQConfig(QuantizationConfig):
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
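The skip check added here is a plain substring match against a layer's dotted prefix, so listing "vision_model" in modules_to_not_convert leaves every linear layer inside the vision tower unquantized. A standalone sketch of that behavior (the example prefixes are illustrative):

from typing import List

def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
    # Skip quantization when any listed module name occurs in the layer's prefix.
    return any(module_name in prefix for module_name in modules_to_not_convert)

skip = ["vision_model"]
print(is_layer_skipped_awq("vision_model.encoder.layers.0.self_attn.qkv", skip))     # True  -> UnquantizedLinearMethod
print(is_layer_skipped_awq("language_model.model.layers.0.mlp.gate_up_proj", skip))  # False -> AWQLinearMethod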
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.

View File

@ -122,7 +122,7 @@ def input_processor_for_blip(
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
def __init__(self, config: BlipVisionConfig):
def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
super().__init__()
self.config = config
@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
def __init__(
self,
config: BlipVisionConfig,
config: Union[BlipVisionConfig, Blip2VisionConfig],
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@ -189,11 +190,13 @@ class BlipParallelAttention(nn.Module):
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.projection",
)
self.tp_size = get_tensor_model_parallel_world_size()
@ -235,9 +238,12 @@ class BlipParallelAttention(nn.Module):
class BlipMLP(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -246,11 +252,13 @@ class BlipMLP(nn.Module):
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@ -262,24 +270,32 @@ class BlipMLP(nn.Module):
class BlipEncoderLayer(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
self.mlp = BlipMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
config: BlipConfig
"""
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -321,8 +340,10 @@ class BlipEncoder(nn.Module):
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
BlipEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
@ -354,19 +380,24 @@ class BlipVisionModel(nn.Module):
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
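The post-layernorm rule is now two-step: an explicit require_post_norm argument wins, and when it is left as None the norm is built only for a full-depth encoder (the intermediate-feature case drops it to conserve memory). A minimal sketch of that decision, detached from any module:

from typing import Optional

def needs_post_layernorm(num_built_layers: int,
                         num_hidden_layers: int,
                         require_post_norm: Optional[bool] = None) -> bool:
    # Explicit override takes priority; otherwise keep the norm only when the
    # encoder was built with its full depth.
    if require_post_norm is None:
        return num_built_layers == num_hidden_layers
    return require_post_norm

print(needs_post_layernorm(24, 24))                           # True: full encoder
print(needs_post_layernorm(23, 24))                           # False: truncated for feature extraction
print(needs_post_layernorm(23, 24, require_post_norm=True))   # True: caller forces the norm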

View File

@ -490,7 +490,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config)
self.vision_model = BlipVisionModel(config.vision_config, quant_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,

View File

@ -192,6 +192,7 @@ class CLIPParallelAttention(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -211,12 +212,14 @@ class CLIPParallelAttention(nn.Module):
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
@ -259,20 +262,25 @@ class CLIPParallelAttention(nn.Module):
class CLIPMLP(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@ -284,21 +292,29 @@ class CLIPMLP(nn.Module):
class CLIPEncoderLayer(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
self.self_attn = CLIPParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module):
config: CLIPConfig
"""
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
@ -339,8 +359,10 @@ class CLIPEncoder(nn.Module):
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
CLIPEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
@ -354,11 +376,17 @@ class CLIPEncoder(nn.Module):
class CLIPVisionTransformer(nn.Module):
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
@ -370,19 +398,25 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
def forward(
@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
@ -418,7 +457,10 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
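Because post_layernorm can now be None, the transformer's forward pass has to guard the final normalization. A toy sketch of that guard (an assumed shape of the forward path, not the exact vLLM code; the encoder stack is replaced by nn.Identity):

from typing import Optional
import torch
import torch.nn as nn

class TinyVisionTransformer(nn.Module):

    def __init__(self, hidden_size: int, use_post_norm: bool) -> None:
        super().__init__()
        self.encoder = nn.Identity()  # stand-in for the stack of encoder layers
        self.post_layernorm: Optional[nn.Module] = (
            nn.LayerNorm(hidden_size) if use_post_norm else None)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        hidden_states = self.encoder(inputs_embeds)
        # Skip the final norm when it was never built (intermediate-feature case).
        if self.post_layernorm is None:
            return hidden_states
        return self.post_layernorm(hidden_states)

x = torch.randn(1, 4, 8)
print(TinyVisionTransformer(8, use_post_norm=False)(x).shape)  # torch.Size([1, 4, 8])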

View File

@ -113,7 +113,8 @@ class Idefics2VisionAttention(nn.Module):
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@ -130,12 +131,14 @@ class Idefics2VisionAttention(nn.Module):
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
@ -178,7 +181,8 @@ class Idefics2VisionMLP(nn.Module):
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
@ -187,12 +191,14 @@ class Idefics2VisionMLP(nn.Module):
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -204,13 +210,22 @@ class Idefics2VisionMLP(nn.Module):
class Idefics2EncoderLayer(nn.Module):
def __init__(self, config: Idefics2Config):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config)
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
config: Idefics2Config
"""
def __init__(self, config: Idefics2Config):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
def forward(
@ -275,12 +298,20 @@ class Idefics2Encoder(nn.Module):
class Idefics2VisionTransformer(nn.Module):
def __init__(self, config: Idefics2VisionConfig):
def __init__(
self,
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config)
self.encoder = Idefics2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
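The prefix threaded through every constructor composes into the fully qualified parameter name of each linear layer, e.g. "vision_model.encoder.layers.0.self_attn.qkv_proj". That string is what a quantization config sees when it picks a quant method for the layer, which is why the vision towers now have to forward it. A small illustration of how the nested prefixes compose (the top-level "vision_model" prefix is an example):

def qualified_name(*parts: str) -> str:
    # Each child module appends its own name to the prefix it receives.
    return ".".join(p for p in parts if p)

prefix = qualified_name("vision_model", "encoder", "layers.0", "self_attn", "qkv_proj")
print(prefix)                     # vision_model.encoder.layers.0.self_attn.qkv_proj
print("vision_model" in prefix)   # True: matched by AWQ's modules_to_not_convert check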

View File

@ -137,6 +137,7 @@ class InternParallelAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
@ -165,6 +166,7 @@ class InternParallelAttention(nn.Module):
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.qk_normalization = config.qk_normalization
@ -181,6 +183,7 @@ class InternParallelAttention(nn.Module):
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
@ -284,20 +287,26 @@ class InternSdpaAttention(nn.Module):
class InternMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
@ -315,6 +324,7 @@ class InternVisionEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
@ -324,9 +334,12 @@ class InternVisionEncoderLayer(nn.Module):
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads)
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.attn")
self.mlp = InternMLP(config, quant_config=quant_config)
self.mlp = InternMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@ -343,6 +356,7 @@ class InternVisionEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
prefix: str = "",
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
@ -351,7 +365,8 @@ class InternVisionEncoderLayer(nn.Module):
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads)
num_dummy_heads=num_dummy_heads,
prefix=prefix)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
@ -377,6 +392,7 @@ class InternVisionEncoder(nn.Module):
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
):
super().__init__()
@ -390,8 +406,9 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads)
for _ in range(num_hidden_layers)
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
@ -412,7 +429,8 @@ class InternVisionModel(nn.Module):
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -423,6 +441,7 @@ class InternVisionModel(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
def get_input_embeddings(self):

View File

@ -19,7 +19,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
@ -418,11 +419,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.config = config
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.select_layer = config.select_layer
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio
@ -430,7 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.vision_model = self._init_vision_model(config, self.is_mono)
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix="vision_model",
)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
@ -441,6 +447,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _patch_quant_config(self, config: PretrainedConfig,
quant_config: QuantizationConfig):
# The AWQ models from OpenGVLab are missing `modules_to_not_convert`;
# patch the quant_config to add it back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
llm_quant_config = getattr(text_config, "quantization_config",
None)
if (not quant_config.modules_to_not_convert) and \
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
@ -448,17 +466,28 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return Sampler()
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono:
vision_feature_layer = self.select_layer
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
return InternVisionModel(
config.vision_config,
num_hidden_layers_override=num_hidden_layers)
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
else:
return InternVisionPatchModel(config.vision_config)
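The vision tower is built only up to the requested feature layer; a negative select_layer counts back from the last layer, so the number of layers to instantiate follows the arithmetic below (mirroring the computation in _init_vision_model):

def layers_to_build(num_hidden_layers: int, feature_layer: int) -> int:
    # Negative indices count from the end of the encoder stack.
    if feature_layer < 0:
        return num_hidden_layers + feature_layer + 1
    return feature_layer + 1

print(layers_to_build(24, -1))   # 24: last layer, full encoder, post norm kept by default
print(layers_to_build(24, -2))   # 23: second-to-last layer, post norm skipped by default
print(layers_to_build(24, 10))   # 11: zero-based positive index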

View File

@ -1,12 +1,12 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
Tuple, TypedDict, Union)
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig)
PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig):
class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override?
return PixtralHFVisionModel(vision_config)
return PixtralHFVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
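LlavaLikeConfig is a typing.Protocol, so any HF config that exposes vision_config and vision_feature_layer (LlavaConfig, LlavaNextConfig, and so on) satisfies it structurally without inheriting from it, which is what lets one init_vision_tower_for_llava serve all the LLaVA variants. A minimal demonstration of the structural match (runtime_checkable and DummyConfig are added here purely for the demo):

from typing import Protocol, runtime_checkable
from transformers import PretrainedConfig

@runtime_checkable
class LlavaLikeConfig(Protocol):
    vision_config: PretrainedConfig
    vision_feature_layer: int

class DummyConfig:  # note: no inheritance from LlavaLikeConfig
    def __init__(self) -> None:
        self.vision_config = PretrainedConfig()
        self.vision_feature_layer = -2

print(isinstance(DummyConfig(), LlavaLikeConfig))  # True: matching attributes are enough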
@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,

View File

@ -26,7 +26,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
@ -259,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext,
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@ -303,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(

View File

@ -26,6 +26,7 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@ -179,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext,
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
# adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
@ -281,7 +256,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,

View File

@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext,
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)

View File

@ -395,7 +395,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
self.vpm = self.init_vision_module(config, quant_config)
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@ -647,7 +647,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> nn.Module:
raise NotImplementedError
def init_vision_module(self) -> nn.Module:
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
raise NotImplementedError
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
@ -693,7 +697,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
# TODO: refactor this vision model
try:
import timm
@ -817,8 +825,13 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@ -929,9 +942,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
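When drop_vision_last_layer is set, the tower is still trimmed after construction by slicing the ModuleList, which keeps the parameter names of the remaining layers unchanged. A tiny sketch with a toy layer count:

import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(27)])
layers = layers[:-1]              # drop_vision_last_layer: discard the final encoder layer
print(len(layers), type(layers))  # 26 <class 'torch.nn.modules.container.ModuleList'>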

View File

@ -379,9 +379,13 @@ class MllamaVisionSdpaAttention(nn.Module):
class MllamaVisionEncoderLayer(nn.Module):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
is_gated: bool = False):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
is_gated: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -390,7 +394,9 @@ class MllamaVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config)
self.mlp = CLIPMLP(config)
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(self.hidden_size,
eps=config.norm_eps)
@ -427,16 +433,23 @@ class MllamaVisionEncoderLayer(nn.Module):
class MllamaVisionEncoder(nn.Module):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
num_layers=32,
is_gated=False,
output_hidden_states=None):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
num_layers: int = 32,
is_gated: bool = False,
output_hidden_states=None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList([
MllamaVisionEncoderLayer(config, is_gated)
for _ in range(num_layers)
MllamaVisionEncoderLayer(config,
quant_config=quant_config,
is_gated=is_gated,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_layers)
])
self.output_hidden_states = output_hidden_states or []
@ -463,8 +476,14 @@ class MllamaVisionEncoder(nn.Module):
class MllamaVisionModel(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
def __init__(
self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
@ -500,12 +519,19 @@ class MllamaVisionModel(nn.Module):
# encoders
self.transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_hidden_layers,
is_gated=False,
output_hidden_states=config.intermediate_layers_indices)
self.global_transformer = MllamaVisionEncoder(config,
config.num_global_layers,
is_gated=True)
output_hidden_states=config.intermediate_layers_indices,
prefix=f"{prefix}.transformer",
)
self.global_transformer = MllamaVisionEncoder(
config,
quant_config,
config.num_global_layers,
is_gated=True,
prefix=f"{prefix}.global_transformer",
)
def apply_class_embedding(self,
hidden_state: torch.Tensor) -> torch.Tensor:
@ -648,6 +674,7 @@ class MllamaTextCrossAttention(nn.Module):
config: Optional[config_mllama.MllamaTextConfig] = None,
layer_idx: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -673,6 +700,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_key_value_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
@ -680,6 +708,7 @@ class MllamaTextCrossAttention(nn.Module):
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead
@ -692,6 +721,7 @@ class MllamaTextCrossAttention(nn.Module):
self.head_dim,
self.scaling,
self.num_local_key_value_heads,
prefix=f"{prefix}.attn",
)
def forward(
@ -791,15 +821,21 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
quant_config: Optional[QuantizationConfig]) \
-> None:
def __init__(
self,
config: config_mllama.MllamaTextConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.cross_attn = MllamaTextCrossAttention(
config=config,
layer_idx=layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.cross_attn",
)
self.input_layernorm = RMSNorm(config.hidden_size,
@ -811,6 +847,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -854,10 +891,15 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model"
def __init__(self, config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig]):
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
@ -869,13 +911,20 @@ class MllamaTextModel(nn.Module):
if layer_idx in self.cross_attention_layers:
layers.append(
MllamaCrossAttentionDecoderLayer(
config, layer_idx, quant_config=quant_config))
config,
layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append(
LlamaDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config))
LlamaDecoderLayer(
config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
))
self.layers = nn.ModuleList(layers)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
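MllamaTextModel interleaves cross-attention blocks with ordinary Llama decoder layers according to cross_attention_layers, and now hands each one a per-index prefix. A standalone sketch of that selection (the layer count and indices are toy values):

def plan_decoder_layers(num_hidden_layers, cross_attention_layers, prefix="model"):
    # Cross-attention blocks go at the configured indices; Llama decoder layers elsewhere.
    plan = []
    for layer_idx in range(num_hidden_layers):
        kind = ("MllamaCrossAttentionDecoderLayer"
                if layer_idx in cross_attention_layers else "LlamaDecoderLayer")
        plan.append((f"{prefix}.layers.{layer_idx}", kind))
    return plan

for name, kind in plan_decoder_layers(8, cross_attention_layers=[3, 7]):
    print(name, kind)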
@ -932,12 +981,19 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
]
def __init__(self, config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig]):
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config)
self.model = MllamaTextModel(config,
cache_config,
quant_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
@ -994,11 +1050,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
config.pad_token_id if config.pad_token_id is not None else -1
self.image_size = config.vision_config.image_size
self.vision_model = MllamaVisionModel(config.vision_config)
self.vision_model = MllamaVisionModel(config.vision_config,
quant_config)
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
prefix="language_model",
)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,

View File

@ -4,10 +4,13 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from .intern_vit import InternVisionModel
@ -56,9 +59,11 @@ class NVLM_D_Model(InternVLChatModel):
)
def _init_vision_model(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
num_hidden_layers: int):
# We added additional dummy heads to the original num of heads to make
# the number of heads divisible by 8.
return InternVisionModel(config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
num_dummy_heads=7)
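The dummy heads exist only to make the attention head count divide evenly across tensor-parallel ranks. Assuming the underlying InternViT encoder has 25 attention heads (not shown in this hunk), padding by 7 gives 32, which splits cleanly for tp_size 2, 4, or 8:

num_heads = 25        # assumed head count of the InternViT encoder used by NVLM-D
num_dummy_heads = 7
padded = num_heads + num_dummy_heads
for tp_size in (2, 4, 8):
    print(tp_size, padded % tp_size == 0)   # all True once padded to 32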

View File

@ -142,7 +142,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)

View File

@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768)
def _init_img_processor(hf_config: PretrainedConfig):
def _init_img_processor(hf_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]):
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2)
@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
num_hidden_layers = layer_idx + 1
img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
)
return img_processor
@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig) -> None:
def __init__(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]) -> None:
super().__init__()
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
self.img_processor = _init_img_processor(config)
self.img_processor = _init_img_processor(config, quant_config)
image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens']
@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
self.language_model = LlamaForCausalLM(config, cache_config,
quant_config)

View File

@ -767,9 +767,17 @@ def input_processor_for_pixtral_hf(
class PixtralHFMLP(nn.Module):
def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
@ -787,8 +795,15 @@ class PixtralHFMLP(nn.Module):
class PixtralHFAttention(nn.Module):
def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
assert not config.hidden_size % config.num_attention_heads
self.n_heads = config.num_attention_heads
@ -796,6 +811,7 @@ class PixtralHFAttention(nn.Module):
self.scale = self.head_dim**-0.5
# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
@ -840,11 +856,22 @@ class PixtralHFAttention(nn.Module):
class PixtralHFTransformerBlock(nn.Module):
def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config)
self.feed_forward = PixtralHFMLP(config)
self.attention = PixtralHFAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.feed_forward = PixtralHFMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward")
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
def forward(
@ -864,11 +891,27 @@ class PixtralHFTransformerBlock(nn.Module):
class PixtralHFTransformer(nn.Module):
def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(PixtralHFTransformerBlock(config))
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
PixtralHFTransformerBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(
self,
@ -883,7 +926,15 @@ class PixtralHFTransformer(nn.Module):
class PixtralHFVisionModel(nn.Module):
def __init__(self, config: PixtralVisionConfig):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -895,7 +946,24 @@ class PixtralHFVisionModel(nn.Module):
bias=False,
)
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(config)
self.transformer = PixtralHFTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)
num_hidden_layers = config.num_hidden_layers
if len(self.transformer.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers.")
if require_post_norm is True:
msg = "PixtralHFVisionModel does not have post-layernorm"
raise ValueError(msg)
self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding(
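PixtralHFVisionModel accepts the same keyword-only arguments as the other towers so init_vision_tower_for_llava can call every backend uniformly, but since the Pixtral encoder has no final layernorm it only validates them and rejects require_post_norm=True. A sketch of that contract, mirroring the checks in the diff:

from typing import Optional

def check_vision_tower_args(num_built_layers: int,
                            num_hidden_layers: int,
                            require_post_norm: Optional[bool] = None) -> None:
    if num_built_layers > num_hidden_layers:
        raise ValueError(f"The original encoder only has {num_hidden_layers} "
                         f"layers, but you requested {num_built_layers} layers.")
    if require_post_norm is True:
        raise ValueError("PixtralHFVisionModel does not have post-layernorm")

check_vision_tower_args(24, 24)                            # fine
check_vision_tower_args(24, 24, require_post_norm=False)   # fine
try:
    check_vision_tower_args(24, 24, require_post_norm=True)
except ValueError as e:
    print(e)   # PixtralHFVisionModel does not have post-layernorm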

View File

@ -248,8 +248,10 @@ class SiglipParallelAttention(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -266,12 +268,14 @@ class SiglipParallelAttention(nn.Module):
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
@ -314,8 +318,10 @@ class SiglipMLP(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
@ -326,11 +332,13 @@ class SiglipMLP(nn.Module):
config.hidden_size,
config.intermediate_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -346,15 +354,20 @@ class SiglipEncoderLayer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
self.self_attn = SiglipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = SiglipSdpaAttention(config)
@ -363,6 +376,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@ -392,8 +406,10 @@ class SiglipEncoder(nn.Module):
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
@ -402,8 +418,10 @@ class SiglipEncoder(nn.Module):
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
SiglipEncoderLayer(config, quant_config=quant_config)
for _ in range(num_hidden_layers)
SiglipEncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
def forward(
@ -424,7 +442,8 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@ -433,7 +452,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config=config, quant_config=quant_config)
self.mlp = SiglipMLP(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size = hidden_state.shape[0]
@ -454,9 +475,13 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
):
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
@ -465,26 +490,34 @@ class SiglipVisionTransformer(nn.Module):
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None
self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(
config=config, quant_config=quant_config)
config=config,
quant_config=quant_config,
prefix=f"{prefix}.head",
)
def forward(
self,
@ -517,8 +550,11 @@ class SiglipVisionModel(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
):
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()
num_heads = config.num_attention_heads
@ -529,6 +565,8 @@ class SiglipVisionModel(nn.Module):
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
def get_input_embeddings(self) -> nn.Module:
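After this change every vision tower shares one constructor shape: positional config and quant_config, then keyword-only num_hidden_layers_override, require_post_norm, and prefix. A generic stand-in showing that calling convention (not the vLLM class itself, which also needs tensor-parallel state to be initialized):

from typing import Optional

class VisionTowerLike:
    # Stand-in mirroring the shared constructor shape used across the towers.

    def __init__(
        self,
        config: dict,
        quant_config: Optional[object] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        self.num_layers = (num_hidden_layers_override
                           if num_hidden_layers_override is not None
                           else config["num_hidden_layers"])
        self.require_post_norm = require_post_norm
        self.prefix = prefix

tower = VisionTowerLike(
    {"num_hidden_layers": 27},
    None,                            # quant_config
    num_hidden_layers_override=26,   # build only up to the feature layer
    require_post_norm=False,         # new override introduced by this commit
    prefix="vision_tower",           # parameter names are rooted at this prefix
)
print(tower.num_layers, tower.require_post_norm, tower.prefix)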