mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Quantization] Pool model support bitsandbytes (#18087)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
1b15df2546
commit
6781af5608
@ -8,9 +8,11 @@ import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ..models.utils import check_embeddings_close
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
models_4bit_to_test = [
|
||||
@ -19,6 +21,10 @@ models_4bit_to_test = [
|
||||
"quantize inflight model with both HF and Mistral format weights")
|
||||
]
|
||||
|
||||
models_4bit_to_embedding_test = [
|
||||
("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
|
||||
]
|
||||
|
||||
models_pre_qaunt_4bit_to_test = [
|
||||
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
|
||||
'read pre-quantized 4-bit FP4 model'),
|
||||
@ -31,6 +37,12 @@ models_pre_quant_8bit_to_test = [
|
||||
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
|
||||
]
|
||||
|
||||
models_pre_quant_8bit_to_test = [
|
||||
('meta-llama/Llama-Guard-3-8B-INT8',
|
||||
'read pre-quantized llama 8-bit model'),
|
||||
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
@ -39,7 +51,8 @@ models_pre_quant_8bit_to_test = [
|
||||
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name, description) -> None:
|
||||
|
||||
hf_model_kwargs = {"load_in_4bit": True}
|
||||
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True))
|
||||
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
|
||||
model_name, False, hf_model_kwargs)
|
||||
|
||||
@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name, description) -> None:
|
||||
|
||||
hf_model_kwargs = {"load_in_4bit": True}
|
||||
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True))
|
||||
validate_generated_texts(hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts[:1],
|
||||
@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
|
||||
compare_two_settings(model_name, common_args, pp_args)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
@pytest.mark.parametrize("model_name, description",
|
||||
models_4bit_to_embedding_test)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@create_new_process_for_each_test()
|
||||
def test_4bit_bnb_embedding_model(
|
||||
model_name,
|
||||
description,
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
|
||||
# The example_prompts has ending "\n", for example:
|
||||
# "Write a short story about a robot that dreams for the first time.\n"
|
||||
# sentence_transformers will strip the input texts, see:
|
||||
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
|
||||
# This makes the input_ids different between hf_model and vllm_model.
|
||||
# So we need to strip the input texts to avoid test failing.
|
||||
example_prompts = [str(s).strip() for s in example_prompts]
|
||||
|
||||
# Inflight 4bit quantization
|
||||
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True))
|
||||
with hf_runner(
|
||||
model_name,
|
||||
dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs,
|
||||
is_sentence_transformer=True,
|
||||
) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(model_name,
|
||||
task="embed",
|
||||
dtype=dtype,
|
||||
quantization="bitsandbytes") as vllm_model:
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=5e-2,
|
||||
)
|
||||
|
||||
|
||||
def log_generated_texts(prompts, outputs, runner_name):
|
||||
logged_texts = []
|
||||
for i, (_, generated_text) in enumerate(outputs):
|
||||
|
||||
@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -133,6 +134,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
||||
def _maybe_pool_model(module_name:str):
|
||||
# For pool model, we need to add the prefix `model.`
|
||||
# for the weight name if possible.
|
||||
if self.is_pool_model and self.target_modules[0]. \
|
||||
startswith("model.") and not module_name.startswith(
|
||||
"model."):
|
||||
return "model."+module_name
|
||||
|
||||
return module_name
|
||||
|
||||
if use_safetensors:
|
||||
iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
@ -148,6 +159,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
# original names.
|
||||
mapped_name = self.weight_mapper(org_name)
|
||||
mapped_name=_maybe_pool_model(mapped_name)
|
||||
|
||||
|
||||
yield org_name, mapped_name, param
|
||||
|
||||
def _get_quantized_weights_iterator(
|
||||
@ -405,7 +419,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
raise AttributeError(
|
||||
f"Model {type(model).__name__} does not support BitsAndBytes "
|
||||
"quantization yet. No 'packed_modules_mapping' found.")
|
||||
|
||||
self.is_pool_model=is_pooling_model(model)
|
||||
self.modules_mapping = ParamMapping(
|
||||
copy.deepcopy(model.packed_modules_mapping))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user