[PERF] Decouple projections from GDN custom op. Attempt 2 (#28083)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson 2025-11-06 05:01:12 +04:00 committed by GitHub
parent 1767658559
commit b6a248bdd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 205 additions and 53 deletions

View File

@ -462,7 +462,7 @@ class CompilationConfig:
"vllm::short_conv", "vllm::short_conv",
"vllm::linear_attention", "vllm::linear_attention",
"vllm::plamo2_mamba_mixer", "vllm::plamo2_mamba_mixer",
"vllm::gdn_attention", "vllm::gdn_attention_core",
"vllm::kda_attention", "vllm::kda_attention",
"vllm::sparse_attn_indexer", "vllm::sparse_attn_indexer",
] ]

View File

@ -369,6 +369,109 @@ class GemmaRMSNorm(CustomOp):
return self.forward_native(x, residual) return self.forward_native(x, residual)
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
"""RMS Normalization with optional gating.
This is a native PyTorch implementation that supports:
- Standard RMS normalization
- Group RMS normalization
- Optional gating with SiLU activation
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
group_size: int | None = None,
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize RMSNormGated.
Args:
hidden_size: Size of the hidden dimension
eps: Epsilon for numerical stability
group_size: If not None, do GroupNorm with each group
having group_size elements.
group_size=None is equivalent to group_size=hidden_size
(i.e. there's only 1 group).
norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
If False and z is provided: out = norm(x * silu(z))
device: Device to create parameters on
dtype: Data type for parameters
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward_native(
self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
"""
Native PyTorch implementation of RMS normalization with gating.
Args:
x: Input tensor
z: Optional gating tensor
Returns:
Normalized (and optionally gated) tensor
If z is not None:
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
"""
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * F.silu(z)
# RMS Normalization
if self.group_size is None:
# Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps)
out = x_normed * self.weight
else:
# Group RMS norm
from einops import rearrange
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps)
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
return out
def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
""" """
Layer Normalization. Layer Normalization.

View File

@ -30,12 +30,14 @@ from vllm.distributed import (
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import ( from vllm.model_executor.layers.fla.ops import (
RMSNormGated,
chunk_gated_delta_rule, chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule, fused_recurrent_gated_delta_rule,
) )
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@ -436,17 +438,66 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
): ):
return torch.ops.vllm.gdn_attention( """
hidden_states, Forward pass with three parts:
output, 1. Input projection
2. Core attention (custom op)
3. Output projection
"""
num_tokens = hidden_states.size(0)
# ============================================================
# Part 1: Input Projection
# ============================================================
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = map(
lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
# ============================================================
# Part 2: Core Attention (Custom Op)
# ============================================================
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.gdn_attention_core(
mixed_qkv,
b,
a,
core_attn_out,
self.prefix, self.prefix,
) )
def _forward( # ============================================================
# Part 3: Output Projection
# ============================================================
z_shape_og = z.shape
# Reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def _forward_core(
self, self,
hidden_states: torch.Tensor, mixed_qkv: torch.Tensor,
output: torch.Tensor, b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
): ):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata: AttentionMetadata = forward_context.attn_metadata
@ -471,18 +522,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens
# 1. Set up dimensions for reshapes later mixed_qkv = mixed_qkv[:num_actual_tokens]
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) b = b[:num_actual_tokens]
projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) a = a[:num_actual_tokens]
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = map(
lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
# 2. Convolution sequence transformation # 1. Convolution sequence transformation
conv_weights = self.conv1d.weight.view( conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2) self.conv1d.weight.size(0), self.conv1d.weight.size(2)
) )
@ -498,7 +542,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_spec = None mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv mixed_qkv_non_spec = mixed_qkv
# 2.1: process the mutli-query part # 1.1: Process the multi-query part
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec, mixed_qkv_spec,
@ -515,7 +559,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
validate_data=False, validate_data=False,
) )
# 2.2: process the remaining part # 1.2: Process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
@ -570,9 +614,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
g_non_spec = g g_non_spec = g
beta_non_spec = beta beta_non_spec = beta
# 3. Recurrent attention # 2. Recurrent attention
# 3.1: process the mutlti-query part # 2.1: Process the multi-query part
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec, q=query_spec,
@ -590,7 +634,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else: else:
core_attn_out_spec, last_recurrent_state = None, None core_attn_out_spec, last_recurrent_state = None, None
# 3.2: process the remaining part # 2.2: Process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0 initial_state[~has_initial_state, ...] = 0
@ -633,30 +677,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else: else:
core_attn_out_non_spec, last_recurrent_state = None, None core_attn_out_non_spec, last_recurrent_state = None, None
# Merge core attention output # 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None: if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
core_attn_out = torch.empty( merged_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]), (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype, dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device, device=core_attn_out_non_spec.device,
) )
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec) merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
elif spec_sequence_masks is not None: elif spec_sequence_masks is not None:
core_attn_out = core_attn_out_spec core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
else: else:
core_attn_out = core_attn_out_non_spec core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class Qwen3NextAttention(nn.Module): class Qwen3NextAttention(nn.Module):
@ -1260,29 +1294,44 @@ class Qwen3NextForCausalLM(
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
def gdn_attention( def gdn_attention_core(
hidden_states: torch.Tensor, mixed_qkv: torch.Tensor,
output: torch.Tensor, b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str, layer_name: str,
) -> None: ) -> None:
"""
Custom op for the core attention computation.
Only handles the convolution + recurrent attention part.
Input/output projections are handled outside this op.
"""
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output) self._forward_core(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
)
def gdn_attention_fake( def gdn_attention_core_fake(
hidden_states: torch.Tensor, mixed_qkv: torch.Tensor,
output: torch.Tensor, b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str, layer_name: str,
) -> None: ) -> None:
"""Fake implementation for torch.compile."""
return return
direct_register_custom_op( direct_register_custom_op(
op_name="gdn_attention", op_name="gdn_attention_core",
op_func=gdn_attention, op_func=gdn_attention_core,
mutates_args=["output"], mutates_args=["core_attn_out"],
fake_impl=gdn_attention_fake, fake_impl=gdn_attention_core_fake,
) )