Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 13:45:51 +08:00)
[Quantization] Enable BNB support for InternS1 (#21953)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent: 4931486988
commit: 28b18cc741
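This change removes the BitsAndBytes loader's hard dependency on a top-level get_expert_mapping method. MoE expert mappings are now resolved once at load time through a new get_moe_expert_mapping utility, which also checks a model's direct children, and are cached on the loader as expert_params_mapping. That is what enables BNB quantization for models such as InternS1, whose MoE language model is nested inside a multimodal wrapper.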
vllm/model_executor/model_loader/bitsandbytes_loader.py:

@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     pt_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.models import is_pooling_model
-from vllm.model_executor.utils import (get_packed_modules_mapping,
+from vllm.model_executor.utils import (get_moe_expert_mapping,
+                                       get_packed_modules_mapping,
                                        set_weight_attrs)
 from vllm.platforms import current_platform
 
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
 logger = init_logger(__name__)
 
 
+def is_moe_model(model: torch.nn.Module) -> bool:
+    """Checks if the model contains FusedMoE layers."""
+    return bool(any(
+        isinstance(module, FusedMoE) for module in model.modules()))
+
+
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
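The new is_moe_model helper detects MoE models structurally, by scanning submodules for FusedMoE, rather than relying on model-specific attributes. A minimal sketch of the same detection pattern, with a hypothetical FusedMoEStub standing in for vLLM's FusedMoE layer:

    import torch.nn as nn

    class FusedMoEStub(nn.Module):
        """Hypothetical stand-in for vLLM's FusedMoE layer."""

    class ToyMoEModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.dense = nn.Linear(8, 8)
            self.experts = FusedMoEStub()

    def is_moe_model(model: nn.Module) -> bool:
        # nn.Module.modules() yields the module itself and all descendants,
        # so nested expert layers are found at any depth.
        return any(isinstance(m, FusedMoEStub) for m in model.modules())

    assert is_moe_model(ToyMoEModel())
    assert not is_moe_model(nn.Linear(8, 8))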
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        # Store the mapping of expert parameters for MoE models.
+        self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
         self.pre_quant: bool = False
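Each expert_params_mapping entry is a (param_name, weight_name, expert_id, shard_id) tuple in FusedMoE's convention. Illustrative entries for a hypothetical two-expert checkpoint with gate/up/down projections (the exact strings depend on the model's checkpoint layout):

    # (param_name, weight_name, expert_id, shard_id) -- illustrative only
    expert_params_mapping: list[tuple[str, str, int, str]] = [
        ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
        ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
        ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
        ("experts.w13_", "experts.1.gate_proj.", 1, "w1"),
        ("experts.w2_", "experts.1.down_proj.", 1, "w2"),
        ("experts.w13_", "experts.1.up_proj.", 1, "w3"),
    ]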
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
-            elif (isinstance(module, FusedMoE)
-                  and hasattr(module.quant_method, "quant_config")):
-                if not hasattr(model, "get_expert_mapping"):
-                    raise AttributeError(
-                        f"MoE Model {type(model).__name__} does not support "
-                        "BitsAndBytes quantization yet. Ensure this model has "
-                        "'get_expert_mapping' method.")
+            elif isinstance(module, FusedMoE) and hasattr(
+                    module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
                 if self.pre_quant:
                     raise ValueError(
@@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         "BitsAndBytes 8bit quantization with FusedMoE is not "
                         "supported yet.")
                 # Get the corresponding weight name using module name and
-                # get_expert_mapping.
-                expert_mapping = model.get_expert_mapping()
-                for exp in expert_mapping:
+                # expert_params_mapping.
+                for exp in self.expert_params_mapping:
                     weight_name = exp[1]
                     rep_name = name.replace("experts",
                                             "") + weight_name.removesuffix(".")
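The rep_name arithmetic splices the per-expert checkpoint prefix into the FusedMoE module's name. A worked example, using a hypothetical module name and the illustrative mapping entry shown earlier:

    name = "model.layers.0.mlp.experts"   # hypothetical FusedMoE module name
    weight_name = "experts.0.gate_proj."  # exp[1] from the mapping
    rep_name = name.replace("experts", "") + weight_name.removesuffix(".")
    assert rep_name == "model.layers.0.mlp.experts.0.gate_proj"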
@@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
             elif isinstance(module, FusedMoE):
-                expert_mapping = model.get_expert_mapping()
+                expert_mapping = self.expert_params_mapping
                 for exp in expert_mapping:
                     if exp[-1] == "w2":
                         weight_name = exp[1]
@@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.is_pool_model = is_pooling_model(model)
         self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
 
+        if is_moe_model(model):
+            self.expert_params_mapping = get_moe_expert_mapping(model)
+            if not self.expert_params_mapping:
+                raise AttributeError(
+                    f"MoE Model {type(model).__name__} does not support "
+                    "BitsAndBytes quantization yet. Ensure this model has "
+                    "'get_expert_mapping' method.")
         # For some models like Molmo, we need to use hf_to_vllm_mapper
         # to ensure correct loading of weights.
         if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
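With this hunk the expert mapping is resolved a single time when the loader inspects the model, and the later hunks (at loader lines 434, 468 and 580) read the cached self.expert_params_mapping instead of calling model.get_expert_mapping() at each use site. A condensed sketch of the load-time logic (the function name here is illustrative, not the loader's actual method):

    def resolve_expert_mapping(loader, model) -> None:
        # Assumes the is_moe_model and get_moe_expert_mapping helpers
        # introduced elsewhere in this diff.
        if is_moe_model(model):
            loader.expert_params_mapping = get_moe_expert_mapping(model)
            if not loader.expert_params_mapping:
                raise AttributeError(
                    f"MoE Model {type(model).__name__} does not support "
                    "BitsAndBytes quantization yet.")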
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         """
         from bitsandbytes.functional import QuantState
 
-        if not hasattr(model, "get_expert_mapping"):
+        if not self.expert_params_mapping:
             return dict()
 
-        expert_mapping = model.get_expert_mapping()
+        expert_mapping = self.expert_params_mapping
         expert_qs_dict = {}
         for name, module in model.named_modules():
             if not isinstance(module, FusedMoE):
vllm/model_executor/utils.py:

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utils for model executor."""
+
 import copy
 from typing import Any, Optional
 
@@ -9,6 +10,7 @@ import torch
 
 def set_random_seed(seed: int) -> None:
     from vllm.platforms import current_platform
+
     current_platform.seed_everything(seed)
 
 
@@ -29,7 +31,7 @@ def set_weight_attrs(
         return
     for key, value in weight_attrs.items():
         assert not hasattr(
-            weight, key), (f"Overwriting existing tensor attribute: {key}")
+            weight, key), f"Overwriting existing tensor attribute: {key}"
 
         # NOTE(woosuk): During weight loading, we often do something like:
         # narrowed_tensor = param.data.narrow(0, offset, len)
@@ -41,6 +43,7 @@ def set_weight_attrs(
         # we sync the param tensor after its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform
+
         if current_platform.is_tpu() and key == "weight_loader":
             value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
@@ -78,3 +81,16 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
         else:
             parent_map.update(child_map)
     return parent_map
+
+
+def get_moe_expert_mapping(
+        model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]:
+    if parent_map := getattr(model, "get_expert_mapping", None):
+        return parent_map()
+    else:
+        # We only check main components instead of whole model submodules
+        for child in model.children():
+            child_map = getattr(child, "get_expert_mapping", None)
+            if child_map is not None:
+                return child_map()
+        return []
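get_moe_expert_mapping first looks for get_expert_mapping on the model itself, then falls back to the model's direct children (deliberately one level deep, not a full recursive search). This is what enables multimodal models such as InternS1, where the MoE language model is a child of the top-level wrapper. A minimal sketch with toy classes (not vLLM model code), assuming the get_moe_expert_mapping added above:

    import torch.nn as nn

    class ToyLanguageModel(nn.Module):
        def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
            return [("experts.w13_", "experts.0.gate_proj.", 0, "w1")]

    class ToyMultimodalWrapper(nn.Module):
        """Has no get_expert_mapping of its own; its LM child provides it."""

        def __init__(self) -> None:
            super().__init__()
            self.vision_tower = nn.Identity()
            self.language_model = ToyLanguageModel()

    assert get_moe_expert_mapping(ToyMultimodalWrapper())  # found on the child
    assert get_moe_expert_mapping(nn.Identity()) == []     # no mapping anywhere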