From ef7eefe17a7dc212ddb8a8aabd7760218a10e25e Mon Sep 17 00:00:00 2001 From: Tao He Date: Thu, 18 Sep 2025 16:16:04 +0800 Subject: [PATCH] [Qwen] Add fp8 checkpoint support for qwen3-next. (#25079) Signed-off-by: Tao He --- vllm/model_executor/models/qwen3_next.py | 35 ++++++++++---------- vllm/model_executor/models/qwen3_next_mtp.py | 8 +++-- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ca9f4d402dac..eb060cb90f44 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm) # yapf: enable from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # projection of the input hidden states self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 self.projection_size_ba = self.num_v_heads * 2 - self.in_proj = MergedColumnParallelLinear( + self.in_proj_qkvz = ColumnParallelLinear( input_size=self.hidden_size, - output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + output_size=self.projection_size_qkvz, bias=False, quant_config=quant_config, - prefix=f"{prefix}.in_proj", + prefix=f"{prefix}.in_proj_qkvz", + ) + # ba_proj doesn't support blockwise fp8 quantization. + self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_ba, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_ba", ) query_key_settings = (self.key_dim, 0, False) @@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens - - # 1. Set up dimensions for reshapes later - projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) if spec_token_masks is not None: spec_token_masks = spec_token_masks[:num_actual_tokens] - projected_states_qkvz, projected_states_ba = torch.split( - projected_states, - [ - self.projection_size_qkvz // self.tp_size, - self.projection_size_ba // self.tp_size - ], - dim=-1, - ) + + # 1. Set up dimensions for reshapes later + projected_states_qkvz, _ = self.in_proj_qkvz( + hidden_states[:num_actual_tokens]) + projected_states_ba, _ = self.in_proj_ba( + hidden_states[:num_actual_tokens]) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba) query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), @@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - ("in_proj", "in_proj_qkvz", 0), - ("in_proj", "in_proj_ba", 1), ] params_dict = dict(self.named_parameters()) @@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, "v_proj", ], "gate_up_proj": ["gate_proj", "up_proj"], - "in_proj": ["in_proj_qkvz", "in_proj_ba"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 190a1750e673..c755eeb9b4ea 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module): self.config.hidden_size, gather_output=True, bias=False, - return_bias=False) + return_bias=False, + quant_config=quant_config, + prefix=f'{prefix}.fc') self.layers = torch.nn.ModuleList( Qwen3NextDecoderLayer( @@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): model_config=model_config, cache_config=cache_config, quant_config=quant_config, - prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}', + prefix=f'{prefix}.layers.{idx}', ) for idx in range(self.num_mtp_layers)) self.make_empty_intermediate_tensors = ( @@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP): self.config = config self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config, prefix=maybe_prefix( - prefix, "model")) + prefix, "mtp")) self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead(self.unpadded_vocab_size, config.hidden_size,