mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 14:09:09 +08:00
[Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization working (#14498)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
73ae0b44e9
commit
fb16eea48b
@ -43,7 +43,6 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVCrossParallelLinear,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
@ -814,11 +813,20 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
self.q_local_size = self.num_local_heads * self.head_dim
|
self.q_local_size = self.num_local_heads * self.head_dim
|
||||||
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
|
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
|
||||||
|
|
||||||
self.qkv_proj = QKVCrossParallelLinear(
|
# TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
|
||||||
|
# quantization
|
||||||
|
self.q_proj = ColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_size=self.num_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_proj",
|
||||||
|
)
|
||||||
|
self.kv_proj = QKVParallelLinear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.num_heads,
|
total_num_heads=0,
|
||||||
self.num_key_value_heads,
|
total_num_kv_heads=self.num_key_value_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.qkv_proj",
|
prefix=f"{prefix}.qkv_proj",
|
||||||
@ -854,11 +862,15 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||||
cross_attention_states: Optional[torch.Tensor],
|
cross_attention_states: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
|
q, _ = self.q_proj(hidden_states)
|
||||||
if cross_attention_states is not None:
|
if cross_attention_states is not None:
|
||||||
|
kv, _ = self.kv_proj(cross_attention_states)
|
||||||
|
k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
|
||||||
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
|
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
|
||||||
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
|
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
|
||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
else:
|
||||||
|
k = v = None
|
||||||
|
|
||||||
q = q.view(-1, self.num_local_heads, self.head_dim)
|
q = q.view(-1, self.num_local_heads, self.head_dim)
|
||||||
q = self.q_norm(q)
|
q = self.q_norm(q)
|
||||||
@ -1149,8 +1161,13 @@ class MllamaForCausalLM(nn.Module):
|
|||||||
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||||
SupportsV0Only):
|
SupportsV0Only):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"self_attn.qkv_proj": [
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
"self_attn.q_proj",
|
||||||
|
"self_attn.k_proj",
|
||||||
|
"self_attn.v_proj",
|
||||||
|
],
|
||||||
|
"cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -1420,9 +1437,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
(".qkv_proj", ".q_proj", "q"),
|
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||||
(".qkv_proj", ".k_proj", "k"),
|
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||||
(".qkv_proj", ".v_proj", "v"),
|
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||||
|
(".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
|
||||||
|
(".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
|
||||||
(".gate_up_proj", ".gate_proj", 0),
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
(".gate_up_proj", ".up_proj", 1),
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user