[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Tao He 2025-09-18 16:16:04 +08:00 committed by GitHub
parent 350c94deb3
commit ef7eefe17a
2 changed files with 22 additions and 21 deletions
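
At a glance, the change splits the gated delta-net's fused in_proj into separate qkvz and b/a projections (so blockwise-fp8 weights and scales load cleanly) and threads quant_config plus checkpoint-matching prefixes through the MTP head. For context, a minimal sketch of serving a pre-quantized fp8 Qwen3-Next checkpoint with vLLM's offline API; the model path and parallel size below are placeholders, not values from this commit:

    from vllm import LLM, SamplingParams

    # Placeholder path: point this at the fp8-quantized Qwen3-Next checkpoint.
    llm = LLM(
        model="path/to/Qwen3-Next-fp8",
        tensor_parallel_size=4,  # example value; size to your GPUs
    )

    params = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["Briefly, what is blockwise fp8 quantization?"], params)
    print(outputs[0].outputs[0].text)

For a checkpoint that already carries a quantization_config, vLLM picks the scheme up from the model config, so no extra quantization flag should be needed.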

vllm/model_executor/models/qwen3_next.py

@@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
         )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
+        )
         query_key_settings = (self.key_dim, 0, False)
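
The in-line comment is the crux of the change: with blockwise fp8 checkpoints the large qkvz projection is quantized (weight scales stored per block, e.g. 128x128), while the tiny b/a projection (num_v_heads * 2 outputs) either cannot be tiled onto that block grid or is left unquantized in the checkpoint, so fusing both into one MergedColumnParallelLinear breaks weight/scale loading. A rough alignment check, with made-up head counts and block size (not the real Qwen3-Next configuration):

    def block_aligned(output_size: int, tp_size: int, block: int = 128) -> bool:
        """Does each tensor-parallel shard of this output dim land on the block grid?"""
        shard = output_size // tp_size
        return shard % block == 0

    projection_size_qkvz = 2 * 16 * 256 + 2 * 32 * 128  # hypothetical key/value dims
    projection_size_ba = 32 * 2                         # hypothetical num_v_heads * 2

    for tp in (1, 2, 4):
        print(tp, block_aligned(projection_size_qkvz, tp),
              block_aligned(projection_size_ba, tp))
    # The wide qkvz projection stays block-aligned under TP; the 64-wide b/a one never is.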
@@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
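
The forward path drops the torch.split over TP-sharded sizes and simply calls the two projections. The two forms are numerically equivalent when the fused weight is the concatenation of the split weights; a toy check in plain torch (shapes made up):

    import torch

    torch.manual_seed(0)
    hidden, qkvz, ba, tokens = 64, 48, 8, 5
    w_qkvz, w_ba = torch.randn(qkvz, hidden), torch.randn(ba, hidden)
    x = torch.randn(tokens, hidden)

    # Old form: one fused projection, then split along the output dim.
    fused = x @ torch.cat([w_qkvz, w_ba], dim=0).T
    old_qkvz, old_ba = torch.split(fused, [qkvz, ba], dim=-1)

    # New form: two independent projections.
    new_qkvz, new_ba = x @ w_qkvz.T, x @ w_ba.T

    assert torch.allclose(old_qkvz, new_qkvz) and torch.allclose(old_ba, new_ba)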
@@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]
         params_dict = dict(self.named_parameters())
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
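
With in_proj gone, the stacked-params and packed-modules entries above are removed: in_proj_qkvz.weight and in_proj_ba.weight in the checkpoint now match model parameters by name and load directly, instead of being copied into shards of a fused parameter. Roughly how such a mapping is consulted during weight loading (an illustration, not vLLM's actual loader):

    stacked_params_mapping = [
        # (fused param in the model, weight name in the checkpoint, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def resolve(checkpoint_name: str):
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in checkpoint_name:
                return checkpoint_name.replace(weight_name, param_name), shard_id
        return checkpoint_name, None  # direct name match, e.g. ...in_proj_qkvz.weight

    print(resolve("model.layers.0.self_attn.q_proj.weight"))       # remapped onto qkv_proj
    print(resolve("model.layers.0.linear_attn.in_proj_ba.weight"))  # loads as-is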

vllm/model_executor/models/qwen3_next_mtp.py

@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                                        self.config.hidden_size,
                                        gather_output=True,
                                        bias=False,
-                                       return_bias=False)
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                 model_config=model_config,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+                prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))

         self.make_empty_intermediate_tensors = (
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
         self.config = config
         self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
-                                                      prefix, "model"))
+                                                      prefix, "mtp"))
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                       config.hidden_size,
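
In the MTP predictor, passing quant_config and renaming the prefixes ('mtp', 'mtp.fc', 'mtp.layers.{idx}') matters because fp8 checkpoints identify quantized and skipped modules by their dotted names, presumably matching how the MTP weights are named in the checkpoint's quantization config; the vLLM module prefixes have to line up with those names for the right layers to be treated as fp8. A small sketch of that name matching, with a hypothetical ignore list (assumed behaviour, not taken from this commit):

    ignore = ["mtp.fc", "lm_head"]  # hypothetical unquantized modules from a checkpoint config

    def is_ignored(prefix: str) -> bool:
        """A module is skipped when its full dotted prefix matches an ignore entry."""
        return any(prefix == name or prefix.startswith(name + ".") for name in ignore)

    print(is_ignored("mtp.fc"))                         # True  -> kept in high precision
    print(is_ignored("mtp.layers.0.mlp.gate_up_proj"))  # False -> loaded as fp8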