[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91 2025-11-19 02:49:36 +02:00 committed by GitHub
parent 9912b8ccb8
commit 1395461f5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 92 additions and 90 deletions

View File

@ -426,6 +426,10 @@ class MambaMixer2(MambaBase, CustomOp):
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
self.register_buffer("conv_weights", conv_weights, persistent=False)
# - these are TPed by heads to reduce the size of the
# temporal shape
@ -459,6 +463,17 @@ class MambaMixer2(MambaBase, CustomOp):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
# - get hidden_states, B and C after depthwise convolution.
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@ -470,10 +485,24 @@ class MambaMixer2(MambaBase, CustomOp):
self.cache_config = cache_config
self.prefix = prefix
# Pre-compute sizes for forward pass
self.tped_intermediate_size = self.intermediate_size // self.tp_size
self.tped_conv_size = self.conv_dim // self.tp_size
self.tped_dt_size = self.num_heads // self.tp_size
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.tped_intermediate_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
pass
@ -481,22 +510,55 @@ class MambaMixer2(MambaBase, CustomOp):
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
torch.ops.vllm.mamba_mixer2(
hidden_states,
output,
self.prefix,
mup_vector,
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
# 2. Prepare inputs for conv + SSM
ssm_output = torch.empty(
[
hidden_states.shape[0],
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
def forward_cuda(
# 3. conv + SSM
# (split `projected_states` into hidden_states_B_C, dt in the custom op to
# ensure it is not treated as an intermediate tensor by torch compile)
torch.ops.vllm.mamba_mixer2(
projected_states,
ssm_output,
self.prefix,
)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
gate = projected_states[..., : self.tped_intermediate_size]
hidden_states = self.norm(ssm_output, gate)
# 5. Final linear projection
output, _ = self.out_proj(hidden_states)
return output
def conv_ssm_forward(
self,
hidden_states: torch.Tensor,
projected_states: torch.Tensor,
output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
hidden_states_B_C, dt = torch.split(
projected_states[..., self.tped_intermediate_size :],
[self.tped_conv_size, self.tped_dt_size],
dim=-1,
)
forward_context = get_forward_context()
# attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
@ -524,46 +586,13 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split(
projected_states,
[
self.intermediate_size // self.tp_size,
self.conv_dim // self.tp_size,
self.num_heads // self.tp_size,
],
dim=-1,
)
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
self.groups_ssm_state_size // self.tp_size,
],
dim=-1,
)
if attn_metadata is None:
# profile run
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
return hidden_states
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
@ -622,18 +651,8 @@ class MambaMixer2(MambaBase, CustomOp):
block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
output[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
@ -658,7 +677,7 @@ class MambaMixer2(MambaBase, CustomOp):
) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
self.conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
@ -673,7 +692,9 @@ class MambaMixer2(MambaBase, CustomOp):
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
hidden_states_B_C_p
)
# 3. State Space Model sequence transformation
initial_states = None
@ -815,7 +836,7 @@ class MambaMixer2(MambaBase, CustomOp):
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
conv_state,
conv_weights,
self.conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
@ -823,7 +844,9 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=block_idx_last_computed_token_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
hidden_states_B_C_d
)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
@ -861,15 +884,6 @@ class MambaMixer2(MambaBase, CustomOp):
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
@ -901,21 +915,19 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_mixer2(
hidden_states: torch.Tensor,
projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
self.conv_ssm_forward(projected_states=projected_states, output=output)
def mamba_mixer2_fake(
hidden_states: torch.Tensor,
projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: torch.Tensor | None = None,
) -> None:
return

View File

@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output)
output = self.mamba(hidden_states)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)

View File

@ -198,10 +198,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
residual: torch.Tensor | None,
**kwargs,
):
output = torch.empty_like(hidden_states)
self.mamba(
output = self.mamba(
hidden_states,
output,
mup_vector=self.mup_vector,
)
return output, residual

View File

@ -115,8 +115,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output)
output = self.mamba(hidden_states)
hidden_states = residual + output * self.residual_multiplier
residual = hidden_states

View File

@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output)
output = self.mixer(hidden_states)
return output, residual

View File

@ -376,8 +376,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output)
output = self.mixer(hidden_states)
return output, residual

View File

@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Process through Mamba mixer
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
)
output = self.mamba(hidden_states)
# residual connection after mamba
hidden_states = residual + output