[Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization working (#14498)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-03-09 12:47:45 +08:00 committed by GitHub
parent 73ae0b44e9
commit fb16eea48b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -43,7 +43,6 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -814,11 +813,20 @@ class MllamaTextCrossAttention(nn.Module):
self.q_local_size = self.num_local_heads * self.head_dim self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim self.kv_local_size = self.num_local_key_value_heads * self.head_dim
self.qkv_proj = QKVCrossParallelLinear( # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
# quantization
self.q_proj = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.head_dim, self.head_dim,
self.num_heads, total_num_heads=0,
self.num_key_value_heads, total_num_kv_heads=self.num_key_value_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
@ -854,11 +862,15 @@ class MllamaTextCrossAttention(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states) q, _ = self.q_proj(hidden_states)
if cross_attention_states is not None: if cross_attention_states is not None:
kv, _ = self.kv_proj(cross_attention_states)
k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim) k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim) v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k) k = self.k_norm(k)
else:
k = v = None
q = q.view(-1, self.num_local_heads, self.head_dim) q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q) q = self.q_norm(q)
@ -1149,8 +1161,13 @@ class MllamaForCausalLM(nn.Module):
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only): SupportsV0Only):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "self_attn.qkv_proj": [
"gate_up_proj": ["gate_proj", "up_proj"] "self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -1420,9 +1437,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".qkv_proj", ".k_proj", "k"), (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".qkv_proj", ".v_proj", "v"), (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
(".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]