[Bugfix] Fix BNB loader target_modules (#10720)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent aa39a8e175
commit 1f958a7d52
@@ -6,7 +6,6 @@ import fnmatch
 import glob
 import inspect
 import itertools
-import json
 import math
 import os
 import warnings
@@ -18,7 +17,7 @@ import gguf
 import huggingface_hub
 import numpy as np
 import torch
-from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub import HfApi
 from torch import nn
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
@@ -704,51 +703,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.unsharded_weights_modules: List[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: List[str] = []
-        # we don't need to quantize the whole model, only the target modules
-        # that are specified in the adapter config file. If the adapter config
-        # file is not provided, we will quantize the default modules.
-        if (not load_config.model_loader_extra_config
-                or "qlora_adapter_name_or_path"
-                not in load_config.model_loader_extra_config):
-            self.target_modules = []
-            return
-
-        qlora_adapter = load_config.model_loader_extra_config[
-            "qlora_adapter_name_or_path"]
-
-        config_file_path = self._get_config_file(qlora_adapter)
-
-        with open(config_file_path) as f:
-            config = json.load(f)
-            self.target_modules = config["target_modules"]
-            # TODO: target_modules could be either a list or a regex string.
-            # We need to handle both cases.
-            assert isinstance(self.target_modules,
-                              list), "Unsupported target_modules: "
-            f"{self.target_modules}"
-
-    def _get_config_file(self, qlora_adapter: str) -> str:
-        is_local = os.path.isdir(qlora_adapter)
-        config_file_path = None
-        if is_local:
-            for file in self.possible_config_file_names:
-                config_file_path = os.path.join(qlora_adapter, file)
-                if os.path.exists(config_file_path):
-                    break
-        else:
-            hf_api = HfApi()
-            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
-            for file in self.possible_config_file_names:
-                if file in repo_files:
-                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
-                                                       filename=file)
-                    break
-
-        if not config_file_path:
-            raise ValueError(
-                f"Cannot find adapter config file in {qlora_adapter}")
-
-        return config_file_path
+        # Store all module names (from transformers) that support
+        # BNB quantization.
+        self.target_modules: List[str] = []

     def _get_weight_files(
             self,
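With this change the constructor no longer reads target_modules from a QLoRA adapter config (which is also why the json and hf_hub_download imports disappear in the hunks above); it only initializes an empty list that is filled later from the model's own linear modules. For readers unfamiliar with the removed path, here is a minimal sketch of what that adapter-config lookup amounted to for a local adapter directory; the file name and helper name below are illustrative, not vLLM's API:

import json
import os

def read_target_modules(adapter_dir: str) -> list:
    # The removed code tried several candidate config file names; a single
    # PEFT-style name is checked here for brevity.
    config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"Cannot find adapter config file in {adapter_dir}")
    with open(config_path) as f:
        config = json.load(f)
    target_modules = config["target_modules"]
    # Only the list form was supported; PEFT also allows a regex string,
    # which the removed code left as a TODO.
    assert isinstance(target_modules, list), \
        f"Unsupported target_modules: {target_modules}"
    return target_modules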
@@ -1030,25 +987,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 inverse_stacked_mapping[packed] = []
             inverse_stacked_mapping[packed].insert(idx, orig)

-        linear_module_lst = []
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
                     # Map vllm's names to transformers' names.
                     for sub_name in sub_modules:
-                        linear_module_lst.append(
+                        self.target_modules.append(
                             name.replace(last_name, sub_name))
                 else:
-                    linear_module_lst.append(name)
-        if self.target_modules:
-            # Update self.target_modules
-            self.target_modules = [
-                qual_name for qual_name in linear_module_lst
-                if any(t in qual_name for t in self.target_modules)
-            ]
-        else:
-            self.target_modules = linear_module_lst
+                    self.target_modules.append(name)
         assert (self.target_modules
                 ), "vllm currently does not support BNB quantization for"
         f" {type(model).__name__}"
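After this hunk, target_modules is built directly while walking the model's linear modules: each packed vLLM module name (e.g. qkv_proj) is expanded back into the per-projection names that checkpoints and BNB expect, and every other linear module is kept as-is; the old two-pass filtering through linear_module_lst is gone. A self-contained sketch of that expansion, assuming a toy packed-modules mapping and module list (the data below is illustrative, not vLLM's internal structures):

from typing import Dict, List, Tuple

# Toy stand-in for a packed-modules mapping:
# original transformers name -> (packed vLLM name, index within the pack).
packed_mapping: Dict[str, Tuple[str, int]] = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

# Invert it: packed vLLM name -> original names, ordered by pack index.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (packed, idx) in packed_mapping.items():
    inverse_stacked_mapping.setdefault(packed, []).insert(idx, orig)

# Fully qualified names a walk over the model's linear modules might yield.
linear_module_names = [
    "model.layers.0.self_attn.qkv_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.gate_up_proj",
    "model.layers.0.mlp.down_proj",
]

target_modules: List[str] = []
for name in linear_module_names:
    last_name = name.split(".")[-1]
    if sub_modules := inverse_stacked_mapping.get(last_name, []):
        # Map the packed vLLM name back to transformers' per-projection names.
        for sub_name in sub_modules:
            target_modules.append(name.replace(last_name, sub_name))
    else:
        target_modules.append(name)

print(target_modules)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj', 'model.layers.0.self_attn.o_proj',
#  'model.layers.0.mlp.gate_proj', 'model.layers.0.mlp.up_proj',
#  'model.layers.0.mlp.down_proj']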