[Quantization] Pool model support bitsandbytes (#18087)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-05-20 00:03:43 +08:00 committed by GitHub
parent 1b15df2546
commit 6781af5608
2 changed files with 79 additions and 3 deletions
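
In short: pooling (embedding) models can now be loaded with in-flight bitsandbytes quantization. A minimal offline-inference sketch of what the new test exercises; the LLM arguments below are assumed from the vllm_runner(..., task="embed", quantization="bitsandbytes") call in the test and are not spelled out elsewhere in this diff:

from vllm import LLM

# In-flight 4-bit bitsandbytes quantization of an embedding (pooling) model,
# mirroring test_4bit_bnb_embedding_model added below.
llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",
    dtype="half",
    quantization="bitsandbytes",
)
# encode() runs the pooling path and returns one embedding per prompt.
outputs = llm.encode(["Write a short story about a robot that dreams."])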


@@ -8,9 +8,11 @@ import gc
import pytest
import torch
from transformers import BitsAndBytesConfig
from tests.quantization.utils import is_quant_method_supported
from ..models.utils import check_embeddings_close
from ..utils import compare_two_settings, create_new_process_for_each_test
models_4bit_to_test = [
@@ -19,6 +21,10 @@ models_4bit_to_test = [
"quantize inflight model with both HF and Mistral format weights")
]
models_4bit_to_embedding_test = [
("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]
models_pre_qaunt_4bit_to_test = [
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
@@ -31,6 +37,12 @@ models_pre_quant_8bit_to_test = [
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]
models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8',
'read pre-quantized llama 8-bit model'),
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@@ -39,7 +51,8 @@ models_pre_quant_8bit_to_test = [
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
hf_model_kwargs = {"load_in_4bit": True}
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
load_in_4bit=True))
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name, False, hf_model_kwargs)
@@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
hf_model_kwargs = {"load_in_4bit": True}
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
load_in_4bit=True))
validate_generated_texts(hf_runner,
vllm_runner,
example_prompts[:1],
@@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
compare_two_settings(model_name, common_args, pp_args)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
models_4bit_to_embedding_test)
@pytest.mark.parametrize("dtype", ["half"])
@create_new_process_for_each_test()
def test_4bit_bnb_embedding_model(
model_name,
description,
hf_runner,
vllm_runner,
example_prompts,
dtype: str,
) -> None:
# Each prompt in example_prompts ends with "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids differ between hf_model and vllm_model.
# So we strip the input texts here to keep the two runs comparable.
example_prompts = [str(s).strip() for s in example_prompts]
# Inflight 4bit quantization
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
load_in_4bit=True))
with hf_runner(
model_name,
dtype=dtype,
model_kwargs=hf_model_kwargs,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model_name,
task="embed",
dtype=dtype,
quantization="bitsandbytes") as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=5e-2,
)
def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
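
The HF-side change in the generation tests above replaces the bare load_in_4bit=True kwarg with an explicit BitsAndBytesConfig. For reference, the equivalent direct transformers usage is roughly the following sketch (the model name is reused from the embedding test list; AutoModel is an assumption, since the tests actually go through hf_runner):

from transformers import AutoModel, BitsAndBytesConfig

# Same quantization settings the tests now build via hf_model_kwargs.
model = AutoModel.from_pretrained(
    "intfloat/e5-mistral-7b-instruct",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)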


@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -133,6 +134,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
def _maybe_pool_model(module_name: str):
# For pooling models, we need to add the `model.` prefix
# to the weight name when it is missing.
if self.is_pool_model and self.target_modules[0].startswith(
"model.") and not module_name.startswith("model."):
return "model." + module_name
return module_name
if use_safetensors:
iterator = safetensors_weights_iterator(
hf_weights_files,
@@ -148,6 +159,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
mapped_name = _maybe_pool_model(mapped_name)
yield org_name, mapped_name, param
def _get_quantized_weights_iterator(
@@ -405,7 +419,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found.")
self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(
copy.deepcopy(model.packed_modules_mapping))
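
For reference, a standalone toy sketch of the prefix handling that the new _maybe_pool_model helper performs while iterating weights (the module names and the target_modules value here are made up for illustration; in the loader they come from the model's packed_modules_mapping):

# Hypothetical, self-contained sketch of the `model.` prefix fix-up:
# pooling models register their decoder under a `model.` prefix, while the
# checkpoint may store weight names without it.
target_modules = ["model.layers.0.self_attn.qkv_proj"]  # assumed example value
is_pool_model = True

def maybe_pool_model(module_name: str) -> str:
    if (is_pool_model and target_modules[0].startswith("model.")
            and not module_name.startswith("model.")):
        return "model." + module_name
    return module_name

print(maybe_pool_model("layers.0.self_attn.qkv_proj.weight"))
# -> model.layers.0.self_attn.qkv_proj.weight
print(maybe_pool_model("model.embed_tokens.weight"))
# -> model.embed_tokens.weight (already prefixed, left unchanged)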