mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 12:22:13 +08:00
[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
parent
25373b6c6c
commit
b690e34824
@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
|||||||
batch_size = 1
|
batch_size = 1
|
||||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||||
|
out = torch.empty_like(x)
|
||||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||||
@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
|||||||
D = torch.randn(dim, device=device)
|
D = torch.randn(dim, device=device)
|
||||||
z = torch.randn_like(x) if has_z else None
|
z = torch.randn_like(x) if has_z else None
|
||||||
state_ref = state.detach().clone()
|
state_ref = state.detach().clone()
|
||||||
out = selective_state_update(state,
|
selective_state_update(state,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
D=D,
|
D=D,
|
||||||
z=z,
|
z=z,
|
||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True)
|
dt_softplus=True,
|
||||||
|
out=out)
|
||||||
out_ref = selective_state_update_ref(state_ref,
|
out_ref = selective_state_update_ref(state_ref,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
|||||||
],
|
],
|
||||||
dim=0)
|
dim=0)
|
||||||
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||||
|
out = torch.empty_like(x)
|
||||||
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||||
@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
|||||||
z = torch.randn_like(x) if has_z else None
|
z = torch.randn_like(x) if has_z else None
|
||||||
state_ref = state[state_indices, :].clone()
|
state_ref = state[state_indices, :].clone()
|
||||||
state_before = state.clone()
|
state_before = state.clone()
|
||||||
out = selective_state_update(state,
|
selective_state_update(state,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
D=D,
|
D=D,
|
||||||
z=z,
|
z=z,
|
||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=padded_state_indices,
|
state_batch_indices=padded_state_indices,
|
||||||
pad_slot_id=PAD_SLOT_ID)
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
out=out)
|
||||||
out_ref = selective_state_update_ref(state_ref,
|
out_ref = selective_state_update_ref(state_ref,
|
||||||
x[:batch_size],
|
x[:batch_size],
|
||||||
dt[:batch_size],
|
dt[:batch_size],
|
||||||
@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
|||||||
dtype=torch.int32, device=device)
|
dtype=torch.int32, device=device)
|
||||||
|
|
||||||
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
||||||
|
out = torch.empty_like(x)
|
||||||
if not tie_hdim:
|
if not tie_hdim:
|
||||||
dt = torch.randn(batch_size,
|
dt = torch.randn(batch_size,
|
||||||
nheads,
|
nheads,
|
||||||
@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
|||||||
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||||
z = torch.randn_like(x) if has_z else None
|
z = torch.randn_like(x) if has_z else None
|
||||||
state_ref = state[state_indices, :].detach().clone()
|
state_ref = state[state_indices, :].detach().clone()
|
||||||
out = selective_state_update(state,
|
selective_state_update(state,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
D=D,
|
D=D,
|
||||||
z=z,
|
z=z,
|
||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=state_indices,
|
state_batch_indices=state_indices,
|
||||||
pad_slot_id=PAD_SLOT_ID)
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
out=out)
|
||||||
out_ref = selective_state_update_ref(state_ref,
|
out_ref = selective_state_update_ref(state_ref,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
|
|||||||
@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
|||||||
|
|
||||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
||||||
B, C, chunk_size)
|
B, C, chunk_size)
|
||||||
|
Y = torch.empty_like(X)
|
||||||
Y, final_state = mamba_chunk_scan_combined(X,
|
final_state = mamba_chunk_scan_combined(X,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
D=None,
|
D=None,
|
||||||
return_final_states=True)
|
return_final_states=True,
|
||||||
|
out=Y)
|
||||||
|
|
||||||
# just test the last in sequence
|
# just test the last in sequence
|
||||||
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
|
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
|
||||||
@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
|||||||
_query_start_loc_to_chunk_indices_offsets(
|
_query_start_loc_to_chunk_indices_offsets(
|
||||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||||
|
|
||||||
Y, new_states = mamba_chunk_scan_combined(
|
Y = torch.empty_like(X)
|
||||||
|
new_states = mamba_chunk_scan_combined(
|
||||||
X,
|
X,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
|||||||
chunk_offsets=chunk_offsets,
|
chunk_offsets=chunk_offsets,
|
||||||
return_varlen_states=True,
|
return_varlen_states=True,
|
||||||
initial_states=states,
|
initial_states=states,
|
||||||
|
out=Y,
|
||||||
)
|
)
|
||||||
|
|
||||||
# just test the last in sequence
|
# just test the last in sequence
|
||||||
|
|||||||
@ -220,7 +220,8 @@ class MambaMixer(CustomOp):
|
|||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
query_start_loc=attn_metadata.query_start_loc)
|
||||||
else:
|
else:
|
||||||
scan_outputs = selective_state_update(
|
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
||||||
|
selective_state_update(
|
||||||
mamba_cache_params.ssm_state,
|
mamba_cache_params.ssm_state,
|
||||||
hidden_states.transpose(0, 1),
|
hidden_states.transpose(0, 1),
|
||||||
discrete_time_step.transpose(0, 1),
|
discrete_time_step.transpose(0, 1),
|
||||||
@ -231,7 +232,8 @@ class MambaMixer(CustomOp):
|
|||||||
gate.transpose(0, 1),
|
gate.transpose(0, 1),
|
||||||
time_proj_bias,
|
time_proj_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
||||||
|
out=scan_outputs)
|
||||||
scan_outputs = scan_outputs.transpose(0, 1)
|
scan_outputs = scan_outputs.transpose(0, 1)
|
||||||
|
|
||||||
# 4. Final linear projection
|
# 4. Final linear projection
|
||||||
|
|||||||
@ -541,7 +541,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||||
# Separate prefill and decode by splitting varlen input
|
# Separate prefill and decode by splitting varlen input
|
||||||
# Split along token dimension
|
# Split along token dimension
|
||||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||||
hidden_states_B_C[:num_actual_tokens],
|
hidden_states_B_C[:num_actual_tokens],
|
||||||
@ -583,7 +582,28 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
1]
|
1]
|
||||||
if has_prefill else None)
|
if has_prefill else None)
|
||||||
|
|
||||||
ssd_output_list = []
|
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||||
|
# and decode outputs
|
||||||
|
preallocated_ssm_out = torch.empty(
|
||||||
|
[
|
||||||
|
num_prefill_tokens + num_decodes,
|
||||||
|
(self.num_heads // self.tp_size) * self.head_dim
|
||||||
|
],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||||
|
preallocated_ssm_out,
|
||||||
|
[num_decodes, num_prefill_tokens],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||||
|
preallocated_ssm_out,
|
||||||
|
[num_prefill_tokens, num_decodes],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
# Process prefill requests
|
# Process prefill requests
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
@ -623,7 +643,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
has_initial_states_p[:num_prefills, None, None, None],
|
has_initial_states_p[:num_prefills, None, None, None],
|
||||||
ssm_state[state_indices_tensor_p], 0)
|
ssm_state[state_indices_tensor_p], 0)
|
||||||
|
|
||||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
# NOTE: final output is an in-place update of out tensor
|
||||||
|
varlen_state = mamba_chunk_scan_combined(
|
||||||
hidden_states_p.view(1, num_prefill_tokens,
|
hidden_states_p.view(1, num_prefill_tokens,
|
||||||
self.num_heads // self.tp_size,
|
self.num_heads // self.tp_size,
|
||||||
self.head_dim),
|
self.head_dim),
|
||||||
@ -646,15 +667,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
return_final_states=False,
|
return_final_states=False,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
|
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
|
||||||
|
self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
# update ssm states
|
# update ssm states
|
||||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||||
ssm_state[state_indices_tensor_p] = varlen_state
|
ssm_state[state_indices_tensor_p] = varlen_state
|
||||||
|
|
||||||
# - reshape
|
|
||||||
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
|
|
||||||
|
|
||||||
# Process decode requests
|
# Process decode requests
|
||||||
if has_decode:
|
if has_decode:
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
@ -684,8 +704,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||||
# - mamba_cache_params.ssm_state's slots will be selected
|
# - mamba_cache_params.ssm_state's slots will be selected
|
||||||
# using state_indices_tensor_d
|
# using state_indices_tensor_d
|
||||||
|
# NOTE: final output is an in-place update of out tensor
|
||||||
hidden_states_d = selective_state_update(
|
selective_state_update(
|
||||||
ssm_state,
|
ssm_state,
|
||||||
hidden_states_d,
|
hidden_states_d,
|
||||||
dt_d,
|
dt_d,
|
||||||
@ -697,26 +717,16 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=state_indices_tensor_d,
|
state_batch_indices=state_indices_tensor_d,
|
||||||
|
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||||
|
self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
ssd_output_list.insert(
|
|
||||||
0,
|
|
||||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
|
||||||
self.head_dim))
|
|
||||||
else:
|
|
||||||
ssd_output_list.append(
|
|
||||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
|
||||||
self.head_dim))
|
|
||||||
|
|
||||||
# Merge prefill and decode outputs before passing to gated MLP
|
|
||||||
hidden_states = torch.vstack(ssd_output_list)
|
|
||||||
|
|
||||||
# 4. gated MLP
|
# 4. gated MLP
|
||||||
# GatedRMSNorm internally applying SiLU to the gate
|
# GatedRMSNorm internally applying SiLU to the gate
|
||||||
# SiLU is applied internally before normalization, unlike standard
|
# SiLU is applied internally before normalization, unlike standard
|
||||||
# norm usage
|
# norm usage
|
||||||
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
|
hidden_states = self.norm(preallocated_ssm_out,
|
||||||
|
gate[:num_actual_tokens])
|
||||||
|
|
||||||
# 5. Final linear projection
|
# 5. Final linear projection
|
||||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||||
|
|||||||
@ -205,7 +205,8 @@ def selective_state_update(state,
|
|||||||
dt_bias=None,
|
dt_bias=None,
|
||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
state_batch_indices=None,
|
state_batch_indices=None,
|
||||||
pad_slot_id=PAD_SLOT_ID):
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
out=None):
|
||||||
"""
|
"""
|
||||||
Argument:
|
Argument:
|
||||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||||
@ -223,10 +224,9 @@ def selective_state_update(state,
|
|||||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||||
in this case, the kernel will not process entries at
|
in this case, the kernel will not process entries at
|
||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
Return:
|
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||||
out: (batch, dim) or (batch, nheads, dim)
|
In-place updated.
|
||||||
"""
|
"""
|
||||||
has_heads = state.dim() > 3
|
|
||||||
if state.dim() == 3:
|
if state.dim() == 3:
|
||||||
state = state.unsqueeze(1)
|
state = state.unsqueeze(1)
|
||||||
if x.dim() == 2:
|
if x.dim() == 2:
|
||||||
@ -245,6 +245,8 @@ def selective_state_update(state,
|
|||||||
z = z.unsqueeze(1)
|
z = z.unsqueeze(1)
|
||||||
if dt_bias is not None and dt_bias.dim() == 1:
|
if dt_bias is not None and dt_bias.dim() == 1:
|
||||||
dt_bias = dt_bias.unsqueeze(0)
|
dt_bias = dt_bias.unsqueeze(0)
|
||||||
|
if out.dim() == 2:
|
||||||
|
out = out.unsqueeze(1)
|
||||||
|
|
||||||
_, nheads, dim, dstate = state.shape
|
_, nheads, dim, dstate = state.shape
|
||||||
batch = x.shape[0]
|
batch = x.shape[0]
|
||||||
@ -264,7 +266,8 @@ def selective_state_update(state,
|
|||||||
assert dt_bias.shape == (nheads, dim)
|
assert dt_bias.shape == (nheads, dim)
|
||||||
if state_batch_indices is not None:
|
if state_batch_indices is not None:
|
||||||
assert state_batch_indices.shape == (batch, )
|
assert state_batch_indices.shape == (batch, )
|
||||||
out = torch.empty_like(x)
|
assert out.shape == x.shape
|
||||||
|
|
||||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||||
(0, 0, 0))
|
(0, 0, 0))
|
||||||
@ -328,9 +331,6 @@ def selective_state_update(state,
|
|||||||
BLOCK_SIZE_M,
|
BLOCK_SIZE_M,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
)
|
)
|
||||||
if not has_heads:
|
|
||||||
out = out.squeeze(1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def selective_scan_fn(u,
|
def selective_scan_fn(u,
|
||||||
|
|||||||
@ -454,6 +454,7 @@ def _chunk_scan_fwd(
|
|||||||
chunk_indices=None,
|
chunk_indices=None,
|
||||||
chunk_offsets=None,
|
chunk_offsets=None,
|
||||||
initial_states=None,
|
initial_states=None,
|
||||||
|
out=None,
|
||||||
):
|
):
|
||||||
batch, seqlen, nheads, headdim = x.shape
|
batch, seqlen, nheads, headdim = x.shape
|
||||||
_, _, nchunks, chunk_size = dt.shape
|
_, _, nchunks, chunk_size = dt.shape
|
||||||
@ -483,20 +484,10 @@ def _chunk_scan_fwd(
|
|||||||
else:
|
else:
|
||||||
chunk_indices, chunk_offsets = None, None
|
chunk_indices, chunk_offsets = None, None
|
||||||
|
|
||||||
# Allocates output.
|
assert out.shape == x.shape
|
||||||
out = torch.empty(batch,
|
|
||||||
seqlen,
|
|
||||||
nheads,
|
|
||||||
headdim,
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype)
|
|
||||||
if z is not None:
|
if z is not None:
|
||||||
out_x = torch.empty(batch,
|
out_x = torch.empty_like(x)
|
||||||
seqlen,
|
|
||||||
nheads,
|
|
||||||
headdim,
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype)
|
|
||||||
assert out_x.stride() == out.stride()
|
assert out_x.stride() == out.stride()
|
||||||
else:
|
else:
|
||||||
out_x = None
|
out_x = None
|
||||||
@ -579,4 +570,4 @@ def _chunk_scan_fwd(
|
|||||||
IS_TRITON_22=TRITON_22,
|
IS_TRITON_22=TRITON_22,
|
||||||
HAS_INITSTATES=initial_states is not None,
|
HAS_INITSTATES=initial_states is not None,
|
||||||
)
|
)
|
||||||
return out, out_x
|
return out_x
|
||||||
|
|||||||
@ -36,7 +36,8 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
chunk_offsets=None,
|
chunk_offsets=None,
|
||||||
cu_seqlens=None,
|
cu_seqlens=None,
|
||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
dt_limit=(0.0, float("inf"))):
|
dt_limit=(0.0, float("inf")),
|
||||||
|
out=None):
|
||||||
batch, seqlen, nheads, headdim = x.shape
|
batch, seqlen, nheads, headdim = x.shape
|
||||||
_, _, ngroups, dstate = B.shape
|
_, _, ngroups, dstate = B.shape
|
||||||
assert nheads % ngroups == 0
|
assert nheads % ngroups == 0
|
||||||
@ -134,7 +135,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
||||||
# a seq_idx change, in which case we take states information from
|
# a seq_idx change, in which case we take states information from
|
||||||
# init_states.
|
# init_states.
|
||||||
out, out_x = _chunk_scan_fwd(
|
out_x = _chunk_scan_fwd(
|
||||||
CB,
|
CB,
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
@ -147,9 +148,10 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
chunk_offsets=chunk_offsets,
|
chunk_offsets=chunk_offsets,
|
||||||
initial_states=initial_states,
|
initial_states=initial_states,
|
||||||
|
out=out,
|
||||||
)
|
)
|
||||||
if cu_seqlens is None:
|
if cu_seqlens is None:
|
||||||
return out, out_x, dt, dA_cumsum, states, final_states
|
return out_x, dt, dA_cumsum, states, final_states
|
||||||
else:
|
else:
|
||||||
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
|
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
|
||||||
varlen_states = chunk_state_varlen(
|
varlen_states = chunk_state_varlen(
|
||||||
@ -161,7 +163,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
states.squeeze(0),
|
states.squeeze(0),
|
||||||
initial_states=initial_states,
|
initial_states=initial_states,
|
||||||
)
|
)
|
||||||
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
|
return out_x, dt, dA_cumsum, states, final_states, varlen_states
|
||||||
|
|
||||||
|
|
||||||
def mamba_chunk_scan_combined(x,
|
def mamba_chunk_scan_combined(x,
|
||||||
@ -180,6 +182,7 @@ def mamba_chunk_scan_combined(x,
|
|||||||
cu_seqlens=None,
|
cu_seqlens=None,
|
||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
|
out=None,
|
||||||
return_final_states=False,
|
return_final_states=False,
|
||||||
return_varlen_states=False):
|
return_varlen_states=False):
|
||||||
"""
|
"""
|
||||||
@ -197,15 +200,14 @@ def mamba_chunk_scan_combined(x,
|
|||||||
seq_idx: (batch, seqlen)
|
seq_idx: (batch, seqlen)
|
||||||
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
||||||
dt_softplus: Whether to apply softplus to dt
|
dt_softplus: Whether to apply softplus to dt
|
||||||
Return:
|
out: Preallocated output tensor
|
||||||
out: (batch, seqlen, nheads, headdim)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not return_varlen_states:
|
if not return_varlen_states:
|
||||||
cu_seqlens = None
|
cu_seqlens = None
|
||||||
else:
|
else:
|
||||||
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
|
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
|
||||||
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
|
out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
|
||||||
x,
|
x,
|
||||||
dt,
|
dt,
|
||||||
A,
|
A,
|
||||||
@ -221,12 +223,14 @@ def mamba_chunk_scan_combined(x,
|
|||||||
chunk_offsets=chunk_offsets,
|
chunk_offsets=chunk_offsets,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
dt_softplus=dt_softplus,
|
dt_softplus=dt_softplus,
|
||||||
dt_limit=dt_limit)
|
dt_limit=dt_limit,
|
||||||
|
out=out)
|
||||||
if not return_varlen_states:
|
if not return_varlen_states:
|
||||||
return out if not return_final_states else (out, final_states)
|
if not return_final_states:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return final_states
|
||||||
else:
|
else:
|
||||||
varlen_states = rest[0]
|
varlen_states = rest[0]
|
||||||
return (out,
|
return (varlen_states) if not return_final_states else (final_states,
|
||||||
varlen_states) if not return_final_states else (out,
|
|
||||||
final_states,
|
|
||||||
varlen_states)
|
varlen_states)
|
||||||
|
|||||||
@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module):
|
|||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
query_start_loc=attn_metadata.query_start_loc)
|
||||||
else:
|
else:
|
||||||
scan_outputs = selective_state_update(
|
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
||||||
|
selective_state_update(
|
||||||
mamba_cache_params.ssm_state,
|
mamba_cache_params.ssm_state,
|
||||||
hidden_states.transpose(0, 1),
|
hidden_states.transpose(0, 1),
|
||||||
discrete_time_step.transpose(0, 1),
|
discrete_time_step.transpose(0, 1),
|
||||||
@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module):
|
|||||||
None if self.yoco_kv else gate.transpose(0, 1),
|
None if self.yoco_kv else gate.transpose(0, 1),
|
||||||
time_proj_bias,
|
time_proj_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
||||||
|
out=scan_outputs)
|
||||||
scan_outputs = scan_outputs.transpose(0, 1)
|
scan_outputs = scan_outputs.transpose(0, 1)
|
||||||
|
|
||||||
# 4. Final linear projection
|
# 4. Final linear projection
|
||||||
|
|||||||
@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
|
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
|
||||||
if has_prefill else None)
|
if has_prefill else None)
|
||||||
|
|
||||||
ssd_output_list = []
|
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||||
|
# and decode outputs
|
||||||
|
preallocated_ssm_out = torch.empty(
|
||||||
|
[
|
||||||
|
num_prefill_tokens + num_decodes,
|
||||||
|
(self.num_heads // self.tp_size) * self.head_dim
|
||||||
|
],
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||||
|
preallocated_ssm_out,
|
||||||
|
[num_prefill_tokens, num_decodes],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
# Process prefill requests
|
# Process prefill requests
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
initial_states = torch.where(
|
initial_states = torch.where(
|
||||||
mamba2_metadata.has_initial_states[:, None, None, None],
|
mamba2_metadata.has_initial_states[:, None, None, None],
|
||||||
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
||||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
varlen_state = mamba_chunk_scan_combined(
|
||||||
hidden_states_p.view(1, num_prefill_tokens,
|
hidden_states_p.view(1, num_prefill_tokens,
|
||||||
self.num_heads // self.tp_size,
|
self.num_heads // self.tp_size,
|
||||||
self.head_dim),
|
self.head_dim),
|
||||||
@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
return_final_states=False,
|
return_final_states=False,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
|
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
|
||||||
|
self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
# update ssm states
|
# update ssm states
|
||||||
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
||||||
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
|
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
|
||||||
|
|
||||||
# - reshape
|
|
||||||
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
|
|
||||||
|
|
||||||
# Process decode requests
|
# Process decode requests
|
||||||
if has_decode:
|
if has_decode:
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||||
# - mamba_cache_params.ssm_state's slots will be selected
|
# - mamba_cache_params.ssm_state's slots will be selected
|
||||||
# using state_indices_tensor_d
|
# using state_indices_tensor_d
|
||||||
|
selective_state_update(
|
||||||
hidden_states_d = selective_state_update(
|
|
||||||
mamba_cache_params.ssm_state,
|
mamba_cache_params.ssm_state,
|
||||||
hidden_states_d,
|
hidden_states_d,
|
||||||
dt,
|
dt,
|
||||||
@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module):
|
|||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=state_indices_tensor_d,
|
state_batch_indices=state_indices_tensor_d,
|
||||||
|
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||||
|
self.head_dim),
|
||||||
)
|
)
|
||||||
assert self.num_heads % self.tp_size == 0
|
assert self.num_heads % self.tp_size == 0
|
||||||
ssd_output_list.append(
|
|
||||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
|
||||||
self.head_dim))
|
|
||||||
|
|
||||||
# Merge prefill and decode outputs before passing to MLP
|
|
||||||
hidden_states = torch.vstack(ssd_output_list)
|
|
||||||
|
|
||||||
# 4. Final linear projection
|
# 4. Final linear projection
|
||||||
out = self.out_proj(hidden_states)
|
out = self.out_proj(preallocated_ssm_out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user