From a99b9f7dee0ad261284cbcd823f5b37381d15ac1 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 14 Jul 2025 15:34:34 +0800
Subject: [PATCH] [Quantization] add BNB for MixtralForCausalLM (#20893)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/model_loader/utils.py      |   7 +-
 vllm/model_executor/models/granitemoe.py       | 105 +++++++++++++++++-
 .../model_executor/models/granitemoeshared.py  |   5 +-
 vllm/model_executor/models/mixtral.py          |  21 ++--
 vllm/model_executor/models/olmoe.py            |   3 +-
 vllm/model_executor/models/qwen2_moe.py        |   3 +-
 vllm/model_executor/models/qwen3_moe.py        |   4 +-
 7 files changed, 128 insertions(+), 20 deletions(-)

diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 792a1044a5640..8e5f332ba7ccf 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -227,7 +227,12 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
+        "fp8",
+        "compressed-tensors",
+        "gptq_marlin",
+        "awq_marlin",
+        "quark",
+        "bitsandbytes",
     ]
 
     vllm_supported_archs = ModelRegistry.get_supported_archs()
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 5a70f3a616c6d..142b0e9672958 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -45,12 +45,14 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from . import mixtral
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, make_layers, maybe_prefix
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers,
+                    maybe_prefix)
 
 
 class GraniteMoeMoE(nn.Module):
@@ -307,6 +309,103 @@ class GraniteMoeModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def _load_weights(self,
+                      weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """
+        This function is copied from `MixtralModel.load_weights`, mainly to
+        decouple from mixtral, avoiding impact on support like BNB
+        quantization.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         new_weights = {}
@@ -339,7 +438,7 @@ class GraniteMoeModel(nn.Module):
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return self._load_weights(new_weights.items())
 
 
 class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py
index bb160dbce45b2..7303f48537828 100644
--- a/vllm/model_executor/models/granitemoeshared.py
+++ b/vllm/model_executor/models/granitemoeshared.py
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from . import mixtral
-from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
+from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import AutoWeightsLoader, make_layers, maybe_prefix
 
@@ -242,7 +241,7 @@ class GraniteMoeSharedModel(nn.Module):
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return GraniteMoeModel._load_weights(self, new_weights.items())
 
 
 class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index dec365119c725..30de83da49e0e 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -317,6 +317,15 @@ class MixtralModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -326,16 +335,9 @@ class MixtralModel(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="w1",
-            ckpt_down_proj_name="w2",
-            ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
@@ -486,3 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 33438216ac1a0..7552f64c423ea 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -352,6 +352,7 @@ class OlmoeModel(nn.Module):
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -380,7 +381,7 @@ class OlmoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 597f4c7e1206e..84bae87804c13 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -413,6 +413,7 @@ class Qwen2MoeModel(nn.Module):
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -442,7 +443,7 @@ class Qwen2MoeModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index c87f41fa7c064..0f749b3e38f15 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -400,11 +400,9 @@ class Qwen3MoeModel(nn.Module):
             ".v_scale", "_v_scale", ".weight_scale", "_weight_scale",
             ".input_scale", "_input_scale")
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
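
Note (editor's addition, not part of the patch): with "bitsandbytes" now accepted in the quantized-Mixtral list, the new path can be exercised through vLLM's offline `LLM` API by requesting in-flight BNB quantization. The snippet below is a minimal sketch rather than anything taken from this PR: the checkpoint name is only an example, and depending on the vLLM version `load_format="bitsandbytes"` may need to be passed explicitly alongside `quantization="bitsandbytes"` or may be inferred automatically.

    # Minimal sketch; assumes an example Mixtral checkpoint and enough GPU
    # memory for the 4-bit weights.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # example model ID, not prescribed by the patch
        quantization="bitsandbytes",                   # in-flight BNB quantization
        load_format="bitsandbytes",                    # required on some vLLM versions, inferred on others
        dtype="bfloat16",
        max_model_len=4096,                            # small context to keep the KV cache modest
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)

The corresponding engine arguments (`--quantization bitsandbytes`, `--load-format bitsandbytes`) should behave the same way when serving, since they map onto the same EngineArgs fields as the constructor keywords above.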