From d1ddf340c88110ab9b961d5464ed7599bd4dc9a8 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 9 Oct 2025 18:52:27 +0800 Subject: [PATCH] [V0 deprecation] Remove `QKVCrossParallelLinear` implementation (#26475) Signed-off-by: Isotr0py --- vllm/lora/layers/qkv_x_parallel_linear.py | 8 - vllm/model_executor/layers/linear.py | 237 +----------------- .../compressed_tensors/transform/linear.py | 6 +- vllm/model_executor/model_loader/utils.py | 6 - 4 files changed, 2 insertions(+), 255 deletions(-) delete mode 100644 vllm/lora/layers/qkv_x_parallel_linear.py diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py deleted file mode 100644 index 785cdf38e3603..0000000000000 --- a/vllm/lora/layers/qkv_x_parallel_linear.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .base import BaseLayerWithLoRA - - -# TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3881ba12faa06..49b683a1a9f9a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,10 +3,9 @@ import itertools from abc import abstractmethod -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, Union import torch -import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import ( @@ -1440,237 +1439,3 @@ class RowParallelLinear(LinearBase): s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s - - -@CustomOp.register("qkv_cross_parallel_linear") -class QKVCrossParallelLinear(LinearBase): - """Linear layers for efficient cross-attention's QKV transformation. - - Args: - hidden_size: input hidden state size of the transformer. - head_size: size of each attention head. - total_num_heads: total number of attention query heads. - total_num_kv_heads: total number of attention key/value heads. If - None, assume total_num_kv_heads = total_num_heads. - bias: If true, add bias. - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) - """ - - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - # input_size and output_size are not used, just for alignment - input_size = hidden_size - output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size - super().__init__( - input_size=input_size, - output_size=output_size, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - ) - - self.quant_config = quant_config - - # Empty placeholders for loading as a single module. 
- placeholder_size = 0 - assert self.quant_method is not None - self.quant_method.create_weights( - self, - placeholder_size, - [placeholder_size], - placeholder_size, - placeholder_size, - self.params_dtype, - weight_loader=self.weight_loader, - ) - - # Use a dictionary to avoid submodules parameters auto-registration: - # drop-in replacement for a `QKVParallelLinear` module. - self.proj = dict() - self.proj["q_proj_decoder"] = ColumnParallelLinear( - input_size=hidden_size, - output_size=total_num_heads * head_size, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.q_proj_decoder", - ) - - self.proj["kv_proj_encoder"] = QKVParallelLinear( - hidden_size=hidden_size, - head_size=head_size, - total_num_heads=0, - total_num_kv_heads=total_num_kv_heads, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.kv_proj_encoder", - ) - - # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. - self.q_size = self.q_proj_decoder.output_size_per_partition - self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size - - if bias: - self.bias = torch.nn.Parameter() - set_weight_attrs( - self.bias, - { - "output_dim": 0, - "weight_loader": self.weight_loader_v1, - }, - ) - else: - self.bias = None - - def process_weights_after_loading(self): - for layer in self.proj.values(): - if self.quant_method is not None: - self.quant_method.process_weights_after_loading(layer) - - @property - def q_proj_decoder(self) -> ColumnParallelLinear: - layer = self.proj["q_proj_decoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") - return layer - - @property - def kv_proj_encoder(self) -> QKVParallelLinear: - layer = self.proj["kv_proj_encoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") - return layer - - def sync_weight_attrs( - self, - src_param: nn.Parameter, - tgt_param: nn.Parameter, - mode: Literal["q_proj_decoder", "kv_proj_encoder"], - ): - missing_attrs_dict = { - k: getattr(src_param, k) - for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys())) - } - # TODO(Isotr0py): handle bitsandbytes 8bit - use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False) - if missing_attrs_dict and use_bitsandbytes_4bit: - q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( - missing_attrs_dict - ) - if mode == "q_proj_decoder": - set_weight_attrs(tgt_param, q_proj_attrs) - elif mode == "kv_proj_encoder": - set_weight_attrs(tgt_param, kv_proj_attrs) - else: - set_weight_attrs(tgt_param, missing_attrs_dict) - - def _is_same_param( - self, - src_param: torch.nn.Parameter, - map_param: torch.nn.Parameter, - ) -> bool: - """Check if two parameters are exactly pointing to same things.""" - # ignore weight_loader because it's always different - key_to_ignore = ["weight_loader", "_weight_loader"] - has_same_type_name = type(src_param) is type(map_param) - src_param_attrs = { - k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore - } - map_param_attrs = { - k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore - } - has_same_attrs = src_param_attrs == map_param_attrs - return has_same_type_name and has_same_attrs - - def 
select_proj_params( - self, - layer: nn.Module, - param: nn.Parameter, - ) -> nn.Parameter: - """ - Given the placeholder param, - return the corresponding param in the proj layers. - """ - target_param_list = [ - v for _, v in layer.named_parameters() if self._is_same_param(param, v) - ] - assert len(target_param_list) == 1 - target_param = target_param_list[0] - return target_param - - def forward( # type: ignore[override] - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - q, _ = self.q_proj_decoder(decoder_hidden_states) - if encoder_hidden_states is None: - # Encoder KV already cached. - k = None - v = None - else: - # Prefill phase, encoder KV cached here. - kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) - # Split kv in half - k, v = kv_enc.split(self.kv_size, dim=-1) - return q, k, v - - def weight_loader_v1( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - # just like all other parameters, does not yet - # support loading bias with weight_loader_v2 - layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder - target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () - layer.weight_loader(target_param, loaded_weight, *shard_id_args) - - def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder - target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () - if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED: - layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args) - else: - layer.weight_loader(target_param, loaded_weight, *shard_id_args) - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", q_size={self.q_size}" - s += f", kv_size={self.kv_size}" - s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" - s += ", gather_output=False" - return s diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py index a51fe28b975e2..edd2706b470fd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -16,7 +16,6 @@ from compressed_tensors.utils import is_match from vllm.model_executor.layers.linear import ( WEIGHT_LOADER_V2_SUPPORTED, LinearMethodBase, - QKVCrossParallelLinear, ) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsScheme, @@ -89,10 +88,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): # hack around this by getting weight loader v1 so ULM can load correctly quant_method_name = self.quant_method.__class__.__name__ if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: - if isinstance(layer, QKVCrossParallelLinear): - weight_loader_v1 = layer.weight_loader_v1 - else: - weight_loader_v1 = layer.weight_loader + weight_loader_v1 = layer.weight_loader extra_weight_attrs["weight_loader"] = weight_loader_v1 self.quant_method.create_weights( diff --git a/vllm/model_executor/model_loader/utils.py 
b/vllm/model_executor/model_loader/utils.py index 364b73d6b68d8..5ae32f1d120c0 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -17,7 +17,6 @@ from vllm.attention import Attention from vllm.attention.layer import MLAAttention from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -108,11 +107,6 @@ def process_weights_after_loading( maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) for _, module in model.named_modules(): - if isinstance(module, QKVCrossParallelLinear): - # NOTE(Isotr0py): special case for cross QKV layer because - # q and kv proj aren't registered as submodules intentionally - module.process_weights_after_loading() - continue quant_method = getattr(module, "quant_method", None) if isinstance(quant_method, QuantizeMethodBase): # When quant methods need to process weights after loading
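
For reference, the removed `QKVCrossParallelLinear` wrapped a `ColumnParallelLinear` for the decoder-side Q projection and a `QKVParallelLinear` (constructed with `total_num_heads=0`) for the encoder-side fused KV projection, then split the fused output into K and V. Below is a minimal, single-device sketch of that computation using plain `nn.Linear` stand-ins for the tensor-parallel vLLM layers; the class and parameter names are illustrative only, not vLLM APIs, and quantization/weight-loading details are omitted.

# Illustration only: plain-PyTorch stand-in for the removed layer's forward
# pass. nn.Linear replaces vLLM's ColumnParallelLinear / QKVParallelLinear;
# the names below (CrossAttnQKVSketch, ...) are hypothetical.
from typing import Optional

import torch
import torch.nn as nn


class CrossAttnQKVSketch(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.kv_size = num_kv_heads * head_size
        # Q is projected from the decoder hidden states...
        self.q_proj_decoder = nn.Linear(hidden_size, num_heads * head_size)
        # ...while K/V come from the encoder hidden states as one fused matmul.
        self.kv_proj_encoder = nn.Linear(hidden_size, 2 * self.kv_size)

    def forward(self, decoder_hidden_states, encoder_hidden_states=None):
        q = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V are already cached, only Q is needed.
            return q, None, None
        # Prefill phase: project encoder states and split the fused KV in half.
        kv = self.kv_proj_encoder(encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


if __name__ == "__main__":
    layer = CrossAttnQKVSketch(hidden_size=64, head_size=16, num_heads=4, num_kv_heads=2)
    q, k, v = layer(torch.randn(2, 64), torch.randn(3, 64))
    print(q.shape, k.shape, v.shape)  # (2, 64) (3, 32) (3, 32)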