[V0 deprecation] Remove QKVCrossParallelLinear implementation (#26475)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent ec10fd0abc
commit d1ddf340c8
@@ -1,8 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import BaseLayerWithLoRA


# TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
    pass
@@ -3,10 +3,9 @@

import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (
@@ -1440,237 +1439,3 @@ class RowParallelLinear(LinearBase):
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            placeholder_size,
            [placeholder_size],
            placeholder_size,
            placeholder_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder",
        )

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder",
        )

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader_v1,
                },
            )
        else:
            self.bias = None

    def process_weights_after_loading(self):
        for layer in self.proj.values():
            if self.quant_method is not None:
                self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict
            )
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.
        """
        target_param_list = [
            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s
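For context, the sketch below is a minimal plain-PyTorch illustration of the pattern the removed layer implemented, not vLLM's API and not part of this commit: Q is projected from decoder hidden states, K/V are projected from encoder hidden states and split in half, and K/V come back as None at decode time once the encoder KV has been cached. The class and variable names are invented, and the tensor-parallel sharding, quantization, and custom weight-loading handled by the real QKVCrossParallelLinear are deliberately omitted.

# Illustrative stand-in only -- not vLLM's API. Mirrors the q/kv split logic
# of the removed QKVCrossParallelLinear using plain torch.nn.Linear layers.
from typing import Optional

import torch
import torch.nn as nn


class CrossAttnQKVSketch(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        total_num_kv_heads = total_num_kv_heads or total_num_heads
        self.kv_size = total_num_kv_heads * head_size
        # Q comes from the decoder side (ColumnParallelLinear in the real layer).
        self.q_proj_decoder = nn.Linear(hidden_size, total_num_heads * head_size)
        # K/V come from the encoder side as one fused projection
        # (QKVParallelLinear with total_num_heads=0 in the real layer).
        self.kv_proj_encoder = nn.Linear(hidden_size, 2 * self.kv_size)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ):
        q = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V are already cached elsewhere.
            return q, None, None
        # Prefill phase: project encoder states and split the fused KV in half.
        kv = self.kv_proj_encoder(encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


# Example: 2 decoder tokens, 3 encoder tokens, 4 heads of size 16.
layer = CrossAttnQKVSketch(hidden_size=64, head_size=16, total_num_heads=4)
q, k, v = layer(torch.randn(2, 64), torch.randn(3, 64))
print(q.shape, k.shape, v.shape)  # (2, 64), (3, 64), (3, 64)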
@@ -16,7 +16,6 @@ from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (
    WEIGHT_LOADER_V2_SUPPORTED,
    LinearMethodBase,
    QKVCrossParallelLinear,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsScheme,
@@ -89,10 +88,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            if isinstance(layer, QKVCrossParallelLinear):
                weight_loader_v1 = layer.weight_loader_v1
            else:
                weight_loader_v1 = layer.weight_loader
            weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
@@ -17,7 +17,6 @@ from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
@@ -108,11 +107,6 @@ def process_weights_after_loading(
    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    for _, module in model.named_modules():
        if isinstance(module, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            module.process_weights_after_loading()
            continue
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading