mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 12:58:02 +08:00
support bitsandbytes quantization with more models (#9148)
This commit is contained in:
parent
9ba0bd6aa6
commit
2f4117c38e
@ -9,22 +9,22 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
|
from tests.utils import fork_new_process_for_each_test
|
||||||
from ..utils import fork_new_process_for_each_test
|
|
||||||
|
|
||||||
models_4bit_to_test = [
|
models_4bit_to_test = [
|
||||||
('huggyllama/llama-7b', 'quantize model inflight'),
|
("facebook/opt-125m", "quantize opt model inflight"),
|
||||||
]
|
]
|
||||||
|
|
||||||
models_pre_qaunt_4bit_to_test = [
|
models_pre_qaunt_4bit_to_test = [
|
||||||
('lllyasviel/omost-llama-3-8b-4bits',
|
|
||||||
'read pre-quantized 4-bit NF4 model'),
|
|
||||||
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
|
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
|
||||||
'read pre-quantized 4-bit FP4 model'),
|
'read pre-quantized 4-bit FP4 model'),
|
||||||
|
('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
|
||||||
]
|
]
|
||||||
|
|
||||||
models_pre_quant_8bit_to_test = [
|
models_pre_quant_8bit_to_test = [
|
||||||
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
|
('meta-llama/Llama-Guard-3-8B-INT8',
|
||||||
|
'read pre-quantized llama 8-bit model'),
|
||||||
|
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,
|
|||||||
hf_str = hf_log["generated_text"]
|
hf_str = hf_log["generated_text"]
|
||||||
vllm_str = vllm_log["generated_text"]
|
vllm_str = vllm_log["generated_text"]
|
||||||
prompt = hf_log["prompt"]
|
prompt = hf_log["prompt"]
|
||||||
|
|
||||||
assert hf_str == vllm_str, (f"Model: {model_name}"
|
assert hf_str == vllm_str, (f"Model: {model_name}"
|
||||||
f"Mismatch between HF and vLLM outputs:\n"
|
f"Mismatch between HF and vLLM outputs:\n"
|
||||||
f"Prompt: {prompt}\n"
|
f"Prompt: {prompt}\n"
|
||||||
|
|||||||
@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||||
|
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
|
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
if output_dim is not None:
|
# bitsandbytes loads only the specific portion of the weights, so there is
|
||||||
|
# no need to narrow here
|
||||||
|
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||||
shard_size = param_data.shape[output_dim]
|
shard_size = param_data.shape[output_dim]
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
("v", (self.total_num_heads + self.total_num_kv_heads) *
|
("v", (self.total_num_heads + self.total_num_kv_heads) *
|
||||||
self.head_size, self.total_num_kv_heads * self.head_size),
|
self.head_size, self.total_num_kv_heads * self.head_size),
|
||||||
]
|
]
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||||
|
False)
|
||||||
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
# Special case for Quantized Weights.
|
# Special case for Quantized Weights.
|
||||||
@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_size, shard_offset = adjust_marlin_shard(
|
shard_size, shard_offset = adjust_marlin_shard(
|
||||||
param, shard_size, shard_offset)
|
param, shard_size, shard_offset)
|
||||||
|
|
||||||
|
if use_bitsandbytes_4bit:
|
||||||
|
orig_qkv_offsets = {
|
||||||
|
"q": (0, self.total_num_heads * self.head_size),
|
||||||
|
"k": (self.total_num_heads * self.head_size,
|
||||||
|
self.total_num_kv_heads * self.head_size),
|
||||||
|
"v":
|
||||||
|
((self.total_num_heads + self.total_num_kv_heads) *
|
||||||
|
self.head_size,
|
||||||
|
self.total_num_kv_heads * self.head_size),
|
||||||
|
"total":
|
||||||
|
((self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||||
|
self.head_size, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_qkv_offsets, shard_id)
|
||||||
|
|
||||||
loaded_weight_shard = loaded_weight.narrow(
|
loaded_weight_shard = loaded_weight.narrow(
|
||||||
output_dim, shard_offset, shard_size)
|
output_dim, shard_offset, shard_size)
|
||||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||||
|
|||||||
@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
return []
|
||||||
|
|
||||||
|
|
||||||
class BitsAndBytesLinearMethod(LinearMethodBase):
|
class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||||
@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
|||||||
if generation == 0 or generation == 1:
|
if generation == 0 or generation == 1:
|
||||||
matmul_states[i] = MatmulLtState()
|
matmul_states[i] = MatmulLtState()
|
||||||
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
|
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
|
||||||
matmul_states[i].SCB = quant_states[i]
|
matmul_states[i].SCB = quant_states[i].to(x.device)
|
||||||
matmul_states[i].threshold = (
|
matmul_states[i].threshold = (
|
||||||
self.quant_config.llm_int8_threshold)
|
self.quant_config.llm_int8_threshold)
|
||||||
matmul_states[i].has_fp16_weights = (
|
matmul_states[i].has_fp16_weights = (
|
||||||
|
|||||||
@ -736,15 +736,26 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||||
|
|
||||||
# TODO: these module names are for Llama only,
|
|
||||||
# change so that it works with other models as well
|
|
||||||
default_target_modules = [
|
|
||||||
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
|
|
||||||
"o_proj"
|
|
||||||
]
|
|
||||||
|
|
||||||
possible_config_file_names = ["adapter_config.json"]
|
possible_config_file_names = ["adapter_config.json"]
|
||||||
|
|
||||||
|
default_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
'.fc1.',
|
||||||
|
'.fc2.',
|
||||||
|
'.dense.',
|
||||||
|
'.query_key_value.',
|
||||||
|
'.qkv_proj.',
|
||||||
|
'.dense_h_to_4h.',
|
||||||
|
'.dense_4h_to_h.',
|
||||||
|
'.out_proj.',
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self, load_config: LoadConfig):
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
|
|
||||||
@ -754,7 +765,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
if (not load_config.model_loader_extra_config
|
if (not load_config.model_loader_extra_config
|
||||||
or "qlora_adapter_name_or_path"
|
or "qlora_adapter_name_or_path"
|
||||||
not in load_config.model_loader_extra_config):
|
not in load_config.model_loader_extra_config):
|
||||||
self.target_modules = self.default_target_modules
|
self.target_modules = []
|
||||||
return
|
return
|
||||||
|
|
||||||
qlora_adapter = load_config.model_loader_extra_config[
|
qlora_adapter = load_config.model_loader_extra_config[
|
||||||
@ -901,10 +912,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||||
hf_weights_files, use_safetensors):
|
hf_weights_files, use_safetensors):
|
||||||
|
|
||||||
if not weight_name.endswith(".weight"):
|
if not weight_name.endswith((".weight", ".bias")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
qweight_name = weight_name.replace(".weight", ".qweight")
|
qweight_name = weight_name.replace(".weight", ".qweight")
|
||||||
|
|
||||||
if qweight_name in quant_state_dict:
|
if qweight_name in quant_state_dict:
|
||||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||||
yield qweight_name, weight_tensor
|
yield qweight_name, weight_tensor
|
||||||
@ -920,7 +932,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
use_safetensors)
|
use_safetensors)
|
||||||
temp_state_dict = {}
|
temp_state_dict = {}
|
||||||
for weight_name, weight_tensor in weight_iterator:
|
for weight_name, weight_tensor in weight_iterator:
|
||||||
if weight_name.endswith(".weight"):
|
if weight_name.endswith((".weight", ".bias")):
|
||||||
continue
|
continue
|
||||||
# bitsandbytes library requires
|
# bitsandbytes library requires
|
||||||
# weight.quant_state.bitsandbytes__* in CPU
|
# weight.quant_state.bitsandbytes__* in CPU
|
||||||
@ -943,9 +955,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
# pre quantized weights would have a quant_state
|
# pre quantized weights would have a quant_state
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||||
hf_weights_files, use_safetensors):
|
hf_weights_files, use_safetensors):
|
||||||
# Filter out all weights whose suffix is not ".weight"
|
|
||||||
if not weight_name.endswith(".weight"):
|
if not weight_name.endswith((".weight", ".bias")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
|
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
|
||||||
in temp_state_dict) or \
|
in temp_state_dict) or \
|
||||||
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
|
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
|
||||||
@ -965,15 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||||
hf_weights_files, use_safetensors):
|
hf_weights_files, use_safetensors):
|
||||||
if any(target_module in weight_name
|
|
||||||
for target_module in self.target_modules):
|
if any(target_module in weight_name for target_module in
|
||||||
|
self.target_modules) and weight_name.endswith(".weight"):
|
||||||
weight_name = weight_name.replace(".weight", ".qweight")
|
weight_name = weight_name.replace(".weight", ".qweight")
|
||||||
|
|
||||||
# weight partitions of different modules occur at
|
if any(module in weight_name
|
||||||
# different dimensions
|
for module in self.column_parallel_weights_modules):
|
||||||
# TODO: these module names are for Llama only,
|
|
||||||
# change so that it works with other models as well
|
|
||||||
if 'down_proj' in weight_name or 'o_proj' in weight_name:
|
|
||||||
total_size = weight_tensor.size(-1)
|
total_size = weight_tensor.size(-1)
|
||||||
start_index = total_size // tp_size * tp_rank
|
start_index = total_size // tp_size * tp_rank
|
||||||
end_index = total_size // tp_size * (tp_rank + 1)
|
end_index = total_size // tp_size * (tp_rank + 1)
|
||||||
@ -1022,6 +1034,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
f"Model {type(model).__name__} does not support BitsAndBytes "
|
f"Model {type(model).__name__} does not support BitsAndBytes "
|
||||||
"quantization yet.")
|
"quantization yet.")
|
||||||
|
|
||||||
|
if len(self.target_modules) == 0:
|
||||||
|
if hasattr(model, 'default_bitsandbytes_target_modules'):
|
||||||
|
self.target_modules = model.default_bitsandbytes_target_modules
|
||||||
|
else:
|
||||||
|
self.target_modules = self.default_target_modules
|
||||||
|
|
||||||
|
if hasattr(model, 'column_parallel_weights_modules'):
|
||||||
|
self.column_parallel_weights_modules = \
|
||||||
|
model.column_parallel_weights_modules
|
||||||
|
else:
|
||||||
|
self.column_parallel_weights_modules = []
|
||||||
|
|
||||||
|
self.model_type = type(model).__name__
|
||||||
|
|
||||||
logger.info("Loading weights with BitsAndBytes quantization. "
|
logger.info("Loading weights with BitsAndBytes quantization. "
|
||||||
" May take a while ...")
|
" May take a while ...")
|
||||||
|
|
||||||
|
|||||||
@ -391,6 +391,17 @@ class FalconModel(nn.Module):
|
|||||||
|
|
||||||
class FalconForCausalLM(nn.Module, SupportsPP):
|
class FalconForCausalLM(nn.Module, SupportsPP):
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
bitsandbytes_stacked_params_mapping = {}
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".query_key_value.",
|
||||||
|
".dense.",
|
||||||
|
".dense_h_to_4h.",
|
||||||
|
".dense_4h_to_h.",
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FalconConfig,
|
config: FalconConfig,
|
||||||
|
|||||||
@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
"gate_up_proj",
|
"gate_up_proj",
|
||||||
"down_proj",
|
"down_proj",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
|
||||||
# Gemma does not apply LoRA to the embedding layer.
|
# Gemma does not apply LoRA to the embedding layer.
|
||||||
embedding_modules = {}
|
embedding_modules = {}
|
||||||
embedding_padding_modules = []
|
embedding_padding_modules = []
|
||||||
|
|||||||
@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
# Gemma does not apply LoRA to the embedding layer.
|
# Gemma does not apply LoRA to the embedding layer.
|
||||||
embedding_modules = {}
|
embedding_modules = {}
|
||||||
embedding_padding_modules = []
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||||
bitsandbytes_stacked_params_mapping = {
|
bitsandbytes_stacked_params_mapping = {
|
||||||
# shard_name, weight_name, index
|
# shard_name, weight_name, index
|
||||||
"q_proj": ("qkv_proj", 0),
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
|||||||
@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
"lm_head": "output_embeddings"
|
"lm_head": "output_embeddings"
|
||||||
}
|
}
|
||||||
embedding_padding_modules = ["lm_head"]
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||||
bitsandbytes_stacked_params_mapping = {
|
bitsandbytes_stacked_params_mapping = {
|
||||||
# shard_name, weight_name, index
|
# shard_name, weight_name, index
|
||||||
"q_proj": ("qkv_proj", 0),
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
|||||||
@ -315,6 +315,19 @@ class OPTModel(nn.Module):
|
|||||||
|
|
||||||
class OPTForCausalLM(nn.Module, SupportsPP):
|
class OPTForCausalLM(nn.Module, SupportsPP):
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
}
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".out_proj.", ".fc2."]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OPTConfig,
|
config: OPTConfig,
|
||||||
|
|||||||
@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
"fc1",
|
"fc1",
|
||||||
"fc2",
|
"fc2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# BitsAndBytes-specific attributes
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
}
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
|
||||||
|
]
|
||||||
|
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||||
|
column_parallel_weights_modules = [".fc2.", ".dense."]
|
||||||
|
|
||||||
embedding_modules = {}
|
embedding_modules = {}
|
||||||
embedding_padding_modules = []
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user