Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-25 23:11:27 +08:00)
[Model] Support quantization of Qwen2VisionTransformer (#9817)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Parent: 890ca36072
Commit: d087bf863e
@ -126,15 +126,18 @@ class Qwen2VisionMLP(nn.Module):
|
|||||||
hidden_features: int = None,
|
hidden_features: int = None,
|
||||||
act_layer: Type[nn.Module] = QuickGELU,
|
act_layer: Type[nn.Module] = QuickGELU,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc1 = ColumnParallelLinear(in_features,
|
self.fc1 = ColumnParallelLinear(in_features,
|
||||||
hidden_features,
|
hidden_features,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.fc1")
|
||||||
self.act = act_layer()
|
self.act = act_layer()
|
||||||
self.fc2 = RowParallelLinear(hidden_features,
|
self.fc2 = RowParallelLinear(hidden_features,
|
||||||
in_features,
|
in_features,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.fc2")
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_parallel, _ = self.fc1(x)
|
x_parallel, _ = self.fc1(x)
|
||||||
@ -196,6 +199,7 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
projection_size: Optional[int] = None,
|
projection_size: Optional[int] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Per attention head and per partition values.
|
# Per attention head and per partition values.
|
||||||
@ -207,10 +211,12 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
|
|
||||||
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
||||||
output_size=3 * projection_size,
|
output_size=3 * projection_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv")
|
||||||
self.proj = RowParallelLinear(input_size=projection_size,
|
self.proj = RowParallelLinear(input_size=projection_size,
|
||||||
output_size=embed_dim,
|
output_size=embed_dim,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.proj")
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend()
|
self.attn_backend: _Backend = get_vit_attn_backend()
|
||||||
@ -310,6 +316,7 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
act_layer: Type[nn.Module] = QuickGELU,
|
act_layer: Type[nn.Module] = QuickGELU,
|
||||||
norm_layer: Type[nn.Module] = None,
|
norm_layer: Type[nn.Module] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if norm_layer is None:
|
if norm_layer is None:
|
||||||
@ -321,11 +328,13 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
self.attn = Qwen2VisionAttention(embed_dim=dim,
|
self.attn = Qwen2VisionAttention(embed_dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn")
|
||||||
self.mlp = Qwen2VisionMLP(dim,
|
self.mlp = Qwen2VisionMLP(dim,
|
||||||
mlp_hidden_dim,
|
mlp_hidden_dim,
|
||||||
act_layer=act_layer,
|
act_layer=act_layer,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp")
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
|
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
|
||||||
@ -374,6 +383,7 @@ class Qwen2VisionPatchMerger(nn.Module):
|
|||||||
norm_layer: Type[nn.Module] = None,
|
norm_layer: Type[nn.Module] = None,
|
||||||
spatial_merge_size: int = 2,
|
spatial_merge_size: int = 2,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||||
@ -384,12 +394,14 @@ class Qwen2VisionPatchMerger(nn.Module):
|
|||||||
ColumnParallelLinear(self.hidden_size,
|
ColumnParallelLinear(self.hidden_size,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp.0"),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
RowParallelLinear(self.hidden_size,
|
RowParallelLinear(self.hidden_size,
|
||||||
d_model,
|
d_model,
|
||||||
bias=True,
|
bias=True,
|
||||||
quant_config=quant_config),
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp.2"),
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -440,6 +452,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
vision_config: Qwen2VLVisionConfig,
|
vision_config: Qwen2VLVisionConfig,
|
||||||
norm_eps: float = 1e-6,
|
norm_eps: float = 1e-6,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -467,28 +480,29 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
Qwen2VisionBlock(
|
Qwen2VisionBlock(dim=embed_dim,
|
||||||
dim=embed_dim,
|
num_heads=num_heads,
|
||||||
num_heads=num_heads,
|
mlp_ratio=mlp_ratio,
|
||||||
mlp_ratio=mlp_ratio,
|
norm_layer=norm_layer,
|
||||||
norm_layer=norm_layer,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||||
) for _ in range(depth)
|
for layer_idx in range(depth)
|
||||||
])
|
])
|
||||||
self.merger = Qwen2VisionPatchMerger(
|
self.merger = Qwen2VisionPatchMerger(
|
||||||
d_model=hidden_size,
|
d_model=hidden_size,
|
||||||
context_dim=embed_dim,
|
context_dim=embed_dim,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.merger",
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
return self.blocks[0].mlp.fc2.weight.dtype
|
return self.patch_embed.proj.weight.dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return self.blocks[0].mlp.fc2.weight.device
|
return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
pos_ids = []
|
pos_ids = []
|
||||||
@ -932,10 +946,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self.visual = Qwen2VisionTransformer(
|
self.visual = Qwen2VisionTransformer(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
|
quant_config=quant_config,
|
||||||
# NOTE: Qwen2-VL vision encoder does not support any
|
prefix="visual",
|
||||||
# quantization method now.
|
|
||||||
quant_config=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = Qwen2Model(config,
|
self.model = Qwen2Model(config,
|
||||||
@ -1175,7 +1187,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if "visual" in name and "qkv.weight" in name:
|
if "visual" in name and name.endswith("qkv.weight"):
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
visual_num_heads = self.config.vision_config.num_heads
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
visual_embed_dim = self.config.vision_config.embed_dim
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
@ -1184,7 +1196,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
visual_embed_dim)
|
visual_embed_dim)
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||||
elif "visual" in name and "qkv.bias" in name:
|
elif "visual" in name and name.endswith("qkv.bias"):
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
visual_num_heads = self.config.vision_config.num_heads
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
visual_embed_dim = self.config.vision_config.embed_dim
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user