From a655eb30252fe266ce16fde2aa9f8f9554ccd46e Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 4 Jan 2025 06:19:02 +0800
Subject: [PATCH] [Misc]Add BNB quantization for Qwen2VL (#11719)

Signed-off-by: Jee Jee Li
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/models/qwen2_vl.py | 69 +++++++++++++++-----------
 1 file changed, 40 insertions(+), 29 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 26b6d768ad4f6..5a8c6e4deb7ac 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -38,7 +38,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -239,6 +239,8 @@ class Qwen2VisionAttention(nn.Module):
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
@@ -261,24 +263,41 @@ class Qwen2VisionAttention(nn.Module):
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
+
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
         x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
@@ -614,24 +633,6 @@ class Qwen2VisionTransformer(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
    # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
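
A minimal usage sketch for the loading path this patch enables, assuming the standard vLLM bitsandbytes flags and an illustrative Qwen2-VL checkpoint name (not taken from the commit above):

    # Sketch only: in-flight bitsandbytes loading of a Qwen2-VL checkpoint.
    # Assumes bitsandbytes is installed; the model name is illustrative.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2-VL-7B-Instruct",
        quantization="bitsandbytes",   # quantize the weights on the fly with BNB
        load_format="bitsandbytes",
    )
    outputs = llm.generate("Describe a sunset over the ocean.",
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)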