[Bugfix] Fix BNB loader target_modules (#10720)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 1f958a7d52
parent aa39a8e175
Author: Jee Jee Li <pandaleefree@gmail.com>
Date:   2024-12-05 13:20:26 +08:00


@@ -6,7 +6,6 @@ import fnmatch
 import glob
 import inspect
 import itertools
-import json
 import math
 import os
 import warnings
@@ -18,7 +17,7 @@ import gguf
 import huggingface_hub
 import numpy as np
 import torch
-from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub import HfApi
 from torch import nn
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
@@ -704,51 +703,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.unsharded_weights_modules: List[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: List[str] = []
-        # we don't need to quantize the whole model, only the target modules
-        # that are specified in the adapter config file. If the adapter config
-        # file is not provided, we will quantize the default modules.
-        if (not load_config.model_loader_extra_config
-                or "qlora_adapter_name_or_path"
-                not in load_config.model_loader_extra_config):
-            self.target_modules = []
-            return
-
-        qlora_adapter = load_config.model_loader_extra_config[
-            "qlora_adapter_name_or_path"]
-        config_file_path = self._get_config_file(qlora_adapter)
-
-        with open(config_file_path) as f:
-            config = json.load(f)
-            self.target_modules = config["target_modules"]
-            # TODO: target_modules could be either a list or a regex string.
-            # We need to handle both cases.
-            assert isinstance(self.target_modules,
-                              list), "Unsupported target_modules: "
-            f"{self.target_modules}"
-
-    def _get_config_file(self, qlora_adapter: str) -> str:
-        is_local = os.path.isdir(qlora_adapter)
-        config_file_path = None
-
-        if is_local:
-            for file in self.possible_config_file_names:
-                config_file_path = os.path.join(qlora_adapter, file)
-                if os.path.exists(config_file_path):
-                    break
-        else:
-            hf_api = HfApi()
-            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
-            for file in self.possible_config_file_names:
-                if file in repo_files:
-                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
-                                                       filename=file)
-                    break
-
-        if not config_file_path:
-            raise ValueError(
-                f"Cannot find adapter config file in {qlora_adapter}")
-
-        return config_file_path
+        # Store all module names (from transformers) that support
+        # BNB quantization.
+        self.target_modules: List[str] = []
 
     def _get_weight_files(
         self,
@@ -1030,25 +987,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 inverse_stacked_mapping[packed] = []
             inverse_stacked_mapping[packed].insert(idx, orig)
 
-        linear_module_lst = []
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
                     # Map vllm's names to transformers' names.
                     for sub_name in sub_modules:
-                        linear_module_lst.append(
+                        self.target_modules.append(
                             name.replace(last_name, sub_name))
                 else:
-                    linear_module_lst.append(name)
-
-        if self.target_modules:
-            # Update self.target_modules
-            self.target_modules = [
-                qual_name for qual_name in linear_module_lst
-                if any(t in qual_name for t in self.target_modules)
-            ]
-        else:
-            self.target_modules = linear_module_lst
+                    self.target_modules.append(name)
+
+        assert (self.target_modules
+                ), "vllm currently does not support BNB quantization for"
+        f" {type(model).__name__}"
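
For readers outside the vLLM codebase, the hunk above boils down to: walk every linear layer of the instantiated model, expand vLLM's packed projection names (e.g. qkv_proj) back into the per-projection names used by transformers checkpoints, and use the resulting list as the BNB target_modules. Below is a minimal standalone sketch of that traversal; ToyAttention, ToyModel, stacked_params_mapping and collect_target_modules are illustrative stand-ins (plain torch.nn.Linear replaces vLLM's LinearBase, and the hand-written mapping replaces the inverse stacked-parameter mapping the loader builds from the model), not vLLM's actual classes or attributes.

from typing import Dict, List

from torch import nn


class ToyAttention(nn.Module):
    # Stand-in for an attention block whose q/k/v projections are fused
    # into a single packed linear layer named "qkv_proj", as in vLLM.
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden, 3 * hidden)
        self.o_proj = nn.Linear(hidden, hidden)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention() for _ in range(2)])


# Analogue of the inverse stacked-parameter mapping built in the hunk above:
# packed module name -> original transformers sub-module names, in order.
stacked_params_mapping: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}


def collect_target_modules(model: nn.Module) -> List[str]:
    """Collect fully qualified linear-module names, expanding packed
    modules back into the per-projection names HF checkpoints use."""
    target_modules: List[str] = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):  # LinearBase in vLLM proper
            continue
        last_name = name.split(".")[-1]
        if sub_modules := stacked_params_mapping.get(last_name, []):
            target_modules.extend(
                name.replace(last_name, sub) for sub in sub_modules)
        else:
            target_modules.append(name)
    return target_modules


if __name__ == "__main__":
    print(collect_target_modules(ToyModel()))
    # ['layers.0.q_proj', 'layers.0.k_proj', 'layers.0.v_proj',
    #  'layers.0.o_proj', 'layers.1.q_proj', 'layers.1.k_proj',
    #  'layers.1.v_proj', 'layers.1.o_proj']

Because the list is now derived from the model itself rather than from an adapter config, the assert at the end of the hunk can simply require that at least one quantizable linear module was found.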