Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-26 09:04:27 +08:00)
support bitsandbytes quantization with more models (#9148)
commit 2f4117c38e
parent 9ba0bd6aa6
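For context, the end-to-end path this commit extends can be exercised by loading one of the newly covered checkpoints through vLLM's bitsandbytes loader. The snippet below is a hedged sketch, not part of the diff: it assumes the quantization="bitsandbytes" and load_format="bitsandbytes" engine arguments that this loader is selected by, and uses facebook/opt-125m, one of the models added to the tests.

# Hedged sketch, not part of the commit: in-flight 4-bit quantization of one
# of the newly supported models via vLLM's bitsandbytes path.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          quantization="bitsandbytes",
          load_format="bitsandbytes")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)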
@@ -9,22 +9,22 @@ import pytest

import torch

from tests.quantization.utils import is_quant_method_supported

from ..utils import fork_new_process_for_each_test
from tests.utils import fork_new_process_for_each_test

models_4bit_to_test = [
    ('huggyllama/llama-7b', 'quantize model inflight'),
    ("facebook/opt-125m", "quantize opt model inflight"),
]

models_pre_qaunt_4bit_to_test = [
    ('lllyasviel/omost-llama-3-8b-4bits',
     'read pre-quantized 4-bit NF4 model'),
    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
     'read pre-quantized 4-bit FP4 model'),
    ('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
]

models_pre_quant_8bit_to_test = [
    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
    ('meta-llama/Llama-Guard-3-8B-INT8',
     'read pre-quantized llama 8-bit model'),
    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]
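As a rough illustration of how these (model_name, description) tuples are consumed, the hedged sketch below shows a pytest parametrization of the same shape; the test name and body are placeholders, not the actual tests in this file.

# Placeholder sketch only: shows how the model lists above feed
# pytest parametrization; the real tests compare HF and vLLM generations.
import pytest

models_4bit_to_test = [
    ("facebook/opt-125m", "quantize opt model inflight"),
]


@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(model_name, description):
    assert model_name  # placeholder body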
@@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,

        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]

        assert hf_str == vllm_str, (f"Model: {model_name}"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"
@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        param_data = param.data
        if output_dim is not None:
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if output_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):

                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):

                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
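The adjust_bitsandbytes_4bit_shard call above exists because the q/k/v offsets are computed in unquantized elements, while the parameter actually holds packed 4-bit data with a smaller first dimension. The standalone sketch below illustrates that proportional rescaling; the helper name and the example numbers are illustrative assumptions, not the vLLM implementation.

# Hedged sketch, not the vLLM implementation: map (offset, size) from
# unquantized-element space into the smaller packed 4-bit tensor, keeping the
# same q/k/v proportions.
def rescale_shard_for_packed_4bit(orig_offset, orig_size, total_unquantized,
                                  quantized_rows):
    offset = orig_offset * quantized_rows // total_unquantized
    size = orig_size * quantized_rows // total_unquantized
    return size, offset


# Example: 32 query heads + 2 * 8 KV heads, head_size 128 -> 6144 rows
# unquantized; suppose the packed 4-bit tensor has 3072 rows. The "k" shard
# (offset 4096, size 1024) then maps to size 512 at offset 2048.
print(rescale_shard_for_packed_4bit(4096, 1024, 6144, 3072))  # (512, 2048)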
@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):

        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
        return []


class BitsAndBytesLinearMethod(LinearMethodBase):
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):

            if generation == 0 or generation == 1:
                matmul_states[i] = MatmulLtState()
                matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
                matmul_states[i].SCB = quant_states[i]
                matmul_states[i].SCB = quant_states[i].to(x.device)
                matmul_states[i].threshold = (
                    self.quant_config.llm_int8_threshold)
                matmul_states[i].has_fp16_weights = (
@@ -736,15 +736,26 @@ class ShardedStateLoader(BaseModelLoader):

class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

    # TODO: these module names are for Llama only,
    # change so that it works with other models as well
    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
    ]

    possible_config_file_names = ["adapter_config.json"]

    default_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
        '.fc1.',
        '.fc2.',
        '.dense.',
        '.query_key_value.',
        '.qkv_proj.',
        '.dense_h_to_4h.',
        '.dense_4h_to_h.',
        '.out_proj.',
    ]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
@@ -754,7 +765,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
            self.target_modules = self.default_target_modules
            self.target_modules = []
            return

        qlora_adapter = load_config.model_loader_extra_config[
@@ -901,10 +912,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if not weight_name.endswith(".weight"):
            if not weight_name.endswith((".weight", ".bias")):
                continue

            qweight_name = weight_name.replace(".weight", ".qweight")

            if qweight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield qweight_name, weight_tensor
@@ -920,7 +932,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                                               use_safetensors)
        temp_state_dict = {}
        for weight_name, weight_tensor in weight_iterator:
            if weight_name.endswith(".weight"):
            if weight_name.endswith((".weight", ".bias")):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
@@ -943,9 +955,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):

        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            # Filter out all weights whose suffix is not ".weight"
            if not weight_name.endswith(".weight"):

            if not weight_name.endswith((".weight", ".bias")):
                continue

            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
                    in temp_state_dict) or \
                (f"{weight_name}.quant_state.bitsandbytes__fp4" \
@@ -965,15 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):

            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                if any(target_module in weight_name
                       for target_module in self.target_modules):

                if any(target_module in weight_name for target_module in
                       self.target_modules) and weight_name.endswith(".weight"):
                    weight_name = weight_name.replace(".weight", ".qweight")

                    # weight partitions of different modules occur at
                    # different dimensions
                    # TODO: these module names are for Llama only,
                    # change so that it works with other models as well
                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
                    if any(module in weight_name
                           for module in self.column_parallel_weights_modules):

                        total_size = weight_tensor.size(-1)
                        start_index = total_size // tp_size * tp_rank
                        end_index = total_size // tp_size * (tp_rank + 1)
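The column-parallel check above replaces the hard-coded down_proj/o_proj test, but the per-rank slicing it guards is unchanged: each tensor-parallel rank takes an equal slice of the weight's last dimension before quantizing it. The standalone snippet below mirrors that arithmetic for illustration only; tp_slice_last_dim is a made-up helper, not vLLM code.

# Illustration of the per-rank narrowing guarded above: weights listed in
# column_parallel_weights_modules are split along their last dimension so
# each tensor-parallel rank only quantizes its own slice.
import torch


def tp_slice_last_dim(weight, tp_rank, tp_size):
    total_size = weight.size(-1)
    start_index = total_size // tp_size * tp_rank
    end_index = total_size // tp_size * (tp_rank + 1)
    return weight[..., start_index:end_index]


w = torch.randn(8, 16)                                    # a down_proj-style weight
print(tp_slice_last_dim(w, tp_rank=1, tp_size=4).shape)   # torch.Size([8, 4])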
@@ -1022,6 +1034,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        if len(self.target_modules) == 0:
            if hasattr(model, 'default_bitsandbytes_target_modules'):
                self.target_modules = model.default_bitsandbytes_target_modules
            else:
                self.target_modules = self.default_target_modules

        if hasattr(model, 'column_parallel_weights_modules'):
            self.column_parallel_weights_modules = \
                model.column_parallel_weights_modules
        else:
            self.column_parallel_weights_modules = []

        self.model_type = type(model).__name__

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")
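To make the new fallback order explicit, the hedged sketch below replays the same resolution on a dummy object: a model class that declares its own bitsandbytes attributes wins, otherwise the loader's generic defaults (or an empty list) apply. DummyModel and DEFAULT_TARGET_MODULES are hypothetical names used only for this illustration.

# Hedged sketch with a hypothetical DummyModel: replays the resolution order
# above (model-provided bnb attributes first, generic loader defaults second).
DEFAULT_TARGET_MODULES = [".q_proj.", ".k_proj.", ".v_proj.", ".o_proj."]


class DummyModel:
    default_bitsandbytes_target_modules = [".query_key_value.", ".dense."]
    # no column_parallel_weights_modules attribute on purpose


model = DummyModel()
target_modules = getattr(model, "default_bitsandbytes_target_modules",
                         DEFAULT_TARGET_MODULES)
column_parallel = getattr(model, "column_parallel_weights_modules", [])
print(target_modules)   # ['.query_key_value.', '.dense.']
print(column_parallel)  # []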
@@ -391,6 +391,17 @@ class FalconModel(nn.Module):


class FalconForCausalLM(nn.Module, SupportsPP):

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {}
    default_bitsandbytes_target_modules = [
        ".query_key_value.",
        ".dense.",
        ".dense_h_to_4h.",
        ".dense_4h_to_h.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]

    def __init__(
        self,
        config: FalconConfig,
@@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

        "gate_up_proj",
        "down_proj",
    ]

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
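The bitsandbytes_stacked_params_mapping entries above describe how per-projection checkpoint weights fold into the fused parameters vLLM actually allocates (qkv_proj, gate_up_proj). The sketch below shows one way a loader could apply such a mapping to a checkpoint weight name; BNB_STACKED_PARAMS_MAPPING and map_to_stacked_param are hypothetical names for illustration, not the loader's real helpers.

# Hedged sketch with hypothetical names: redirect a per-projection checkpoint
# weight name onto the fused parameter it is packed into, plus its shard index.
BNB_STACKED_PARAMS_MAPPING = {
    # shard_name: (fused_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def map_to_stacked_param(weight_name):
    for shard_name, (fused_name, index) in BNB_STACKED_PARAMS_MAPPING.items():
        if shard_name in weight_name:
            return weight_name.replace(shard_name, fused_name), index
    return weight_name, None


print(map_to_stacked_param("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)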
@@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
@@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
@@ -315,6 +315,19 @@ class OPTModel(nn.Module):


class OPTForCausalLM(nn.Module, SupportsPP):

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".out_proj.", ".fc2."]

    def __init__(
        self,
        config: OPTConfig,
@@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

        "fc1",
        "fc2",
    ]

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".fc2.", ".dense."]

    embedding_modules = {}
    embedding_padding_modules = []