[Model] Support quantization of Qwen2VisionTransformer (#9817)

Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Michael Goin 2024-10-31 01:41:20 -04:00 committed by GitHub
parent 890ca36072
commit d087bf863e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -126,15 +126,18 @@ class Qwen2VisionMLP(nn.Module):
hidden_features: int = None, hidden_features: int = None,
act_layer: Type[nn.Module] = QuickGELU, act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.fc1 = ColumnParallelLinear(in_features, self.fc1 = ColumnParallelLinear(in_features,
hidden_features, hidden_features,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.act = act_layer() self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features, self.fc2 = RowParallelLinear(hidden_features,
in_features, in_features,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.fc2")
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x) x_parallel, _ = self.fc1(x)
@ -196,6 +199,7 @@ class Qwen2VisionAttention(nn.Module):
num_heads: Optional[int] = None, num_heads: Optional[int] = None,
projection_size: Optional[int] = None, projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
@ -207,10 +211,12 @@ class Qwen2VisionAttention(nn.Module):
self.qkv = ColumnParallelLinear(input_size=embed_dim, self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size, output_size=3 * projection_size,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.proj = RowParallelLinear(input_size=projection_size, self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim, output_size=embed_dim,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.proj")
# Detect attention implementation. # Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend() self.attn_backend: _Backend = get_vit_attn_backend()
@ -310,6 +316,7 @@ class Qwen2VisionBlock(nn.Module):
act_layer: Type[nn.Module] = QuickGELU, act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None, norm_layer: Type[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -321,11 +328,13 @@ class Qwen2VisionBlock(nn.Module):
self.attn = Qwen2VisionAttention(embed_dim=dim, self.attn = Qwen2VisionAttention(embed_dim=dim,
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
self.mlp = Qwen2VisionMLP(dim, self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim, mlp_hidden_dim,
act_layer=act_layer, act_layer=act_layer,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor: rotary_pos_emb: torch.Tensor) -> torch.Tensor:
@ -374,6 +383,7 @@ class Qwen2VisionPatchMerger(nn.Module):
norm_layer: Type[nn.Module] = None, norm_layer: Type[nn.Module] = None,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size**2)
@ -384,12 +394,14 @@ class Qwen2VisionPatchMerger(nn.Module):
ColumnParallelLinear(self.hidden_size, ColumnParallelLinear(self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
quant_config=quant_config), quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
nn.GELU(), nn.GELU(),
RowParallelLinear(self.hidden_size, RowParallelLinear(self.hidden_size,
d_model, d_model,
bias=True, bias=True,
quant_config=quant_config), quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
]) ])
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -440,6 +452,7 @@ class Qwen2VisionTransformer(nn.Module):
vision_config: Qwen2VLVisionConfig, vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
@ -467,28 +480,29 @@ class Qwen2VisionTransformer(nn.Module):
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
Qwen2VisionBlock( Qwen2VisionBlock(dim=embed_dim,
dim=embed_dim, num_heads=num_heads,
num_heads=num_heads, mlp_ratio=mlp_ratio,
mlp_ratio=mlp_ratio, norm_layer=norm_layer,
norm_layer=norm_layer, quant_config=quant_config,
quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}")
) for _ in range(depth) for layer_idx in range(depth)
]) ])
self.merger = Qwen2VisionPatchMerger( self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size, d_model=hidden_size,
context_dim=embed_dim, context_dim=embed_dim,
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.merger",
) )
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype return self.patch_embed.proj.weight.dtype
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = [] pos_ids = []
@ -932,10 +946,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Qwen2VisionTransformer( self.visual = Qwen2VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=quant_config,
- # NOTE: Qwen2-VL vision encoder does not support any
+ prefix="visual",
- # quantization method now.
- quant_config=None,
) )
self.model = Qwen2Model(config, self.model = Qwen2Model(config,
@ -1175,7 +1187,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
if "visual" in name and "qkv.weight" in name: if "visual" in name and name.endswith("qkv.weight"):
visual_num_heads = self.config.vision_config.num_heads visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads head_size = visual_embed_dim // visual_num_heads
@ -1184,7 +1196,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
visual_embed_dim) visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name: elif "visual" in name and name.endswith("qkv.bias"):
visual_num_heads = self.config.vision_config.num_heads visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads head_size = visual_embed_dim // visual_num_heads