mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 16:27:15 +08:00
[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
parent
9912b8ccb8
commit
1395461f5f
@ -426,6 +426,10 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
|
||||
# and `set_weight_attrs` doesn't allow overriding it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
self.register_buffer("conv_weights", conv_weights, persistent=False)
|
||||
|
||||
# - these are TPed by heads to reduce the size of the
|
||||
# temporal shape
|
||||
@ -459,6 +463,17 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@ -470,10 +485,24 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
# Pre-compute sizes for forward pass
|
||||
self.tped_intermediate_size = self.intermediate_size // self.tp_size
|
||||
self.tped_conv_size = self.conv_dim // self.tp_size
|
||||
self.tped_dt_size = self.num_heads // self.tp_size
|
||||
|
||||
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.tped_intermediate_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
pass
|
||||
@ -481,22 +510,55 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
mup_vector,
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
# 2. Prepare inputs for conv + SSM
|
||||
ssm_output = torch.empty(
|
||||
[
|
||||
hidden_states.shape[0],
|
||||
(self.num_heads // self.tp_size) * self.head_dim,
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
# 3. conv + SSM
|
||||
# (split `projected_states` into hidden_states_B_C, dt in the custom op to
|
||||
# ensure it is not treated as an intermediate tensor by torch compile)
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
projected_states,
|
||||
ssm_output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applies SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
gate = projected_states[..., : self.tped_intermediate_size]
|
||||
hidden_states = self.norm(ssm_output, gate)
|
||||
|
||||
# 5. Final linear projection
|
||||
output, _ = self.out_proj(hidden_states)
|
||||
|
||||
return output
|
||||
|
||||
def conv_ssm_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
hidden_states_B_C, dt = torch.split(
|
||||
projected_states[..., self.tped_intermediate_size :],
|
||||
[self.tped_conv_size, self.tped_dt_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
@ -524,46 +586,13 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
gate, hidden_states_B_C, dt = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.conv_dim // self.tp_size,
|
||||
self.num_heads // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# profile run
|
||||
hidden_states_B_C = (
|
||||
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
|
||||
).contiguous()
|
||||
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
return hidden_states
|
||||
|
||||
# NOTE: V0 puts prefill before decode, V1 puts decode before prefill
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
@ -622,18 +651,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
block_idx_first_scheduled_token_p = None
|
||||
num_computed_tokens_p = None
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
num_prefill_tokens + num_decodes,
|
||||
(self.num_heads // self.tp_size) * self.head_dim,
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
output[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
@ -658,7 +677,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
) # this is the form that causal-conv see
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
@ -673,7 +692,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
query_start_loc=query_start_loc_p,
|
||||
).transpose(0, 1)[:num_prefill_tokens]
|
||||
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
|
||||
hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
@ -815,7 +836,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
@ -823,7 +844,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
initial_state_idx=block_idx_last_computed_token_d,
|
||||
)
|
||||
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
|
||||
hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
@ -861,15 +884,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
)
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applies SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
|
||||
|
||||
# 5. Final linear projection
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
@ -901,21 +915,19 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
|
||||
|
||||
def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
|
||||
self.conv_ssm_forward(projected_states=projected_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer2_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output)
|
||||
output = self.mamba(hidden_states)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
|
||||
@ -198,10 +198,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
|
||||
residual: torch.Tensor | None,
|
||||
**kwargs,
|
||||
):
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
output = self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
mup_vector=self.mup_vector,
|
||||
)
|
||||
return output, residual
|
||||
|
||||
@ -115,8 +115,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output)
|
||||
output = self.mamba(hidden_states)
|
||||
hidden_states = residual + output * self.residual_multiplier
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output)
|
||||
output = self.mixer(hidden_states)
|
||||
return output, residual
|
||||
|
||||
|
||||
|
||||
@ -376,8 +376,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output)
|
||||
output = self.mixer(hidden_states)
|
||||
return output, residual
|
||||
|
||||
|
||||
|
||||
@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Process through Mamba mixer
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
)
|
||||
output = self.mamba(hidden_states)
|
||||
|
||||
# residual connection after mamba
|
||||
hidden_states = residual + output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user