diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index d2ae8959b103d..0ad50640bb3bc 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen3 model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
@@ -47,27 +47,31 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
-from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
 
 class Qwen3Attention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 head_dim: Optional[int] = None,
-                 rms_norm_eps: float = 1e-06,
-                 qkv_bias: bool = False,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 rope_scaling: Optional[tuple] = None,
-                 prefix: str = "",
-                 attn_type: str = AttentionType.DECODER) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -89,6 +93,7 @@ class Qwen3Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.dual_chunk_attention_config = dual_chunk_attention_config
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -113,15 +118,22 @@ class Qwen3Attention(nn.Module):
             max_position=max_position,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            attn_type=attn_type,
+            **{
+                "layer_idx": extract_layer_index(prefix),
+                "dual_chunk_attention_config": dual_chunk_attention_config,
+            } if dual_chunk_attention_config else {},
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=attn_type)
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
@@ -161,6 +173,9 @@ class Qwen3DecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
 
         # By default, Qwen3 uses causal attention as it is a decoder-only model.
         # You can override the HF config with `is_causal=False` to enable
@@ -185,6 +200,7 @@ class Qwen3DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index ca14fd06574ec..7410589190bac 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -185,6 +185,7 @@ class Qwen3MoeAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -208,6 +209,7 @@ class Qwen3MoeAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.dual_chunk_attention_config = dual_chunk_attention_config
 
         self.qkv_proj = QKVParallelLinear(hidden_size,
                                           self.head_dim,
@@ -229,14 +231,21 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            **{
+                "layer_idx": extract_layer_index(prefix),
+                "dual_chunk_attention_config": dual_chunk_attention_config,
+            } if dual_chunk_attention_config else {},
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
@@ -280,6 +289,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -293,6 +305,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
 
         # `mlp_only_layers` in the config.
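Note on the wiring above: both files read `dual_chunk_attention_config` from the HF config with `getattr(..., None)` and forward `layer_idx` plus the config dict to `Attention` only when the config actually defines it, so models without dual chunk attention construct `Attention` exactly as before. The sketch below illustrates that conditional `**kwargs` pattern in isolation; `FakeAttention`, `build_attention`, and the keys in `dca` are hypothetical stand-ins for this example, not vLLM's `Attention` API or a real model config.

```python
from typing import Any, Optional


class FakeAttention:
    """Hypothetical stand-in for vllm.attention.Attention (illustration only)."""

    def __init__(self, num_heads: int, head_dim: int, scale: float,
                 **extra: Any) -> None:
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        # `layer_idx` / `dual_chunk_attention_config` land here only when the
        # caller expanded the optional kwargs dict.
        self.extra = extra


def build_attention(
        prefix: str,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> FakeAttention:
    # Recovers the layer index from a prefix such as "model.layers.7.self_attn",
    # mirroring what extract_layer_index(prefix) is used for in the diff
    # (hypothetical re-implementation, not the vLLM helper itself).
    layer_idx = next(int(p) for p in prefix.split(".") if p.isdigit())
    return FakeAttention(
        num_heads=8,
        head_dim=64,
        scale=64**-0.5,
        # Same conditional expansion as in the diff: pass the extra kwargs
        # only when a dual chunk attention config is present.
        **{
            "layer_idx": layer_idx,
            "dual_chunk_attention_config": dual_chunk_attention_config,
        } if dual_chunk_attention_config else {},
    )


# Without a DCA config, no extra kwargs reach the attention backend.
assert build_attention("model.layers.7.self_attn").extra == {}

# With a DCA config (placeholder keys and values, not a real model config),
# layer_idx and the dict are forwarded.
dca = {"chunk_size": 262144, "local_size": 8192}
attn = build_attention("model.layers.7.self_attn", dca)
assert attn.extra == {"layer_idx": 7, "dual_chunk_attention_config": dca}
```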