[Misc]Add BNB quantization for Qwen2VL (#11719)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
commit a655eb3025
parent 1543914c04
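For context, this commit lets Qwen2-VL checkpoints be loaded with bitsandbytes (BNB) quantization. Below is a minimal usage sketch, assuming a vLLM build that includes this commit; the model name is illustrative, and the exact flags for in-flight BNB quantization may differ between vLLM releases (it is not part of this diff).

from vllm import LLM

# In-flight 4-bit BitsAndBytes quantization of a Qwen2-VL checkpoint.
# In this vLLM release line both `quantization` and `load_format` are
# expected to be set to "bitsandbytes"; adjust to your installed version.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

# Image inputs are omitted here; a plain text prompt is enough to verify
# that the quantized weights load and generate.
outputs = llm.generate("Hello, my name is")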
@@ -38,7 +38,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -239,6 +239,8 @@ class Qwen2VisionAttention(nn.Module):
         super().__init__()
         # Per attention head and per partition values.
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
@@ -261,24 +263,41 @@ class Qwen2VisionAttention(nn.Module):
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # [s, b, c] --> [s, b, head * 3 * head_dim]
+
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
         x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
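The new split_qkv above gathers the column-parallel QKV projection across tensor-parallel ranks, chunks it into q/k/v, and then re-splits each along the head dimension so every rank keeps only its own heads. The following is a standalone, single-process sketch of that shape bookkeeping only: the sizes are made up, no distributed calls are made, and plain torch.chunk stands in for vLLM's split_tensor_along_last_dim helper.

import torch

# Hypothetical sizes, chosen only to illustrate the shapes involved.
seq_len, bs = 4, 2
tp_size, tp_rank = 2, 0
num_heads, head_dim = 16, 80
heads_per_rank = num_heads // tp_size

# After the all-gather, every rank sees the full fused projection:
# [s, b, 3 * num_heads * head_dim]
qkv = torch.randn(seq_len, bs, 3 * num_heads * head_dim)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)

# Each rank then keeps only its slice of the head dimension.
q, k, v = (x.chunk(tp_size, dim=-1)[tp_rank] for x in (q, k, v))

# 3 * [s, b, head * head_dim] -> 3 * [s, b, heads_per_rank, head_dim]
new_shape = (seq_len, bs, heads_per_rank, head_dim)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
print(q.shape)  # torch.Size([4, 2, 8, 80])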
@@ -614,24 +633,6 @@ class Qwen2VisionTransformer(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
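For reference, the deleted else-branch used to regroup the fused visual qkv checkpoint tensor from q/k/v-major order into head-major order before handing it to the default loader. A standalone sketch of that now-removed weight transform, with made-up vision-tower sizes, in case the shape gymnastics are unclear:

import torch

# Hypothetical vision-tower sizes, for illustration only.
visual_num_heads = 16
visual_embed_dim = 1280
head_size = visual_embed_dim // visual_num_heads  # 80

# Fused qkv.weight as stored in the checkpoint: [3 * embed_dim, embed_dim].
loaded_weight = torch.randn(3 * visual_embed_dim, visual_embed_dim)

# The removed code regrouped rows from (q|k|v, head) to (head, q|k|v) order.
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size,
                                   visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
print(loaded_weight.shape)  # torch.Size([3840, 1280])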
@@ -935,6 +936,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "lm_head.": "language_model.lm_head.",
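The new class attribute tells the bitsandbytes weight loader how individual checkpoint projections (q_proj, k_proj, v_proj, gate_proj, up_proj) map onto the fused parameters the vLLM model actually holds (qkv_proj, gate_up_proj) and which shard slot each one fills. A toy lookup illustrating how such a mapping is read; the dict literal mirrors the one added above, but the resolve() helper is hypothetical and not part of vLLM.

# Mirrors the mapping added in this commit.
bitsandbytes_stacked_params_mapping = {
    # shard_name: (stacked_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve(checkpoint_name: str) -> tuple[str, int | None]:
    """Hypothetical helper: map a checkpoint weight name to the fused
    vLLM parameter name and the shard slot it should fill."""
    for shard, (stacked, idx) in bitsandbytes_stacked_params_mapping.items():
        if shard in checkpoint_name:
            return checkpoint_name.replace(shard, stacked), idx
    return checkpoint_name, None

print(resolve("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)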