[Model] Mamba2 varlen refactor (#21467)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang 2025-09-26 07:31:14 -04:00 committed by GitHub
parent 633f943e30
commit 2b6b1d7809
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 722 additions and 864 deletions

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)
@ -185,9 +185,14 @@ def generate_continuous_batched_examples(example_lens_by_batch,
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
# varlen has implicit batch=1
dt2 = dt2.squeeze(0)
X2 = X2.squeeze(0)
B2 = B2.squeeze(0)
C2 = C2.squeeze(0)
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
@pytest.mark.parametrize("itype",
@ -198,7 +203,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):
# this tests the kernels on a single example (no batching)
# this tests the kernels on a single example (bs=1)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)
cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
# varlen has implicit batch=1
X = X.squeeze(0)
dt = dt.squeeze(0)
A = A.squeeze(0)
B = B.squeeze(0)
C = C.squeeze(0)
Y = torch.empty_like(X)
final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True,
out=Y)
final_state = mamba_chunk_scan_combined_varlen(X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
out=Y)
# just test the last in sequence
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(final_state[:, -1],
torch.testing.assert_close(final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
cu_seqlens, chunk_size, cu_seqlens[-1])
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
new_states = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
for i in range(num_examples):
# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
state_ref = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)
@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
partial_state = mamba_chunk_scan_combined_varlen(
X_chunked,
dt_chunked,
A,
@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)
@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
torch.int32)
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
dim=0)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
# fmt: on
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_cu_seqlens[-1])
Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
state_chunked = mamba_chunk_scan_combined_varlen(
remaining_X_chunked,
remaining_dt_chunked,
A,
@ -510,7 +528,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
@ -518,17 +535,17 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, :chunked_seqlens[i], ...],
Y_ref_seq[:, :chunked_seqlens[i], ...],
Y_seq[:chunked_seqlens[i], ...],
Y_ref_seq[:chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i]:, ...],
Y_ref_seq[:, chunked_seqlens[i]:, ...],
Y_seq[chunked_seqlens[i]:, ...],
Y_ref_seq[chunked_seqlens[i]:, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023

View File

@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
@ -504,6 +504,7 @@ class MambaMixer2(MambaBase, CustomOp):
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@ -545,6 +546,7 @@ class MambaMixer2(MambaBase, CustomOp):
out, _ = self.out_proj(hidden_states)
return out
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
@ -570,9 +572,6 @@ class MambaMixer2(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@ -620,15 +619,15 @@ class MambaMixer2(MambaBase, CustomOp):
ssm_state[state_indices_tensor_p], 0)
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
dt_p,
self.A,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=chunk_size,
D=self.D,
@ -639,17 +638,15 @@ class MambaMixer2(MambaBase, CustomOp):
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_states
# Process decode requests
if has_decode:

View File

@ -427,7 +427,7 @@ def causal_conv1d_fn(
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to('cpu'))
seqlens = query_start_loc.diff().to('cpu')
args = seqlens
MAX_NUM_PROGRAMS = 1024

View File

@ -99,34 +99,28 @@ def _bmm_chunk_fwd_kernel(
seq_idx_ptr,
# Matrix dimensions
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
chunk_size: tl.constexpr,
K: tl.constexpr,
ngroups: tl.constexpr,
stride_a_seqlen: tl.int64,
stride_a_head: tl.int64,
stride_ak: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_bk: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_outm: tl.int64,
stride_outn: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
pid_ch = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
@ -135,10 +129,10 @@ def _bmm_chunk_fwd_kernel(
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -150,6 +144,8 @@ def _bmm_chunk_fwd_kernel(
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# compute a * b.T
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
@ -165,18 +161,19 @@ def _bmm_chunk_fwd_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
# Zero out the results that are not from the same request
# in the varlen batch
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
@ -185,78 +182,61 @@ def _bmm_chunk_fwd_kernel(
(offs_n[None, :] < chunk_size))
def _bmm_chunk_fwd(a,
b,
chunk_size,
seq_idx=None,
causal=False,
output_dtype=None):
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
out: (nchunks, ngroups, chunk_size, chunk_size)
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if a.stride(-1) != 1 and a.stride(1) != 1:
assert seq_idx is not None
assert seq_idx.shape == (seqlen, )
if a.stride(-1) != 1 and a.stride(0) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(1) != 1:
if b.stride(-1) != 1 and b.stride(0) != 1:
b = b.contiguous()
nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(batch, nchunks, chunk_size, chunk_size) if not has_groups else
(batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
if not has_groups else nchunks * ngroups)
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
a_ptr=a,
b_ptr=b,
out_ptr=out,
seq_idx_ptr=seq_idx,
seqlen=seqlen,
chunk_size=chunk_size,
K=k,
ngroups=ngroups,
stride_a_seqlen=a.stride(0),
stride_a_head=a.stride(1),
stride_ak=a.stride(2),
stride_b_seqlen=b.stride(0),
stride_b_head=b.stride(1),
stride_bk=b.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_outm=out.stride(-2),
stride_outn=out.stride(-1),
stride_seq_idx_seqlen=seq_idx.stride(0),
IS_CAUSAL=causal,
dot_dtype=dot_dtype,
)
return out

View File

@ -6,7 +6,6 @@
# ruff: noqa: E501,SIM102
import torch
from packaging import version
from vllm.triton_utils import tl, triton
@ -114,7 +113,6 @@ def _chunk_scan_fwd_kernel(
x_ptr,
z_ptr,
out_ptr,
out_x_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
@ -126,60 +124,49 @@ def _chunk_scan_fwd_kernel(
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
chunk_size,
hdim,
dstate,
batch,
chunk_size: tl.constexpr,
hdim: tl.constexpr,
dstate: tl.constexpr,
seqlen,
nheads_ngroups_ratio,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_cb_batch,
stride_cb_chunk,
stride_cb_head,
stride_cb_csize_m,
stride_cb_csize_k,
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_z_batch,
stride_z_seqlen,
stride_z_head,
stride_z_hdim,
stride_out_batch,
stride_out_seqlen,
stride_out_head,
stride_out_hdim,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_C_batch,
stride_C_seqlen,
stride_C_head,
stride_C_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
stride_D_head,
stride_cb_chunk: tl.int64,
stride_cb_head: tl.int64,
stride_cb_csize_m: tl.int64,
stride_cb_csize_k: tl.constexpr,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_z_seqlen: tl.int64,
stride_z_head: tl.int64,
stride_z_hdim: tl.constexpr,
stride_out_seqlen: tl.int64,
stride_out_head: tl.int64,
stride_out_hdim: tl.constexpr,
stride_dt_chunk: tl.int64,
stride_dt_head: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_head: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
stride_C_seqlen: tl.int64,
stride_C_head: tl.int64,
stride_C_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
stride_D_head: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
@ -187,9 +174,7 @@ def _chunk_scan_fwd_kernel(
IS_TRITON_22: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_c = tl.program_id(axis=1).to(tl.int64)
if not HAS_INITSTATES:
c_idx = pid_c
c_off = 0
@ -201,53 +186,51 @@ def _chunk_scan_fwd_kernel(
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + (
pid_h // nheads_ngroups_ratio) * stride_cb_head
x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + (
cb_ptr += c_idx * stride_cb_chunk + (pid_h //
nheads_ngroups_ratio) * stride_cb_head
x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += c_idx * chunk_size * stride_C_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_C_head
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head
prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
mask=c_idx >= 1,
other=0)
seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
mask=c_idx >= 1,
other=0)
if HAS_INITSTATES:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx
if HAS_INITSTATES:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx
# get current seq idx
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
seq_idx_m = tl.load(
seq_idx_ptr +
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
# get current seq idx
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
seq_idx_m = tl.load(
seq_idx_ptr +
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if ((c_off == 0) and
(seq_idx_prev != seq_idx_m
) # if a seq is changed exactly on boundary
or (c_off > 0) # implies a new example (pseudo chunk)
):
# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if ((c_off == 0) and (seq_idx_prev != seq_idx_m
) # if a seq is changed exactly on boundary
or (c_off > 0) # implies a new example (pseudo chunk)
):
# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate
# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
@ -256,7 +239,6 @@ def _chunk_scan_fwd_kernel(
# - handle chunk state limit
if HAS_INITSTATES:
# have to split this if otherwise compilation will have problems
dA_cs_m_boundary = 0.0
@ -296,13 +278,11 @@ def _chunk_scan_fwd_kernel(
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
else:
# - handle seq idx when HAS_INITSTATES==False
if not HAS_INITSTATES:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
@ -319,18 +299,15 @@ def _chunk_scan_fwd_kernel(
prev_states_ptrs = prev_states_ptr + (
offs_n[None, :] * prev_states_hdim +
offs_k_dstate[:, None] * prev_states_dstate)
if HAS_SEQ_IDX:
if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
0.0)
else:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
else:
scale_m = tl.exp(dA_cs_m)
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
@ -416,15 +393,7 @@ def _chunk_scan_fwd_kernel(
acc += x_residual * D
if HAS_Z:
out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :])
tl.store(out_x_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))
z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs,
@ -433,7 +402,7 @@ def _chunk_scan_fwd_kernel(
other=0.0).to(tl.float32)
acc *= z * tl.sigmoid(z)
out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs,
@ -449,126 +418,110 @@ def _chunk_scan_fwd(
dA_cumsum,
C,
states,
out,
seq_idx,
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
out=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
assert seq_idx is not None, "this implementation requires seq_idx"
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
if z is not None:
assert z.shape == x.shape
assert C.shape == (seqlen, ngroups, dstate)
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
if z is not None:
assert z.shape == x.shape
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
assert states.shape == (nchunks, nheads, headdim, dstate)
assert seq_idx.shape == (seqlen, )
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if initial_states is not None:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert chunk_indices is not None and chunk_offsets is not None, \
"chunk_indices and chunk_offsets should have been set"
else:
chunk_indices, chunk_offsets = None, None
if initial_states is not None:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert chunk_indices is not None and chunk_offsets is not None, \
"chunk_indices and chunk_offsets should have been set"
else:
chunk_indices, chunk_offsets = None, None
assert out.shape == x.shape
if z is not None:
out_x = torch.empty_like(x)
assert out_x.stride() == out.stride()
else:
out_x = None
grid = lambda META: (
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
headdim, META['BLOCK_SIZE_N']), batch * nchunks
headdim, META['BLOCK_SIZE_N']), nchunks
if chunk_offsets is None else len(chunk_offsets), nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
z.stride(3)) if z is not None else (0, 0, 0, 0))
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
_chunk_scan_fwd_kernel[grid](
cb,
x,
z,
out,
out_x,
dt,
dA_cumsum,
seq_idx,
C,
states,
D,
initial_states,
chunk_indices,
chunk_offsets,
len(chunk_indices) if chunk_indices is not None else 0,
chunk_size,
headdim,
dstate,
batch,
seqlen,
nheads // ngroups,
cb.stride(0),
cb.stride(1),
cb.stride(2),
cb.stride(3),
cb.stride(4),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
z_strides[0],
z_strides[1],
z_strides[2],
z_strides[3],
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else
(0, 0)),
C.stride(0),
C.stride(1),
C.stride(2),
C.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
D.stride(0) if D is not None else 0,
True,
D is not None,
D.dim() == 2 if D is not None else True,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
cb_ptr=cb,
x_ptr=x,
z_ptr=z,
out_ptr=out,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
seq_idx_ptr=seq_idx,
C_ptr=C,
states_ptr=states,
D_ptr=D,
initstates_ptr=initial_states,
chunk_indices_ptr=chunk_indices,
chunk_offsets_ptr=chunk_offsets,
chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
chunk_size=chunk_size,
hdim=headdim,
dstate=dstate,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_cb_chunk=cb.stride(0),
stride_cb_head=cb.stride(1),
stride_cb_csize_m=cb.stride(2),
stride_cb_csize_k=cb.stride(3),
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_z_seqlen=z_strides[0],
stride_z_head=z_strides[1],
stride_z_hdim=z_strides[2],
stride_out_seqlen=out.stride(0),
stride_out_head=out.stride(1),
stride_out_hdim=out.stride(2),
stride_dt_chunk=dt.stride(1),
stride_dt_head=dt.stride(0),
stride_dt_csize=dt.stride(2),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_seq_idx_seqlen=seq_idx.stride(0),
stride_C_seqlen=C.stride(0),
stride_C_head=C.stride(1),
stride_C_dstate=C.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
stride_D_head=D.stride(0) if D is not None else 0,
IS_CAUSAL=True,
HAS_D=D is not None,
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
HAS_Z=z is not None,
HAS_SEQ_IDX=seq_idx is not None,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out_x
return

View File

@ -35,41 +35,35 @@ def _chunk_cumsum_fwd_kernel(
dt_out_ptr,
dA_cumsum_ptr,
# Matrix dimension
batch,
seqlen,
nheads,
chunk_size,
dt_min,
dt_max,
nheads: tl.constexpr,
chunk_size: tl.constexpr,
dt_min: tl.constexpr,
dt_max: tl.constexpr,
# Strides
stride_dt_batch,
stride_dt_seqlen,
stride_dt_head,
stride_A_head,
stride_dt_bias_head,
stride_dt_out_batch,
stride_dt_out_chunk,
stride_dt_out_head,
stride_dt_out_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_dt_seqlen: tl.int64,
stride_dt_head: tl.constexpr,
stride_A_head: tl.constexpr,
stride_dt_bias_head: tl.constexpr,
stride_dt_out_head: tl.int64,
stride_dt_out_chunk: tl.int64,
stride_dt_out_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
BLOCK_SIZE_CHUNK: tl.constexpr,
):
pid_b = tl.program_id(axis=0)
# if dt is long, may cause problems, so use 64 bit
# https://github.com/triton-lang/triton/issues/1058
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
pid_c = tl.program_id(axis=0).to(tl.int64)
pid_h = tl.program_id(axis=1)
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
dt_out_ptr += pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
@ -93,9 +87,8 @@ def _chunk_cumsum_fwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
0.0)
@ -197,56 +190,46 @@ def _chunk_state_fwd_kernel(
dA_cumsum_ptr,
seq_idx_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
batch,
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
seqlen,
nheads_ngroups_ratio,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -259,13 +242,11 @@ def _chunk_state_fwd_kernel(
dA_cs_last = tl.load(dA_cumsum_ptr +
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
if HAS_SEQ_IDX:
seq_idx_last = tl.load(seq_idx_ptr +
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
seq_idx_last = tl.load(seq_idx_ptr +
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
@ -280,29 +261,28 @@ def _chunk_state_fwd_kernel(
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
seq_idx_k = tl.load(seq_idx_ptrs,
mask=offs_k < chunk_size_limit - k,
other=-1)
seq_idx_k = tl.load(seq_idx_ptrs,
mask=offs_k < chunk_size_limit - k,
other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
else:
scale = tl.where(seq_idx_k == seq_idx_last,
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
scale = tl.where(seq_idx_k == seq_idx_last,
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
@ -400,36 +380,35 @@ def _chunk_state_varlen_kernel(
states_ptr,
initstates_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
seqlen,
nheads_ngroups_ratio,
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_chunk_states_chunk,
stride_chunk_states_head,
stride_chunk_states_hdim,
stride_chunk_states_dstate,
stride_states_batch,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_chunk_states_chunk: tl.int64,
stride_chunk_states_head: tl.int64,
stride_chunk_states_hdim: tl.int64,
stride_chunk_states_dstate: tl.constexpr,
stride_states_batch: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@ -558,52 +537,47 @@ def _chunk_cumsum_fwd(dt,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads = dt.shape
seqlen, nheads = dt.shape
assert A.shape == (nheads, )
if dt_bias is not None:
assert dt_bias.shape == (nheads, )
nchunks = math.ceil(seqlen / chunk_size)
dt_out = torch.empty(batch,
nheads,
dt_out = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
dA_cumsum = torch.empty(batch,
nheads,
dA_cumsum = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
grid_chunk_cs = lambda META: (batch, nchunks,
grid_chunk_cs = lambda META: (nchunks,
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt,
A,
dt_bias,
dt_out,
dA_cumsum,
batch,
seqlen,
nheads,
chunk_size,
dt_limit[0],
dt_limit[1],
dt.stride(0),
dt.stride(1),
dt.stride(2),
A.stride(0),
dt_bias.stride(0) if dt_bias is not None else 0,
dt_out.stride(0),
dt_out.stride(2),
dt_out.stride(1),
dt_out.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
dt_softplus,
dt_ptr=dt,
A_ptr=A,
dt_bias_ptr=dt_bias,
dt_out_ptr=dt_out,
dA_cumsum_ptr=dA_cumsum,
seqlen=seqlen,
nheads=nheads,
chunk_size=chunk_size,
dt_min=dt_limit[0],
dt_max=dt_limit[1],
stride_dt_seqlen=dt.stride(0),
stride_dt_head=dt.stride(1),
stride_A_head=A.stride(0),
stride_dt_bias_head=dt_bias.stride(0)
if dt_bias is not None else 0,
stride_dt_out_head=dt_out.stride(0),
stride_dt_out_chunk=dt_out.stride(1),
stride_dt_out_csize=dt_out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
DT_SOFTPLUS=dt_softplus,
HAS_DT_BIAS=dt_bias is not None,
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
)
@ -617,63 +591,57 @@ def _chunk_state_fwd(B,
seq_idx=None,
states=None,
states_in_fp32=True):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = B.shape
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
assert seq_idx is not None
assert seq_idx.shape == (seqlen, )
if states is not None:
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
assert states.shape == (nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty((batch, nchunks, nheads, headdim, dstate),
states = torch.empty((nchunks, nheads, headdim, dstate),
device=x.device,
dtype=states_dtype)
grid = lambda META: (
triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x,
B,
states,
dt,
dA_cumsum,
seq_idx,
headdim,
dstate,
chunk_size,
batch,
seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
B.stride(0),
B.stride(1),
B.stride(2),
B.stride(-1),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_SEQ_IDX=seq_idx is not None,
x_ptr=x,
b_ptr=B,
states_ptr=states,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
seq_idx_ptr=seq_idx,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_seq_idx_seqlen=seq_idx.stride(0),
)
return states
@ -705,46 +673,52 @@ def chunk_state_varlen(B,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device)
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x,
B,
dt,
dA_cumsum,
chunk_states,
cu_seqlens,
states,
initial_states,
headdim,
dstate,
chunk_size,
total_seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
dt.stride(1),
dt.stride(0),
dt.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
chunk_states.stride(0),
chunk_states.stride(1),
chunk_states.stride(2),
chunk_states.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
x_ptr=x,
b_ptr=B,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
chunk_states_ptr=chunk_states,
cu_seqlens_ptr=cu_seqlens,
states_ptr=states,
initstates_ptr=initial_states,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_chunk_states_chunk=chunk_states.stride(0),
stride_chunk_states_head=chunk_states.stride(1),
stride_chunk_states_hdim=chunk_states.stride(2),
stride_chunk_states_dstate=chunk_states.stride(3),
stride_states_batch=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
HAS_INITSTATES=initial_states is not None)
return states

View File

@ -31,6 +31,7 @@ def _mamba_chunk_scan_combined_fwd(x,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
@ -41,14 +42,13 @@ def _mamba_chunk_scan_combined_fwd(x,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
out=None):
state_dtype=None):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
seqlen, nheads, headdim = x.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, seqlen, nheads)
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
if z is not None:
@ -56,25 +56,24 @@ def _mamba_chunk_scan_combined_fwd(x,
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
assert seq_idx.shape == (seqlen, )
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
1) != 1: # Either M or K dimension should be contiguous
0) != 1: # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
1) != 1: # Either M or K dimension should be contiguous
0) != 1: # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
if initial_states is not None:
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
headdim, dstate)
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
dstate)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
@ -114,18 +113,16 @@ def _mamba_chunk_scan_combined_fwd(x,
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum,
dA_cumsum, # (nheads, nchunks, chunk_size)
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
if initial_states is not None else
None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
@ -144,87 +141,88 @@ def _mamba_chunk_scan_combined_fwd(x,
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out_x = _chunk_scan_fwd(
_chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
out, # in-place update
seq_idx,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out_x, dt, dA_cumsum, states, final_states
else:
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
B.squeeze(0),
x.squeeze(0),
dt.squeeze(0),
dA_cumsum.squeeze(0),
cu_seqlens,
states.squeeze(0),
initial_states=initial_states,
)
return out_x, dt, dA_cumsum, states, final_states, varlen_states
varlen_states = chunk_state_varlen(
B,
x,
dt,
dA_cumsum,
cu_seqlens,
states,
initial_states=initial_states,
)
return varlen_states
def mamba_chunk_scan_combined(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False,
state_dtype=None):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
def mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
chunk_indices=None,
chunk_offsets=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
):
"""
Argument:
x: (seqlen, nheads, headdim)
dt: (seqlen, nheads)
A: (nheads)
B: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
chunk_size: int
seq_idx: (seqlen)
cu_seqlens: (batch + 1)
out: (seqlen, nheads, headdim) preallocated output tensor
D: (nheads, headdim) or (nheads,)
z: (seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
dt_softplus: Whether to apply softplus to dt
out: (seqlen, nheads, headdim) preallocated output tensor
state_dtype: The data type of the ssm state
Return:
varlen_states: (batch, nheads, headdim, dstate)
"""
assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
assert seq_idx is not None
varlen_states = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=D,
z=z,
dt_bias=dt_bias,
@ -235,14 +233,6 @@ def mamba_chunk_scan_combined(x,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out,
state_dtype=state_dtype)
if not return_varlen_states:
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (varlen_states) if not return_final_states else (final_states,
varlen_states)
return varlen_states

View File

@ -27,64 +27,46 @@ def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
final_states_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
dim: tl.constexpr,
nchunks,
seqlen,
chunk_size,
chunk_size: tl.constexpr,
# Strides
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_dim,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_out_dim,
stride_final_states_batch,
stride_final_states_head,
stride_final_states_dim,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_dim: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_out_dim: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_initstates_batch: tl.int64,
stride_initstates_head: tl.int64,
stride_initstates_dim: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
IS_CONT_BATCHED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
pid_h = tl.program_id(axis=1)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head
if HAS_INITSTATES:
initstates_ptr += pid_h * stride_initstates_head
if not IS_CONT_BATCHED:
initstates_ptr += pid_b * stride_initstates_batch
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
# - states will be the past state of the sequence that continues on the current check
if not HAS_INITSTATES:
@ -101,65 +83,63 @@ def _state_passing_fwd_kernel(
out_ptrs += stride_out_chunk
prev_seq_idx_chunk_end = 0
logical_chunk_idx = 0
for c in range(nchunks):
for c in range(nchunks - 1):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
(c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
if HAS_INITSTATES:
if prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
tl.store(out_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
@ -168,81 +148,53 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
dA_cumsum,
seq_idx,
chunk_offsets,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
if chunk_size is None:
chunk_size = dA_cumsum.shape[-1]
else:
assert chunk_size == dA_cumsum.shape[-1]
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
nchunks, nheads, dim = states.shape
chunk_size = dA_cumsum.shape[-1]
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((batch, nchunks, nheads, dim),
out = torch.empty((nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
final_states = torch.empty((batch, nheads, dim),
device=states.device,
dtype=torch.float32)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2))
if initial_states is not None else (0, 0, 0))
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states,
out,
final_states,
dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
states_ptr=states,
out_ptr=out,
dA_cs_ptr=dA_cumsum,
initstates_ptr=initial_states,
seq_idx_ptr=seq_idx,
chunk_offsets_ptr=chunk_offsets,
chunk_meta_num=len(chunk_offsets)
if chunk_offsets is not None else 0,
dim=dim,
nchunks=nchunks,
seqlen=seqlen if seq_idx is not None else 0,
chunk_size=chunk_size if seq_idx is not None else 0,
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_dim=states.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_out_dim=out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_initstates_batch=initial_states_strides[0],
stride_initstates_head=initial_states_strides[1],
stride_initstates_dim=initial_states_strides[2],
stride_seq_idx_seqlen=seq_idx.stride(0),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_CONT_BATCHED=is_cont_batched,
)
return out, final_states
return out

View File

@ -35,7 +35,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -262,6 +262,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
@ -302,9 +303,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@ -356,17 +354,17 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_state = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
dt,
self.A,
B.view(1, num_prefill_tokens, 1, -1),
C.view(1, num_prefill_tokens, 1, -1),
B.view(num_prefill_tokens, 1, -1),
C.view(num_prefill_tokens, 1, -1),
chunk_size=chunk_size,
D=self.D,
z=gate_p.view(1, num_prefill_tokens,
z=gate_p.view(num_prefill_tokens,
self.num_heads // self.tp_size, self.head_dim),
dt_bias=self.dt_bias,
seq_idx=seq_idx_p,
@ -374,11 +372,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype,
)

View File

@ -115,7 +115,7 @@ class Mamba2AttentionMetadata:
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor
prep_initial_states: bool
@ -151,7 +151,7 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_p = None
seq_lens = common_attn_metadata.seq_lens
seq_idx_p = None
@ -179,7 +179,7 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
query_start_loc.device)
common_attn_metadata.query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
@ -190,7 +190,6 @@ class Mamba2AttentionMetadataBuilder(
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed,
@ -217,7 +216,7 @@ class Mamba2AttentionMetadataBuilder(
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,