diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 8dece26ddb29c..4c32ae81b34c5 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) + out = torch.empty_like(x) dt = torch.randn(batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 @@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype): D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None state_ref = state.detach().clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out) out_ref = selective_state_update_ref(state_ref, x, dt, @@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, ], dim=0) x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + out = torch.empty_like(x) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 @@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].clone() state_before = state.clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID) + selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out) out_ref = selective_state_update_ref(state_ref, x[:batch_size], dt[:batch_size], @@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices( dtype=torch.int32, device=device) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + out = torch.empty_like(x) if not tie_hdim: dt = torch.randn(batch_size, nheads, @@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices( C = torch.randn(batch_size, ngroups, dstate, device=device) z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].detach().clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices, - pad_slot_id=PAD_SLOT_ID) + selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out) out_ref = selective_state_update_ref(state_ref, x, dt, diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 00c1a2911d7db..67b14a7faa89f 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, B, C, chunk_size) - - Y, final_state = mamba_chunk_scan_combined(X, - dt, - A, - B, - C, - chunk_size, - D=None, - return_final_states=True) + Y = torch.empty_like(X) + final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True, + out=Y) # just test the last in sequence torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol) @@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, _query_start_loc_to_chunk_indices_offsets( cu_seqlens, chunk_size, cu_seqlens[-1]) - Y, new_states = mamba_chunk_scan_combined( + Y = torch.empty_like(X) + new_states = mamba_chunk_scan_combined( X, dt, A, @@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, chunk_offsets=chunk_offsets, return_varlen_states=True, initial_states=states, + out=Y, ) # just test the last in sequence diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 796c8d9375727..60cf3e11885a1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -220,7 +220,8 @@ class MambaMixer(CustomOp): has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: - scan_outputs = selective_state_update( + scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) + selective_state_update( mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), @@ -231,7 +232,8 @@ class MambaMixer(CustomOp): gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) + state_batch_indices=mamba_cache_params.state_indices_tensor, + out=scan_outputs) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 36edac2375d0e..5ac9a7f9ab3e4 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -541,7 +541,6 @@ class MambaMixer2(MambaBase, CustomOp): # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - # NOTE: V0 put prefill before decode, v1 puts decode before prefill if envs.VLLM_USE_V1: hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C[:num_actual_tokens], @@ -583,7 +582,28 @@ class MambaMixer2(MambaBase, CustomOp): 1] if has_prefill else None) - ssd_output_list = [] + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( + [ + num_prefill_tokens + num_decodes, + (self.num_heads // self.tp_size) * self.head_dim + ], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + if envs.VLLM_USE_V1: + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) + else: + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests if has_prefill: @@ -623,7 +643,8 @@ class MambaMixer2(MambaBase, CustomOp): has_initial_states_p[:num_prefills, None, None, None], ssm_state[state_indices_tensor_p], 0) - scan_output, varlen_state = mamba_chunk_scan_combined( + # NOTE: final output is an in-place update of out tensor + varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), @@ -646,15 +667,14 @@ class MambaMixer2(MambaBase, CustomOp): return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), + out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, + self.head_dim), ) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor ssm_state[state_indices_tensor_p] = varlen_state - # - reshape - ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) - # Process decode requests if has_decode: # 2. Convolution sequence transformation @@ -684,8 +704,8 @@ class MambaMixer2(MambaBase, CustomOp): # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d - - hidden_states_d = selective_state_update( + # NOTE: final output is an in-place update of out tensor + selective_state_update( ssm_state, hidden_states_d, dt_d, @@ -697,26 +717,16 @@ class MambaMixer2(MambaBase, CustomOp): dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d, + out=preallocated_ssm_out_d.view(num_decodes, -1, + self.head_dim), ) - if envs.VLLM_USE_V1: - ssd_output_list.insert( - 0, - hidden_states_d.view(-1, (self.num_heads // self.tp_size) * - self.head_dim)) - else: - ssd_output_list.append( - hidden_states_d.view(-1, (self.num_heads // self.tp_size) * - self.head_dim)) - - # Merge prefill and decode outputs before passing to gated MLP - hidden_states = torch.vstack(ssd_output_list) - # 4. gated MLP # GatedRMSNorm internally applying SiLU to the gate # SiLU is applied internally before normalization, unlike standard # norm usage - hidden_states = self.norm(hidden_states, gate[:num_actual_tokens]) + hidden_states = self.norm(preallocated_ssm_out, + gate[:num_actual_tokens]) # 5. Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3f67fc35afdfc..838290a9f5fb2 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -205,7 +205,8 @@ def selective_state_update(state, dt_bias=None, dt_softplus=False, state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID): + pad_slot_id=PAD_SLOT_ID, + out=None): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -223,10 +224,9 @@ def selective_state_update(state, for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - Return: - out: (batch, dim) or (batch, nheads, dim) + out: Preallocated ssm output tensor. Assume same shape as x. + In-place updated. """ - has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: @@ -245,6 +245,8 @@ def selective_state_update(state, z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) + if out.dim() == 2: + out = out.unsqueeze(1) _, nheads, dim, dstate = state.shape batch = x.shape[0] @@ -264,7 +266,8 @@ def selective_state_update(state, assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch, ) - out = torch.empty_like(x) + assert out.shape == x.shape + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) @@ -328,9 +331,6 @@ def selective_state_update(state, BLOCK_SIZE_M, num_warps=num_warps, ) - if not has_heads: - out = out.squeeze(1) - return out def selective_scan_fn(u, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 61eff0c008f60..fc2b3b25fd0a8 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -454,6 +454,7 @@ def _chunk_scan_fwd( chunk_indices=None, chunk_offsets=None, initial_states=None, + out=None, ): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape @@ -483,20 +484,10 @@ def _chunk_scan_fwd( else: chunk_indices, chunk_offsets = None, None - # Allocates output. - out = torch.empty(batch, - seqlen, - nheads, - headdim, - device=x.device, - dtype=x.dtype) + assert out.shape == x.shape + if z is not None: - out_x = torch.empty(batch, - seqlen, - nheads, - headdim, - device=x.device, - dtype=x.dtype) + out_x = torch.empty_like(x) assert out_x.stride() == out.stride() else: out_x = None @@ -579,4 +570,4 @@ def _chunk_scan_fwd( IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, ) - return out, out_x + return out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index b121275e9eb38..ad2853a3d8a8b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -36,7 +36,8 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_offsets=None, cu_seqlens=None, dt_softplus=False, - dt_limit=(0.0, float("inf"))): + dt_limit=(0.0, float("inf")), + out=None): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 @@ -134,7 +135,7 @@ def _mamba_chunk_scan_combined_fwd(x, # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from # init_states. - out, out_x = _chunk_scan_fwd( + out_x = _chunk_scan_fwd( CB, x, dt, @@ -147,9 +148,10 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, initial_states=initial_states, + out=out, ) if cu_seqlens is None: - return out, out_x, dt, dA_cumsum, states, final_states + return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" varlen_states = chunk_state_varlen( @@ -161,7 +163,7 @@ def _mamba_chunk_scan_combined_fwd(x, states.squeeze(0), initial_states=initial_states, ) - return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + return out_x, dt, dA_cumsum, states, final_states, varlen_states def mamba_chunk_scan_combined(x, @@ -180,6 +182,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + out=None, return_final_states=False, return_varlen_states=False): """ @@ -197,15 +200,14 @@ def mamba_chunk_scan_combined(x, seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt - Return: - out: (batch, seqlen, nheads, headdim) + out: Preallocated output tensor """ if not return_varlen_states: cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( x, dt, A, @@ -221,12 +223,14 @@ def mamba_chunk_scan_combined(x, chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, - dt_limit=dt_limit) + dt_limit=dt_limit, + out=out) if not return_varlen_states: - return out if not return_final_states else (out, final_states) + if not return_final_states: + return + else: + return final_states else: varlen_states = rest[0] - return (out, - varlen_states) if not return_final_states else (out, - final_states, + return (varlen_states) if not return_final_states else (final_states, varlen_states) diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index a4ded2b7a3047..1a761d01fc066 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module): has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: - scan_outputs = selective_state_update( + scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) + selective_state_update( mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), @@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module): None if self.yoco_kv else gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) + state_batch_indices=mamba_cache_params.state_indices_tensor, + out=scan_outputs) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 9bc577cfe3a3e..8b1df66f02805 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module): query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] if has_prefill else None) - ssd_output_list = [] + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( + [ + num_prefill_tokens + num_decodes, + (self.num_heads // self.tp_size) * self.head_dim + ], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests if has_prefill: @@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module): initial_states = torch.where( mamba2_metadata.has_initial_states[:, None, None, None], mamba_cache_params.ssm_state[state_indices_tensor_p], 0) - scan_output, varlen_state = mamba_chunk_scan_combined( + varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), @@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module): return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), + out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, + self.head_dim), ) # update ssm states # - varlen state is a (batch, nheads, headdim, dstate) tensor mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state - # - reshape - ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) - # Process decode requests if has_decode: # 2. Convolution sequence transformation @@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module): # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d - - hidden_states_d = selective_state_update( + selective_state_update( mamba_cache_params.ssm_state, hidden_states_d, dt, @@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module): dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d, + out=preallocated_ssm_out_d.view(num_decodes, -1, + self.head_dim), ) assert self.num_heads % self.tp_size == 0 - ssd_output_list.append( - hidden_states_d.view(-1, (self.num_heads // self.tp_size) * - self.head_dim)) - - # Merge prefill and decode outputs before passing to MLP - hidden_states = torch.vstack(ssd_output_list) # 4. Final linear projection - out = self.out_proj(hidden_states) + out = self.out_proj(preallocated_ssm_out) return out