[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-07-11 16:01:13 +08:00 committed by GitHub
parent 762be26a8e
commit 8020e98c9f
8 changed files with 561 additions and 88 deletions

View File

@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
from tests.quantization.utils import is_quant_method_supported
from ...utils import compare_two_settings, multi_gpu_test
from ..utils import check_embeddings_close
from ..utils import check_embeddings_close, check_logprobs_close
models_4bit_to_test = [
("facebook/opt-125m", "quantize opt model inflight"),
@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]
models_4bit_to_moe_test = [
("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
]
models_pre_qaunt_4bit_to_test = [
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
compare_two_settings(model_name, common_args, pp_args)
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
))
with vllm_runner(model_name,
quantization='bitsandbytes',
enforce_eager=False) as llm:
vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
max_tokens=32,
num_logprobs=5)
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
transformers_outputs = llm.generate_greedy_logprobs_limit(
example_prompts, max_tokens=32, num_logprobs=5)
check_logprobs_close(
outputs_0_lst=transformers_outputs,
outputs_1_lst=vllm_outputs,
name_0="transformers",
name_1="vllm",
)
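
For reference, a rough end-user sketch of what this test exercises (assuming a CUDA GPU with bitsandbytes installed; it mirrors the quantization='bitsandbytes' argument used by vllm_runner above and is not part of this diff):

from vllm import LLM, SamplingParams

# Weights are quantized to NF4 on the fly while loading ("in-flight").
llm = LLM(
    model="allenai/OLMoE-1B-7B-0125-Instruct",
    quantization="bitsandbytes",
)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)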
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
model_name,
pre_quant=False,
hf_model_kwargs=None,
vllm_tp_size=1):
vllm_tp_size=1,
max_tokens=8):
# NOTE: run vLLM first, as it requires a clean process
# when using distributed inference
@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
quantization=None if pre_quant else 'bitsandbytes',
tensor_parallel_size=vllm_tp_size,
enforce_eager=False) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
vllm_outputs = llm.generate_greedy(prompts, max_tokens)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
# Clean up the GPU memory for the next test
@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
# Run with HF runner
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
hf_outputs = llm.generate_greedy(prompts, 8)
hf_outputs = llm.generate_greedy(prompts, max_tokens)
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
# Clean up the GPU memory for the next test
gc.collect()
torch.cuda.empty_cache()
# Compare the generated strings
for hf_log, vllm_log in zip(hf_logs, vllm_logs):
hf_str = hf_log["generated_text"]
vllm_str = vllm_log["generated_text"]
prompt = hf_log["prompt"]
assert hf_str == vllm_str, (f"Model: {model_name}"
f"Mismatch between HF and vLLM outputs:\n"
f"Prompt: {prompt}\n"

View File

@ -883,14 +883,21 @@ class FusedMoE(torch.nn.Module):
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
def _load_w13(self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module):
param.data.copy_(loaded_weight)
return True if return_success else None
# Case for BitsAndBytes
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
shard_dim = 0
expert_data = param.data[expert_id]
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
# BNB inflight quantization has already sharded the weights
full_load = True
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full=full_load,
)
return True if return_success else None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
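
For intuition, a minimal standalone sketch (illustrative shapes only, not vLLM code) of what _load_w13 does with the fused gate_up buffer: w1 (gate_proj) fills the first half of the output dimension and w3 (up_proj) the second half, and load_full=True simply skips the tensor-parallel narrow because the BNB in-flight path already hands over the full shard.

import torch

intermediate, hidden = 4, 8                    # illustrative sizes
w13 = torch.zeros(2 * intermediate, hidden)    # fused gate_up expert buffer
gate_proj = torch.randn(intermediate, hidden)  # checkpoint w1
up_proj = torch.randn(intermediate, hidden)    # checkpoint w3

# shard_id == "w1" -> first half; shard_id == "w3" -> second half
w13.narrow(0, 0, intermediate).copy_(gate_proj)
w13.narrow(0, intermediate, intermediate).copy_(up_proj)
assert torch.equal(w13[:intermediate], gate_proj)
assert torch.equal(w13[intermediate:], up_proj)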

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any, Callable, Optional, Union
import torch
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules=llm_int8_skip_modules,
llm_int8_threshold=llm_int8_threshold)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self)
return None
@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
return substr_check or prefix_check
def calculate_quant_ratio(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
else:
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params
def calculate_quant_ratio(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
else:
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
def create_qweight_for_8bit():
qweight = Int8Params(
data=torch.empty(sum(output_partition_sizes),
@ -394,3 +401,210 @@ try:
except AttributeError as error:
raise error
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
"""MoE method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.3.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.load_in_8bit:
call_fun = self._create_weights_8bit
else:
call_fun = self._create_weights_4bit
call_fun(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
w13, w2 = self._apply_4bit_dequant(layer)
return fused_experts(
hidden_states=x,
w1=w13,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def _create_weights_4bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
quant_ratio = calculate_quant_ratio(params_dtype)
# Fused gate_up_proj (column parallel)
w13_total_size = (hidden_size * 2 *
intermediate_size_per_partition) // quant_ratio
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
set_weight_attrs(
w13_qweight,
{
"num_experts": num_experts,
"input_dim": hidden_size,
"output_dim": 2 * intermediate_size_per_partition,
"experts_shape": (
num_experts,
intermediate_size_per_partition * 2,
hidden_size,
),
"pack_factor": quant_ratio,
"use_bitsandbytes_4bit": True,
},
)
# down_proj (row parallel)
w2_total_size = (hidden_size *
intermediate_size_per_partition) // quant_ratio
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w2_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
w2_qweight,
{
"num_experts": num_experts,
"input_dim": intermediate_size_per_partition,
"output_dim": hidden_size,
"experts_shape": (
num_experts,
hidden_size,
intermediate_size_per_partition,
),
"pack_factor": quant_ratio,
"use_bitsandbytes_4bit": True,
},
)
layer.register_parameter("w2_weight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
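
As a sanity check on the packed buffer sizes above: with fp16 parameters, calculate_quant_ratio returns 2, i.e. each uint8 element stores two 4-bit values. A back-of-the-envelope sketch with illustrative dimensions (not taken from any specific model):

import torch

def calculate_quant_ratio(dtype):
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
    return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits

hidden_size, intermediate = 2048, 1024
quant_ratio = calculate_quant_ratio(torch.float16)        # -> 2
w13_bytes = hidden_size * 2 * intermediate // quant_ratio  # per expert
w2_bytes = hidden_size * intermediate // quant_ratio       # per expert
print(quant_ratio, w13_bytes, w2_bytes)                    # 2 2097152 1048576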
def _create_weights_8bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def _apply_4bit_dequant(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
from bitsandbytes.functional import dequantize_4bit
w13 = dequantize_4bit(
layer.w13_weight.reshape(-1, 1),
layer.w13_weight.bnb_quant_state,
)
w2 = dequantize_4bit(
layer.w2_weight.reshape(-1, 1),
layer.w2_weight.bnb_quant_state,
)
w13 = w13.reshape(layer.w13_weight.experts_shape)
w2 = w2.reshape(layer.w2_weight.experts_shape)
return w13, w2
def _apply_8bit_dequant(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
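
For reference, a minimal sketch (assuming a CUDA device and bitsandbytes>=0.45.3; not vLLM code) of the NF4 round trip that apply relies on: the packed expert weights are stored as a flat uint8 column, dequantized against the bound quant state, and reshaped to the experts_shape recorded at weight creation.

import torch
from bitsandbytes.functional import quantize_4bit, dequantize_4bit

experts_shape = (8, 64, 128)  # illustrative (num_experts, out_dim, in_dim)
w = torch.randn(experts_shape, dtype=torch.float16, device="cuda")

# Quantize to NF4 as a flat column, then recover the dense tensor.
qweight, quant_state = quantize_4bit(w.reshape(-1, 1), quant_type="nf4")
w_restored = dequantize_4bit(qweight, quant_state).reshape(experts_shape)
print((w - w_restored).abs().max())  # NF4 is lossy, so expect a small error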

View File

@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
elif (isinstance(module, FusedMoE)
and hasattr(module.quant_method, "quant_config")):
if not hasattr(model, "get_expert_mapping"):
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method.")
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant:
raise ValueError(
"Prequant BitsAndBytes models with FusedMoE is not "
"supported yet.")
if self.load_8bit:
raise ValueError(
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet.")
# Get the corresponding weight name using module name and
# get_expert_mapping.
expert_mapping = model.get_expert_mapping()
for exp in expert_mapping:
weight_name = exp[1]
rep_name = name.replace("experts",
"") + weight_name.removesuffix(".")
self.target_modules.append(rep_name)
assert (self.target_modules
), "vllm currently does not support BNB quantization for"
), "vLLM currently does not support BNB quantization for"
f" {type(model).__name__}"
def _classify_module_sharding(self, model: nn.Module):
@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = model.get_expert_mapping()
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
rep_name = name.replace(
"experts", "") + weight_name.removesuffix(".")
self.column_sharded_weights_modules.append(rep_name)
def _verify_model_compatibility(self, model: nn.Module,
model_config: ModelConfig) -> None:
@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
def _dequantize_dq(self, quant_states: Any):
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This
comes at the cost of increased memory usage.
"""
from bitsandbytes.functional import QuantState, dequantize_blockwise
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
def _dequantize_single_state(quant_state):
"""Helper function to dequantize a single QuantState object."""
if not (isinstance(quant_state, QuantState)
and quant_state.nested):
return
logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may not implement the loaded-weights tracker.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
absmax = dequantize_blockwise(quant_state.absmax,
quant_state.state2)
absmax += quant_state.offset
param_dict = dict(model.named_parameters())
# Ensure float32 dtype
if absmax.dtype != torch.float32:
absmax = absmax.float()
quant_state.absmax = absmax
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None
if isinstance(quant_states, dict):
for quant_state in quant_states.values():
_dequantize_single_state(quant_state)
else:
_dequantize_single_state(quant_states)
return quant_states
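
For context, a short sketch (assuming a CUDA device; not vLLM code) of where a nested quant state comes from and what the helper above undoes at load time: quantizing with compress_statistics=True double-quantizes the absmax, and the original fp32 absmax is recovered with dequantize_blockwise plus the stored offset.

import torch
from bitsandbytes.functional import quantize_4bit, dequantize_blockwise

w = torch.randn(4096, 1, dtype=torch.float16, device="cuda")
_, qs = quantize_4bit(w, quant_type="nf4", compress_statistics=True)
assert qs.nested  # absmax itself is blockwise-quantized (Double Quantization)

# The same recovery performed once by _dequantize_dq:
absmax = dequantize_blockwise(qs.absmax, qs.state2) + qs.offset
print(absmax.dtype, absmax.shape)  # float32 absmax, ready for inference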
def _fuse_moe_quant_states(self, model: nn.Module,
quant_states_dict: dict) -> dict:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
from bitsandbytes.functional import QuantState
if not hasattr(model, "get_expert_mapping"):
return dict()
expert_mapping = model.get_expert_mapping()
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):
continue
w1_states_lst = []
w2_states_lst = []
w3_states_lst = []
for exp in expert_mapping:
shard_id = exp[-1]
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
layer_prefix = name.split("experts")[0]
weight_qual_name = layer_prefix + exp[1] + "weight"
quant_state = self._dequantize_dq(
quant_states_dict[weight_qual_name])
if shard_id == "w1":
w1_states_lst.append(quant_state)
elif shard_id == "w2":
w2_states_lst.append(quant_state)
else:
w3_states_lst.append(quant_state)
del quant_states_dict[weight_qual_name]
assert (len(w1_states_lst) == len(w2_states_lst) ==
len(w3_states_lst))
w13_absmax_lst = []
w2_absmax_lst = []
w13_total_dim0 = 0
w2_total_dim0 = 0
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
w3_states_lst):
assert w1_qs.shape == w3_qs.shape
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
# w1 and w3 are interleaved in storage
w13_absmax_lst.append(w1_qs.absmax)
w13_absmax_lst.append(w3_qs.absmax)
w2_absmax_lst.append(w2_qs.absmax)
w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
w2_total_dim0 += w2_qs.shape[0]
w13_absmax = torch.cat(w13_absmax_lst)
w2_absmax = torch.cat(w2_absmax_lst)
# Create fused quantization state for w13.
w13_qs = QuantState(
absmax=w13_absmax,
shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
code=w1_states_lst[0].code,
blocksize=w1_states_lst[0].blocksize,
quant_type="nf4",
dtype=w1_states_lst[0].dtype,
)
# Create fused quantization state for w2.
w2_qs = QuantState(
absmax=w2_absmax,
shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
code=w2_states_lst[0].code,
blocksize=w2_states_lst[0].blocksize,
quant_type="nf4",
dtype=w2_states_lst[0].dtype,
)
# The weight suffixes .w13_weight and .w2_weight are consistent
# with the param in BitsAndBytesMoEMethod.
w13_weight_name = name + ".w13_weight"
w2_weight_name = name + ".w2_weight"
expert_qs_dict[w13_weight_name] = w13_qs
expert_qs_dict[w2_weight_name] = w2_qs
return expert_qs_dict
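
A small self-contained sketch (assuming a CUDA device and bitsandbytes>=0.45.3; not vLLM code) of the fusion idea above: per-expert NF4 quant states whose blocksize, code table, and dtype match can be merged into one QuantState by concatenating their absmax vectors and summing their shapes, and dequantizing the concatenated packed weights with the fused state matches dequantizing each expert separately.

import torch
from bitsandbytes.functional import (QuantState, dequantize_4bit,
                                     quantize_4bit)

experts = [torch.randn(64, 128, dtype=torch.float16, device="cuda")
           for _ in range(2)]
pairs = [quantize_4bit(w.reshape(-1, 1), quant_type="nf4") for w in experts]

fused_qweight = torch.cat([q for q, _ in pairs])
fused_state = QuantState(
    absmax=torch.cat([s.absmax for _, s in pairs]),
    shape=(sum(s.shape[0] for _, s in pairs), 1),
    code=pairs[0][1].code,
    blocksize=pairs[0][1].blocksize,
    quant_type="nf4",
    dtype=pairs[0][1].dtype,
)
fused = dequantize_4bit(fused_qweight, fused_state)
reference = torch.cat([dequantize_4bit(q, s) for q, s in pairs])
assert torch.allclose(fused, reference)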
def _stack_quantization_states(
self, model: nn.Module,
quant_state_dict: dict) -> dict[str, dict[int, Any]]:
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
param_dict = dict(model.named_parameters())
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue
@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
stacked_quant_state_dict[quant_param_name][shard_index] = (
quant_state_dict[non_stacked_param_name])
return stacked_quant_state_dict
def _bind_quant_states_to_params(self, model: nn.Module,
stacked_quant_state_dict: dict) -> None:
# save quant_states and offsets as the attributes of the parameters
param_dict = dict(model.named_parameters())
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
# Dequantize double quantized values during weight loading.
dequantize_dq(quant_states)
self._dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})
if not isinstance(quant_states, dict):
continue
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may not implement the loaded-weights tracker.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
expert_quant_state_dict = self._fuse_moe_quant_states(
model, quant_state_dict)
stacked_quant_state_dict = self._stack_quantization_states(
model, quant_state_dict)
stacked_quant_state_dict = {
**expert_quant_state_dict,
**stacked_quant_state_dict
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def dequantize_dq(quant_states: dict) -> None:
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This comes
at the cost of increased memory usage.
"""
from bitsandbytes.functional import QuantState, dequantize_blockwise
for _, quant_state in quant_states.items():
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
if isinstance(quant_state, QuantState) and quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax,
quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
quant_state.absmax = absmax
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None

View File

@ -330,6 +330,15 @@ class OlmoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
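
For reference, each entry returned by make_expert_params_mapping is a (param_name, weight_name, expert_id, shard_id) tuple that load_weights below matches against checkpoint names; a quick way to inspect it (num_experts is illustrative here, the model itself passes self.config.num_experts):

from vllm.model_executor.layers.fused_moe import FusedMoE

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=64,  # illustrative value
)
for param_name, weight_name, expert_id, shard_id in mapping[:3]:
    print(param_name, weight_name, expert_id, shard_id)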
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -341,14 +350,6 @@ class OlmoeModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@ -379,7 +380,7 @@ class OlmoeModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
for mapping in self.get_expert_mapping():
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
@ -425,6 +426,17 @@ class OlmoeModel(nn.Module):
class OlmoeForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

View File

@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

View File

@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
for mapping in self.get_expert_mapping():
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if "layers.13.mlp.experts.w2_weight" in name:
pass
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

View File

@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()