From 1395461f5fb76145433c1dc8a3b7262ee3799bf8 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Wed, 19 Nov 2025 02:49:36 +0200 Subject: [PATCH] [Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../layers/mamba/mamba_mixer2.py | 160 ++++++++++-------- vllm/model_executor/models/bamba.py | 3 +- vllm/model_executor/models/falcon_h1.py | 4 +- .../model_executor/models/granitemoehybrid.py | 3 +- vllm/model_executor/models/mamba2.py | 3 +- vllm/model_executor/models/nemotron_h.py | 3 +- vllm/model_executor/models/zamba2.py | 6 +- 7 files changed, 92 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index fb45afa33dad6..57313990b8206 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -426,6 +426,10 @@ class MambaMixer2(MambaBase, CustomOp): # `ColumnParallelLinear` and `MergedColumnParallelLinear`, # and `set_weight_attrs` doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + self.register_buffer("conv_weights", conv_weights, persistent=False) # - these are TPed by heads to reduce the size of the # temporal shape @@ -459,6 +463,17 @@ class MambaMixer2(MambaBase, CustomOp): intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps ) + # - get hidden_states, B and C after depthwise convolution. + self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + ], + dim=-1, + ) + compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -470,10 +485,24 @@ class MambaMixer2(MambaBase, CustomOp): self.cache_config = cache_config self.prefix = prefix + # Pre-compute sizes for forward pass + self.tped_intermediate_size = self.intermediate_size // self.tp_size + self.tped_conv_size = self.conv_dim // self.tp_size + self.tped_dt_size = self.num_heads // self.tp_size + + self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.tped_intermediate_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + ], + dim=-1, + ) + def forward_native( self, hidden_states: torch.Tensor, - output: torch.Tensor, mup_vector: torch.Tensor | None = None, ): pass @@ -481,22 +510,55 @@ class MambaMixer2(MambaBase, CustomOp): def forward( self, hidden_states: torch.Tensor, - output: torch.Tensor, mup_vector: torch.Tensor | None = None, ): - torch.ops.vllm.mamba_mixer2( - hidden_states, - output, - self.prefix, - mup_vector, + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + if mup_vector is not None: + projected_states = projected_states * mup_vector + + # 2. Prepare inputs for conv + SSM + ssm_output = torch.empty( + [ + hidden_states.shape[0], + (self.num_heads // self.tp_size) * self.head_dim, + ], + dtype=hidden_states.dtype, + device=hidden_states.device, ) - def forward_cuda( + # 3. conv + SSM + # (split `projected_states` into hidden_states_B_C, dt in the custom op to + # ensure it is not treated as an intermediate tensor by torch compile) + torch.ops.vllm.mamba_mixer2( + projected_states, + ssm_output, + self.prefix, + ) + + # 4. gated MLP + # GatedRMSNorm internally applying SiLU to the gate + # SiLU is applied internally before normalization, unlike standard + # norm usage + gate = projected_states[..., : self.tped_intermediate_size] + hidden_states = self.norm(ssm_output, gate) + + # 5. Final linear projection + output, _ = self.out_proj(hidden_states) + + return output + + def conv_ssm_forward( self, - hidden_states: torch.Tensor, + projected_states: torch.Tensor, output: torch.Tensor, - mup_vector: torch.Tensor | None = None, ): + hidden_states_B_C, dt = torch.split( + projected_states[..., self.tped_intermediate_size :], + [self.tped_conv_size, self.tped_dt_size], + dim=-1, + ) + forward_context = get_forward_context() # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill @@ -524,46 +586,13 @@ class MambaMixer2(MambaBase, CustomOp): cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p - # 1. Gated MLP's linear projection - projected_states, _ = self.in_proj(hidden_states) - - if mup_vector is not None: - projected_states = projected_states * mup_vector - - gate, hidden_states_B_C, dt = torch.split( - projected_states, - [ - self.intermediate_size // self.tp_size, - self.conv_dim // self.tp_size, - self.num_heads // self.tp_size, - ], - dim=-1, - ) - - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) - - # - get hidden_states, B and C after depthwise convolution. - split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( - hidden_states_B_C, - [ - self.intermediate_size // self.tp_size, - self.groups_ssm_state_size // self.tp_size, - self.groups_ssm_state_size // self.tp_size, - ], - dim=-1, - ) - if attn_metadata is None: # profile run hidden_states_B_C = ( hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1) ).contiguous() - hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C) - hidden_states = self.norm(hidden_states, gate) - out, _ = self.out_proj(hidden_states) - return out + hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C) + return hidden_states # NOTE: V0 put prefill before decode, v1 puts decode before prefill num_prefills = attn_metadata.num_prefills # request count @@ -622,18 +651,8 @@ class MambaMixer2(MambaBase, CustomOp): block_idx_first_scheduled_token_p = None num_computed_tokens_p = None - # Preallocate output tensor to avoid memcpy cost for merging prefill - # and decode outputs - preallocated_ssm_out = torch.empty( - [ - num_prefill_tokens + num_decodes, - (self.num_heads // self.tp_size) * self.head_dim, - ], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, + output[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0, ) @@ -658,7 +677,7 @@ class MambaMixer2(MambaBase, CustomOp): ) # this is the form that causal-conv see hidden_states_B_C_p = causal_conv1d_fn( x, - conv_weights, + self.conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, @@ -673,7 +692,9 @@ class MambaMixer2(MambaBase, CustomOp): query_start_loc=query_start_loc_p, ).transpose(0, 1)[:num_prefill_tokens] - hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p) + hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn( + hidden_states_B_C_p + ) # 3. State Space Model sequence transformation initial_states = None @@ -815,7 +836,7 @@ class MambaMixer2(MambaBase, CustomOp): hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, conv_state, - conv_weights, + self.conv_weights, self.conv1d.bias, self.activation, conv_state_indices=state_indices_tensor_d, @@ -823,7 +844,9 @@ class MambaMixer2(MambaBase, CustomOp): initial_state_idx=block_idx_last_computed_token_d, ) - hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d) + hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn( + hidden_states_B_C_d + ) # 3. State Space Model sequence transformation n_groups = self.n_groups // self.tp_size @@ -861,15 +884,6 @@ class MambaMixer2(MambaBase, CustomOp): out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - # 4. gated MLP - # GatedRMSNorm internally applying SiLU to the gate - # SiLU is applied internally before normalization, unlike standard - # norm usage - hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens]) - - # 5. Final linear projection - output[:num_actual_tokens], _ = self.out_proj(hidden_states) - def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: assert self.model_config is not None assert self.cache_config is not None @@ -901,21 +915,19 @@ class MambaMixer2(MambaBase, CustomOp): def mamba_mixer2( - hidden_states: torch.Tensor, + projected_states: torch.Tensor, output: torch.Tensor, layer_name: str, - mup_vector: torch.Tensor | None = None, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector) + self.conv_ssm_forward(projected_states=projected_states, output=output) def mamba_mixer2_fake( - hidden_states: torch.Tensor, + projected_states: torch.Tensor, output: torch.Tensor, layer_name: str, - mup_vector: torch.Tensor | None = None, ) -> None: return diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index e0a2defd5127e..c6cc83487fec2 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output) + output = self.mamba(hidden_states) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 3653425b8e1ca..b985847af5daf 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -198,10 +198,8 @@ class FalconH1SSMDecoderLayer(nn.Module): residual: torch.Tensor | None, **kwargs, ): - output = torch.empty_like(hidden_states) - self.mamba( + output = self.mamba( hidden_states, - output, mup_vector=self.mup_vector, ) return output, residual diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 05177f1d1ac2c..a340112ec62ae 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -115,8 +115,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output) + output = self.mamba(hidden_states) hidden_states = residual + output * self.residual_multiplier residual = hidden_states diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index fc17f98be1986..5fcfa94312303 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output) + output = self.mixer(hidden_states) return output, residual diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index f7e0caf410e10..8675eff592224 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -376,8 +376,7 @@ class NemotronHMambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output) + output = self.mixer(hidden_states) return output, residual diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 64e6979c8fcfb..729a9655d0879 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) # Process through Mamba mixer - output = torch.empty_like(hidden_states) - self.mamba( - hidden_states, - output, - ) + output = self.mamba(hidden_states) # residual connection after mamba hidden_states = residual + output