[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91 2025-11-19 02:49:36 +02:00 committed by GitHub
parent 9912b8ccb8
commit 1395461f5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 92 additions and 90 deletions

View File

@ -426,6 +426,10 @@ class MambaMixer2(MambaBase, CustomOp):
# `ColumnParallelLinear` and `MergedColumnParallelLinear`, # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow overriding it # and `set_weight_attrs` doesn't allow overriding it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
self.register_buffer("conv_weights", conv_weights, persistent=False)
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape
@ -459,6 +463,17 @@ class MambaMixer2(MambaBase, CustomOp):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
) )
# - get hidden_states, B and C after depthwise convolution.
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
compilation_config = get_current_vllm_config().compilation_config compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
@ -470,10 +485,24 @@ class MambaMixer2(MambaBase, CustomOp):
self.cache_config = cache_config self.cache_config = cache_config
self.prefix = prefix self.prefix = prefix
# Pre-compute sizes for forward pass
self.tped_intermediate_size = self.intermediate_size // self.tp_size
self.tped_conv_size = self.conv_dim // self.tp_size
self.tped_dt_size = self.num_heads // self.tp_size
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.tped_intermediate_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
def forward_native( def forward_native(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: torch.Tensor | None = None, mup_vector: torch.Tensor | None = None,
): ):
pass pass
@ -481,22 +510,55 @@ class MambaMixer2(MambaBase, CustomOp):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: torch.Tensor | None = None, mup_vector: torch.Tensor | None = None,
): ):
torch.ops.vllm.mamba_mixer2( # 1. Gated MLP's linear projection
hidden_states, projected_states, _ = self.in_proj(hidden_states)
output, if mup_vector is not None:
self.prefix, projected_states = projected_states * mup_vector
mup_vector,
# 2. Prepare inputs for conv + SSM
ssm_output = torch.empty(
[
hidden_states.shape[0],
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
) )
def forward_cuda( # 3. conv + SSM
# (split `projected_states` into hidden_states_B_C, dt in the custom op to
# ensure it is not treated as an intermediate tensor by torch compile)
torch.ops.vllm.mamba_mixer2(
projected_states,
ssm_output,
self.prefix,
)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
gate = projected_states[..., : self.tped_intermediate_size]
hidden_states = self.norm(ssm_output, gate)
# 5. Final linear projection
output, _ = self.out_proj(hidden_states)
return output
def conv_ssm_forward(
self, self,
hidden_states: torch.Tensor, projected_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
): ):
hidden_states_B_C, dt = torch.split(
projected_states[..., self.tped_intermediate_size :],
[self.tped_conv_size, self.tped_dt_size],
dim=-1,
)
forward_context = get_forward_context() forward_context = get_forward_context()
# attn_metadata contains metadata necessary for the mamba2 triton # attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
@ -524,46 +586,13 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split(
projected_states,
[
self.intermediate_size // self.tp_size,
self.conv_dim // self.tp_size,
self.num_heads // self.tp_size,
],
dim=-1,
)
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
if attn_metadata is None: if attn_metadata is None:
# profile run # profile run
hidden_states_B_C = ( hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1) hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous() ).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C) hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate) return hidden_states
out, _ = self.out_proj(hidden_states)
return out
# NOTE: V0 put prefill before decode, v1 puts decode before prefill # NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count num_prefills = attn_metadata.num_prefills # request count
@ -622,18 +651,8 @@ class MambaMixer2(MambaBase, CustomOp):
block_idx_first_scheduled_token_p = None block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None num_computed_tokens_p = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out, output[:num_actual_tokens],
[num_decodes, num_prefill_tokens], [num_decodes, num_prefill_tokens],
dim=0, dim=0,
) )
@ -658,7 +677,7 @@ class MambaMixer2(MambaBase, CustomOp):
) # this is the form that causal-conv sees ) # this is the form that causal-conv sees
hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p = causal_conv1d_fn(
x, x,
conv_weights, self.conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
conv_states=conv_state, conv_states=conv_state,
@ -673,7 +692,9 @@ class MambaMixer2(MambaBase, CustomOp):
query_start_loc=query_start_loc_p, query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens] ).transpose(0, 1)[:num_prefill_tokens]
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p) hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
hidden_states_B_C_p
)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
initial_states = None initial_states = None
@ -815,7 +836,7 @@ class MambaMixer2(MambaBase, CustomOp):
hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d, hidden_states_B_C_d,
conv_state, conv_state,
conv_weights, self.conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices=state_indices_tensor_d, conv_state_indices=state_indices_tensor_d,
@ -823,7 +844,9 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=block_idx_last_computed_token_d, initial_state_idx=block_idx_last_computed_token_d,
) )
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d) hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
hidden_states_B_C_d
)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size n_groups = self.n_groups // self.tp_size
@ -861,15 +884,6 @@ class MambaMixer2(MambaBase, CustomOp):
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
) )
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None assert self.model_config is not None
assert self.cache_config is not None assert self.cache_config is not None
@ -901,21 +915,19 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_mixer2( def mamba_mixer2(
hidden_states: torch.Tensor, projected_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
mup_vector: torch.Tensor | None = None,
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector) self.conv_ssm_forward(projected_states=projected_states, output=output)
def mamba_mixer2_fake( def mamba_mixer2_fake(
hidden_states: torch.Tensor, projected_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
mup_vector: torch.Tensor | None = None,
) -> None: ) -> None:
return return

View File

@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = self.mamba(hidden_states)
self.mamba(hidden_states, output)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states) hidden_states = self.feed_forward(hidden_states)

View File

@ -198,10 +198,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
residual: torch.Tensor | None, residual: torch.Tensor | None,
**kwargs, **kwargs,
): ):
output = torch.empty_like(hidden_states) output = self.mamba(
self.mamba(
hidden_states, hidden_states,
output,
mup_vector=self.mup_vector, mup_vector=self.mup_vector,
) )
return output, residual return output, residual

View File

@ -115,8 +115,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
output = torch.empty_like(hidden_states) output = self.mamba(hidden_states)
self.mamba(hidden_states, output)
hidden_states = residual + output * self.residual_multiplier hidden_states = residual + output * self.residual_multiplier
residual = hidden_states residual = hidden_states

View File

@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = self.mixer(hidden_states)
self.mixer(hidden_states, output)
return output, residual return output, residual

View File

@ -376,8 +376,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states) output = self.mixer(hidden_states)
self.mixer(hidden_states, output)
return output, residual return output, residual

View File

@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Process through Mamba mixer # Process through Mamba mixer
output = torch.empty_like(hidden_states) output = self.mamba(hidden_states)
self.mamba(
hidden_states,
output,
)
# residual connection after mamba # residual connection after mamba
hidden_states = residual + output hidden_states = residual + output