[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
commit ef7eefe17a
parent 350c94deb3
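With this change, an FP8-quantized Qwen3-Next checkpoint can be served through the usual vLLM entry points. Below is a minimal usage sketch, not part of this commit; the checkpoint id and tensor-parallel size are assumptions, and any Qwen3-Next checkpoint whose config.json carries an FP8 quantization_config should exercise the same code path.

from vllm import LLM, SamplingParams

# Assumed FP8 checkpoint id; substitute the actual quantized model path.
llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct-FP8",
          tensor_parallel_size=4)  # sized for illustration only
outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)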
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
+        )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
         )
 
         query_key_settings = (self.key_dim, 0, False)
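The merged in_proj is split into in_proj_qkvz and in_proj_ba so that the large QKV/Z projection can carry blockwise FP8 weights while the tiny b/a projection (2 * num_v_heads outputs) is built as a separate layer, per the inline comment that ba_proj doesn't support blockwise fp8 quantization. The sketch below is illustrative reasoning only, not vLLM code; the block size and head count are assumptions.

# Blockwise FP8 stores one scale per (block_n x block_k) tile of the weight,
# so a quantized output dimension should be a multiple of the block size.
block_n, block_k = 128, 128            # typical weight_block_size in FP8 checkpoints
num_v_heads = 32                       # assumed value for illustration
projection_size_ba = num_v_heads * 2   # 64: smaller than a single 128-wide block

print(projection_size_ba % block_n == 0)   # False -> keep this projection separate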
@@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
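In the forward pass, the single merged projection followed by torch.split is replaced by two independent projections. The two formulations produce the same tensors whenever the merged weight is the concatenation of the per-projection weights, as the plain-torch sketch below illustrates (illustrative only: no tensor parallelism, random weights, hypothetical sizes).

import torch

hidden_size, size_qkvz, size_ba = 16, 24, 8
x = torch.randn(4, hidden_size)

w_qkvz = torch.randn(size_qkvz, hidden_size)
w_ba = torch.randn(size_ba, hidden_size)
w_merged = torch.cat([w_qkvz, w_ba], dim=0)     # old merged in_proj weight

merged_out = x @ w_merged.t()                   # old: one projection ...
qkvz_old, ba_old = torch.split(merged_out, [size_qkvz, size_ba], dim=-1)  # ... then split

qkvz_new = x @ w_qkvz.t()                       # new: in_proj_qkvz
ba_new = x @ w_ba.t()                           # new: in_proj_ba

print(torch.allclose(qkvz_old, qkvz_new), torch.allclose(ba_old, ba_new))  # True True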
@@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]
 
         params_dict = dict(self.named_parameters())
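Because the module is no longer fused, the two ("in_proj", ...) entries are dropped from stacked_params_mapping, and checkpoint tensors named in_proj_qkvz / in_proj_ba now load directly into parameters of the same name. The sketch below is a simplified rendering of the mapping pattern, not the full load_weights implementation; the example weight name is hypothetical.

def resolve(name, stacked_params_mapping):
    # Rewrite a checkpoint weight name onto a fused parameter, if it matches.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # no match: direct, same-name load

mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]
print(resolve("layers.0.linear_attn.in_proj_qkvz.weight", mapping))
# ('layers.0.linear_attn.in_proj_qkvz.weight', None)  -> loaded as-is

old_mapping = mapping + [("in_proj", "in_proj_qkvz", 0), ("in_proj", "in_proj_ba", 1)]
print(resolve("layers.0.linear_attn.in_proj_qkvz.weight", old_mapping))
# ('layers.0.linear_attn.in_proj.weight', 0)          -> pre-commit stacked load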
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
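The packed_modules_mapping entry for in_proj goes away for the same reason: there is no fused in_proj module to expand any more, so names such as in_proj_qkvz are looked up verbatim wherever the mapping is consulted. The helper below is an assumed, illustrative lookup only, not vLLM's actual API.

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand(prefix):
    # Expand a fused module prefix into the per-shard names a checkpoint uses.
    parent, _, module = prefix.rpartition(".")
    shards = packed_modules_mapping.get(module, [module])
    return [f"{parent}.{shard}" for shard in shards]

print(expand("layers.0.self_attn.qkv_proj"))
# ['layers.0.self_attn.q_proj', 'layers.0.self_attn.k_proj', 'layers.0.self_attn.v_proj']
print(expand("layers.0.linear_attn.in_proj_qkvz"))
# ['layers.0.linear_attn.in_proj_qkvz']  -> no longer treated as packed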
@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                                        self.config.hidden_size,
                                        gather_output=True,
                                        bias=False,
-                                       return_bias=False)
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
 
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
                 model_config=model_config,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+                prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))
 
         self.make_empty_intermediate_tensors = (
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
         self.config = config
         self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
-                                                      prefix, "model"))
+                                                      prefix, "mtp"))
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                       config.hidden_size,
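The remaining hunks touch the multi-token predictor: its fc projection is now built with quant_config and an explicit prefix, its decoder layers are indexed locally as layers.{idx}, and the top-level prefix becomes "mtp". Per-layer quantization decisions and weight/scale loading are keyed off fully qualified module names, so these prefixes presumably have to line up with the names the FP8 checkpoint uses. The composition sketch below is illustrative only; the exact checkpoint tensor names and the single-layer count are assumptions.

def qualify(*parts):
    # Join prefix pieces into a fully qualified parameter name.
    return ".".join(p for p in parts if p)

prefix = "mtp"        # was "model" before this commit
num_mtp_layers = 1    # assumed value for illustration
print(qualify(prefix, "fc", "weight"))                    # mtp.fc.weight
for idx in range(num_mtp_layers):
    print(qualify(prefix, f"layers.{idx}",
                  "linear_attn.in_proj_qkvz.weight"))     # mtp.layers.0.linear_attn.in_proj_qkvz.weight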