mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 17:14:24 +08:00
[Bugfix] Fix BNB loader target_modules (#10720)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
aa39a8e175
commit
1f958a7d52
@ -6,7 +6,6 @@ import fnmatch
|
|||||||
import glob
|
import glob
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@ -18,7 +17,7 @@ import gguf
|
|||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||||
@ -704,51 +703,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
self.unsharded_weights_modules: List[str] = []
|
self.unsharded_weights_modules: List[str] = []
|
||||||
# Save the module names that are sharded by column.
|
# Save the module names that are sharded by column.
|
||||||
self.column_sharded_weights_modules: List[str] = []
|
self.column_sharded_weights_modules: List[str] = []
|
||||||
# we don't need to quantize the whole model, only the target modules
|
# Store all module names (from transformers) that support
|
||||||
# that are specified in the adapter config file. If the adapter config
|
# BNB quantization.
|
||||||
# file is not provided, we will quantize the default modules.
|
self.target_modules: List[str] = []
|
||||||
if (not load_config.model_loader_extra_config
|
|
||||||
or "qlora_adapter_name_or_path"
|
|
||||||
not in load_config.model_loader_extra_config):
|
|
||||||
self.target_modules = []
|
|
||||||
return
|
|
||||||
|
|
||||||
qlora_adapter = load_config.model_loader_extra_config[
|
|
||||||
"qlora_adapter_name_or_path"]
|
|
||||||
|
|
||||||
config_file_path = self._get_config_file(qlora_adapter)
|
|
||||||
|
|
||||||
with open(config_file_path) as f:
|
|
||||||
config = json.load(f)
|
|
||||||
self.target_modules = config["target_modules"]
|
|
||||||
# TODO: target_modules could be either a list or a regex string.
|
|
||||||
# We need to handle both cases.
|
|
||||||
assert isinstance(self.target_modules,
|
|
||||||
list), "Unsupported target_modules: "
|
|
||||||
f"{self.target_modules}"
|
|
||||||
|
|
||||||
def _get_config_file(self, qlora_adapter: str) -> str:
|
|
||||||
is_local = os.path.isdir(qlora_adapter)
|
|
||||||
config_file_path = None
|
|
||||||
if is_local:
|
|
||||||
for file in self.possible_config_file_names:
|
|
||||||
config_file_path = os.path.join(qlora_adapter, file)
|
|
||||||
if os.path.exists(config_file_path):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
hf_api = HfApi()
|
|
||||||
repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
|
|
||||||
for file in self.possible_config_file_names:
|
|
||||||
if file in repo_files:
|
|
||||||
config_file_path = hf_hub_download(repo_id=qlora_adapter,
|
|
||||||
filename=file)
|
|
||||||
break
|
|
||||||
|
|
||||||
if not config_file_path:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot find adapter config file in {qlora_adapter}")
|
|
||||||
|
|
||||||
return config_file_path
|
|
||||||
|
|
||||||
def _get_weight_files(
|
def _get_weight_files(
|
||||||
self,
|
self,
|
||||||
@ -1030,25 +987,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
inverse_stacked_mapping[packed] = []
|
inverse_stacked_mapping[packed] = []
|
||||||
inverse_stacked_mapping[packed].insert(idx, orig)
|
inverse_stacked_mapping[packed].insert(idx, orig)
|
||||||
|
|
||||||
linear_module_lst = []
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, (LinearBase, )):
|
if isinstance(module, (LinearBase, )):
|
||||||
last_name = name.split(".")[-1]
|
last_name = name.split(".")[-1]
|
||||||
if sub_modules := inverse_stacked_mapping.get(last_name, []):
|
if sub_modules := inverse_stacked_mapping.get(last_name, []):
|
||||||
# Map vllm's names to transformers' names.
|
# Map vllm's names to transformers' names.
|
||||||
for sub_name in sub_modules:
|
for sub_name in sub_modules:
|
||||||
linear_module_lst.append(
|
self.target_modules.append(
|
||||||
name.replace(last_name, sub_name))
|
name.replace(last_name, sub_name))
|
||||||
else:
|
else:
|
||||||
linear_module_lst.append(name)
|
self.target_modules.append(name)
|
||||||
if self.target_modules:
|
|
||||||
# Update self.target_modules
|
|
||||||
self.target_modules = [
|
|
||||||
qual_name for qual_name in linear_module_lst
|
|
||||||
if any(t in qual_name for t in self.target_modules)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
self.target_modules = linear_module_lst
|
|
||||||
assert (self.target_modules
|
assert (self.target_modules
|
||||||
), "vllm currently does not support BNB quantization for"
|
), "vllm currently does not support BNB quantization for"
|
||||||
f" {type(model).__name__}"
|
f" {type(model).__name__}"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user