Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:45:01 +08:00)
support bitsandbytes 8-bit and FP4 quantized models (#7445)
This commit is contained in:
parent 257afc37c5 · commit 4664ceaad6
@@ -209,8 +209,14 @@ class HfRunner:

    def wrap_device(self, input: _T) -> _T:
        if not is_cpu():
            # Check if the input is already on the GPU
            if hasattr(input, 'device') and input.device.type == "cuda":
                return input  # Already on GPU, no need to move
            return input.to("cuda")
        else:
            # Check if the input is already on the CPU
            if hasattr(input, 'device') and input.device.type == "cpu":
                return input  # Already on CPU, no need to move
            return input.to("cpu")

    def __init__(
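The guard above skips the device copy when the tensor already lives where it needs to be (pre-quantized HF models typically arrive on the GPU already). A minimal standalone sketch of the same idempotent-move pattern; the helper name move_if_needed is illustrative and not part of vLLM:

    import torch

    def move_if_needed(t: torch.Tensor, device: str) -> torch.Tensor:
        # Skip the copy when the tensor already lives on the target device type.
        if t.device.type == device:
            return t
        return t.to(device)

    x = torch.zeros(2)                    # allocated on CPU
    assert move_if_needed(x, "cpu") is x  # no copy is made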
@@ -2,85 +2,115 @@

Run `pytest tests/quantization/test_bitsandbytes.py`.
'''

import gc

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams

models_to_test = [
models_4bit_to_test = [
    ('huggyllama/llama-7b', 'quantize model inflight'),
    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]

models_pre_qaunt_4bit_to_test = [
    ('lllyasviel/omost-llama-3-8b-4bits',
     'read pre-quantized 4-bit NF4 model'),
    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
     'read pre-quantized 4-bit FP4 model'),
]

models_pre_quant_8bit_to_test = [
    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
]

@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:

    hf_model_kwargs = {"load_in_4bit": True}
    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name, hf_model_kwargs)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_pre_qaunt_4bit_to_test)
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                       model_name, description) -> None:

    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_pre_quant_8bit_to_test)
def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:

    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name)
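The three tests above can be run selectively; a small sketch of invoking them programmatically through pytest (path taken from the docstring; the -k filter is just one convenient choice):

    import pytest

    # Run only the 4-bit tests from this file; "-k 4bit" matches both
    # test_load_4bit_bnb_model and test_load_pre_quant_4bit_bnb_model.
    pytest.main(["tests/quantization/test_bitsandbytes.py", "-k", "4bit", "-q"])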

def log_generated_texts(prompts, outputs, runner_name):
    logged_texts = []
    for i, (_, generated_text) in enumerate(outputs):
        log_entry = {
            "prompt": prompts[i],
            "runner_name": runner_name,
            "generated_text": generated_text,
        }
        logged_texts.append(log_entry)
    return logged_texts
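For reference, log_generated_texts returns one dict per prompt; with made-up values the result looks like this:

    example_logs = [
        {
            "prompt": "Hello, my name is",                      # example value only
            "runner_name": "HfRunner",
            "generated_text": "Hello, my name is John and I",   # example value only
        },
    ]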

def validate_generated_texts(hf_runner,
                             vllm_runner,
                             prompts,
                             model_name,
                             hf_model_kwargs=None):

    if hf_model_kwargs is None:
        hf_model_kwargs = {}

    # Run with HF runner
    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
        hf_outputs = llm.generate_greedy(prompts, 8)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # Clean up the GPU memory for the next test
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

    #Run with vLLM runner
    with vllm_runner(model_name,
                     quantization='bitsandbytes',
                     load_format='bitsandbytes',
                     enforce_eager=True) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
                     enforce_eager=True,
                     gpu_memory_utilization=0.8) as llm:
        vllm_outputs = llm.generate_greedy(prompts, 8)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

        # check the weights in MLP & SelfAttention are quantized to torch.uint8
        qweight = model.model.layers[0].mlp.gate_up_proj.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
    # Clean up the GPU memory for the next test
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

        qweight = model.model.layers[0].mlp.down_proj.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')

        qweight = model.model.layers[0].self_attn.o_proj.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')

        qweight = model.model.layers[0].self_attn.qkv_proj.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')

        # some weights should not be quantized
        weight = model.lm_head.weight
        assert weight.dtype != torch.uint8, (
            'lm_head weight dtype should not be torch.uint8')

        weight = model.model.embed_tokens.weight
        assert weight.dtype != torch.uint8, (
            'embed_tokens weight dtype should not be torch.uint8')

        weight = model.model.layers[0].input_layernorm.weight
        assert weight.dtype != torch.uint8, (
            'input_layernorm weight dtype should not be torch.uint8')

        weight = model.model.layers[0].post_attention_layernorm.weight
        assert weight.dtype != torch.uint8, (
            'input_layernorm weight dtype should not be torch.uint8')

        # check the output of the model is expected
        sampling_params = SamplingParams(temperature=0.0,
                                         logprobs=1,
                                         prompt_logprobs=1,
                                         max_tokens=8)

        prompts = ['That which does not kill us', 'To be or not to be,']
        expected_outputs = [
            'That which does not kill us makes us stronger.',
            'To be or not to be, that is the question.'
        ]
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        assert len(outputs) == len(prompts)

        for index in range(len(outputs)):
            # compare the first line of the output
            actual_output = outputs[index][1][0].split('\n', 1)[0]
            expected_output = expected_outputs[index].split('\n', 1)[0]

            assert len(actual_output) >= len(expected_output), (
                f'Actual {actual_output} should be larger than or equal to '
                f'expected {expected_output}')
            actual_output = actual_output[:len(expected_output)]

            assert actual_output == expected_output, (
                f'Expected: {expected_output}, but got: {actual_output}')
    # Compare the generated strings
    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]
        assert hf_str == vllm_str, (f"Model: {model_name}"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"
                                    f"HF Output: '{hf_str}'\n"
                                    f"vLLM Output: '{vllm_str}'")
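On the HF side, hf_model_kwargs={"load_in_4bit": True} is forwarded to from_pretrained and triggers in-flight bitsandbytes quantization. A hedged sketch of the equivalent direct call, assuming a transformers version that still accepts load_in_4bit as a keyword (newer releases route it through a BitsAndBytesConfig object):

    from transformers import AutoModelForCausalLM

    # In-flight 4-bit quantization of an unquantized checkpoint at load time.
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",   # same model the in-flight 4-bit test uses
        load_in_4bit=True,
        device_map="auto",
    )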
@@ -405,6 +405,8 @@ class ModelConfig:
                raise ValueError(
                    "BitAndBytes quantization with TP or PP is not supported yet.")

        # Remove the constraint after the bitsandbytes issue is fixed:
        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
        if self.quantization == "bitsandbytes" and self.enforce_eager is False:
            logger.warning("CUDA graph is not supported on BitAndBytes yet, "
                           "fallback to the eager mode.")
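For end users the same combination of flags applies; a sketch of loading a bitsandbytes-quantized model through the public LLM API (the model name is just an example, and per the warning above eager mode is forced even if enforce_eager is omitted):

    from vllm import LLM, SamplingParams

    llm = LLM(model="huggyllama/llama-7b",
              quantization="bitsandbytes",
              load_format="bitsandbytes",
              enforce_eager=True)
    outputs = llm.generate(["To be or not to be,"],
                           SamplingParams(temperature=0.0, max_tokens=8))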
@@ -36,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   qkv_offsets: Dict[str, Tuple[int, int]],
                                   loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

    total, _ = qkv_offsets["total"]
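The renamed helper rescales the dense q/k/v offsets into the packed 4-bit parameter's row space; its full body lies outside this hunk, so the proportional arithmetic in the sketch below is an assumption (function name and example shapes are illustrative):

    from typing import Dict, Tuple

    def rescale_shard(qkv_offsets: Dict[str, Tuple[int, int]],
                      shard_id: str,
                      packed_rows: int) -> Tuple[int, int]:
        # Map an (offset, size) expressed in unpacked rows onto the packed
        # uint8 qweight, keeping the q/k/v proportions intact.
        total, _ = qkv_offsets["total"]
        offset, size = qkv_offsets[shard_id]
        return size * packed_rows // total, offset * packed_rows // total

    offsets = {"q": (0, 4096), "k": (4096, 1024),
               "v": (5120, 1024), "total": (6144, 0)}
    print(rescale_shard(offsets, "k", 6144 // 2))   # -> (512, 2048)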
@@ -505,8 +505,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id
@@ -858,8 +859,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
@@ -871,7 +873,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

        if is_gguf_weight:
@@ -1,7 +1,6 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
@@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig):

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(self, ) -> None:
        pass
    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float32",
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        llm_int8_skip_modules: Optional[Any] = None,
        llm_int8_threshold: float = 0.0,
    ) -> None:

        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        self.llm_int8_skip_modules = llm_int8_skip_modules
        self.llm_int8_threshold = llm_int8_threshold

    def __repr__(self) -> str:
        return "BitsAndBytesConfig"
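With the expanded constructor, a configuration can be built directly from these parameters; a short sketch (the import path is assumed from the file being modified, and the 8-bit values are examples):

    from vllm.model_executor.layers.quantization.bitsandbytes import (
        BitsAndBytesConfig)

    # 8-bit settings; the remaining parameters keep the defaults shown above.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True,
                                    load_in_4bit=False,
                                    llm_int8_threshold=6.0)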
@@ -41,7 +60,46 @@ class BitsAndBytesConfig(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        return cls()

        def get_safe_value(config, keys, default_value=None):
            try:
                value = cls.get_from_keys(config, keys)
                return value if value is not None else default_value
            except ValueError:
                return default_value

        load_in_8bit = get_safe_value(config, ["load_in_8bit"],
                                      default_value=False)
        load_in_4bit = get_safe_value(config, ["load_in_4bit"],
                                      default_value=True)
        bnb_4bit_compute_dtype = get_safe_value(config,
                                                ["bnb_4bit_compute_dtype"],
                                                default_value="float32")
        bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
                                             default_value="fp4")
        bnb_4bit_use_double_quant = get_safe_value(
            config, ["bnb_4bit_use_double_quant"], default_value=False)
        llm_int8_enable_fp32_cpu_offload = get_safe_value(
            config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
        llm_int8_has_fp16_weight = get_safe_value(config,
                                                  ["llm_int8_has_fp16_weight"],
                                                  default_value=False)
        llm_int8_skip_modules = get_safe_value(config,
                                               ["llm_int8_skip_modules"],
                                               default_value=[])
        llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
                                            default_value=0.0)

        return cls(
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
            llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
            llm_int8_skip_modules=llm_int8_skip_modules,
            llm_int8_threshold=llm_int8_threshold)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
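from_config consumes the HF-style quantization_config dict found in a pre-quantized checkpoint's config.json; an illustrative input limited to keys the method actually reads (values are example data):

    from vllm.model_executor.layers.quantization.bitsandbytes import (
        BitsAndBytesConfig)

    hf_quant_config = {                      # illustrative values only
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }
    config = BitsAndBytesConfig.from_config(hf_quant_config)
    # Keys that are absent (e.g. llm_int8_threshold) fall back to the defaults.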
@@ -78,39 +136,58 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        quant_ratio = 0
        if params_dtype.is_floating_point:
            quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo(
                torch.uint8).bits
        from bitsandbytes.nn import Int8Params

        def calculate_quant_ratio(dtype):
            if dtype.is_floating_point:
                return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
            else:
                return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(sum(output_partition_sizes),
                                 input_size_per_partition,
                                 dtype=torch.int8),
                has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
                requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": 1,
                    "use_bitsandbytes_8bit": True,
                    "generation": 0
                })
            return qweight

        def create_qweight_for_4bit():
            quant_ratio = calculate_quant_ratio(params_dtype)

            total_size = input_size_per_partition * sum(output_partition_sizes)
            if total_size % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized "
                    "weight shape.")

            qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
                                                     1,
                                                     dtype=torch.uint8),
                                         requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": quant_ratio,
                    "use_bitsandbytes_4bit": True
                })
            return qweight

        if self.quant_config.load_in_8bit:
            qweight = create_qweight_for_8bit()
        else:
            quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo(
                torch.uint8).bits
            qweight = create_qweight_for_4bit()

        if input_size_per_partition * sum(
                output_partition_sizes) % quant_ratio != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. ")
        qweight = Parameter(
            torch.empty(
                input_size_per_partition * sum(output_partition_sizes) //
                quant_ratio,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                # In bitsandbytes, a tensor of shape [n,m] is quantized to
                #[n*m/pack_ratio, 1],so the output_dim is 0
                "output_dim": 0,
                "pack_factor": quant_ratio,
                "use_bitsandbytes": True,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
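The pack factor above is simply the bit-width ratio between the unquantized dtype and uint8 storage; for float16 that ratio is 2, so an [n, m] tensor packs into an [n*m/2, 1] uint8 buffer, matching the in-code comment above. A quick numeric check:

    import torch

    quant_ratio = torch.finfo(torch.float16).bits // torch.iinfo(torch.uint8).bits
    assert quant_ratio == 2
    n, m = 4096, 11008                   # example layer shape
    packed_rows = n * m // quant_ratio   # rows of the packed [*, 1] uint8 qweight
    assert packed_rows == (n * m) // 2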
@@ -119,6 +196,88 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.quant_config.load_in_8bit:
            return self._apply_8bit_weight(layer, x, bias)
        else:
            return self._apply_4bit_weight(layer, x, bias)

    def _apply_8bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        # only load the bitsandbytes module when needed
        from bitsandbytes import MatmulLtState, matmul

        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
        offsets = qweight.bnb_shard_offsets
        quant_states = qweight.bnb_quant_state
        matmul_states = qweight.matmul_state
        generation = qweight.generation

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(
            [quant_state[1].shape[0] for quant_state in quant_states.items()])
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.float16,
                          device=x.device)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]

            # in profile_run or the first generation of inference,
            # create new matmul_states
            if generation == 0 or generation == 1:
                matmul_states[i] = MatmulLtState()
                matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
                matmul_states[i].SCB = quant_states[i]
                matmul_states[i].threshold = (
                    self.quant_config.llm_int8_threshold)
                matmul_states[i].has_fp16_weights = (
                    self.quant_config.llm_int8_has_fp16_weight)
                matmul_states[i].is_training = False
                if matmul_states[i].threshold > 0.0 and not matmul_states[
                        i].has_fp16_weights:
                    matmul_states[i].use_pool = True

            new_x = bf_x.unsqueeze(0)

            out[:, current_index:current_index + output_size] = matmul(
                new_x,
                qweight[offsets[i]:offsets[i + 1]],
                state=matmul_states[i])

            current_index += output_size

            # only update the matmul_states if it is not profile_run
            if (generation > 0
                    and not self.quant_config.llm_int8_has_fp16_weight
                    and matmul_states[i].CB is not None
                    and matmul_states[i].CxB is not None):
                del matmul_states[i].CB
                qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB

        out = out.to(original_type)

        if bias is not None:
            out += bias

        qweight.generation += 1

        return out

    def _apply_4bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        # only load the bitsandbytes module when needed
        from bitsandbytes import matmul_4bit
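In the 8-bit path above, bnb_shard_offsets records where each fused sub-module's rows begin inside the flattened qweight, and the loop processes one slice per shard. A torch-only sketch of that slicing; the offsets and shapes are made up for illustration:

    import torch

    offsets = [0, 4, 10]                     # hypothetical row boundaries
    qweight = torch.arange(10).unsqueeze(1)  # stand-in fused weight, shape [10, 1]
    shards = [qweight[offsets[i]:offsets[i + 1]]
              for i in range(len(offsets) - 1)]
    assert [s.shape[0] for s in shards] == [4, 6]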
@@ -771,7 +771,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        return pt_weights_iterator(hf_weights_files)

    def _get_quantized_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
@@ -780,11 +784,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
            from bitsandbytes.functional import QuantState
            if bitsandbytes.__version__ < "0.42.0":
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
            from bitsandbytes.functional import quantize_4bit
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
@@ -793,80 +795,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict = {}

        def quantized_checkpoint() -> Generator:
            # First iterate over all quant state weights
            weight_iterator = self._hf_weight_iter(hf_weights_files,
                                                   use_safetensors)
            temp_state_dict = {}
            for weight_name, weight_tensor in weight_iterator:
                if weight_name.endswith(".weight"):
                    continue
                # TODO: only nf4 quantization is supported for now
                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
                    raise NotImplementedError(
                        "Only bitsandbytes_nf4 quantization"
                        f"is supported for now. {weight_name} is fp4 quantized"
                    )
                temp_state_dict[weight_name] = weight_tensor

            # Closure to parse quant_state for each prequant weight
            def _parse_quant_state(param_name: str,
                                   temp_state_dict: Dict) -> QuantState:
                quant_state = {}
                for k in temp_state_dict:
                    if param_name + "." in k:
                        quant_state[k] = temp_state_dict[k]
                # bitsandbytes library requires
                # weight.quant_state.bitsandbytes__nf4 in CPU
                quant_state[param_name +
                            ".quant_state.bitsandbytes__nf4"] = quant_state[
                                param_name +
                                ".quant_state.bitsandbytes__nf4"].cpu().data
                return QuantState.from_dict(quant_state, device="cuda")

            # Second iterate over all prequant and normal weights
            # pre quantized weights would have a quant_state
            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                # Filter out all weights whose suffix is not ".weight"
                if not weight_name.endswith(".weight"):
                    continue
                if weight_name + ".quant_state.bitsandbytes__nf4" \
                        in temp_state_dict:
                    quant_state = _parse_quant_state(weight_name,
                                                     temp_state_dict)
                    weight_name = weight_name.replace(".weight", ".qweight")
                    quant_state_dict[weight_name] = quant_state
                    yield weight_name.replace(".weight",
                                              ".qweight"), weight_tensor
                else:
                    yield weight_name, weight_tensor

        def generator() -> Generator:
            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
                    # bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
                            compress_statistics=True,
                            quant_type="nf4")

                    quant_state_dict[weight_name] = quant_state
                else:
                    processed_weight = weight_tensor

                yield weight_name, processed_weight
        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            return quantized_checkpoint(), quant_state_dict
        return generator(), quant_state_dict
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not weight_name.lower().endswith(".scb"):
                continue

            weight_key = weight_name.lower().replace(".scb", ".qweight")
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if not weight_name.endswith(".weight"):
                continue

            qweight_name = weight_name.replace(".weight", ".qweight")
            if qweight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield qweight_name, weight_tensor
            else:
                yield weight_name, weight_tensor
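Pre-quantized 8-bit checkpoints store a row-scale tensor next to each int8 weight, named with an .SCB suffix; the generator above pairs the two by renaming both to .qweight. A tiny illustration of the key mapping with a made-up parameter name:

    weight_name = "model.layers.0.mlp.down_proj.SCB"   # made-up checkpoint key
    qweight_key = weight_name.lower().replace(".scb", ".qweight")
    assert qweight_key == "model.layers.0.mlp.down_proj.qweight"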

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        weight_iterator = self._hf_weight_iter(hf_weights_files,
                                               use_safetensors)
        temp_state_dict = {}
        for weight_name, weight_tensor in weight_iterator:
            if weight_name.endswith(".weight"):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in weight_name:
                temp_state_dict[weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]

            return QuantState.from_dict(quant_state, device="cuda")

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            # Filter out all weights whose suffix is not ".weight"
            if not weight_name.endswith(".weight"):
                continue
            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
                in temp_state_dict) or \
                (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                in temp_state_dict):
                quant_state = _parse_quant_state(weight_name, temp_state_dict)
                weight_name = weight_name.replace(".weight", ".qweight")
                quant_state_dict[weight_name] = quant_state
                yield weight_name.replace(".weight", ".qweight"), weight_tensor
            else:
                yield weight_name, weight_tensor
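Pre-quantized 4-bit checkpoints carry per-weight quant-state entries whose names end in .quant_state.bitsandbytes__nf4 or __fp4, and the check above simply looks for either suffix. An illustration with made-up keys:

    name = "model.layers.0.self_attn.qkv_proj.weight"   # made-up parameter name
    temp_state_dict = {
        name + ".quant_state.bitsandbytes__fp4": "packed quant state",  # example
        name + ".absmax": "per-block absmax",                           # example
    }
    is_pre_quantized = (f"{name}.quant_state.bitsandbytes__nf4" in temp_state_dict
                        or f"{name}.quant_state.bitsandbytes__fp4" in temp_state_dict)
    assert is_pre_quantized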

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        from bitsandbytes.functional import quantize_4bit
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if any(target_module in weight_name
                   for target_module in self.target_modules):
                weight_name = weight_name.replace(".weight", ".qweight")
                # bitsandbytes requires data in GPU
                loaded_weight = weight_tensor.cuda().data
                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4")

                quant_state_dict[weight_name] = quant_state
            else:
                processed_weight = weight_tensor

            yield weight_name, processed_weight

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
@@ -883,16 +916,26 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        is_quantized_checkpoint = False
        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
        if quant_config is not None and quant_config.get(
                'quant_method') == "bitsandbytes":
            is_quantized_checkpoint = True

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get('quant_method')
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get('load_in_8bit', False)

        qweight_iterator, quant_state_dict = \
            self._get_quantized_weights_iterator(
            model_config.model, model_config.revision, is_quantized_checkpoint)
            model_config.model, model_config.revision, pre_quant, load_8bit)

        model.load_weights(qweight_iterator)
@@ -942,6 +985,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        offsets = np.concatenate(([0], np.cumsum(num_elements)))
        set_weight_attrs(param, {"bnb_shard_offsets": offsets})

        if load_8bit:
            set_weight_attrs(
                param, {"matmul_state": [None] * len(quant_states)})

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],