[Core] Optimizing cross-attention QKVParallelLinear computation (#12325)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Co-authored-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
parent 5d802522a7
commit 69ff99fdcd
@@ -1227,3 +1227,98 @@ class RowParallelLinear(LinearBase):
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid submodules parameters auto-registration:
+        # drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split kv in half
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
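For orientation, here is a minimal, single-process sketch of the calling convention the new layer introduces: Q is computed only from decoder states, K/V only from encoder states, and K/V are skipped entirely when the encoder output is already cached. It re-implements the idea with plain nn.Linear (no tensor parallelism, no quantization, toy sizes), so it is illustrative only and is not the vLLM layer itself.

# Toy, single-process sketch of the calling convention above.
# Mimics QKVCrossParallelLinear with plain nn.Linear; names and sizes
# are illustrative only.
import torch
import torch.nn as nn


class ToyCrossQKV(nn.Module):

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int):
        super().__init__()
        # Q comes from decoder states only; K/V from encoder states only.
        self.q_proj = nn.Linear(hidden_size, num_heads * head_size)
        self.kv_proj = nn.Linear(hidden_size, 2 * num_kv_heads * head_size)
        self.kv_size = num_kv_heads * head_size

    def forward(self, decoder_hidden_states, encoder_hidden_states=None):
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V are already in the KV cache.
            k = v = None
        else:
            # Prefill phase: project encoder states once, split into K and V.
            k, v = self.kv_proj(encoder_hidden_states).split(self.kv_size,
                                                             dim=-1)
        return q, k, v


layer = ToyCrossQKV(hidden_size=64, head_size=16, num_heads=4, num_kv_heads=2)
dec = torch.randn(3, 64)   # decoder tokens
enc = torch.randn(5, 64)   # encoder tokens (prefill only)
q, k, v = layer(dec, enc)
print(q.shape, k.shape, v.shape)   # (3, 64) (5, 32) (5, 32)
q, k, v = layer(dec, None)         # decode step: k and v are None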
@@ -31,6 +31,7 @@ from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -169,7 +170,7 @@ class BartEncoderAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

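The num_kv_heads simplification above (repeated in the decoder self-attention and cross-attention hunks below) relies on BART having no GQA, i.e. total_num_kv_heads == total_num_heads. A quick check with illustrative head counts that the old and new expressions agree whenever TP evenly divides the head count:

# Equivalence check for the num_kv_heads simplification, assuming BART's
# no-GQA setup (total_num_kv_heads == total_num_heads == 16, illustrative).
total_num_heads = total_num_kv_heads = 16
for tp_world_size in (1, 2, 4, 8, 16):
    num_heads = total_num_heads // tp_world_size       # per-rank Q heads
    old = max(1, total_num_kv_heads // tp_world_size)  # old expression
    new = num_heads                                    # new expression
    assert old == new, (tp_world_size, old, new)
print("old and new num_kv_heads agree for tp in {1, 2, 4, 8, 16}")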
@@ -248,7 +249,7 @@ class BartDecoderSelfAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

@@ -299,14 +300,14 @@ class BartCrossAttention(nn.Module):
                 f" and `num_heads`: {num_heads}).")
         self.scaling = self.head_dim**-0.5

-        self.qkv_proj = QKVParallelLinear(
-            self.d_model,
-            self.d_model // self.total_num_heads,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-        )
+        # TP sharding sizes is accounted for within "*Parallel" layers.
+        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
+                                               self.d_model //
+                                               self.total_num_heads,
+                                               self.total_num_heads,
+                                               self.total_num_kv_heads,
+                                               bias,
+                                               quant_config=quant_config)

         self.out_proj = RowParallelLinear(
             embed_dim,
@@ -327,10 +328,7 @@ class BartCrossAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-
+        self.num_kv_heads = self.num_heads  # No GQA in bart
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -347,18 +345,7 @@ class BartCrossAttention(nn.Module):
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""

-        # (afeldman-nm 2024/07/22) TODO:
-        # Need a more efficient solution for q/k/v
-        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
-        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        if encoder_hidden_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
-            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
-                                    dim=-1)
+        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

         attn_output = self.attn(q, k, v)

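To make the intended saving concrete, a back-of-envelope MAC count for the cross-attention projections under illustrative BART-large-like sizes (rough numbers for intuition, not measurements from this PR): the old path ran the fused QKV projection over both decoder and encoder tokens and discarded the unused halves, while the new path computes only Q for decoder tokens and only K/V for encoder tokens.

# Back-of-envelope MAC count for cross-attention projections, with
# illustrative sizes (BART-large-ish: d_model=1024, 16 heads, no GQA).
# These are not measured numbers from the PR, just arithmetic.
d_model, head_size, num_heads, num_kv_heads = 1024, 64, 16, 16
q_width = num_heads * head_size            # 1024
kv_width = 2 * num_kv_heads * head_size    # 2048
fused_width = q_width + kv_width           # 3072

num_dec_tokens, num_enc_tokens = 32, 512   # one illustrative prefill step

# Old path: fused QKV applied to decoder tokens AND encoder tokens.
old_macs = (num_dec_tokens + num_enc_tokens) * d_model * fused_width
# New path: Q only for decoder tokens, K/V only for encoder tokens.
new_macs = (num_dec_tokens * d_model * q_width +
            num_enc_tokens * d_model * kv_width)

print(f"old: {old_macs:,} MACs, new: {new_macs:,} MACs, "
      f"saving: {1 - new_macs / old_macs:.0%}")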
@@ -43,6 +43,7 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -798,21 +799,22 @@ class MllamaTextCrossAttention(nn.Module):
         self.config = config
         self.pipeline_parallel_rank = get_pp_group().rank_in_group
         self.tensor_parallel_size = get_tp_group().world_size
-        self.num_heads = self.config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.num_local_heads = self.num_heads // self.tensor_parallel_size
-        self.num_key_value_heads = self.config.num_key_value_heads
         self.num_local_key_value_heads = \
             self.num_key_value_heads // self.tensor_parallel_size
-        self.dropout = config.dropout
         self.hidden_size = config.hidden_size
         self.head_dim = config.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.layer_idx = layer_idx
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        # TODO: change to Q/KV separate linear after #7448 is merged
-        self.qkv_proj = QKVParallelLinear(
+        self.qkv_proj = QKVCrossParallelLinear(
             self.hidden_size,
             self.head_dim,
             self.num_heads,
@@ -821,6 +823,7 @@ class MllamaTextCrossAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
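Note that the model keeps prefix=f"{prefix}.qkv_proj": vLLM's checkpoint loading dispatches through the weight_loader attribute attached to each parameter, so the placeholder weight/bias in QKVCrossParallelLinear simply re-routes incoming "qkv_proj" shards to the real sub-layers underneath. A toy sketch of that routing (not vLLM's loader; shapes and names are made up):

# Toy routing sketch mirroring weight_loader_weight above: shards aimed at
# the fused "qkv_proj" parameter land in the decoder-Q or encoder-KV weight.
import torch

q_weight = torch.zeros(4, 8)    # stands in for q_proj_decoder.weight
kv_weight = torch.zeros(6, 8)   # stands in for kv_proj_encoder.weight


def toy_weight_loader(loaded_weight: torch.Tensor, shard_id: str) -> None:
    if shard_id == "q":
        q_weight.copy_(loaded_weight)
    else:  # "k" or "v": first/second half of the encoder KV weight
        half = kv_weight.shape[0] // 2
        dst = kv_weight[:half] if shard_id == "k" else kv_weight[half:]
        dst.copy_(loaded_weight)


toy_weight_loader(torch.ones(4, 8), "q")
toy_weight_loader(torch.full((3, 8), 2.0), "k")
toy_weight_loader(torch.full((3, 8), 3.0), "v")
print(q_weight.sum().item(), kv_weight.sum().item())  # 32.0, 120.0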
@@ -851,21 +854,12 @@ class MllamaTextCrossAttention(nn.Module):
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        qkv_dec, _ = self.qkv_proj(hidden_states)
-        q, _, _ = qkv_dec.split(
-            [self.q_local_size, self.kv_local_size, self.kv_local_size],
-            dim=-1)
-        if cross_attention_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(cross_attention_states)
-            _, k, v = qkv_enc.split(
-                [self.q_local_size, self.kv_local_size, self.kv_local_size],
-                dim=-1)
+        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        if cross_attention_states is not None:
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)

@@ -889,6 +883,7 @@ class MllamaTextCrossAttention(nn.Module):
         kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         # Skip writing kv-cache for the initial profiling run.
+        # TODO (NickLucche) replace with custom attn bias and use standard attn
         if len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,
@@ -650,4 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
     return tensors