From 28b18cc741e596ea6f9981b8365c4819523fc24b Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 1 Aug 2025 19:09:54 +0800
Subject: [PATCH] [Quantization] Enable BNB support for InternS1 (#21953)

Signed-off-by: Jee Jee Li
---
 .../model_loader/bitsandbytes_loader.py | 39 ++++++++++++-------
 vllm/model_executor/utils.py            | 20 +++++++++-
 2 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 68fcb785691c8..f54dfab5238e1 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     pt_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.models import is_pooling_model
-from vllm.model_executor.utils import (get_packed_modules_mapping,
+from vllm.model_executor.utils import (get_moe_expert_mapping,
+                                       get_packed_modules_mapping,
                                        set_weight_attrs)
 from vllm.platforms import current_platform
 
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
 
+def is_moe_model(model: torch.nn.Module) -> bool:
+    """Checks if the model contains FusedMoE layers."""
+    return bool(any(
+        isinstance(module, FusedMoE) for module in model.modules()))
+
+
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        # Store the mapping of expert parameters for MoE models.
+        self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
         self.pre_quant: bool = False
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
-            elif (isinstance(module, FusedMoE)
-                  and hasattr(module.quant_method, "quant_config")):
-                if not hasattr(model, "get_expert_mapping"):
-                    raise AttributeError(
-                        f"MoE Model {type(model).__name__} does not support "
-                        "BitsAndBytes quantization yet. Ensure this model has "
-                        "'get_expert_mapping' method.")
+            elif isinstance(module, FusedMoE) and hasattr(
+                    module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
                 if self.pre_quant:
                     raise ValueError(
@@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         "BitsAndBytes 8bit quantization with FusedMoE is not "
                         "supported yet.")
                 # Get the corresponding weight name using module name and
-                # get_expert_mapping.
-                expert_mapping = model.get_expert_mapping()
-                for exp in expert_mapping:
+                # expert_params_mapping.
+
+                for exp in self.expert_params_mapping:
                     weight_name = exp[1]
                     rep_name = name.replace("experts",
                                             "") + weight_name.removesuffix(".")
@@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
             elif isinstance(module, FusedMoE):
-                expert_mapping = model.get_expert_mapping()
+                expert_mapping = self.expert_params_mapping
                 for exp in expert_mapping:
                     if exp[-1] == "w2":
                         weight_name = exp[1]
@@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.is_pool_model = is_pooling_model(model)
         self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
+        if is_moe_model(model):
+            self.expert_params_mapping = get_moe_expert_mapping(model)
+            if not self.expert_params_mapping:
+                raise AttributeError(
+                    f"MoE Model {type(model).__name__} does not support "
+                    "BitsAndBytes quantization yet. Ensure this model has "
+                    "'get_expert_mapping' method.")
 
         # For some models like Molmo, we need to use hf_to_vllm_mapper
         # to ensure correct loading of weights.
         if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         """
         from bitsandbytes.functional import QuantState
 
-        if not hasattr(model, "get_expert_mapping"):
+        if not self.expert_params_mapping:
             return dict()
 
-        expert_mapping = model.get_expert_mapping()
+        expert_mapping = self.expert_params_mapping
         expert_qs_dict = {}
         for name, module in model.named_modules():
             if not isinstance(module, FusedMoE):
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 2b20ca2a3ba3f..41ed0b09c5a2a 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utils for model executor."""
+
 import copy
 from typing import Any, Optional
 
@@ -9,6 +10,7 @@ import torch
 
 def set_random_seed(seed: int) -> None:
     from vllm.platforms import current_platform
+
     current_platform.seed_everything(seed)
 
 
@@ -29,7 +31,7 @@ def set_weight_attrs(
         return
     for key, value in weight_attrs.items():
         assert not hasattr(
-            weight, key), (f"Overwriting existing tensor attribute: {key}")
+            weight, key), f"Overwriting existing tensor attribute: {key}"
 
         # NOTE(woosuk): During weight loading, we often do something like:
         # narrowed_tensor = param.data.narrow(0, offset, len)
@@ -41,6 +43,7 @@ def set_weight_attrs(
         # we sync the param tensor after its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform
+
         if current_platform.is_tpu() and key == "weight_loader":
             value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
@@ -77,4 +80,17 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
                     f"safely because of conflicts from {type(child).__name__}.")
             else:
                 parent_map.update(child_map)
-    return parent_map
\ No newline at end of file
+    return parent_map
+
+
+def get_moe_expert_mapping(
+        model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]:
+    if parent_map := getattr(model, "get_expert_mapping", None):
+        return parent_map()
+    else:
+        # We only check main components instead of whole model submodules
+        for child in model.children():
+            child_map = getattr(child, "get_expert_mapping", None)
+            if child_map is not None:
+                return child_map()
+    return []
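Reviewer note: the functional core of this change is the new get_moe_expert_mapping
helper in vllm/model_executor/utils.py. Previously the loader required
get_expert_mapping on the top-level model; the helper now also falls back to the
model's direct children, which is what lets wrapper-style multimodal models such as
InternS1, whose MoE language model is a child module, pass the check. Below is a
minimal standalone sketch of that resolution order; ToyMoE, Wrapper, and
resolve_expert_mapping are hypothetical stand-ins written for illustration, not
vLLM classes, and the mapping entry is illustrative only.

    # Sketch of the two-step lookup mirrored from get_moe_expert_mapping.
    # ToyMoE / Wrapper are hypothetical; only torch is a real dependency.
    import torch


    class ToyMoE(torch.nn.Module):
        """Stand-in for an MoE language model exposing get_expert_mapping."""

        def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
            # Illustrative (param_name, weight_name, expert_id, shard_id) entry.
            return [("experts.w13_weight", "experts.0.gate_proj.", 0, "w1")]


    class Wrapper(torch.nn.Module):
        """Stand-in for a multimodal wrapper without its own mapping."""

        def __init__(self) -> None:
            super().__init__()
            self.language_model = ToyMoE()


    def resolve_expert_mapping(
            model: torch.nn.Module) -> list[tuple[str, str, int, str]]:
        # 1) Prefer a mapping defined on the model itself.
        if fn := getattr(model, "get_expert_mapping", None):
            return fn()
        # 2) Fall back to direct children only, not the whole module tree.
        for child in model.children():
            child_fn = getattr(child, "get_expert_mapping", None)
            if child_fn is not None:
                return child_fn()
        return []


    print(resolve_expert_mapping(ToyMoE()))    # found on the model itself
    print(resolve_expert_mapping(Wrapper()))   # found via child language_model
    print(resolve_expert_mapping(torch.nn.Linear(2, 2)))  # [] -> no mapping

Restricting the fallback to model.children() rather than model.modules() keeps the
search to main components, and an empty result is exactly the condition the patched
loader turns into the AttributeError for MoE models without a usable mapping.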