[Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization work (#14498)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-03-09 12:47:45 +08:00 committed by GitHub
parent 73ae0b44e9
commit fb16eea48b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -43,7 +43,6 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -814,11 +813,20 @@ class MllamaTextCrossAttention(nn.Module):
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
self.qkv_proj = QKVCrossParallelLinear(
# TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
# quantization
self.q_proj = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
self.num_key_value_heads,
total_num_heads=0,
total_num_kv_heads=self.num_key_value_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
@ -854,11 +862,15 @@ class MllamaTextCrossAttention(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
q, _ = self.q_proj(hidden_states)
if cross_attention_states is not None:
kv, _ = self.kv_proj(cross_attention_states)
k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)
else:
k = v = None
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
@ -1149,8 +1161,13 @@ class MllamaForCausalLM(nn.Module):
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -1420,9 +1437,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
(".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]