[Qwen3] Enable dual-chunk-attention support for Qwen3 models. (#21924)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Tao He 2025-08-07 10:58:08 +08:00 committed by GitHub
parent 6b47ef24de
commit 7377131a2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 31 deletions
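In short: the change reads an optional `dual_chunk_attention_config` dict from the HuggingFace model config and threads it through both the dense and MoE Qwen3 attention stacks — into the rotary-embedding construction, onto the `Attention` module (together with the layer index), and stored on the attention object itself. As a hedged sketch, the HF-side config might look like the following; the inner key names and values are illustrative only (the diff merely requires the attribute to exist, falling back to `None` via `getattr`):

from transformers import PretrainedConfig

# Hypothetical config carrying a dual-chunk-attention section.
# The inner keys below are illustrative, not taken from this diff.
config = PretrainedConfig(
    dual_chunk_attention_config={
        "chunk_size": 262144,
        "local_size": 8192,
        "original_max_position_embeddings": 262144,
    }
)

dca_config = getattr(config, "dual_chunk_attention_config", None)
print(dca_config is not None)  # True -> the extra attention kwargs get forwarded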

View File

@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen3 model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch import nn
@@ -47,27 +47,31 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
-from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
+                    maybe_prefix)

 logger = init_logger(__name__)


 class Qwen3Attention(nn.Module):

-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 head_dim: Optional[int] = None,
-                 rms_norm_eps: float = 1e-06,
-                 qkv_bias: bool = False,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 rope_scaling: Optional[tuple] = None,
-                 prefix: str = "",
-                 attn_type: str = AttentionType.DECODER) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -89,6 +93,7 @@ class Qwen3Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.dual_chunk_attention_config = dual_chunk_attention_config

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@ -113,15 +118,22 @@ class Qwen3Attention(nn.Module):
max_position=max_position, max_position=max_position,
base=self.rope_theta, base=self.rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type,
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
) )
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
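The `**{...} if dual_chunk_attention_config else {}` expression above forwards the extra keyword arguments only when a dual-chunk config is actually present, so the `Attention` constructor call is unchanged for regular models. A minimal, self-contained sketch of the same pattern follows; `make_attention`, its parameters, and the layer-index parse are illustrative stand-ins, not vLLM APIs:

from typing import Any, Optional


def make_attention(num_heads: int, head_dim: int, **extra: Any) -> dict:
    # Stand-in for a constructor that may or may not receive DCA kwargs.
    return {"num_heads": num_heads, "head_dim": head_dim, **extra}


def build(dual_chunk_attention_config: Optional[dict[str, Any]], prefix: str) -> dict:
    # Only pass layer_idx / dual_chunk_attention_config when DCA is enabled,
    # mirroring the conditional **kwargs expansion used in the diff.
    return make_attention(
        32,
        128,
        **{
            "layer_idx": int(prefix.split(".")[-2]),  # crude stand-in for extract_layer_index
            "dual_chunk_attention_config": dual_chunk_attention_config,
        } if dual_chunk_attention_config else {},
    )


print(build(None, "model.layers.0.self_attn"))                   # no extra kwargs
print(build({"chunk_size": 8192}, "model.layers.3.self_attn"))   # extra kwargs forwarded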
@@ -161,6 +173,9 @@ class Qwen3DecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)

         # By default, Qwen3 uses causal attention as it is a decoder-only model.
         # You can override the HF config with `is_causal=False` to enable
@@ -185,6 +200,7 @@ class Qwen3DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,

View File

@@ -185,6 +185,7 @@ class Qwen3MoeAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -208,6 +209,7 @@ class Qwen3MoeAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.dual_chunk_attention_config = dual_chunk_attention_config

         self.qkv_proj = QKVParallelLinear(hidden_size,
                                           self.head_dim,
@@ -229,14 +231,21 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            **{
+                "layer_idx": extract_layer_index(prefix),
+                "dual_chunk_attention_config": dual_chunk_attention_config,
+            } if dual_chunk_attention_config else {},
+        )

         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
@@ -280,6 +289,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -293,6 +305,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )

         # `mlp_only_layers` in the config.
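Both files pass `layer_idx=extract_layer_index(prefix)` to `Attention` when dual chunk attention is enabled; the helper is now imported from `.utils`. The sketch below shows the behavior the diff appears to rely on (recovering the numeric layer index from a dotted module prefix such as `model.layers.7.self_attn`). It is an assumption for illustration, not the vLLM implementation:

def extract_layer_index_sketch(prefix: str) -> int:
    """Assumed behavior: return the single integer component of a dotted
    module prefix, e.g. "model.layers.7.self_attn" -> 7."""
    indices = [int(part) for part in prefix.split(".") if part.isdigit()]
    if len(indices) != 1:
        raise ValueError(f"expected exactly one layer index in {prefix!r}")
    return indices[0]


assert extract_layer_index_sketch("model.layers.7.self_attn") == 7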