From 8020e98c9f033e76c97eb8261f772d59eba49c9a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 11 Jul 2025 16:01:13 +0800 Subject: [PATCH] [Quantization][1/N] MoE support BNB-Inflight Quantization (#20061) Signed-off-by: Jee Jee Li --- .../models/quantization/test_bitsandbytes.py | 45 +++- vllm/model_executor/layers/fused_moe/layer.py | 36 ++- .../layers/quantization/bitsandbytes.py | 232 ++++++++++++++++- .../model_loader/bitsandbytes_loader.py | 238 ++++++++++++++---- vllm/model_executor/models/olmoe.py | 33 ++- vllm/model_executor/models/phimoe.py | 11 + vllm/model_executor/models/qwen2_moe.py | 35 ++- vllm/model_executor/models/qwen3_moe.py | 19 +- 8 files changed, 561 insertions(+), 88 deletions(-) diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py index 18662fbdd002..e53902cdb8f4 100644 --- a/tests/models/quantization/test_bitsandbytes.py +++ b/tests/models/quantization/test_bitsandbytes.py @@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported from ...utils import compare_two_settings, multi_gpu_test -from ..utils import check_embeddings_close +from ..utils import check_embeddings_close, check_logprobs_close models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), @@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [ ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"), ] +models_4bit_to_moe_test = [ + ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"), +] + models_pre_qaunt_4bit_to_test = [ ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), @@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test) +def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + )) + with vllm_runner(model_name, + quantization='bitsandbytes', + enforce_eager=False) as llm: + vllm_outputs = llm.generate_greedy_logprobs(example_prompts, + max_tokens=32, + num_logprobs=5) + + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + transformers_outputs = llm.generate_greedy_logprobs_limit( + example_prompts, max_tokens=32, num_logprobs=5) + check_logprobs_close( + outputs_0_lst=transformers_outputs, + outputs_1_lst=vllm_outputs, + name_0="transformers", + name_1="vllm", + ) + + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", @@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner, model_name, pre_quant=False, hf_model_kwargs=None, - vllm_tp_size=1): + vllm_tp_size=1, + max_tokens=8): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference @@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner, quantization=None if pre_quant else 'bitsandbytes', tensor_parallel_size=vllm_tp_size, enforce_eager=False) as llm: - vllm_outputs = llm.generate_greedy(prompts, 8) + + vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_logs 
= log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test @@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner, # Run with HF runner with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: - hf_outputs = llm.generate_greedy(prompts, 8) + hf_outputs = llm.generate_greedy(prompts, max_tokens) hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") # Clean up the GPU memory for the next test gc.collect() torch.cuda.empty_cache() - # Compare the generated strings for hf_log, vllm_log in zip(hf_logs, vllm_logs): hf_str = hf_log["generated_text"] vllm_str = vllm_log["generated_text"] prompt = hf_log["prompt"] - assert hf_str == vllm_str, (f"Model: {model_name}" f"Mismatch between HF and vLLM outputs:\n" f"Prompt: {prompt}\n" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4b..4a31e7d8edfa 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -883,14 +883,21 @@ class FusedMoE(torch.nn.Module): expert_data=expert_data, tp_rank=tp_rank) - def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + def _load_w13(self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module): param.data.copy_(loaded_weight) return True if return_success else None + # Case for BitsAndBytes + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + if use_bitsandbytes_4bit: + shard_dim = 0 + + expert_data = param.data[expert_id] + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + # BNB inflight quantization has already sharded the weights + full_load = True + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full=full_load, + ) + return True if return_success else None + # is_transposed: if the dim to shard the weight # should be flipped. 
Required by GPTQ, compressed-tensors # should be whatever dimension intermediate_size_per_partition is diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1ed3ef8d2173..20625f587f51 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any, Callable, Optional, Union import torch +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, set_weight_attrs) @@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig): llm_int8_skip_modules=llm_int8_skip_modules, llm_int8_threshold=llm_int8_threshold) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["LinearMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]: if isinstance(layer, LinearBase): if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules): return UnquantizedLinearMethod() return BitsAndBytesLinearMethod(self) + elif isinstance(layer, FusedMoE): + return BitsAndBytesMoEMethod(self) return None @@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): return substr_check or prefix_check +def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + class BitsAndBytesLinearMethod(LinearMethodBase): """Linear method for BitsAndBytes. @@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase): **extra_weight_attrs): from bitsandbytes.nn import Int8Params - def calculate_quant_ratio(dtype): - if dtype.is_floating_point: - return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits - else: - return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits - def create_qweight_for_8bit(): qweight = Int8Params( data=torch.empty(sum(output_partition_sizes), @@ -394,3 +401,210 @@ try: except AttributeError as error: raise error + + +class BitsAndBytesMoEMethod(FusedMoEMethodBase): + """MoE method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. + """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. 
Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + self.topk_indices_dtype = None + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.load_in_8bit: + call_fun = self._create_weights_8bit + else: + call_fun = self._create_weights_4bit + call_fun( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `BitsAndBytesMoEMethod` yet.") + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) + if self.quant_config.load_in_8bit: + w13, w2 = self._apply_8bit_dequant(layer) + else: + w13, w2 = self._apply_4bit_dequnt(layer) + return fused_experts( + hidden_states=x, + w1=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + def _create_weights_4bit( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + quant_ratio = calculate_quant_ratio(params_dtype) + # Fused gate_up_proj (column parallel) + w13_total_size = (hidden_size * 2 * + intermediate_size_per_partition) // quant_ratio + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_total_size, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + set_weight_attrs( + w13_qweight, + { + "num_experts": + num_experts, + "input_dim": + hidden_size, + "output_dim": + 2 * intermediate_size_per_partition, + "experts_shape": ( + num_experts, + intermediate_size_per_partition * 2, + hidden_size, + ), + "pack_factor": + quant_ratio, + "use_bitsandbytes_4bit": + True, + }, + ) + # down_proj (row parallel) + w2_total_size = (hidden_size * + intermediate_size_per_partition) // quant_ratio + w2_qweight = torch.nn.Parameter( + torch.empty( + 
num_experts, + w2_total_size, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + w2_qweight, + { + "num_experts": + num_experts, + "input_dim": + intermediate_size_per_partition, + "output_dim": + hidden_size, + "experts_shape": ( + num_experts, + hidden_size, + intermediate_size_per_partition, + ), + "pack_factor": + quant_ratio, + "use_bitsandbytes_4bit": + True, + }, + ) + layer.register_parameter("w2_weight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + def _create_weights_8bit( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def _apply_4bit_dequnt( + self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + from bitsandbytes.functional import dequantize_4bit + w13 = dequantize_4bit( + layer.w13_weight.reshape(-1, 1), + layer.w13_weight.bnb_quant_state, + ) + w2 = dequantize_4bit( + layer.w2_weight.reshape(-1, 1), + layer.w2_weight.bnb_quant_state, + ) + w13 = w13.reshape(layer.w13_weight.experts_shape) + w2 = w2.reshape(layer.w2_weight.experts_shape) + return w13, w2 + + def _apply_8bit_dequant( + self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 8e330f7eeaf4..d22b1e7b67d4 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader): # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + elif (isinstance(module, FusedMoE) + and hasattr(module.quant_method, "quant_config")): + if not hasattr(model, "get_expert_mapping"): + raise AttributeError( + f"MoE Model {type(model).__name__} does not support " + "BitsAndBytes quantization yet. Ensure this model has " + "'get_expert_mapping' method.") + # TODO: support FusedMoE with prequant and 8bit. + if self.pre_quant: + raise ValueError( + "Prequant BitsAndBytes models with FusedMoE is not " + "supported yet.") + if self.load_8bit: + raise ValueError( + "BitsAndBytes 8bit quantization with FusedMoE is not " + "supported yet.") + # Get the corresponding weight name using module name and + # get_expert_mapping. 
+ expert_mapping = model.get_expert_mapping() + for exp in expert_mapping: + weight_name = exp[1] + rep_name = name.replace("experts", + "") + weight_name.removesuffix(".") + self.target_modules.append(rep_name) assert (self.target_modules - ), "vllm currently does not support BNB quantization for" + ), "vLLM currently does not support BNB quantization for" f" {type(model).__name__}" def _classify_module_sharding(self, model: nn.Module): @@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): # dimension (dim=-1) elif isinstance(module, (RowParallelLinear, )): self.column_sharded_weights_modules.append(name) + elif isinstance(module, FusedMoE): + expert_mapping = model.get_expert_mapping() + for exp in expert_mapping: + if exp[-1] == "w2": + weight_name = exp[1] + rep_name = name.replace( + "experts", "") + weight_name.removesuffix(".") + self.column_sharded_weights_modules.append(rep_name) def _verify_model_compatibility(self, model: nn.Module, model_config: ModelConfig) -> None: @@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader): self._get_bnb_target_modules(model) self._classify_module_sharding(model) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _dequantize_dq(self, quant_states: Any): + """ + When BNB employs Double Quantization, we perform the dequantization of + these constants during weight loading rather than at inference time, + thereby avoiding this computational overhead during inference. This + comes at the cost of increased memory usage. + """ + from bitsandbytes.functional import QuantState, dequantize_blockwise - self._verify_model_compatibility(model, model_config) - self._initialize_loader_state(model, model_config) + def _dequantize_single_state(quant_state): + """Helper function to dequantize a single QuantState object.""" + if not (isinstance(quant_state, QuantState) + and quant_state.nested): + return - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator( - model_config.model, - model_config.revision, - )) - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights(qweight_iterator) - # Some models may have weights loading tracker unimplemented. - if loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 + absmax = dequantize_blockwise(quant_state.absmax, + quant_state.state2) + absmax += quant_state.offset - param_dict = dict(model.named_parameters()) + # Ensure float32 dtype + if absmax.dtype != torch.float32: + absmax = absmax.float() + + quant_state.absmax = absmax + quant_state.nested = False + quant_state.offset = None + quant_state.state2 = None + + if isinstance(quant_states, dict): + for quant_state in quant_states.values(): + _dequantize_single_state(quant_state) + else: + _dequantize_single_state(quant_states) + return quant_states + + def _fuse_moe_quant_states(self, model: nn.Module, + quant_states_dict: dict) -> dict: + """ + + This function consolidates individual expert quantization states into + fused representations for w13 and w2. 
+ """ + from bitsandbytes.functional import QuantState + + if not hasattr(model, "get_expert_mapping"): + return dict() + + expert_mapping = model.get_expert_mapping() + expert_qs_dict = {} + for name, module in model.named_modules(): + if not isinstance(module, FusedMoE): + continue + w1_states_lst = [] + w2_states_lst = [] + w3_states_lst = [] + for exp in expert_mapping: + shard_id = exp[-1] + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + layer_prefix = name.split("experts")[0] + weight_qual_name = layer_prefix + exp[1] + "weight" + quant_state = self._dequantize_dq( + quant_states_dict[weight_qual_name]) + if shard_id == "w1": + w1_states_lst.append(quant_state) + elif shard_id == "w2": + w2_states_lst.append(quant_state) + else: + w3_states_lst.append(quant_state) + del quant_states_dict[weight_qual_name] + assert (len(w1_states_lst) == len(w2_states_lst) == + len(w3_states_lst)) + w13_absmax_lst = [] + w2_absmax_lst = [] + w13_total_dim0 = 0 + w2_total_dim0 = 0 + for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, + w3_states_lst): + assert w1_qs.shape == w3_qs.shape + assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize + assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype + # w1 and w3 are interleaved in storage + w13_absmax_lst.append(w1_qs.absmax) + w13_absmax_lst.append(w3_qs.absmax) + w2_absmax_lst.append(w2_qs.absmax) + w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0] + w2_total_dim0 += w2_qs.shape[0] + + w13_absmax = torch.cat(w13_absmax_lst) + w2_absmax = torch.cat(w2_absmax_lst) + # Create fused quantization state for w13. + w13_qs = QuantState( + absmax=w13_absmax, + shape=(w13_total_dim0, w1_states_lst[0].shape[1]), + code=w1_states_lst[0].code, + blocksize=w1_states_lst[0].blocksize, + quant_type="nf4", + dtype=w1_states_lst[0].dtype, + ) + # Create fused quantization state for w2. + w2_qs = QuantState( + absmax=w2_absmax, + shape=(w2_total_dim0, w2_states_lst[0].shape[1]), + code=w2_states_lst[0].code, + blocksize=w2_states_lst[0].blocksize, + quant_type="nf4", + dtype=w2_states_lst[0].dtype, + ) + # The weight suffixes .w13_weight and .w2_weight are consistent + # with the param in BitsAndBytesMoEMethod. 
+ w13_weight_name = name + ".w13_weight" + w2_weight_name = name + ".w2_weight" + expert_qs_dict[w13_weight_name] = w13_qs + expert_qs_dict[w2_weight_name] = w2_qs + return expert_qs_dict + + def _stack_quantization_states( + self, model: nn.Module, + quant_state_dict: dict) -> dict[str, dict[int, Any]]: stacked_quant_state_dict: dict[str, dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter - + param_dict = dict(model.named_parameters()) for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): continue @@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader): stacked_quant_state_dict[quant_param_name][shard_index] = ( quant_state_dict[non_stacked_param_name]) + return stacked_quant_state_dict + def _bind_quant_states_to_params(self, model: nn.Module, + stacked_quant_state_dict: dict) -> None: # save quant_states and offsets as the attributes of the parameters + param_dict = dict(model.named_parameters()) for param_name, param in param_dict.items(): if param_name in stacked_quant_state_dict: quant_states = stacked_quant_state_dict[param_name] # Dequantize double quantized values during weight loading. - dequantize_dq(quant_states) + self._dequantize_dq(quant_states) set_weight_attrs(param, {"bnb_quant_state": quant_states}) + if not isinstance(quant_states, dict): + continue pack_ratio = getattr(param, "pack_factor", -1) if pack_ratio == -1: @@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader): if self.load_8bit: set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + + self._verify_model_compatibility(model, model_config) + self._initialize_loader_state(model, model_config) + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + )) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + expert_quant_state_dict = self._fuse_moe_quant_states( + model, quant_state_dict) + + stacked_quant_state_dict = self._stack_quantization_states( + model, quant_state_dict) + + stacked_quant_state_dict = { + **expert_quant_state_dict, + **stacked_quant_state_dict + } + self._bind_quant_states_to_params(model, stacked_quant_state_dict) torch.cuda.empty_cache() def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - - -def dequantize_dq(quant_states: dict) -> None: - """ - When BNB employs Double Quantization, we perform the dequantization of - these constants during weight loading rather than at inference time, - thereby avoiding this computational overhead during inference. This comes - at the cost of increased memory usage. 
- """ - from bitsandbytes.functional import QuantState, dequantize_blockwise - for _, quant_state in quant_states.items(): - # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 - if isinstance(quant_state, QuantState) and quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, - quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - quant_state.absmax = absmax - quant_state.nested = False - quant_state.offset = None - quant_state.state2 = None diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index ebfdb690fe29..33438216ac1a 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -330,6 +330,15 @@ class OlmoeModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -341,14 +350,6 @@ class OlmoeModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -379,7 +380,7 @@ class OlmoeModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - for mapping in expert_params_mapping: + for mapping in self.get_expert_mapping(): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue @@ -425,6 +426,17 @@ class OlmoeModel(nn.Module): class OlmoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 2ab4edc18ccf..0fc64e88a6b6 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, 
SupportsLoRA, SupportsPP): torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index a2c65f4b5edb..597f4c7e1206 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -441,11 +442,11 @@ class Qwen2MoeModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - for mapping in expert_params_mapping: + for mapping in self.get_expert_mapping(): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): continue @@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module): class Qwen2MoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ff182aadf738..c87f41fa7c06 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - + expert_params_mapping = self.get_expert_mapping() params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping()
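
Usage sketch: the snippet below exercises the path added by this patch, loading the OLMoE checkpoint from the new test_4bit_bnb_moe_model with in-flight BitsAndBytes 4-bit quantization through vLLM's offline LLM API. It assumes bitsandbytes>=0.45.3 is installed and a supported GPU is available; the prompt and sampling settings are illustrative and not taken from the patch.

from vllm import LLM, SamplingParams

# In-flight quantization: the original fp16/bf16 checkpoint is quantized to
# NF4 as the weights are loaded, so no pre-quantized checkpoint is required.
llm = LLM(
    model="allenai/OLMoE-1B-7B-0125-Instruct",  # MoE model used in the new test
    quantization="bitsandbytes",  # routes FusedMoE layers to BitsAndBytesMoEMethod
)

outputs = llm.generate(
    ["The capital of France is"],  # illustrative prompt
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)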