diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 8d9ae282153cf..0f20f42d8650b 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -8,9 +8,11 @@ import gc import pytest import torch +from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported +from ..models.utils import check_embeddings_close from ..utils import compare_two_settings, create_new_process_for_each_test models_4bit_to_test = [ @@ -19,6 +21,10 @@ models_4bit_to_test = [ "quantize inflight model with both HF and Mistral format weights") ] +models_4bit_to_embedding_test = [ + ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"), +] + models_pre_qaunt_4bit_to_test = [ ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), @@ -31,6 +37,12 @@ models_pre_quant_8bit_to_test = [ ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"), ] +models_pre_quant_8bit_to_test = [ + ('meta-llama/Llama-Guard-3-8B-INT8', + 'read pre-quantized llama 8-bit model'), + ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"), +] + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @@ -39,7 +51,8 @@ models_pre_quant_8bit_to_test = [ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], model_name, False, hf_model_kwargs) @@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], @@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_4bit_to_embedding_test) +@pytest.mark.parametrize("dtype", ["half"]) +@create_new_process_for_each_test() +def test_4bit_bnb_embedding_model( + model_name, + description, + hf_runner, + vllm_runner, + example_prompts, + dtype: str, +) -> None: + + # The example_prompts has ending "\n", for example: + # "Write a short story about a robot that dreams for the first time.\n" + # sentence_transformers will strip the input texts, see: + # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159 + # This makes the input_ids different between hf_model and vllm_model. + # So we need to strip the input texts to avoid test failing. + example_prompts = [str(s).strip() for s in example_prompts] + + # Inflight 4bit quantization + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) + with hf_runner( + model_name, + dtype=dtype, + model_kwargs=hf_model_kwargs, + is_sentence_transformer=True, + ) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model_name, + task="embed", + dtype=dtype, + quantization="bitsandbytes") as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=5e-2, + ) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 47a7a06bb7445..6771c128c5a1b 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models import is_pooling_model from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -133,6 +134,16 @@ class BitsAndBytesModelLoader(BaseModelLoader): return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + def _maybe_pool_model(module_name:str): + # For pool model, we need to add the prefix `model.` + # for the weight name if possible. + if self.is_pool_model and self.target_modules[0]. \ + startswith("model.") and not module_name.startswith( + "model."): + return "model."+module_name + + return module_name + if use_safetensors: iterator = safetensors_weights_iterator( hf_weights_files, @@ -148,6 +159,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): # mapping weight names from transformers to vllm while preserving # original names. mapped_name = self.weight_mapper(org_name) + mapped_name=_maybe_pool_model(mapped_name) + + yield org_name, mapped_name, param def _get_quantized_weights_iterator( @@ -405,7 +419,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet. No 'packed_modules_mapping' found.") - + self.is_pool_model=is_pooling_model(model) self.modules_mapping = ParamMapping( copy.deepcopy(model.packed_modules_mapping))