[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>

commit ef7eefe17a
parent 350c94deb3
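
For context (not part of the diff): once an fp8 Qwen3-Next checkpoint loads, it is used like any other model through vLLM's offline API. The model path below is a placeholder, not a name this commit introduces; the fp8 scheme itself is picked up from the checkpoint's quantization_config.

    # Minimal usage sketch; the checkpoint path is a placeholder.
    from vllm import LLM, SamplingParams

    llm = LLM(model="path/to/Qwen3-Next-fp8-checkpoint")
    out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)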
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
         )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
+        )
 
         query_key_settings = (self.key_dim, 0, False)
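
Why the split: blockwise fp8 checkpoints quantize weights in fixed-size blocks (commonly 128x128), and the ba projection's output width is only num_v_heads * 2, too small to fit that scheme cleanly, which is what the comment in the diff points at. Keeping it fused with the qkvz projection would drag the whole merged weight through the unsupported path, so it becomes its own ColumnParallelLinear that the quantization config can treat separately. A rough size check, using hypothetical head counts rather than the real Qwen3-Next config:

    # Hypothetical dimensions, for illustration only.
    num_k_heads, num_v_heads = 16, 32
    head_k_dim, head_v_dim = 128, 128

    key_dim = head_k_dim * num_k_heads                    # 2048
    value_dim = head_v_dim * num_v_heads                  # 4096
    projection_size_qkvz = key_dim * 2 + value_dim * 2    # 12288, a multiple of 128
    projection_size_ba = num_v_heads * 2                  # 64, smaller than one block

    block_n = 128  # a typical weight_block_size in blockwise fp8 checkpoints
    print(projection_size_qkvz % block_n)  # 0: fits the blockwise scheme
    print(projection_size_ba % block_n)    # 64: does not, hence the separate layer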
@@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
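
The forward path mirrors the layer split: rather than one merged matmul whose output is torch.split by the tensor-parallel shard sizes, the two projections are now computed independently. The two formulations are numerically equivalent, which the sketch below checks with plain torch.nn.Linear (the real code uses vLLM's tensor-parallel layers, which also handle sharding):

    # Equivalence sketch with plain torch, not the vLLM parallel layers.
    import torch

    hidden, qkvz_out, ba_out = 8, 12, 4
    x = torch.randn(3, hidden)

    merged = torch.nn.Linear(hidden, qkvz_out + ba_out, bias=False)
    qkvz = torch.nn.Linear(hidden, qkvz_out, bias=False)
    ba = torch.nn.Linear(hidden, ba_out, bias=False)

    # Hand the separate layers the matching slices of the merged weight.
    with torch.no_grad():
        qkvz.weight.copy_(merged.weight[:qkvz_out])
        ba.weight.copy_(merged.weight[qkvz_out:])

    out_qkvz, out_ba = torch.split(merged(x), [qkvz_out, ba_out], dim=-1)
    assert torch.allclose(out_qkvz, qkvz(x))
    assert torch.allclose(out_ba, ba(x))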
@@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]
 
         params_dict = dict(self.named_parameters())
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
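
With in_proj_qkvz and in_proj_ba as separate submodules, checkpoint tensors such as a layer's in_proj_qkvz.weight (and the accompanying fp8 scales) already match the module parameters by name, so the two stacked-parameter entries and the "in_proj" packed-modules entry above can go. For reference, a simplified sketch of the rename those removed entries performed during load_weights (illustrative, not the actual loader; the checkpoint key is made up):

    # Toy version of the stacked-parameter rename the removed entries relied on.
    stacked_params_mapping = [
        # (fused param_name, checkpoint shard_name, shard_id)
        ("in_proj", "in_proj_qkvz", 0),
        ("in_proj", "in_proj_ba", 1),
    ]

    ckpt_name = "layers.0.linear_attn.in_proj_qkvz.weight"  # hypothetical key
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            fused_name = ckpt_name.replace(shard_name, param_name)
            print(fused_name, shard_id)  # layers.0.linear_attn.in_proj.weight 0
            break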
@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                                        self.config.hidden_size,
                                        gather_output=True,
                                        bias=False,
-                                       return_bias=False)
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
 
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                 model_config=model_config,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+                prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))
 
         self.make_empty_intermediate_tensors = (
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
         self.config = config
         self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
-                                                      prefix, "model"))
+                                                      prefix, "mtp"))
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                       config.hidden_size,
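
The MTP changes above are about making module prefixes line up with the checkpoint's quantization metadata: the fc projection now receives quant_config and a '{prefix}.fc' prefix, the MTP decoder layers are registered as '{prefix}.layers.{idx}', and the predictor is mounted under "mtp" instead of "model". Prefixes matter because per-module quantization decisions (for example, which modules an fp8 checkpoint leaves unquantized) are looked up by name. A toy version of that kind of prefix-based lookup (illustrative only; vLLM's real logic lives in its quantization configs, and the ignore list below is made up):

    # Hypothetical ignore list; real checkpoints carry this in their quantization_config.
    ignored_modules = ["mtp.fc", "lm_head"]

    def is_quantized(prefix: str) -> bool:
        return not any(prefix == m or prefix.startswith(m + ".") for m in ignored_modules)

    print(is_quantized("mtp.fc"))                           # False: kept in high precision
    print(is_quantized("mtp.layers.0.self_attn.qkv_proj"))  # True: gets the fp8 scheme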