From 5c3f2628d5ddad55fd350ef733167fd7050e4ac6 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 25 Jul 2025 18:57:34 +0800
Subject: [PATCH] [Quantization] Enable BNB support for more MoE models
 (#21370)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/dots1.py    | 148 +++++++++++++------------
 vllm/model_executor/models/glm4_moe.py |  25 +++--
 2 files changed, 93 insertions(+), 80 deletions(-)

diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 4bdcbfabbbc26..9b21a79446138 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -54,8 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -327,6 +327,7 @@ class Dots1DecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class Dots1Model(nn.Module):
 
     fall_back_to_pt_during_load = False
@@ -404,68 +405,12 @@ class Dots1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-@support_torch_compile
-class Dots1ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Dots1Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(
-            input_ids,
-            positions,
-            intermediate_tensors,
-            inputs_embeds,
-        )
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
@@ -477,14 +422,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -534,3 +474,71 @@ class Dots1ForCausalLM(nn.Module, SupportsPP):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Dots1Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 43824abb571a6..6a196fef572de 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -461,6 +461,15 @@ class Glm4MoeModel(nn.Module):
             device=device),
         })
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -472,16 +481,9 @@ class Glm4MoeModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
@@ -570,7 +572,7 @@ class Glm4MoeModel(nn.Module):
         return loaded_params
 
 
-class Glm4MoeForCausalLM(nn.Module, SupportsPP):
+class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -677,6 +679,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
+
 
 def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                         weight_name: str) -> Optional[int]:
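
The core refactor above hoists the FusedMoE.make_expert_params_mapping(...)
call out of load_weights into a reusable get_expert_mapping() method, so that
code outside the weight-loading loop (such as the BitsAndBytes loader, which
must know how per-expert checkpoint tensors fold into the fused expert
parameters) can query the mapping via the top-level *ForCausalLM class. Below
is a minimal illustrative sketch of the shape of the returned tuples, assuming
the (param_name, weight_name, expert_id, shard_id) convention stated in the
glm4_moe comment; it is a hypothetical stand-in, not vLLM's implementation.

    # Hypothetical sketch only -- mirrors the assumed output shape of
    # FusedMoE.make_expert_params_mapping, not its actual code.
    def sketch_expert_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
        mapping: list[tuple[str, str, int, str]] = []
        for expert_id in range(num_experts):
            for shard_id, ckpt_name in (("w1", "gate_proj"),
                                        ("w2", "down_proj"),
                                        ("w3", "up_proj")):
                # gate_proj (w1) and up_proj (w3) fold into one fused
                # gate+up parameter; down_proj (w2) maps to the down
                # parameter. Names here are illustrative placeholders.
                param = ("experts.w13_" if shard_id in ("w1", "w3")
                         else "experts.w2_")
                mapping.append((param,
                                f"experts.{expert_id}.{ckpt_name}.",
                                expert_id, shard_id))
        return mapping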
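
With this patch, both architectures should accept in-flight BitsAndBytes
quantization (and, via the new SupportsLoRA marker, LoRA adapters) the same
way other supported models do. A hedged usage sketch follows; the model ID is
an example dots1 checkpoint, so substitute the repo you actually serve, and
note that some older vLLM releases also required load_format="bitsandbytes".

    from vllm import LLM, SamplingParams

    # In-flight 4-bit BitsAndBytes quantization of an MoE checkpoint.
    llm = LLM(
        model="rednote-hilab/dots.llm1.inst",  # example checkpoint
        quantization="bitsandbytes",
        trust_remote_code=True,
    )
    out = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
    print(out[0].outputs[0].text)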