[Core] Refactor QKVCrossParallelLinear implementation to support BNB 4-bit quantization (#14545)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent 77a318bd01
commit e392d85831
@@ -17,6 +17,7 @@ from vllm.sequence import SampleLogprobs
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
+from ....quantization.utils import is_quant_method_supported
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
@@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@large_gpu_test(min_gb=48)
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+def test_bnb_regression(
+    image_assets: _ImageAssets,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+):
+    stop_sign = image_assets[0].pil_image
+    prompts = [
+        {
+            "prompt": "<|begin_of_text|>The content of the image <|image|> is",
+            "multi_modal_data": {
+                "image": stop_sign
+            },
+        },
+        {
+            "prompt":
+            "The color of the sky is blue but sometimes it can also be",
+        },
+    ]
+    # Test regression about QKVCrossParallelLinear
+    llm = LLM(
+        model=model,
+        dtype=dtype,
+        max_model_len=4096,
+        max_num_seqs=2,
+        enforce_eager=True,
+        quantization="bitsandbytes",
+        load_format="bitsandbytes",
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_tokens,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    assert outputs
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
@@ -2,9 +2,10 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+# TODO(Isotr0py): We might need a more flexible structure to handle
+# bitsandbytes shard offsets.
+def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
+    """
+    Separate the BitsAndBytes 4-bit shard.
+
+    For example, given bnb weight attributes as below:
+    {
+        'bnb_shard_offsets': array([0, 4, 8, 16]),
+        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
+    }
+
+    The function will return:
+    {
+        'bnb_shard_offsets': array([0, 4]),
+        'bnb_quant_state': {0: ...},
+    }
+    and
+    {
+        'bnb_shard_offsets': array([0, 4, 12]),
+        'bnb_quant_state': {0: ..., 1: ...},
+    }
+    """
+    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
+    offset_l = shard_offsets[:2]
+    offset_r = shard_offsets[1:] - shard_offsets[1]
+    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
+    quant_state_r = {
+        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
+        for i in range(1,
+                       len(shard_offsets) - 1)
+    }
+    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
+    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
+    return left, right
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
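For reference, a minimal sketch of what the new helper returns, using the same values as its docstring. The numpy arrays mirror the real shard-offset attributes; the quant-state entries here are placeholder strings rather than real bitsandbytes QuantState objects.

    import numpy as np
    from vllm.model_executor.layers.linear import left_shift_bitsandbytes_4bit_shard

    bnb_weight_attrs = {
        "bnb_shard_offsets": np.array([0, 4, 8, 16]),  # fused q | k | v shards
        "bnb_quant_state": {0: "qs_q", 1: "qs_k", 2: "qs_v"},
    }
    left, right = left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs)
    # left  == {"bnb_shard_offsets": array([0, 4]),     "bnb_quant_state": {0: "qs_q"}}
    # right == {"bnb_shard_offsets": array([0, 4, 12]), "bnb_quant_state": {0: "qs_k", 1: "qs_v"}}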
@@ -1229,7 +1267,24 @@ class RowParallelLinear(LinearBase):
         return s
 
 
-class QKVCrossParallelLinear(torch.nn.Module):
+class QKVCrossParallelLinear(LinearBase):
+    """Linear layers for efficient cross-attention's QKV transformation.
+
+    Args:
+        hidden_size: input hidden state size of the transformer.
+        head_size: size of each attention head.
+        total_num_heads: total number of attention query heads.
+        total_num_kv_heads: total number of attention key/value heads. If
+                            None, assume total_num_kv_heads = total_num_heads.
+        bias: If true, add bias.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+                (e.g. model.layers.0.qkv_proj)
+    """
 
     def __init__(self,
                  hidden_size: int,
@@ -1241,12 +1296,28 @@ class QKVCrossParallelLinear(torch.nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
-        super().__init__()
+        # input_size and output_size are not used, just for alignment
+        input_size = hidden_size
+        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
         self.quant_config = quant_config
 
         # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
+        placeholder_size = 0
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         placeholder_size, [placeholder_size],
+                                         placeholder_size,
+                                         placeholder_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
 
         # Use a dictionary to avoid submodules parameters auto-registration:
         # drop-in replacement for a `QKVParallelLinear` module.
         self.proj = dict()
@@ -1276,18 +1347,94 @@
         if bias:
             self.bias = torch.nn.Parameter()
             set_weight_attrs(self.bias, {
-                "weight_loader": self.weight_loader_bias,
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
         else:
             self.bias = None
 
     @property
-    def q_proj_decoder(self):
-        return self.proj["q_proj_decoder"]
+    def q_proj_decoder(self) -> ColumnParallelLinear:
+        layer = self.proj["q_proj_decoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+        return layer
 
     @property
-    def kv_proj_encoder(self):
-        return self.proj["kv_proj_encoder"]
+    def kv_proj_encoder(self) -> QKVParallelLinear:
+        layer = self.proj["kv_proj_encoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+        return layer
 
-    def forward(self, decoder_hidden_states, encoder_hidden_states):
+    def sync_weight_attrs(
+        self,
+        src_param: nn.Parameter,
+        tgt_param: nn.Parameter,
+        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
+    ):
+        missing_attrs_dict = {
+            k: getattr(src_param, k)
+            for k in (set(src_param.__dict__.keys()) -
+                      set(tgt_param.__dict__.keys()))
+        }
+        # TODO(Isotr0py): handle bitsandbytes 8bit
+        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
+                                        False)
+        if (missing_attrs_dict and use_bitsandbytes_4bit):
+            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
+                missing_attrs_dict)
+            if mode == "q_proj_decoder":
+                set_weight_attrs(tgt_param, q_proj_attrs)
+            elif mode == "kv_proj_encoder":
+                set_weight_attrs(tgt_param, kv_proj_attrs)
+        else:
+            set_weight_attrs(tgt_param, missing_attrs_dict)
+
+    def _is_same_param(
+        self,
+        src_param: torch.nn.Parameter,
+        map_param: torch.nn.Parameter,
+    ) -> bool:
+        """Check if two parameters are exactly pointing to same things."""
+        # ignore weight_loader because it's always different
+        key_to_ignore = ["weight_loader", "_weight_loader"]
+        has_same_type_name = type(src_param) is type(map_param)
+        src_param_attrs = {
+            k: v
+            for k, v in src_param.__dict__.items() if k not in key_to_ignore
+        }
+        map_param_attrs = {
+            k: v
+            for k, v in map_param.__dict__.items() if k not in key_to_ignore
+        }
+        has_same_attrs = src_param_attrs == map_param_attrs
+        return has_same_type_name and has_same_attrs
+
+    def select_proj_params(
+        self,
+        layer: nn.Module,
+        param: nn.Parameter,
+    ) -> nn.Parameter:
+        """
+        Given the placeholder param,
+        return the corresponding param in the proj layers.
+        """
+        target_param_list = [
+            v for _, v in layer.named_parameters()
+            if self._is_same_param(param, v)
+        ]
+        assert len(target_param_list) == 1
+        target_param = target_param_list[0]
+        return target_param
+
+    def forward(  # type: ignore[override]
+        self,
+        decoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
         q, _ = self.q_proj_decoder(decoder_hidden_states)
         if encoder_hidden_states is None:
             # Encoder KV already cached.
@@ -1300,25 +1447,21 @@
         k, v = kv_enc.split(self.kv_size, dim=-1)
         return q, k, v
 
-    def weight_loader_weight(self,
-                             param: torch.nn.Parameter,
-                             loaded_weight: torch.Tensor,
-                             loaded_shard_id: Optional[str] = None):
-        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
-        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.weight
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    def weight_loader(self,
+                      param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[str] = None):
+        layer = (self.q_proj_decoder
+                 if loaded_shard_id == "q" else self.kv_proj_encoder)
+        target_param = self.select_proj_params(layer, param)
+        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
+        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
-    def weight_loader_bias(self,
-                           param: torch.nn.Parameter,
-                           loaded_weight: torch.Tensor,
-                           loaded_shard_id: Optional[str] = None):
-        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.bias
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", kv_size={self.kv_size}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += ", gather_output=False"
+        return s
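Taken together, the refactored layer now loads like any other quantizable linear module: checkpoint shards land on the single placeholder weight/bias, weight_loader routes each shard to the real sub-layers held in self.proj, and bitsandbytes 4-bit attributes written onto the placeholder are split between the two sub-layers via left_shift_bitsandbytes_4bit_shard. A rough usage sketch follows; the sizes are illustrative, it assumes vLLM's distributed/model-parallel state is already initialized (e.g. inside a worker), and it takes the default unquantized path since no quant_config is passed.

    import torch
    from vllm.model_executor.layers.linear import QKVCrossParallelLinear

    layer = QKVCrossParallelLinear(hidden_size=4096,
                                   head_size=128,
                                   total_num_heads=32,
                                   total_num_kv_heads=8,
                                   bias=False)
    decoder_states = torch.randn(2, 4096, dtype=layer.params_dtype)
    encoder_states = torch.randn(6, 4096, dtype=layer.params_dtype)

    # q comes from the decoder stream, k/v from the encoder stream.
    q, k, v = layer(decoder_states, encoder_states)
    # When the encoder KV is already cached, pass None: k and v come back as None.
    q, k, v = layer(decoder_states, None)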
@@ -43,6 +43,7 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -813,20 +814,11 @@ class MllamaTextCrossAttention(nn.Module):
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim
 
-        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
-        # quantization
-        self.q_proj = ColumnParallelLinear(
-            input_size=self.hidden_size,
-            output_size=self.num_heads * self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.q_proj",
-        )
-        self.kv_proj = QKVParallelLinear(
+        self.qkv_proj = QKVCrossParallelLinear(
             self.hidden_size,
             self.head_dim,
-            total_num_heads=0,
-            total_num_kv_heads=self.num_key_value_heads,
+            self.num_heads,
+            self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
@@ -862,15 +854,11 @@
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
+        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
         if cross_attention_states is not None:
-            kv, _ = self.kv_proj(cross_attention_states)
-            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
-        else:
-            k = v = None
 
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
@@ -1161,13 +1149,8 @@ class MllamaForCausalLM(nn.Module):
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "self_attn.qkv_proj": [
-            "self_attn.q_proj",
-            "self_attn.k_proj",
-            "self_attn.v_proj",
-        ],
-        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"],
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -1437,11 +1420,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                            torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
-            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
-            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
-            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
-            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
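With the unified mapping, both self-attention and cross-attention Q/K/V checkpoint weights are renamed onto a qkv_proj parameter and handed to its weight_loader with the matching shard id. A rough sketch of that renaming step; the checkpoint key below is a made-up example, and the model's load_weights follows this pattern rather than this exact code.

    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    name = "language_model.model.layers.3.cross_attn.k_proj.weight"  # hypothetical key
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in name:
            continue
        name = name.replace(shard_name, param_name)
        # -> "...layers.3.cross_attn.qkv_proj.weight", loaded with shard_id "k",
        #    which QKVCrossParallelLinear.weight_loader routes to kv_proj_encoder.
        break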
@@ -1570,4 +1551,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
+    return mask