diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index d6844b8dc5f2..1b6a91840148 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), + ("mistralai/Mistral-7B-Instruct-v0.3", + "quantize inflight model with both HF and Mistral format weights") ] models_pre_qaunt_4bit_to_test = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index c88af56e1805..b2ffca2a4b4d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): model_name_or_path: str, allowed_patterns: List[str], revision: Optional[str] = None, - ) -> Tuple[List[str], str]: + ) -> Tuple[str, List[str], str]: """Retrieve weight files. Download the files if necessary. 
Return the weight files and the file pattern.""" @@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): weight_files = glob.glob( os.path.join(model_name_or_path, pattern)) if weight_files: - return weight_files, pattern + return model_name_or_path, weight_files, pattern else: hf_api = HfApi() repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) @@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): revision, ignore_patterns=self.load_config.ignore_patterns, ) - return glob.glob(os.path.join(hf_folder, pattern)), pattern + return hf_folder, glob.glob( + os.path.join(hf_folder, pattern)), pattern raise RuntimeError( f"No model weights found in: `{model_name_or_path}`") @@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader): allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] - hf_weights_files, matched_pattern = self._get_weight_files( + hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( model_name_or_path, allowed_patterns, revision) - if matched_pattern != "*.safetensors": + use_safetensors = matched_pattern == "*.safetensors" + is_local = os.path.isdir(model_name_or_path) + index_file = SAFE_WEIGHTS_INDEX_NAME + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both at once breaks model loading. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) @@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") - return hf_weights_files, matched_pattern == "*.safetensors" + return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): if use_safetensors: