[V0 deprecation] Remove QKVCrossParallelLinear implementation (#26475)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Isotr0py on 2025-10-09 18:52:27 +08:00; committed by GitHub
parent ec10fd0abc
commit d1ddf340c8
4 changed files with 2 additions and 255 deletions

View File

@@ -1,8 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import BaseLayerWithLoRA


# TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
    pass

View File

@@ -3,10 +3,9 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (
@@ -1440,237 +1439,3 @@ class RowParallelLinear(LinearBase):
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
            bias can be fused with other element-wise operations. we
            skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            placeholder_size,
            [placeholder_size],
            placeholder_size,
            placeholder_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder",
        )
        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder",
        )

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader_v1,
                },
            )
        else:
            self.bias = None

    def process_weights_after_loading(self):
        for layer in self.proj.values():
            if self.quant_method is not None:
                self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict
            )
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.
        """
        target_param_list = [
            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s
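
For context, here is a minimal usage sketch of the layer being removed, based on the `forward` signature above: the query projection consumes decoder hidden states, the fused key/value projection consumes encoder hidden states, and `None` is passed once the encoder KV has already been cached. All sizes, names, and tensors below are illustrative assumptions, not code from this diff.

import torch

# Hypothetical shapes: [num_tokens, hidden_size].
decoder_hidden_states = torch.randn(4, 1024)
encoder_hidden_states = torch.randn(7, 1024)

# Illustrative construction; the head counts and prefix are made up.
qkv_proj = QKVCrossParallelLinear(
    hidden_size=1024,
    head_size=64,
    total_num_heads=16,
    total_num_kv_heads=8,
    bias=False,
    prefix="model.layers.0.cross_attn.qkv_proj",
)

# Prefill: encoder states are available, so K and V are computed here.
q, k, v = qkv_proj(decoder_hidden_states, encoder_hidden_states)

# Decode: encoder KV is already cached by the attention backend, so only Q is produced.
q, k, v = qkv_proj(decoder_hidden_states, None)  # k is None, v is None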

View File

@@ -16,7 +16,6 @@ from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (
    WEIGHT_LOADER_V2_SUPPORTED,
    LinearMethodBase,
    QKVCrossParallelLinear,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsScheme,
@@ -89,10 +88,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            if isinstance(layer, QKVCrossParallelLinear):
                weight_loader_v1 = layer.weight_loader_v1
            else:
                weight_loader_v1 = layer.weight_loader
            weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
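
The branch removed above existed only because QKVCrossParallelLinear exposed its own v1-style entry point (`weight_loader_v1`); every remaining layer simply exposes `weight_loader`. As a rough, self-contained sketch of the allow-list dispatch the surrounding method performs (all names below are hypothetical stand-ins, not vLLM identifiers):

# Stand-in allow-list; the real code checks WEIGHT_LOADER_V2_SUPPORTED.
V2_SUPPORTED = {"SomeV2CapableLinearMethod"}

def falls_back_to_v1(quant_method_cls_name: str) -> bool:
    # The hunk above hands the layer's plain `weight_loader` to create_weights()
    # whenever the wrapped quant method is not on the v2 allow-list.
    return quant_method_cls_name not in V2_SUPPORTED

print(falls_back_to_v1("LegacyLinearMethod"))         # True  -> use layer.weight_loader
print(falls_back_to_v1("SomeV2CapableLinearMethod"))  # False -> weight_loader_v2 path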

View File

@@ -17,7 +17,6 @@ from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
@@ -108,11 +107,6 @@ def process_weights_after_loading(
    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    for _, module in model.named_modules():
        if isinstance(module, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            module.process_weights_after_loading()
            continue
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
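
The special case removed here was needed because QKVCrossParallelLinear stored its q/kv projections in a plain Python dict rather than as registered submodules, so a generic walk over `named_modules()` never reached them and their post-loading hook had to be invoked explicitly. A minimal standalone demonstration of that PyTorch behavior (the `Holder` class is hypothetical, not vLLM code):

import torch.nn as nn

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain dict is not auto-registered, unlike nn.ModuleDict.
        self.proj = {"q_proj_decoder": nn.Linear(8, 8)}

holder = Holder()
# Only the root module is reported; the dict-held Linear is invisible to
# named_modules(), which is why the loader needed an explicit special case.
print([name for name, _ in holder.named_modules()])  # prints ['']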