support bitsandbytes quantization with more models (#9148)

Author: chenqianfzh, committed by GitHub
Date: 2024-10-08 18:52:19 -07:00
Parent: 9ba0bd6aa6
Commit: 2f4117c38e
10 changed files with 165 additions and 28 deletions
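
For reference, the code paths changed below are exercised end to end through vLLM's bitsandbytes load format. A minimal usage sketch, assuming a vLLM build that includes this commit and a working bitsandbytes installation (facebook/opt-125m is one of the OPT checkpoints added to the tests):

from vllm import LLM, SamplingParams

# In-flight 4-bit quantization of an unquantized checkpoint; pre-quantized
# checkpoints (e.g. the 8-bit OPT model in the tests) load the same way.
llm = LLM(
    model="facebook/opt-125m",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)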

tests/quantization/test_bitsandbytes.py

@@ -9,22 +9,22 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from ..utils import fork_new_process_for_each_test
from tests.utils import fork_new_process_for_each_test
models_4bit_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
("facebook/opt-125m", "quantize opt model inflight"),
]
models_pre_qaunt_4bit_to_test = [
('lllyasviel/omost-llama-3-8b-4bits',
'read pre-quantized 4-bit NF4 model'),
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
]
models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
('meta-llama/Llama-Guard-3-8B-INT8',
'read pre-quantized llama 8-bit model'),
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]
@@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,
hf_str = hf_log["generated_text"]
vllm_str = vllm_log["generated_text"]
prompt = hf_log["prompt"]
assert hf_str == vllm_str, (f"Model: {model_name}\n"
f"Mismatch between HF and vLLM outputs:\n"
f"Prompt: {prompt}\n"

vllm/model_executor/layers/linear.py

@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data
if output_dim is not None:
# bitsandbytes already loads only this rank's portion of the weight,
# so there is no need to narrow it here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size,
self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
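
The call to adjust_bitsandbytes_4bit_shard above rescales q/k/v offsets computed in unquantized units into the packed 4-bit tensor's coordinate space before narrowing. A minimal standalone sketch of that proportional remapping (function and variable names here are illustrative, not the vLLM implementation):

from typing import Dict, Tuple

import torch

def rescale_4bit_qkv_shard(packed_param: torch.Tensor,
                           qkv_offsets: Dict[str, Tuple[int, int]],
                           shard_id: str) -> Tuple[int, int]:
    # qkv_offsets maps shard id -> (offset, size) in unquantized elements;
    # "total" carries the unquantized size of the fused qkv weight.
    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[shard_id]
    # The packed 4-bit tensor has fewer rows than the unquantized weight,
    # so scale offset and size by the packed/unquantized ratio.
    packed_total = packed_param.shape[0]
    shard_offset = orig_offset * packed_total // total
    shard_size = orig_size * packed_total // total
    return shard_size, shard_offset

This is also why the orig_qkv_offsets dict above carries a "total" entry: without the unquantized total size, the ratio cannot be computed.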

vllm/model_executor/layers/quantization/bitsandbytes.py

@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
return []
class BitsAndBytesLinearMethod(LinearMethodBase):
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
if generation == 0 or generation == 1:
matmul_states[i] = MatmulLtState()
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
matmul_states[i].SCB = quant_states[i]
matmul_states[i].SCB = quant_states[i].to(x.device)
matmul_states[i].threshold = (
self.quant_config.llm_int8_threshold)
matmul_states[i].has_fp16_weights = (

vllm/model_executor/model_loader/loader.py

@@ -736,15 +736,26 @@ class ShardedStateLoader(BaseModelLoader):
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
possible_config_file_names = ["adapter_config.json"]
default_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
'.fc1.',
'.fc2.',
'.dense.',
'.query_key_value.',
'.qkv_proj.',
'.dense_h_to_4h.',
'.dense_4h_to_h.',
'.out_proj.',
]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
@@ -754,7 +765,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
self.target_modules = self.default_target_modules
self.target_modules = []
return
qlora_adapter = load_config.model_loader_extra_config[
@@ -901,10 +912,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.endswith(".weight"):
if not weight_name.endswith((".weight", ".bias")):
continue
qweight_name = weight_name.replace(".weight", ".qweight")
if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor
@@ -920,7 +932,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
if weight_name.endswith((".weight", ".bias")):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* to be on CPU
@@ -943,9 +955,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
if not weight_name.endswith((".weight", ".bias")):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
in temp_state_dict) or \
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
@@ -965,15 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight")
# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
if any(module in weight_name
for module in self.column_parallel_weights_modules):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
@@ -1022,6 +1034,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
if len(self.target_modules) == 0:
if hasattr(model, 'default_bitsandbytes_target_modules'):
self.target_modules = model.default_bitsandbytes_target_modules
else:
self.target_modules = self.default_target_modules
if hasattr(model, 'column_parallel_weights_modules'):
self.column_parallel_weights_modules = \
model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []
self.model_type = type(model).__name__
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")

vllm/model_executor/models/falcon.py

@@ -391,6 +391,17 @@ class FalconModel(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP):
# BitsAndBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
def __init__(
self,
config: FalconConfig,

vllm/model_executor/models/gemma.py

@@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj",
"down_proj",
]
# BitsAndBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []

vllm/model_executor/models/gemma2.py

@@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
# BitsAndBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

vllm/model_executor/models/llama.py

@@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]
# BitsAndBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

vllm/model_executor/models/opt.py

@@ -315,6 +315,19 @@ class OPTModel(nn.Module):
class OPTForCausalLM(nn.Module, SupportsPP):
# BitsAndBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".out_proj.", ".fc2."]
def __init__(
self,
config: OPTConfig,

vllm/model_executor/models/phi.py

@@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc1",
"fc2",
]
# BitsAndBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".fc2.", ".dense."]
embedding_modules = {}
embedding_padding_modules = []
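
Taken together, the per-model pattern introduced here is three class attributes: default_bitsandbytes_target_modules names the linear weights to quantize, column_parallel_weights_modules marks which of those are partitioned along dim=-1 under tensor parallelism, and bitsandbytes_stacked_params_mapping maps per-projection checkpoint weights into vLLM's fused parameters. A hypothetical new decoder-only model (MyModelForCausalLM is illustrative, not part of this commit) would opt in roughly like this:

import torch.nn as nn

class MyModelForCausalLM(nn.Module):
    # BitsAndBytes specific attributes (hypothetical example)
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".o_proj.",
        ".gate_proj.", ".up_proj.", ".down_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

Models that do not define these attributes fall back to the loader's generic default_target_modules list and an empty column_parallel_weights_modules, as the loader change above shows.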