Add SpecDec support to selective_state_update (#29488)

Signed-off-by: Roi Koren <roik@nvidia.com>
roikoren755 2025-12-08 23:45:18 +02:00 committed by GitHub
parent 799804d140
commit ae0f69b16a
2 changed files with 505 additions and 72 deletions


@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
# set seed
current_platform.seed_everything(0)
batch_size = 4
token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(token_counts.sum().item())
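# cu_seqlens holds cumulative token offsets per sequence: [0, n0, n0 + n1, ...]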
cu_seqlens = torch.tensor(
[0] + torch.cumsum(token_counts, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
)
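# Reference: run the single-token update for each token of each sequence, in order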
out_ref_list = []
for seq_idx in range(batch_size):
start_idx = cu_seqlens[seq_idx].item()
end_idx = cu_seqlens[seq_idx + 1].item()
num_tokens = end_idx - start_idx
for token_idx in range(num_tokens):
idx = start_idx + token_idx
out_ref_list.append(
selective_state_update_ref(
state_ref[seq_idx : seq_idx + 1],
x[idx : idx + 1],
dt[idx : idx + 1],
A,
B[idx : idx + 1],
C[idx : idx + 1],
D=D,
z=z[idx : idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
)
out_ref = torch.cat(out_ref_list, dim=0)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_with_num_accepted_tokens(
dim, dstate, has_z, itype, max_seq_len
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
current_platform.seed_everything(0)
batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(tokens_per_seq.sum().item())
num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
num_accepted_tokens[0] = 0 # Edge case: no accepted tokens
num_accepted_tokens[1] = max_seq_len # Edge case: all tokens accepted
cu_seqlens = torch.tensor(
[0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
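# State pool layout: slots [0, 15) hold initial states, slots 15+ receive per-token results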
total_state_slots = 50
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
)
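# Place each sequence's initial state at the position of its last accepted token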
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
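# Give every draft token its own destination slot so all intermediate states are kept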
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
for seq_idx in range(batch_size):
for token_idx in range(tokens_per_seq[seq_idx].item()):
dst_state_batch_indices[seq_idx, token_idx] = slot_offset
dst_slots_map[(seq_idx, token_idx)] = slot_offset
slot_offset += 1
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
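# Reference: roll the state forward token by token and snapshot it after every step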
state_ref_intermediate = {}
out_ref_list = []
for seq_idx in range(batch_size):
seq_start = cu_seqlens[seq_idx].item()
seq_end = cu_seqlens[seq_idx + 1].item()
num_tokens = seq_end - seq_start
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
initial_slot = state_batch_indices[seq_idx, token_pos].item()
state_seq = state[initial_slot : initial_slot + 1].clone()
for token_idx in range(num_tokens):
global_idx = seq_start + token_idx
out_token = selective_state_update_ref(
state_seq,
x[global_idx : global_idx + 1],
dt[global_idx : global_idx + 1],
A,
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
out_ref_list.append(out_token)
state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
out_ref = torch.cat(out_ref_list, dim=0)
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
for seq_idx in range(batch_size):
num_tokens = tokens_per_seq[seq_idx].item()
for token_idx in range(num_tokens):
dst_slot = dst_slots_map[(seq_idx, token_idx)]
state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_varlen_with_num_accepted(
dim, dstate, has_z, itype, max_seq_len
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 5e-2, 1.5e-1
if torch.version.hip:
atol *= 2
current_platform.seed_everything(0)
batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(tokens_per_seq.sum().item())
num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
num_accepted_tokens[0] = 0 # Edge case: no accepted tokens
num_accepted_tokens[1] = max_seq_len # Edge case: all tokens accepted
cu_seqlens = torch.tensor(
[0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
dtype=torch.int32,
device=device,
)
total_state_slots = 50
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
for seq_idx in range(batch_size):
for token_idx in range(tokens_per_seq[seq_idx].item()):
dst_state_batch_indices[seq_idx, token_idx] = slot_offset
dst_slots_map[(seq_idx, token_idx)] = slot_offset
slot_offset += 1
x = torch.randn(total_tokens, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(total_tokens, dstate, device=device)
C = torch.randn(total_tokens, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
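# Reference: advance the state per token; only the stored per-token states are checked below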
state_ref_intermediate = {}
for seq_idx in range(batch_size):
seq_start = cu_seqlens[seq_idx].item()
seq_end = cu_seqlens[seq_idx + 1].item()
num_tokens = seq_end - seq_start
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
initial_slot = state_batch_indices[seq_idx, token_pos].item()
state_seq = state[initial_slot : initial_slot + 1].clone()
for token_idx in range(num_tokens):
global_idx = seq_start + token_idx
selective_state_update_ref(
state_seq,
x[global_idx : global_idx + 1],
dt[global_idx : global_idx + 1],
A,
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
dt_bias=dt_bias,
dt_softplus=True,
)
state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
cu_seqlens=cu_seqlens,
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
for seq_idx in range(batch_size):
num_tokens = tokens_per_seq[seq_idx].item()
for token_idx in range(num_tokens):
dst_slot = dst_slots_map[(seq_idx, token_idx)]
state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)


@@ -36,10 +36,14 @@ else:
is not None
}
)
@triton.heuristics(
{"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
@triton.jit(do_not_specialize=["N"])
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
state_batch_indices_ptr,
dst_state_batch_indices_ptr,
pad_slot_id,
num_accepted_tokens_ptr,
cu_seqlens_ptr,
# Matrix dimensions
batch,
N,
nheads,
dim,
dstate,
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
stride_out_batch,
stride_out_head,
stride_out_dim,
stride_state_indices_batch,
stride_state_indices_T,
stride_dst_state_indices_batch,
stride_dst_state_indices_T,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_VARLEN: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
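# In varlen mode each program instance handles one whole sequence; bos/eos come from cu_seqlens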
if IS_VARLEN:
bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
seq_len = eos - bos
if seq_len == 0:
return
else:
bos = pid_b
seq_len = 1
state_ptr_base = state_ptr
# If HAS_STATE_BATCH_INDICES is true, the ssm state's batch coordinate
# is taken from state_batch_indices_ptr. Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
dst_state_batch_indices_ptr += pid_b
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
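# Spec decoding: the initial state lives at the slot of the last accepted token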
if IS_SPEC_DECODING:
num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
init_token_idx = tl.maximum(num_accepted - 1, 0)
else:
init_token_idx = 0
dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
if not IS_SPEC_DECODING:
dst_state_batch_idx = tl.load(
dst_state_batch_indices_ptr
+ init_token_idx * stride_dst_state_indices_T
).to(tl.int64)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
)
state_batch_indices_ptr += (
pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
)
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
@@ -123,86 +161,112 @@ def _selective_scan_update_kernel(
)
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
x_ptr += bos * stride_x_batch + pid_h * stride_x_head
dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
z_ptr += bos * stride_z_batch + pid_h * stride_z_head
out_ptr += bos * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if not IS_SPEC_DECODING:
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0)
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
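# Process the sequence one token at a time, carrying the recurrent state forward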
for i_t in range(seq_len):
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
tl.store(dst_state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
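# Spec decoding: persist this token's post-update state to its own destination slot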
if IS_SPEC_DECODING:
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
if token_dst_idx != pad_slot_id:
token_dst_ptrs = (
state_ptr_base
+ token_dst_idx * stride_state_batch
+ pid_h * stride_state_head
+ offs_m[:, None] * stride_state_dim
+ offs_n[None, :] * stride_state_dstate
)
tl.store(
token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
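# Advance the per-token pointers to the next token in this sequence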
x_ptr += stride_x_batch
dt_ptr += stride_dt_batch
B_ptr += stride_B_batch
C_ptr += stride_C_batch
out_ptr += stride_out_batch
if HAS_Z:
z_ptr += stride_z_batch
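# Without spec decoding, only the final state is written back, once the loop finishes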
if not IS_SPEC_DECODING:
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
def selective_state_update(
@@ -220,6 +284,8 @@ def selective_state_update(
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
):
"""
Argument:
@@ -240,6 +306,11 @@ def selective_state_update(
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
num_accepted_tokens: (batch,)
Number of tokens accepted in the previous verification step;
tells the kernel which initial state each sequence should start from.
cu_seqlens: (batch + 1,)
Cumulative sequence lengths, for variable-length inputs in
speculative decoding.
"""
if state.dim() == 3:
state = state.unsqueeze(1)
@@ -261,9 +332,26 @@ def selective_state_update(
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
if num_accepted_tokens is not None:
assert state_batch_indices is not None and state_batch_indices.dim() == 2
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
if state_batch_indices is not None and state_batch_indices.dim() == 1:
state_batch_indices = state_batch_indices.unsqueeze(1)
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
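# With cu_seqlens, batch counts tokens while N counts sequences (the grid's batch axis)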
if cu_seqlens is not None:
N = len(cu_seqlens) - 1
# Only used to verify the shape of
# state_batch_indices and dst_state_batch_indices
max_seqlen = (
state_batch_indices.size(-1) if state_batch_indices is not None else 1
)
else:
N = batch
max_seqlen = 1
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
@@ -279,16 +367,30 @@ def selective_state_update(
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch,)
assert state_batch_indices.shape[0] >= N
assert state_batch_indices.shape[1] >= max_seqlen
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape == (batch,)
assert dst_state_batch_indices.shape[0] >= N
assert dst_state_batch_indices.shape[1] >= max_seqlen
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
if num_accepted_tokens is not None:
assert num_accepted_tokens.shape == (N,)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
state_batch_indices_strides = (
(state_batch_indices.stride(0), state_batch_indices.stride(1))
if state_batch_indices is not None
else (0, 0)
)
dst_state_batch_indices_strides = (
(dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
@@ -321,7 +423,9 @@ def selective_state_update(
state_batch_indices,
dst_state_batch_indices,
pad_slot_id,
batch,
num_accepted_tokens,
cu_seqlens,
N,
nheads,
dim,
dstate,
@@ -353,6 +457,10 @@ def selective_state_update(
out.stride(0),
out.stride(1),
out.stride(2),
state_batch_indices_strides[0],
state_batch_indices_strides[1],
dst_state_batch_indices_strides[0],
dst_state_batch_indices_strides[1],
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,