[Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization work (#14498)
Signed-off-by: Isotr0py <2037008807@qq.com>
commit: fb16eea48b
parent: 73ae0b44e9
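In short, Mllama's cross-attention goes back to a separate q_proj (ColumnParallelLinear) plus a fused K/V-only kv_proj (QKVParallelLinear), because QKVCrossParallelLinear does not support quantization yet and this broke bitsandbytes (BNB) loading. As a rough illustration of the use case the revert keeps working, below is a minimal offline-inference sketch with in-flight BNB quantization; the model ID, max_model_len, and sampling settings are placeholders, and load_format="bitsandbytes" is assumed to be required on vLLM builds of roughly this vintage.

    # Minimal sketch (not part of this commit): offline inference with
    # in-flight bitsandbytes quantization. Model ID and settings are
    # illustrative placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder model ID
        quantization="bitsandbytes",
        load_format="bitsandbytes",  # assumed required on vLLM of this vintage
        max_model_len=4096,
    )
    outputs = llm.generate("Describe a sunset over the ocean.",
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)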
@@ -43,7 +43,6 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -814,11 +813,20 @@ class MllamaTextCrossAttention(nn.Module):
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim
 
-        self.qkv_proj = QKVCrossParallelLinear(
+        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
+        # quantization
+        self.q_proj = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.kv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
-            self.num_heads,
-            self.num_key_value_heads,
+            total_num_heads=0,
+            total_num_kv_heads=self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
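For context on the kv_proj construction above: QKVParallelLinear sizes its full output as (total_num_heads + 2 * total_num_kv_heads) * head_size, so passing total_num_heads=0 turns it into a fused K/V-only projection whose two halves line up with kv_local_size in the forward pass. A small arithmetic sketch, assuming 11B Mllama-like cross-attention dimensions (32 query heads, 8 KV heads, head_dim=128) and tensor_parallel_size=1:

    # Illustrative arithmetic only; dimension values are assumed, TP size = 1.
    num_heads = 32              # query heads
    num_key_value_heads = 8     # KV heads
    head_dim = 128

    q_out = num_heads * head_dim                        # q_proj output: 4096
    kv_out = (0 + 2 * num_key_value_heads) * head_dim   # kv_proj output: 2048
    kv_local_size = num_key_value_heads * head_dim      # 1024 per K or V slice
    assert kv_out == 2 * kv_local_size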
@@ -854,11 +862,15 @@ class MllamaTextCrossAttention(nn.Module):
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        q, _ = self.q_proj(hidden_states)
         if cross_attention_states is not None:
+            kv, _ = self.kv_proj(cross_attention_states)
+            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+        else:
+            k = v = None
 
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
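The new forward path only calls kv_proj when cross_attention_states is present; the fused output is split into K and V halves and reshaped per KV head before k_norm. A standalone shape check of that split/view, with an assumed token count and head dimensions:

    # Standalone shape check for the split/view above; sizes are assumed.
    import torch

    num_tokens = 5
    num_local_key_value_heads = 8
    head_dim = 128
    kv_local_size = num_local_key_value_heads * head_dim

    kv = torch.randn(num_tokens, 2 * kv_local_size)            # fused K/V output
    k, v = kv.split([kv_local_size, kv_local_size], dim=-1)
    k = k.view(-1, num_local_key_value_heads, head_dim)         # (5, 8, 128)
    v = v.view(-1, num_local_key_value_heads, head_dim)         # (5, 8, 128)
    assert k.shape == v.shape == (num_tokens, num_local_key_value_heads, head_dim)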
@@ -1149,8 +1161,13 @@ class MllamaForCausalLM(nn.Module):
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
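packed_modules_mapping tells fusion-aware weight loaders (the bitsandbytes path among them) which checkpoint sub-modules end up packed into a single vLLM module; the finer-grained self_attn./cross_attn. keys are needed because cross-attention now fuses only K and V. A small, illustrative helper (not vLLM's loader code, and the example module path is hypothetical) showing how such a mapping expands a fused module name into the checkpoint names it covers:

    # Illustrative only: expand a fused module name using the mapping above.
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def checkpoint_names(fused_name: str) -> list[str]:
        """Return the checkpoint module names packed into fused_name (hypothetical helper)."""
        for packed, originals in packed_modules_mapping.items():
            if fused_name.endswith(packed):
                stem = fused_name[: -len(packed)]
                return [stem + orig for orig in originals]
        return [fused_name]

    # A cross-attention kv_proj packs the checkpoint's k_proj and v_proj:
    print(checkpoint_names("language_model.model.layers.3.cross_attn.kv_proj"))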
@@ -1420,9 +1437,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
+            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
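In the usual vLLM load_weights pattern, each (param_name, shard_name, shard_id) entry in stacked_params_mapping rewrites a checkpoint weight name onto the fused parameter and tells that parameter's weight_loader which shard to fill; the extra self_attn./cross_attn. entries route the checkpoint's q/k/v weights into the right slices of the new q_proj/kv_proj layout. A simplified, standalone sketch of just the name-rewriting step (the example weight name is hypothetical):

    # Illustrative only: how a (param_name, shard_name, shard_id) mapping
    # rewrites checkpoint weight names onto fused vLLM parameters.
    stacked_params_mapping = [
        (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
        (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
        (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
        (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
        (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    def map_checkpoint_name(name: str):
        """Map a checkpoint weight name to (fused_param_name, shard_id)."""
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name), shard_id
        return name, None  # loaded as-is, no fused shard

    # The cross-attention k_proj weight now lands in the fused kv_proj's "k" shard:
    print(map_checkpoint_name(
        "language_model.model.layers.3.cross_attn.k_proj.weight"))
    # -> ('language_model.model.layers.3.cross_attn.kv_proj.weight', 'k')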