Nvidia ModelOpt workaround for issue 28072 (#30164)

Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>

parent 060893654d
commit 0bb0bae436
@@ -188,7 +188,24 @@ class ModelOptQuantConfigBase(QuantizationConfig):
 
     def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
         if len(self.exclude_modules) > 0:
-            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
+            # This is a workaround for the weights remapping issue:
+            # https://github.com/vllm-project/vllm/issues/28072
+            # Right now, the Nvidia ModelOpt library uses just one wildcard pattern:
+            #     module_path*
+            # It gets applied if the whole tree of modules rooted at module_path
+            # is not quantized. Here we replace such a pattern with 2 patterns that
+            # are collectively equivalent to the original pattern:
+            #     module_path
+            #     module_path.*
+            new_exclude_modules = []
+            for exclude in self.exclude_modules:
+                if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
+                    new_exclude_modules.append(exclude[:-1])
+                    new_exclude_modules.append(exclude[:-1] + ".*")
+                else:
+                    new_exclude_modules.append(exclude)
+
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
 
     @staticmethod
     def get_config_filenames() -> list[str]:
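For readers who want to try the pattern-splitting transformation outside of vLLM, below is a minimal, self-contained sketch of the logic this diff adds. The helper name split_wildcard_patterns and the sample patterns are illustrative assumptions, not part of the actual change, which lives inline in ModelOptQuantConfigBase.apply_vllm_mapper.

# Minimal sketch of the wildcard-splitting workaround (hypothetical helper;
# the real change is inlined in ModelOptQuantConfigBase.apply_vllm_mapper).

def split_wildcard_patterns(exclude_modules: list[str]) -> list[str]:
    """Replace each bare trailing-wildcard pattern "module_path*" with the
    equivalent pair "module_path" and "module_path.*"."""
    new_exclude_modules: list[str] = []
    for exclude in exclude_modules:
        # Rewrite only bare wildcards such as "model.layers*";
        # "model.layers.*" is already in the desired form.
        if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
            new_exclude_modules.append(exclude[:-1])         # the root module itself
            new_exclude_modules.append(exclude[:-1] + ".*")  # everything beneath it
        else:
            new_exclude_modules.append(exclude)
    return new_exclude_modules

if __name__ == "__main__":
    patterns = ["model.layers.0.mlp*", "lm_head", "visual.*"]
    print(split_wildcard_patterns(patterns))
    # ['model.layers.0.mlp', 'model.layers.0.mlp.*', 'lm_head', 'visual.*']

Per the diff's own comment, the single pattern module_path* covers both the root module and all of its descendants; splitting it into module_path and module_path.* preserves that meaning while, presumably, giving the vLLM weights mapper patterns it can remap reliably (see issue 28072).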