# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next model.""" from collections.abc import Iterable from itertools import islice import torch from einops import rearrange from torch import nn from transformers.activations import ACT2FN from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config, ) from vllm.distributed import ( divide, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fla.ops import ( chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader, ) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( HasInnerState, IsHybrid, MixtureOfExperts, SupportsLoRA, SupportsPP, ) from .utils import ( AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) logger = init_logger(__name__) KVCache = tuple[torch.Tensor, torch.Tensor] class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts self.is_sequence_parallel = 
parallel_config.use_sequence_parallel_moe if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}." ) # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_end = ( self.physical_expert_start + self.n_local_physical_experts ) self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate", ) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, expert_gate=self.shared_expert_gate, prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None self.experts = SharedFusedMoE( shared_experts=self.shared_expert, gate=self.gate, num_experts=self.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
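        # Descriptive note: the routed experts (and, when configured, the
        # shared expert) run inside SharedFusedMoE; with a shared expert the
        # call returns a (shared_out, routed_out) pair that is summed below.
        # Under sequence parallelism each rank processes only its chunk of
        # tokens, so the result is all-gathered and trimmed back to
        # num_tokens; otherwise the partial sums are combined with a
        # tensor-parallel all-reduce when tp_size > 1.
        # Illustrative shapes (assumed, not tied to a specific checkpoint):
        # a (num_tokens, hidden_size) input produces router logits of shape
        # (num_tokens, num_experts) and an output with the original shape.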
orig_shape = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=hidden_states ) else: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) if self.shared_expert is not None: final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states ) return final_hidden_states.view(orig_shape) class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): @property def mamba_type(self) -> str: return "linear_attention" def get_attn_backend(self) -> type["AttentionBackend"]: from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend return GDNAttentionBackend def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( self.model_config.dtype, self.cache_config.mamba_cache_dtype ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, self.head_v_dim, self.conv_kernel_size, self.num_spec, ) def __init__( self, config: Qwen3NextConfig, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, speculative_config: SpeculativeConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.hidden_size = config.hidden_size self.num_v_heads = config.linear_num_value_heads self.num_k_heads = config.linear_num_key_heads self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = extract_layer_index(prefix) self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps self.prefix = prefix self.config = config self.model_config = model_config self.cache_config = cache_config self.quant_config = quant_config self.speculative_config = speculative_config self.num_spec = ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0 ) # QKV self.conv_dim = self.key_dim * 2 + self.value_dim self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.conv_dim, bias=False, prefix=f"{prefix}.conv1d", ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) # projection of the input hidden states self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 self.projection_size_ba = self.num_v_heads * 2 self.in_proj_qkvz = ColumnParallelLinear( input_size=self.hidden_size, output_size=self.projection_size_qkvz, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_qkvz", 
) # ba_proj doesn't support blockwise fp8 quantization. self.in_proj_ba = ColumnParallelLinear( input_size=self.hidden_size, output_size=self.projection_size_ba, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_ba", ) query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( self.conv1d.weight, { "weight_loader": mamba_v2_sharded_weight_loader( [ query_key_settings, query_key_settings, value_settings, ], self.tp_size, self.tp_rank, ) }, ) # selective projection used to make dt, B and C input dependant # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter( torch.ones(self.num_v_heads // self.tp_size), ) self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), ) ) set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, group_size=None, norm_before_gate=True, device=current_platform.current_device(), dtype=config.dtype, ) self.out_proj = RowParallelLinear( self.value_dim, self.hidden_size, bias=False, input_is_parallel=True, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self def fix_query_key_value_ordering( self, mixed_qkvz, mixed_ba, ): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, ( self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // self.num_k_heads ), ) new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, 2 * self.num_v_heads // self.num_k_heads, ) mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) mixed_ba = mixed_ba.view(*new_tensor_shape_ba) split_arg_list_qkvz = [ self.head_k_dim, self.head_k_dim, (self.num_v_heads // self.num_k_heads * self.head_v_dim), (self.num_v_heads // self.num_k_heads * self.head_v_dim), ] split_arg_list_ba = [ self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads, ] # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim) b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) return query, key, value, z, b, a def rearrange_mixed_qkv(self, mixed_qkv): if mixed_qkv is None: return None, None, None query, key, value = torch.split( mixed_qkv, [ self.key_dim // self.tp_size, self.key_dim // self.tp_size, self.value_dim // self.tp_size, ], dim=-1, ) query, key = map( lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key), ) value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query.contiguous(), key.contiguous(), value.contiguous() def forward( 
self, hidden_states: torch.Tensor, output: torch.Tensor, ): """ Forward pass with three parts: 1. Input projection 2. Core attention (custom op) 3. Output projection """ num_tokens = hidden_states.size(0) # ============================================================ # Part 1: Input Projection # ============================================================ projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) query, key, value = map( lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) ) mixed_qkv = torch.cat((query, key, value), dim=-1) # ============================================================ # Part 2: Core Attention (Custom Op) # ============================================================ # Note: we should not use torch.empty here like other attention backends, # see discussions in https://github.com/vllm-project/vllm/pull/28182 core_attn_out = torch.zeros( (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), dtype=hidden_states.dtype, device=hidden_states.device, ) torch.ops.vllm.gdn_attention_core( mixed_qkv, b, a, core_attn_out, self.prefix, ) # ============================================================ # Part 3: Output Projection # ============================================================ z_shape_og = z.shape # Reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) def _forward_core( self, mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, ): """ Core attention computation (called by custom op). """ forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: # V1 profile run return assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks spec_token_indx = attn_metadata.spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] a = a[:num_actual_tokens] # 1. 
Convolution sequence transformation conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv # 1.1: Process the multi-query part if spec_sequence_masks is not None: mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=spec_state_indices_tensor[:, 0][ : attn_metadata.num_spec_decodes ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, max_query_len=spec_state_indices_tensor.size(-1), validate_data=False, ) # 1.2: Process the remaining part if attn_metadata.num_prefills > 0: mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec_T, conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ : attn_metadata.num_actual_tokens ], validate_data=True, ) else: mixed_qkv_non_spec = None query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( mixed_qkv_non_spec ) g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g beta_spec = beta g_non_spec = None beta_non_spec = None else: g_spec = g.index_select(1, spec_token_indx) beta_spec = beta.index_select(1, spec_token_indx) g_non_spec = g.index_select(1, non_spec_token_indx) beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None g_non_spec = g beta_non_spec = beta # 2. Recurrent attention # 2.1: Process the multi-query part if spec_sequence_masks is not None: core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( q=query_spec, k=key_spec, v=value_spec, g=g_spec, beta=beta_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], ssm_state_indices=spec_state_indices_tensor, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=True, ) else: core_attn_out_spec, last_recurrent_state = None, None # 2.2: Process the remaining part if attn_metadata.num_prefills > 0: initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] 
= 0 ( core_attn_out_non_spec, last_recurrent_state, ) = chunk_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=initial_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, head_first=False, use_qk_l2norm_in_kernel=True, ) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state.dtype ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc[ : attn_metadata.num_decodes + 1 ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, ) ) else: core_attn_out_non_spec, last_recurrent_state = None, None # 3. Merge core attention output if spec_sequence_masks is not None and core_attn_out_non_spec is not None: merged_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) elif spec_sequence_masks is not None: core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) else: core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) class Qwen3NextAttention(nn.Module): def __init__( self, config: Qwen3NextConfig, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
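            # Worked example (illustrative numbers, not from a real config):
            # with total_num_kv_heads = 2 and tp_size = 8, the assert below
            # requires 8 % 2 == 0, each KV head is replicated across 4 TP
            # ranks, and num_kv_heads becomes max(1, 2 // 8) = 1 per rank.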
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = getattr( config, "dual_chunk_attention_config", None ) self.attn_output_gate = getattr(config, "attn_output_gate", True) self.qkv_proj = QKVParallelLinear( config.hidden_size, self.head_dim, self.total_num_heads * (1 + self.attn_output_gate), self.total_num_kv_heads, bias=getattr(config, "qkv_bias", False), quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=self.head_dim, max_position=config.max_position_embeddings, base=config.rope_theta, rope_scaling=config.rope_scaling, partial_rotary_factor=config.partial_rotary_factor, dual_chunk_attention_config=self.dual_chunk_attention_config, ) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": self.dual_chunk_attention_config, } if self.dual_chunk_attention_config else {}, ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, output: torch.Tensor, hidden_states: torch.Tensor, ): qkv, _ = self.qkv_proj(hidden_states) if self.attn_output_gate: q_gate, k, v = qkv.split( [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 ) orig_shape = q_gate.shape[:-1] q_gate = q_gate.view(*orig_shape, self.num_heads, -1) q, gate = torch.chunk(q_gate, 2, dim=-1) q = q.reshape(*orig_shape, -1) gate = gate.reshape(*orig_shape, -1) else: q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( -1, self.num_heads * self.head_dim ) k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( -1, self.num_kv_heads * self.head_dim ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) if self.attn_output_gate: gate = torch.sigmoid(gate) attn_output = attn_output * gate output[:], _ = self.o_proj(attn_output) class Qwen3NextDecoderLayer(nn.Module): def __init__( self, vllm_config: VllmConfig, layer_type: str, prefix: str = "", ) -> None: super().__init__() config = vllm_config.model_config.hf_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config speculative_config = vllm_config.speculative_config self.layer_type = layer_type self.layer_idx = extract_layer_index(prefix) if self.layer_type == "linear_attention": self.linear_attn = Qwen3NextGatedDeltaNet( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, speculative_config=speculative_config, prefix=f"{prefix}.linear_attn", ) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) else: raise ValueError(f"Invalid layer_type {self.layer_type}") mlp_only_layers = ( [] if not 
hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (self.layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3NextSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, ) self.input_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.layer_scale = getattr(config, "layer_scale", False) if self.layer_scale: self.attn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.dtype, ), ) def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor | None, positions: torch.Tensor = None, **kwargs: object, ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) self_attention_output = torch.empty_like(hidden_states) if self.layer_type == "linear_attention": self.linear_attn( hidden_states=hidden_states, output=self_attention_output, ) elif self.layer_type == "full_attention": self.self_attn( hidden_states=hidden_states, output=self_attention_output, positions=positions, ) else: raise ValueError("Invalid layer_type") hidden_states = self_attention_output if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype) + 1 ) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( f"shape must be the same {len(hidden_states.shape)}, " f"{len(self.ffn_layer_scale.shape)}" ) hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype) + 1 ) return hidden_states, residual @support_torch_compile class Qwen3NextModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen3NextConfig = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, ) def get_layer(prefix: str): return Qwen3NextDecoderLayer( vllm_config, layer_type=config.layer_types[extract_layer_index(prefix)], prefix=prefix, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) if get_pp_group().is_last_rank: self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return 
self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, num_redundant_experts=self.num_redundant_experts, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name.startswith("mtp."): continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # name = apply_attn_prefix(name, params_dict) if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. if ( name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, ) break else: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class QwenNextMixtureOfExperts(MixtureOfExperts): def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, ) -> None: assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): moe = layer.mlp moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() def set_moe_parameters(self): self.expert_weights = [] self.moe_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( layer.mlp, Qwen3NextSparseMoeBlock ): example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) if example_moe is None: raise RuntimeError("No Qwen3Next layer found in the model.layers.") # Set MoE hyperparameters self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 self.num_logical_experts = example_moe.n_logical_experts self.num_physical_experts = example_moe.n_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts self.num_routed_experts = example_moe.n_routed_experts self.num_redundant_experts = example_moe.n_redundant_experts class Qwen3NextForCausalLM( nn.Module, HasInnerState, SupportsLoRA, SupportsPP, QwenNextMixtureOfExperts, IsHybrid, ): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, ( "Qwen3Next currently does not support prefix caching" ) self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config self.model = Qwen3NextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) # Set MoE hyperparameters self.set_moe_parameters() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( 
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype ) @classmethod def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size num_spec = ( vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0 ) return MambaStateShapeCalculator.gated_delta_net_state_shape( tp_size, hf_config.linear_num_key_heads, hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, num_spec, ) def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["mtp."], ) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() def gdn_attention_core( mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, ) -> None: """ Custom op for the core attention computation. Only handles the convolution + recurrent attention part. Input/output projections are handled outside this op. """ forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] self._forward_core( mixed_qkv=mixed_qkv, b=b, a=a, core_attn_out=core_attn_out, ) def gdn_attention_core_fake( mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, ) -> None: """Fake implementation for torch.compile.""" return direct_register_custom_op( op_name="gdn_attention_core", op_func=gdn_attention_core, mutates_args=["core_attn_out"], fake_impl=gdn_attention_core_fake, ) @triton.jit def fused_gdn_gating_kernel( g, beta_output, A_log, a, b, dt_bias, seq_len, NUM_HEADS: tl.constexpr, beta: tl.constexpr, threshold: tl.constexpr, BLK_HEADS: tl.constexpr, ): i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off mask = head_off < NUM_HEADS blk_A_log = tl.load(A_log + head_off, mask=mask) blk_a = tl.load(a + off, mask=mask) blk_b = tl.load(b + off, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask) # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) softplus_x = tl.where( beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x ) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) # compute beta_output = sigmoid(b) blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) tl.store( beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask ) def fused_gdn_gating( A_log: torch.Tensor, a: torch.Tensor, b: torch.Tensor, dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Fused computation of g and beta for Gated Delta Net. 
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) beta_output = b.sigmoid() TODO maybe use torch.compile to replace this triton kernel """ batch, num_heads = a.shape seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) fused_gdn_gating_kernel[grid]( g, beta_output, A_log, a, b, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1, ) return g, beta_output
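

# Illustrative reference (assumption: not part of the upstream module and not
# called from the forward path): a minimal pure-PyTorch sketch of the gating
# math implemented by fused_gdn_gating_kernel above, following the formula in
# the fused_gdn_gating docstring:
#     g = -A_log.float().exp() * softplus(a.float() + dt_bias)
#     beta_output = b.sigmoid()
# The helper name `_gdn_gating_reference` is hypothetical and exists only to
# document the kernel's semantics (e.g. for an eager-mode sanity check).
def _gdn_gating_reference(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Thresholded softplus, mirroring the linearization used in the kernel.
    x = a.float() + dt_bias.float()
    softplus_x = torch.where(
        beta * x <= threshold,
        (1.0 / beta) * torch.log1p(torch.exp(beta * x)),
        x,
    )
    g = -A_log.float().exp() * softplus_x
    beta_output = torch.sigmoid(b.float()).to(b.dtype)
    # Match the (1, batch, num_heads) layout returned by fused_gdn_gating.
    return g.unsqueeze(0), beta_output.unsqueeze(0)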