From 70022ffc002dabc59b936d3fc001b94b81ba08db Mon Sep 17 00:00:00 2001
From: xiao-llm
Date: Thu, 23 Oct 2025 22:14:03 -0400
Subject: [PATCH] Granite 4.0 quark quantization support (#26944)

Signed-off-by: Xiao YU
Signed-off-by: Xiao Yu
Co-authored-by: Xiao YU
---
 .../model_executor/models/granitemoehybrid.py | 75 +++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 14d3a46e54af5..1bb7f4e9b8023 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -330,6 +330,7 @@ class GraniteMoeHybridModel(nn.Module):
         lora_config = vllm_config.lora_config
 
         self.config = config
+        self.quant_config = quant_config
         lora_vocab = (
             (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
             if lora_config
@@ -405,6 +406,33 @@ class GraniteMoeHybridModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        # layers.0.block_sparse_moe.expert_0.input_linear.input_scale
+        ckpt_gate_proj_name = "gate_proj"
+        ckpt_down_proj_name = "down_proj"
+        ckpt_up_proj_name = "up_proj"
+        num_experts = self.config.num_local_experts
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                "block_sparse_moe.experts.w13_"
+                if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                else "block_sparse_moe.experts.w2_",
+                f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -414,6 +442,7 @@ class GraniteMoeHybridModel(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
 
         def _load(n, p):
             param = params_dict[n]
@@ -435,10 +464,56 @@ class GraniteMoeHybridModel(nn.Module):
             weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
             loaded_params.add(n)
 
+        def _load_quant_expert(name, loaded_weight):
+            for mapping in expert_params_mapping:
+                param_name, weight_name, expert_id, shard_id = mapping
+
+                if weight_name not in name:
+                    continue
+
+                name_mapped = name.replace(weight_name, param_name)
+
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name_mapped, self):
+                    continue
+
+                param = params_dict[name_mapped]
+                weight_loader = param.weight_loader
+                success = False
+
+                if weight_loader is not None:
+                    success = weight_loader(
+                        param,
+                        loaded_weight,
+                        name_mapped,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        return_success=True,
+                    )
+
+                if success:
+                    return name_mapped
+            return None
+
         for n, p in weights:
             if "A_log" in n:
                 n = n.replace("A_log", "A")
 
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(n)
+            ):
+                # Loading kv cache quantization scales
+                loaded_weight = p
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                _load(scale_name, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if _load_quant_expert(n, p):
+                continue
+
             # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
             # Mapping different experts' layout:
             # from HF (input_linear, output_linear, router)
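For review purposes, here is a minimal, self-contained sketch of the name remapping that `get_expert_mapping()` and `_load_quant_expert()` perform together. The `num_experts` value and the checkpoint name below are illustrative, not taken from a real Granite config:

```python
def get_expert_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
    # (param_name, weight_name, expert_id, shard_id), as in the patch:
    # gate_proj/up_proj shards land in the fused w13 parameter, down_proj in w2.
    return [
        (
            "block_sparse_moe.experts.w13_"
            if weight_name in ("gate_proj", "up_proj")
            else "block_sparse_moe.experts.w2_",
            f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id, weight_name in [
            ("w1", "gate_proj"),
            ("w2", "down_proj"),
            ("w3", "up_proj"),
        ]
    ]

def remap(name: str, mapping: list[tuple[str, str, int, str]]):
    # Mirrors the prefix rewrite in _load_quant_expert: the per-expert
    # checkpoint prefix is replaced by the fused parameter prefix, so the
    # suffix (weight, weight_scale, input_scale, ...) is preserved.
    for param_name, weight_name, expert_id, shard_id in mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), expert_id, shard_id
    return None

mapping = get_expert_mapping(num_experts=4)
print(remap("model.layers.0.block_sparse_moe.experts.3.gate_proj.weight_scale", mapping))
# ('model.layers.0.block_sparse_moe.experts.w13_weight_scale', 3, 'w1')
```

Because only the per-expert prefix is rewritten, the same mapping covers the fp8 weight and activation scale tensors (such as the `input_scale` example in the comment above) that quark checkpoints store alongside the weights.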
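The kv-cache branch relies on `quant_config.get_cache_scale(name)` returning the target parameter name for kv-cache scale tensors and `None` otherwise; the only local logic is collapsing the loaded tensor to a scalar. A sketch of just that step, assuming per-layer scales stored either as 0-dim tensors or as 1-element vectors:

```python
import torch

def to_scalar_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    # As in the patch: a 0-dim tensor is used as-is; otherwise the first
    # element is taken, since kv-cache scales are per-layer scalars.
    return loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]

print(to_scalar_scale(torch.tensor(0.5)))         # tensor(0.5000)
print(to_scalar_scale(torch.tensor([0.5, 0.5])))  # tensor(0.5000)
```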