Add SpecDec support to selective_state_update (#29488)
Signed-off-by: Roi Koren <roik@nvidia.com>
Parent: 799804d140
Commit: ae0f69b16a
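In short, this change lets selective_state_update consume several speculative-decoding (SpecDec) draft tokens per sequence in a single call, driven by the new cu_seqlens, num_accepted_tokens and per-token destination state slots exercised in the tests below. As orientation, a minimal sketch of the variable-length call shape; the shapes and values are illustrative only and mirror test_selective_state_update_varlen added in this commit, and the import path is assumed from vLLM's mamba ops module:

    import torch

    from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update

    device = "cuda"
    batch_size, dim, dstate, max_seq_len = 4, 2048, 16, 4

    # Each sequence contributes a variable number of draft tokens to this step.
    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(token_counts.sum().item())
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(token_counts, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )  # (batch_size + 1,) cumulative offsets into the packed token tensors

    state = torch.randn(batch_size, dim, dstate, device=device)
    x = torch.randn(total_tokens, dim, device=device)
    dt = torch.randn(total_tokens, dim, device=device)
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    out = torch.empty_like(x)

    # Passing cu_seqlens switches the kernel into varlen mode: sequence i's tokens
    # are x[cu_seqlens[i]:cu_seqlens[i + 1]] and its state is updated token by token.
    selective_state_update(state, x, dt, A, B, C, out=out, cu_seqlens=cu_seqlens)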
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
 
 
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
+def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+    # set seed
+    current_platform.seed_everything(0)
+    batch_size = 4
+    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(token_counts.sum().item())
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(token_counts, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state.detach().clone()
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+    )
+
+    out_ref_list = []
+    for seq_idx in range(batch_size):
+        start_idx = cu_seqlens[seq_idx].item()
+        end_idx = cu_seqlens[seq_idx + 1].item()
+        num_tokens = end_idx - start_idx
+        for token_idx in range(num_tokens):
+            idx = start_idx + token_idx
+            out_ref_list.append(
+                selective_state_update_ref(
+                    state_ref[seq_idx : seq_idx + 1],
+                    x[idx : idx + 1],
+                    dt[idx : idx + 1],
+                    A,
+                    B[idx : idx + 1],
+                    C[idx : idx + 1],
+                    D=D,
+                    z=z[idx : idx + 1] if has_z else None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                )
+            )
+    out_ref = torch.cat(out_ref_list, dim=0)
+    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize("wtype", [torch.float32])
 @pytest.mark.parametrize("itype", [torch.float32])
 @pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_with_num_accepted_tokens(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+    out_ref_list = []
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            out_token = selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+            out_ref_list.append(out_token)
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    out_ref = torch.cat(out_ref_list, dim=0)
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_varlen_with_num_accepted(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
@@ -36,10 +36,14 @@ else:
         is not None
     }
 )
+@triton.heuristics(
+    {"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
+)
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
 @triton.heuristics(
     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
 )
-@triton.jit
+@triton.jit(do_not_specialize=["N"])
 def _selective_scan_update_kernel(
     # Pointers to matrices
     state_ptr,
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
     state_batch_indices_ptr,
     dst_state_batch_indices_ptr,
     pad_slot_id,
+    num_accepted_tokens_ptr,
+    cu_seqlens_ptr,
     # Matrix dimensions
-    batch,
+    N,
     nheads,
     dim,
     dstate,
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
     stride_out_batch,
     stride_out_head,
     stride_out_dim,
+    stride_state_indices_batch,
+    stride_state_indices_T,
+    stride_dst_state_indices_batch,
+    stride_dst_state_indices_T,
     # Meta-parameters
     DT_SOFTPLUS: tl.constexpr,
     TIE_HDIM: tl.constexpr,
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
     HAS_STATE_BATCH_INDICES: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
 
+    if IS_VARLEN:
+        bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
+        eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
+        seq_len = eos - bos
+
+        if seq_len == 0:
+            return
+    else:
+        bos = pid_b
+        seq_len = 1
+
+    state_ptr_base = state_ptr
+
     # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
     # is taken from the state_batch_indices_ptr. Otherwise, the state coordinate
     # is the same as the batch id.
     if HAS_STATE_BATCH_INDICES:
-        dst_state_batch_indices_ptr += pid_b
-        dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
+        if IS_SPEC_DECODING:
+            num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
+            init_token_idx = tl.maximum(num_accepted - 1, 0)
+        else:
+            init_token_idx = 0
+
+        dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
+        if not IS_SPEC_DECODING:
+            dst_state_batch_idx = tl.load(
+                dst_state_batch_indices_ptr
+                + init_token_idx * stride_dst_state_indices_T
+            ).to(tl.int64)
+            dst_state_ptr = state_ptr + (
+                dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
+            )
-        state_batch_indices_ptr += pid_b
+
+        state_batch_indices_ptr += (
+            pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
+        )
         state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
     else:
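For the speculative-decoding path, the block above picks each sequence's initial state slot from state_batch_indices at position max(num_accepted_tokens - 1, 0): the state left by the last accepted draft token of the previous step, or position 0 when nothing was accepted. A small host-side sketch of the same indexing, with assumed toy values (the new tests build state_batch_indices the same way):

    import torch

    PAD_SLOT_ID = -1  # placeholder sentinel for this sketch; vLLM defines its own constant

    num_seqs, max_spec_len = 4, 4
    num_accepted_tokens = torch.tensor([0, 4, 2, 1])  # from the previous verification step
    initial_slots = torch.tensor([7, 3, 11, 0], dtype=torch.int32)  # arbitrary cache slots

    state_batch_indices = torch.full(
        (num_seqs, max_spec_len), PAD_SLOT_ID, dtype=torch.int32
    )
    # mirrors init_token_idx = tl.maximum(num_accepted - 1, 0) in the kernel above
    token_pos = (num_accepted_tokens - 1).clamp(min=0)
    state_batch_indices[torch.arange(num_seqs), token_pos] = initial_slots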
@@ -123,45 +161,47 @@ def _selective_scan_update_kernel(
         )
         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
 
-    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
-    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+    x_ptr += bos * stride_x_batch + pid_h * stride_x_head
+    dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
         dt_bias_ptr += pid_h * stride_dt_bias_head
     A_ptr += pid_h * stride_A_head
-    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
-    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
+    B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+    C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
-        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
-    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+        z_ptr += bos * stride_z_batch + pid_h * stride_z_head
+    out_ptr += bos * stride_out_batch + pid_h * stride_out_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
     state_ptrs = state_ptr + (
         offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
     )
+    if not IS_SPEC_DECODING:
         dst_state_ptrs = dst_state_ptr + (
             offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
         )
-    x_ptrs = x_ptr + offs_m * stride_x_dim
-    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
-
-    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
-    if HAS_STATE_BATCH_INDICES:
-        mask &= state_batch_idx != pad_slot_id
-    state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
 
     if HAS_DT_BIAS:
         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
     if HAS_D:
         D_ptr += pid_h * stride_D_head
-    A_ptrs = A_ptr + (
-        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
-    )
-    D_ptrs = D_ptr + offs_m * stride_D_dim
+    A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
 
+    for i_t in range(seq_len):
+        x_ptrs = x_ptr + offs_m * stride_x_dim
+        dt_ptrs = dt_ptr + offs_m * stride_dt_dim
         B_ptrs = B_ptr + offs_n * stride_B_dstate
         C_ptrs = C_ptr + offs_n * stride_C_dstate
         if HAS_D:
             D_ptrs = D_ptr + offs_m * stride_D_dim
         if HAS_Z:
             z_ptrs = z_ptr + offs_m * stride_z_dim
         out_ptrs = out_ptr + offs_m * stride_out_dim
+        mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+        if HAS_STATE_BATCH_INDICES:
+            mask &= state_batch_idx != pad_slot_id
+        state = tl.load(state_ptrs, mask=mask, other=0.0)
 
         x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
         if not TIE_HDIM:
@@ -171,7 +211,9 @@ def _selective_scan_update_kernel(
             if DT_SOFTPLUS:
                 dt = softplus(dt)
             A = tl.load(
-                A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
+                A_ptrs,
+                mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
+                other=0.0,
             ).to(tl.float32)
             dA = tl.exp(A * dt[:, None])
         else:
@@ -193,10 +235,21 @@ def _selective_scan_update_kernel(
         dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
         state = state * dA + dB * x[:, None]
 
-    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
-    if HAS_STATE_BATCH_INDICES:
-        mask &= state_batch_idx != pad_slot_id
-    tl.store(dst_state_ptrs, state, mask=mask)
+        if IS_SPEC_DECODING:
+            dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
+            token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
+            if token_dst_idx != pad_slot_id:
+                token_dst_ptrs = (
+                    state_ptr_base
+                    + token_dst_idx * stride_state_batch
+                    + pid_h * stride_state_head
+                    + offs_m[:, None] * stride_state_dim
+                    + offs_n[None, :] * stride_state_dstate
+                )
+                tl.store(
+                    token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
+                )
+
         out = tl.sum(state * C[None, :], axis=1)
         if HAS_D:
             out += x * D
@@ -204,6 +257,17 @@ def _selective_scan_update_kernel(
             out *= z * tl.sigmoid(z)
         tl.store(out_ptrs, out, mask=offs_m < dim)
+
+        x_ptr += stride_x_batch
+        dt_ptr += stride_dt_batch
+        B_ptr += stride_B_batch
+        C_ptr += stride_C_batch
+        out_ptr += stride_out_batch
+        if HAS_Z:
+            z_ptr += stride_z_batch
+
+    if not IS_SPEC_DECODING:
+        tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
 
 
 def selective_state_update(
     state,
@@ -220,6 +284,8 @@ def selective_state_update(
     dst_state_batch_indices=None,
     pad_slot_id=PAD_SLOT_ID,
     out=None,
+    num_accepted_tokens=None,
+    cu_seqlens=None,
 ):
     """
     Argument:
@@ -240,6 +306,11 @@ def selective_state_update(
             indices 0 and 3
         out: Preallocated ssm output tensor. Assume same shape as x.
             In-place updated.
+        num_accepted_tokens: (batch,)
+            number of accepted tokens from the previous verification step;
+            tells the kernel which initial state to use
+        cu_seqlens: (batch + 1,)
+            cumulative sequence lengths, for variable-length speculative decoding cases
     """
     if state.dim() == 3:
         state = state.unsqueeze(1)
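A concrete toy reading of the two new arguments documented above, with values assumed purely for illustration:

    token_counts = [2, 1, 3]                   # draft tokens per sequence in this step
    cu_seqlens = [0]
    for n in token_counts:
        cu_seqlens.append(cu_seqlens[-1] + n)  # -> [0, 2, 3, 6], length batch + 1
    num_accepted_tokens = [1, 0, 3]            # per sequence, from the last verification
    # sequence i owns rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed x, dt, B and C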
@@ -261,9 +332,26 @@ def selective_state_update(
         dt_bias = dt_bias.unsqueeze(0)
     if out.dim() == 2:
         out = out.unsqueeze(1)
+    if num_accepted_tokens is not None:
+        assert state_batch_indices is not None and state_batch_indices.dim() == 2
+        assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
+    if state_batch_indices is not None and state_batch_indices.dim() == 1:
+        state_batch_indices = state_batch_indices.unsqueeze(1)
+    if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
+        dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
 
     _, nheads, dim, dstate = state.shape
     batch = x.shape[0]
+    if cu_seqlens is not None:
+        N = len(cu_seqlens) - 1
+        # Only used to verify the shape of
+        # state_batch_indices and dst_state_batch_indices
+        max_seqlen = (
+            state_batch_indices.size(-1) if state_batch_indices is not None else 1
+        )
+    else:
+        N = batch
+        max_seqlen = 1
 
     assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
@@ -279,16 +367,30 @@ def selective_state_update(
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
     if state_batch_indices is not None:
-        assert state_batch_indices.shape == (batch,)
+        assert state_batch_indices.shape[0] >= N
+        assert state_batch_indices.shape[1] >= max_seqlen
     if dst_state_batch_indices is not None:
-        assert dst_state_batch_indices.shape == (batch,)
+        assert dst_state_batch_indices.shape[0] >= N
+        assert dst_state_batch_indices.shape[1] >= max_seqlen
     else:
         # revert to the default behavior of in-place state updates
         dst_state_batch_indices = state_batch_indices
     assert out.shape == x.shape
+    if num_accepted_tokens is not None:
+        assert num_accepted_tokens.shape == (N,)
 
-    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
     z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
+    state_batch_indices_strides = (
+        (state_batch_indices.stride(0), state_batch_indices.stride(1))
+        if state_batch_indices is not None
+        else (0, 0)
+    )
+    dst_state_batch_indices_strides = (
+        (dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
+        if dst_state_batch_indices is not None
+        else (0, 0)
+    )
     # We don't want autotune since it will overwrite the state
     # We instead tune by hand.
     BLOCK_SIZE_M, num_warps = (
@@ -321,7 +423,9 @@ def selective_state_update(
         state_batch_indices,
         dst_state_batch_indices,
         pad_slot_id,
-        batch,
+        num_accepted_tokens,
+        cu_seqlens,
+        N,
         nheads,
         dim,
         dstate,
@@ -353,6 +457,10 @@ def selective_state_update(
         out.stride(0),
         out.stride(1),
         out.stride(2),
+        state_batch_indices_strides[0],
+        state_batch_indices_strides[1],
+        dst_state_batch_indices_strides[0],
+        dst_state_batch_indices_strides[1],
         dt_softplus,
         tie_hdim,
         BLOCK_SIZE_M,