[Core] Optimizing cross-attention QKVParallelLinear computation (#12325)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Co-authored-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Authored by Nicolò Lucchesi on 2025-03-06 10:37:26 +01:00, committed by GitHub
parent 5d802522a7
commit 69ff99fdcd
4 changed files with 121 additions and 44 deletions
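
In short: the cross-attention layers below previously reused one fused QKVParallelLinear for both streams, running the full QKV projection on the decoder hidden states and keeping only the Q slice, then (at prefill) running it again on the encoder hidden states and keeping only the K/V slices. The new QKVCrossParallelLinear wraps a Q-only ColumnParallelLinear for the decoder side and a KV-only QKVParallelLinear (total_num_heads=0) for the encoder side, so each stream is projected exactly once, while checkpoints still load as if the layer were a single fused qkv_proj. A simplified before/after sketch of the call pattern (shapes and tensor-parallel details omitted):

    # Before: one fused projection applied to both streams, unused slices dropped.
    qkv_dec, _ = qkv_proj(decoder_hidden_states)   # computes Q, K and V
    q, _, _ = qkv_dec.split([q_size, kv_size, kv_size], dim=-1)
    qkv_enc, _ = qkv_proj(encoder_hidden_states)   # computes Q, K and V again
    _, k, v = qkv_enc.split([q_size, kv_size, kv_size], dim=-1)

    # After: each stream goes through only the projection it needs.
    q, k, v = qkv_proj(decoder_hidden_states, encoder_hidden_states)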


@@ -1227,3 +1227,98 @@ class RowParallelLinear(LinearBase):
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid submodules parameters auto-registration:
+        # drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split kv in half
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
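
For reference, a minimal usage sketch of the new layer (not part of the diff; the sizes are made up, vLLM's tensor-parallel state is assumed to be initialized with tp_size=1, and hidden states are assumed to be [num_tokens, hidden_size]):

    # Hypothetical sizes, for illustration only.
    hidden_size, head_size = 1024, 64
    cross_qkv = QKVCrossParallelLinear(hidden_size=hidden_size,
                                       head_size=head_size,
                                       total_num_heads=16,
                                       total_num_kv_heads=16,
                                       bias=False)

    # Prefill: both streams are available, encoder K/V are computed here.
    q, k, v = cross_qkv(decoder_hidden_states, encoder_hidden_states)
    # q: [num_decoder_tokens, 16 * 64]; k, v: [num_encoder_tokens, 16 * 64] at tp=1.

    # Decode: encoder K/V already sit in the cross-attention KV cache.
    q, k, v = cross_qkv(decoder_hidden_states, None)  # k and v are None

Checkpoint loading still targets a single qkv_proj module: the placeholder self.weight / self.bias parameters carry weight_loader hooks that route the "q" shard to q_proj_decoder and the "k"/"v" shards to kv_proj_encoder.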


@@ -31,6 +31,7 @@ from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -169,7 +170,7 @@ class BartEncoderAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -248,7 +249,7 @@ class BartDecoderSelfAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -299,14 +300,14 @@ class BartCrossAttention(nn.Module):
                              f" and `num_heads`: {num_heads}).")
         self.scaling = self.head_dim**-0.5
 
-        self.qkv_proj = QKVParallelLinear(
-            self.d_model,
-            self.d_model // self.total_num_heads,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-        )
+        # TP sharding sizes are accounted for within "*Parallel" layers.
+        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
+                                               self.d_model //
+                                               self.total_num_heads,
+                                               self.total_num_heads,
+                                               self.total_num_kv_heads,
+                                               bias,
+                                               quant_config=quant_config)
 
         self.out_proj = RowParallelLinear(
             embed_dim,
@@ -327,10 +328,7 @@ class BartCrossAttention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-
+        self.num_kv_heads = self.num_heads  # No GQA in bart
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -347,18 +345,7 @@
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""
 
-        # (afeldman-nm 2024/07/22) TODO:
-        # Need a more efficient solution for q/k/v
-        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
-        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        if encoder_hidden_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
-            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
-                                    dim=-1)
+        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)
 
         attn_output = self.attn(q, k, v)
 
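
A note on the BART changes above: BART has no grouped-query attention, so num_kv_heads simply equals num_heads, and the per-rank q_size / kv_size bookkeeping can be dropped from the attention modules because the "*Parallel" layers now own the sharding math. As an illustrative example (numbers made up, not taken from the diff) of that bookkeeping inside QKVCrossParallelLinear:

    # Illustrative tensor-parallel arithmetic, assuming the KV heads are sharded.
    total_num_kv_heads, head_size, tp_size = 16, 64, 4
    num_kv_heads_per_rank = total_num_kv_heads // tp_size   # 4 KV heads per GPU
    kv_size = num_kv_heads_per_rank * head_size             # 256
    # kv_proj_encoder emits [num_encoder_tokens, 2 * kv_size] per rank, and
    # kv_enc.split(kv_size, dim=-1) yields k and v of width kv_size each.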


@@ -43,6 +43,7 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -798,21 +799,22 @@ class MllamaTextCrossAttention(nn.Module):
         self.config = config
         self.pipeline_parallel_rank = get_pp_group().rank_in_group
         self.tensor_parallel_size = get_tp_group().world_size
-        self.num_heads = self.config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
         self.num_local_heads = self.num_heads // self.tensor_parallel_size
-        self.num_key_value_heads = self.config.num_key_value_heads
         self.num_local_key_value_heads = \
             self.num_key_value_heads // self.tensor_parallel_size
         self.dropout = config.dropout
         self.hidden_size = config.hidden_size
         self.head_dim = config.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
         self.layer_idx = layer_idx
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim
-        # TODO: change to Q/KV separate linear after #7448 is merged
-        self.qkv_proj = QKVParallelLinear(
+
+        self.qkv_proj = QKVCrossParallelLinear(
             self.hidden_size,
             self.head_dim,
             self.num_heads,
@@ -821,6 +823,7 @@ class MllamaTextCrossAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
@@ -851,21 +854,12 @@ class MllamaTextCrossAttention(nn.Module):
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        qkv_dec, _ = self.qkv_proj(hidden_states)
-        q, _, _ = qkv_dec.split(
-            [self.q_local_size, self.kv_local_size, self.kv_local_size],
-            dim=-1)
-        if cross_attention_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(cross_attention_states)
-            _, k, v = qkv_enc.split(
-                [self.q_local_size, self.kv_local_size, self.kv_local_size],
-                dim=-1)
+        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        if cross_attention_states is not None:
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
 
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
@@ -889,6 +883,7 @@ class MllamaTextCrossAttention(nn.Module):
         kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         # Skip writing kv-cache for the initial profiling run.
+        # TODO (NickLucche) replace with custom attn bias and use standard attn
         if len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,


@@ -650,4 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
-    return tensors
+    return tensors