[Kernel] Chunk-aligned mamba2 (#24683)

This commit is contained in:
parent 61a3431613
commit fea3e476aa
@ -502,9 +502,9 @@ class MambaMixer2(MambaBase, CustomOp):
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p

# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@ -634,9 +634,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=seq_idx_p,
|
||||
chunk_indices=chunk_indices_p,
|
||||
chunk_offsets=chunk_offsets_p,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
|
||||
@ -6,8 +6,6 @@
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
seq_idx_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
|
||||
stride_out_head: tl.int64,
|
||||
stride_outm: tl.int64,
|
||||
stride_outn: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
# Zero out the results that are not from the same request
|
||||
# in the varlen batch
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
|
||||
mask=offs_n < chunk_size_limit,
|
||||
other=-2)
|
||||
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
|
||||
@ -182,12 +171,18 @@ def _bmm_chunk_fwd_kernel(
|
||||
(offs_n[None, :] < chunk_size))
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
|
||||
def _bmm_chunk_fwd(a,
|
||||
b,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
causal=False,
|
||||
output_dtype=None):
|
||||
"""
|
||||
Argument:
a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
chunk_size: int
cu_chunk_seqlens: (nchunks+1,)
causal: if True, then out[i, j] for i > j will be arbitrary; only out[i, j] for i <= j are
guaranteed to be correct.
Return:
@ -195,14 +190,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
"""
seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
assert seq_idx is not None
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
if a.stride(-1) != 1 and a.stride(0) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(0) != 1:
|
||||
b = b.contiguous()
|
||||
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
nchunks = len(cu_chunk_seqlens) - 1
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
|
||||
@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
out_ptr=out,
|
||||
seq_idx_ptr=seq_idx,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
chunk_size=chunk_size,
|
||||
K=k,
|
||||
@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
|
||||
stride_out_head=out.stride(1),
|
||||
stride_outm=out.stride(-2),
|
||||
stride_outn=out.stride(-1),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
IS_CAUSAL=causal,
|
||||
dot_dtype=dot_dtype,
|
||||
)
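To make the new calling convention concrete, here is a rough PyTorch sketch of what _bmm_chunk_fwd computes after this change. It is for illustration only and is not part of the diff: the name bmm_chunk_fwd_ref, the zero padding outside each chunk, and the fp32 accumulation are assumptions, and the causal short-circuit is omitted.

import torch

def bmm_chunk_fwd_ref(a, b, chunk_size, cu_chunk_seqlens):
    # a, b: (seqlen, ngroups, k); cu_chunk_seqlens: (nchunks + 1,)
    seqlen, ngroups, k = a.shape
    nchunks = len(cu_chunk_seqlens) - 1
    out = torch.zeros(nchunks, ngroups, chunk_size, chunk_size, dtype=torch.float32)
    for c in range(nchunks):
        s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
        # every chunk now holds tokens from a single sequence, so the old
        # seq_idx masking inside the kernel is no longer needed
        out[c, :, :e - s, :e - s] = torch.einsum(
            'mgk,ngk->gmn', a[s:e].float(), b[s:e].float())
    return out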
|
||||
|
||||
@ -120,9 +120,7 @@ def _chunk_scan_fwd_kernel(
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
chunk_indices_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size: tl.constexpr,
|
||||
hdim: tl.constexpr,
|
||||
@ -149,7 +147,7 @@ def _chunk_scan_fwd_kernel(
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
stride_C_seqlen: tl.int64,
|
||||
stride_C_head: tl.int64,
|
||||
stride_C_dstate: tl.constexpr,
|
||||
@ -175,170 +173,107 @@ def _chunk_scan_fwd_kernel(
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
if not HAS_INITSTATES:
|
||||
c_idx = pid_c
|
||||
c_off = 0
|
||||
else:
|
||||
c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += c_idx * stride_cb_chunk + (pid_h //
|
||||
cb_ptr += pid_c * stride_cb_chunk + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_cb_head
|
||||
x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += c_idx * chunk_size * stride_C_seqlen + (
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += chunk_seqlen_start * stride_C_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
|
||||
seq_idx_ptr += pid_c * stride_seq_idx_chunk
|
||||
seq_idx = tl.load(seq_idx_ptr)
|
||||
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk,
|
||||
mask=pid_c >= 1,
|
||||
other=-1)
|
||||
|
||||
seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
|
||||
# - we only need seq_idx_prev to be aligned to chunk boundary
|
||||
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
|
||||
mask=c_idx >= 1,
|
||||
other=0)
|
||||
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
|
||||
prev_states_hdim = stride_init_states_hdim
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
else:
|
||||
prev_states_ptr = states_ptr + (
|
||||
pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states, we only need seq_idx_m to point
|
||||
# what is the current seq_idx
|
||||
|
||||
# get current seq idx
|
||||
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr +
|
||||
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
|
||||
|
||||
# - recall that in ssd_state_passing, for the case c_off == 0
|
||||
# i.e., the very first sequence, we made states_ptr hold its initial state
|
||||
# so this edge case is taken care of
|
||||
if ((c_off == 0) and (seq_idx_prev != seq_idx_m
|
||||
) # if a seq is changed exactly on boundary
|
||||
or (c_off > 0) # implies a new example (pseudo chunk)
|
||||
):
|
||||
|
||||
# - replace prev_states_ptr with init_states
|
||||
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
|
||||
prev_states_hdim = stride_init_states_hdim # override strides
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
||||
mask=offs_m < chunk_size,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# - handle chunk state limit
if HAS_INITSTATES:
# have to split this if-statement, otherwise compilation will have problems
dA_cs_m_boundary = 0.0

# get the c_idx for the next (logical) chunk
c_idx_n = tl.load(
chunk_indices_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=-1  # to trigger different chunk
)
|
||||
# - there are things to consider
|
||||
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
|
||||
# contribution of past states
|
||||
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
|
||||
# encroach into the next sequence, where c_off_n is the offset of the next
|
||||
# (logical) chunk.
|
||||
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
|
||||
# (logical) chunk indices.
|
||||
|
||||
if (c_idx == c_idx_n) or c_off > 0:
|
||||
|
||||
# get the next offset
|
||||
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size)
|
||||
|
||||
# in this case, adjust down the chunk_size_limit
|
||||
if c_idx == c_idx_n:
|
||||
chunk_size_limit = min(c_off_n, chunk_size_limit)
|
||||
|
||||
# get the cs at the offset boundary
|
||||
# - c_off == 0 is a passthrough
|
||||
# - We need dA_cs at the boundary, defined by c_off - no need
|
||||
# to increase pointer by pid_m (it is a constant offset,
|
||||
# i.e. the same for all blocks)
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
# - handle seq idx when HAS_INITSTATES==False
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Without the if (pid_c > -1), with Triton 2.1.0, I get
|
||||
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
|
||||
# With Triton 2.2.0, this works
|
||||
if IS_TRITON_22 or c_idx > -1:
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
|
||||
offs_k_dstate[None, :] * stride_C_dstate)
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
prev_states_ptrs = prev_states_ptr + (
|
||||
offs_n[None, :] * prev_states_hdim +
|
||||
offs_k_dstate[:, None] * prev_states_dstate)
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
|
||||
offs_k_dstate[None, :] * stride_C_dstate)
|
||||
|
||||
if not HAS_INITSTATES:
# - this is for continuous batching where there are no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
scale_m = tl.exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate),
|
||||
other=0.0)
|
||||
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
# if no init states AND starting a new sequence, we need zeros
|
||||
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N),
|
||||
dtype=C_ptr.dtype.element_ty)
|
||||
else:
|
||||
# - if there is initstates, we will rely on prev_states, no zeroing
|
||||
# required.
|
||||
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
|
||||
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate),
|
||||
other=0.0)
|
||||
|
||||
# otherwise read the previous state
|
||||
prev_states_ptrs = prev_states_ptr \
|
||||
+ offs_n[None, :] * prev_states_hdim \
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
prev_states = tl.load(prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
else:
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0)
|
||||
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
|
||||
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
|
||||
else:
|
||||
prev_states_ptrs = prev_states_ptr \
|
||||
+ offs_n[None, :] * prev_states_hdim \
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0)
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K),
|
||||
dtype=C_ptr.dtype.element_ty)
|
||||
else:
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
|
||||
offs_k[None, :] * stride_cb_csize_k)
|
||||
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
|
||||
@ -375,7 +310,7 @@ def _chunk_scan_fwd_kernel(
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
@ -393,7 +328,7 @@ def _chunk_scan_fwd_kernel(
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
|
||||
stride_z_hdim * offs_out_n[None, :])
|
||||
z = tl.load(z_ptrs,
|
||||
@ -402,7 +337,7 @@ def _chunk_scan_fwd_kernel(
|
||||
other=0.0).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
|
||||
offs_out_n[None, :] * stride_out_hdim)
|
||||
tl.store(out_ptrs,
|
||||
@ -418,12 +353,11 @@ def _chunk_scan_fwd(
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out,
|
||||
seq_idx,
|
||||
D=None,
|
||||
z=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None,
|
||||
):
|
||||
assert seq_idx is not None, "this implementation requires seq_idx"
|
||||
@ -441,20 +375,10 @@ def _chunk_scan_fwd(
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
assert seq_idx.shape == (nchunks, )
|
||||
|
||||
if initial_states is not None:
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert chunk_indices is not None and chunk_offsets is not None, \
|
||||
"chunk_indices and chunk_offsets should have been set"
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
headdim, META['BLOCK_SIZE_N']), nchunks
|
||||
if chunk_offsets is None else len(chunk_offsets), nheads)
|
||||
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton
|
||||
.cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads)
|
||||
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
@ -476,9 +400,7 @@ def _chunk_scan_fwd(
|
||||
states_ptr=states,
|
||||
D_ptr=D,
|
||||
initstates_ptr=initial_states,
|
||||
chunk_indices_ptr=chunk_indices,
|
||||
chunk_offsets_ptr=chunk_offsets,
|
||||
chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
chunk_size=chunk_size,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
@ -503,7 +425,7 @@ def _chunk_scan_fwd(
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
stride_C_seqlen=C.stride(0),
|
||||
stride_C_head=C.stride(1),
|
||||
stride_C_dstate=C.stride(2),
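The diagonal-block math itself is unchanged by this diff; for readers, a minimal single-head sketch of what _chunk_scan_fwd_kernel accumulates per chunk. The name, and the simplifications (no D skip connection, no z gating, no initial-state special cases), are assumptions, not taken from the change.

import torch

def chunk_scan_ref(cb, x, dt, dA_cs, C, prev_state):
    # one chunk of L tokens, single head/group:
    #   cb: (L, L) = C_chunk @ B_chunk^T, x: (L, hdim), dt/dA_cs: (L,),
    #   C: (L, dstate), prev_state: (hdim, dstate) carried in from the previous chunk
    L = x.shape[0]
    decay = torch.exp(dA_cs[:, None] - dA_cs[None, :])          # exp(A_t - A_s)
    causal = torch.arange(L)[:, None] >= torch.arange(L)[None, :]
    scores = torch.where(causal, cb * decay * dt[None, :], torch.zeros(()))
    y_intra = scores @ x                                         # diagonal-block term
    y_inter = (C @ prev_state.T) * torch.exp(dA_cs)[:, None]     # contribution of prev_state
    return y_intra + y_inter                                     # (L, hdim)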
|
||||
|
||||
@ -6,8 +6,6 @@
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
@ -34,6 +32,7 @@ def _chunk_cumsum_fwd_kernel(
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimension
|
||||
seqlen,
|
||||
nheads: tl.constexpr,
|
||||
@ -61,7 +60,11 @@ def _chunk_cumsum_fwd_kernel(
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=1)
|
||||
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
|
||||
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
dt_ptr += chunk_seqlen_start * stride_dt_seqlen
|
||||
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||
|
||||
@ -74,7 +77,7 @@ def _chunk_cumsum_fwd_kernel(
|
||||
offs_c[None, :] * stride_dt_out_csize)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
|
||||
offs_c[None, :] * stride_dA_cs_csize)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
dt = tl.load(dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) &
|
||||
@ -188,7 +191,7 @@ def _chunk_state_fwd_kernel(
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
@ -212,7 +215,6 @@ def _chunk_state_fwd_kernel(
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
@ -223,14 +225,14 @@ def _chunk_state_fwd_kernel(
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
b_ptr += chunk_seqlen_start * stride_b_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
@ -243,10 +245,7 @@ def _chunk_state_fwd_kernel(
|
||||
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
seq_idx_last = tl.load(seq_idx_ptr +
|
||||
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
@ -261,15 +260,9 @@ def _chunk_state_fwd_kernel(
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
seq_idx_k = tl.load(seq_idx_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=-1)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
scale = tl.where(seq_idx_k == seq_idx_last,
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
|
||||
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
@ -278,7 +271,6 @@ def _chunk_state_fwd_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
@ -534,6 +526,7 @@ def _chunk_state_varlen_kernel(
|
||||
def _chunk_cumsum_fwd(dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf"))):
|
||||
@ -541,7 +534,7 @@ def _chunk_cumsum_fwd(dt,
|
||||
assert A.shape == (nheads, )
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, )
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
nchunks = cu_chunk_seqlens.shape[0] - 1
|
||||
dt_out = torch.empty(nheads,
|
||||
nchunks,
|
||||
chunk_size,
|
||||
@ -561,6 +554,7 @@ def _chunk_cumsum_fwd(dt,
|
||||
dt_bias_ptr=dt_bias,
|
||||
dt_out_ptr=dt_out,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
nheads=nheads,
|
||||
chunk_size=chunk_size,
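For reference, a rough sketch of what _chunk_cumsum_fwd produces per chunk under the new cu_chunk_seqlens layout. This is a simplification for illustration; the helper name and the eager dtype handling are assumptions.

import torch
import torch.nn.functional as F

def chunk_cumsum_ref(dt, A, chunk_size, cu_chunk_seqlens,
                     dt_bias=None, dt_softplus=False, dt_limit=(0.0, float('inf'))):
    # dt: (seqlen, nheads), A: (nheads,) -> dA_cumsum, dt_out: (nheads, nchunks, chunk_size)
    nheads = dt.shape[1]
    nchunks = len(cu_chunk_seqlens) - 1
    dt_out = torch.zeros(nheads, nchunks, chunk_size)
    dA_cumsum = torch.zeros(nheads, nchunks, chunk_size)
    for c in range(nchunks):
        s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
        d = dt[s:e].float().t()                      # (nheads, L)
        if dt_bias is not None:
            d = d + dt_bias[:, None].float()
        if dt_softplus:
            d = F.softplus(d)
        d = d.clamp(*dt_limit)
        dt_out[:, c, :e - s] = d
        dA_cumsum[:, c, :e - s] = torch.cumsum(d * A[:, None].float(), dim=-1)
    return dA_cumsum, dt_out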
|
||||
@ -588,7 +582,7 @@ def _chunk_state_fwd(B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx=None,
|
||||
cu_chunk_seqlens,
|
||||
states=None,
|
||||
states_in_fp32=True):
|
||||
seqlen, nheads, headdim = x.shape
|
||||
@ -599,9 +593,6 @@ def _chunk_state_fwd(B,
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
|
||||
assert seq_idx is not None
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
|
||||
if states is not None:
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
@ -619,7 +610,7 @@ def _chunk_state_fwd(B,
|
||||
states_ptr=states,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
@ -641,7 +632,6 @@ def _chunk_state_fwd(B,
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
)
|
||||
return states
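Likewise, a minimal sketch of the per-chunk state computed by _chunk_state_fwd; the seq_idx masking that this diff removes becomes unnecessary once a chunk never spans two sequences. The helper name and the ngroups == 1 simplification are assumptions.

import torch

def chunk_state_ref(B, x, dt, dA_cumsum, cu_chunk_seqlens):
    # B: (seqlen, ngroups, dstate), x: (seqlen, nheads, headdim)
    # dt, dA_cumsum: (nheads, nchunks, chunk_size); assumes ngroups == 1 for brevity
    seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    nchunks = len(cu_chunk_seqlens) - 1
    states = torch.zeros(nchunks, nheads, headdim, dstate)
    for c in range(nchunks):
        s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
        L = e - s
        decay = torch.exp(dA_cumsum[:, c, L - 1, None] - dA_cumsum[:, c, :L])  # (nheads, L)
        scale = decay * dt[:, c, :L]
        # no seq_idx masking: a chunk never mixes tokens from different sequences
        states[c] = torch.einsum('hl,lhp,ln->hpn',
                                 scale, x[s:e].float(), B[s:e, 0].float())
    return states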
|
||||
|
||||
|
||||
@ -14,8 +14,7 @@ from vllm.triton_utils import triton
|
||||
|
||||
from .ssd_bmm import _bmm_chunk_fwd
|
||||
from .ssd_chunk_scan import _chunk_scan_fwd
|
||||
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
|
||||
chunk_state_varlen)
|
||||
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
@ -37,9 +36,9 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
cu_seqlens=None,
|
||||
cu_chunk_seqlens=None,
|
||||
last_chunk_indices=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None):
|
||||
@ -56,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, )
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
@ -89,6 +88,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit)
|
||||
@ -99,36 +99,31 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx=seq_idx,
|
||||
cu_chunk_seqlens,
|
||||
states_in_fp32=True)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to all be specified.
# - for handling chunked prefill, this requires i) initial_states and
# ii) seq_idx to both be specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum, # (nheads, nchunks, chunk_size)
|
||||
cu_chunk_seqlens,
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None else
|
||||
None, # (batch, nheads, headdim*dstate)
|
||||
seq_idx=seq_idx,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
chunk_offsets=chunk_offsets)
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
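A compact reference of the simplified state passing above, as a sketch only: it assumes out[c] holds the running state after chunk c has been absorbed, which is what the later return states[last_chunk_indices] suggests, and the function name is made up.

import torch

def state_passing_ref(chunk_states, dA_cs_last, seq_idx, initial_states=None):
    # chunk_states: (nchunks, nheads, dim) states computed from each chunk alone
    # dA_cs_last:   (nheads, nchunks) cumulative dA at the last token of each chunk
    # seq_idx:      (nchunks,) sequence index per chunk (a chunk never crosses sequences)
    nchunks = chunk_states.shape[0]
    out = torch.empty_like(chunk_states)
    prev = torch.zeros_like(chunk_states[0])
    prev_seq = -1
    for c in range(nchunks):
        if seq_idx[c] != prev_seq:
            # new sequence: restart from its initial state (or zeros)
            prev = (initial_states[seq_idx[c]].clone() if initial_states is not None
                    else torch.zeros_like(prev))
            prev_seq = int(seq_idx[c])
        prev = torch.exp(dA_cs_last[:, c])[:, None] * prev + chunk_states[c]
        out[c] = prev                                 # running state at the end of chunk c
    return out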
|
||||
|
||||
# 4. Compute batched matrix multiply for C_j^T B_i terms
|
||||
CB = _bmm_chunk_fwd(C,
|
||||
B,
|
||||
chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
cu_chunk_seqlens,
|
||||
output_dtype=torch.float32)
|
||||
|
||||
# 5. Scan and compute the diagonal blocks, taking into
|
||||
@ -148,26 +143,15 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out, # in-place update
|
||||
seq_idx,
|
||||
D=D,
|
||||
z=z,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
varlen_states = chunk_state_varlen(
|
||||
B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
cu_seqlens,
|
||||
states,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
return varlen_states
|
||||
return states[last_chunk_indices]
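For intuition about the new return value (illustrative shapes only, not taken from the diff):

# e.g. with 3 chunks in the batch and last_chunk_indices = [0, 2]:
#   states.shape                      -> (3, nheads, headdim, dstate)
#   states[last_chunk_indices].shape  -> (2, nheads, headdim, dstate)   # final state per sequence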
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
@ -178,14 +162,14 @@ def mamba_chunk_scan_combined_varlen(
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
@ -198,8 +182,10 @@ def mamba_chunk_scan_combined_varlen(
|
||||
B: (seqlen, ngroups, dstate)
|
||||
C: (seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
seq_idx: (seqlen)
|
||||
cu_seqlens: (batch + 1)
|
||||
cu_seqlens: (batch + 1,)
|
||||
cu_chunk_seqlens: (nchunks + 1,)
|
||||
last_chunk_indices: (batch,)
|
||||
seq_idx: (nchunks,)
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (seqlen, nheads, headdim)
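For illustration, a hypothetical call with toy shapes showing how the new arguments fit together. None of this is taken from the diff: the sizes, dtypes, and the assumption that the wrapper returns the per-sequence final states produced by the underlying forward are mine.

import torch
# mamba_chunk_scan_combined_varlen comes from the ssd_combined module shown above

chunk_size = 8
# two sequences of 6 and 10 new tokens, no computed tokens
cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32, device='cuda')
# chunks: [0,6) for seq 0, then [6,14) and [14,16) for seq 1
cu_chunk_seqlens = torch.tensor([0, 6, 14, 16], dtype=torch.int32, device='cuda')
seq_idx = torch.tensor([0, 1, 1], dtype=torch.int32, device='cuda')          # per chunk
last_chunk_indices = torch.tensor([0, 2], dtype=torch.int32, device='cuda')  # per sequence

seqlen, nheads, headdim, ngroups, dstate = 16, 4, 64, 1, 128
x = torch.randn(seqlen, nheads, headdim, device='cuda', dtype=torch.bfloat16)
dt = torch.randn(seqlen, nheads, device='cuda', dtype=torch.bfloat16)
A = -torch.rand(nheads, device='cuda')
B = torch.randn(seqlen, ngroups, dstate, device='cuda', dtype=torch.bfloat16)
C = torch.randn(seqlen, ngroups, dstate, device='cuda', dtype=torch.bfloat16)
out = torch.empty_like(x)

final_states = mamba_chunk_scan_combined_varlen(
    x, dt, A, B, C, chunk_size,
    cu_seqlens=cu_seqlens,
    cu_chunk_seqlens=cu_chunk_seqlens,
    last_chunk_indices=last_chunk_indices,
    seq_idx=seq_idx,
    out=out,
    dt_softplus=True,
)
# final_states: one SSM state per sequence, taken from its last chunk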
|
||||
@ -228,9 +214,9 @@ def mamba_chunk_scan_combined_varlen(
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
state_dtype=state_dtype)
|
||||
|
||||
@ -30,8 +30,7 @@ def _state_passing_fwd_kernel(
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
nchunks,
|
||||
@ -50,94 +49,52 @@ def _state_passing_fwd_kernel(
|
||||
stride_initstates_batch: tl.int64,
|
||||
stride_initstates_head: tl.int64,
|
||||
stride_initstates_dim: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_h = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
||||
states_ptr += pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
|
||||
1) * stride_dA_cs_csize
|
||||
out_ptr += pid_h * stride_out_head
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptr += pid_h * stride_initstates_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
# - states will be the past state of the sequence that continues on the current chunk
if not HAS_INITSTATES:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
else:
initstates_ptr += offs_m * stride_initstates_dim
initstates_ptrs = initstates_ptr
# - for cont batches, the first chunk's previous state is the first sequence's
# init state
if HAS_INITSTATES:
|
||||
initstates_ptrs = initstates_ptr \
|
||||
+ pid_h * stride_initstates_head \
|
||||
+ offs_m * stride_initstates_dim
|
||||
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
out_ptrs += stride_out_chunk
|
||||
prev_seq_idx_chunk_end = 0
|
||||
logical_chunk_idx = 0
|
||||
for c in range(nchunks - 1):
|
||||
prev_seq_idx = 0
|
||||
for c in range(nchunks):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale_mask = True
|
||||
# - the seq to pass forward is the one that is flushed to the right
|
||||
# boundary.
|
||||
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
|
||||
seq_idx_chunk_end = tl.load(seq_idx_ptr +
|
||||
(min((c + 1) * chunk_size, seqlen) - 1) *
|
||||
stride_seq_idx_seqlen)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
if prev_seq_idx_chunk_end != seq_idx_chunk_end:
|
||||
# this means in the current chunk the rightmost flushed seq
|
||||
# has changed.
|
||||
# - so we do not propagate the state from previous chunk
|
||||
# - but rather we load that sequence's init state
|
||||
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
|
||||
|
||||
# - update state with seq_idx_new's init state
|
||||
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
|
||||
# we have started a new sequence
|
||||
if prev_seq_idx != seq_idx:
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \
|
||||
+ pid_h * stride_initstates_head \
|
||||
+ offs_m * stride_initstates_dim
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
|
||||
# - we need to consider the cumsum only of the last sequence in the chunk
|
||||
# - find its starting position (given by c_off of the logical chunk index)
|
||||
# - and subtract the cumsum just before that position from the total cumsum
|
||||
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
|
||||
# sequence index at the start of the current chunk
|
||||
seq_idx_chunk_start = tl.load(seq_idx_ptr +
|
||||
min(c * chunk_size, seqlen) *
|
||||
stride_seq_idx_seqlen)
|
||||
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
|
||||
# - load the chunk offset:
|
||||
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
|
||||
mask=logical_chunk_idx < chunk_meta_num,
|
||||
other=0)
|
||||
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
|
||||
if c_off > 0:
|
||||
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
|
||||
(c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(c_off - 1) > -1 and c_off < chunk_size,
|
||||
other=0.0)
|
||||
dA_cs -= dA_cs_boundary
|
||||
|
||||
# - increment logical chunk index for every physical chunk
|
||||
logical_chunk_idx += 1
|
||||
else:
|
||||
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
|
||||
prev_seq_idx_chunk_end = seq_idx_chunk_end
|
||||
|
||||
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
|
||||
states = scale * states + new_states
|
||||
prev_seq_idx = seq_idx
|
||||
states = tl.exp(dA_cs) * states + new_states
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
|
||||
states_ptrs += stride_states_chunk
|
||||
@ -148,8 +105,8 @@ def _state_passing_fwd_kernel(
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
cu_chunk_seqlens,
|
||||
seq_idx,
|
||||
chunk_offsets,
|
||||
initial_states=None,
|
||||
out_dtype=None,
|
||||
):
|
||||
@ -175,9 +132,7 @@ def _state_passing_fwd(
|
||||
dA_cs_ptr=dA_cumsum,
|
||||
initstates_ptr=initial_states,
|
||||
seq_idx_ptr=seq_idx,
|
||||
chunk_offsets_ptr=chunk_offsets,
|
||||
chunk_meta_num=len(chunk_offsets)
|
||||
if chunk_offsets is not None else 0,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
dim=dim,
|
||||
nchunks=nchunks,
|
||||
seqlen=seqlen if seq_idx is not None else 0,
|
||||
@ -194,7 +149,7 @@ def _state_passing_fwd(
|
||||
stride_initstates_batch=initial_states_strides[0],
|
||||
stride_initstates_head=initial_states_strides[1],
|
||||
stride_initstates_dim=initial_states_strides[2],
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out
|
||||
|
||||
@ -260,9 +260,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
@ -368,9 +368,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
self.num_heads // self.tp_size, self.head_dim),
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=seq_idx_p,
|
||||
chunk_indices=chunk_indices_p,
|
||||
chunk_offsets=chunk_offsets_p,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@ -8,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
||||
@ -17,91 +17,6 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc: torch.Tensor, chunk_size: int,
|
||||
total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
|
||||
lengths, shape (num_seqs + 1,).
|
||||
The first element should be 0. Each entry represents the starting
|
||||
index of a sequence in the flattened token array.
|
||||
chunk_size (int): The size of each physical mamba chunk
|
||||
(number of tokens per chunk).
|
||||
total_seqlens (int): The total number of tokens in the batch.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- chunk_indices (torch.Tensor): 1D tensor of indices
|
||||
indicating the physical chunk for each logical chunk.
|
||||
- chunk_offsets (torch.Tensor): 1D tensor of offsets
|
||||
indicating the starting index of each logical chunk within
|
||||
its physical chunk.
|
||||
|
||||
This function computes the chunk indices and offsets for the given
|
||||
query_start_loc and chunk_size. Both are tensors of integers with length N,
|
||||
where N is the number of logical (pseudo) chunks.
|
||||
A logical chunk is a sequence of tokens that are all part of the same
|
||||
sequence and are all in the same physical mamba chunk.
|
||||
In other words, a logical chunk changes every time we cross a sequence
|
||||
boundary or a physical mamba chunk boundary.
|
||||
Logical chunks are needed to handle batched requests with initial states
|
||||
(see _state_passing_fwd and _chunk_scan_fwd).
|
||||
The chunk_indices tensor contains the index of the physical chunk for each
|
||||
logical chunk.
|
||||
The chunk_offsets tensor contains the offset (AKA starting index) of the
|
||||
logical chunk in the physical chunk.
|
||||
|
||||
Example:
|
||||
query_start_loc = [0, 5, 10]
|
||||
chunk_size = 8
|
||||
total_seqlens = 10
|
||||
-> chunk_indices = [0, 0, 1]
|
||||
-> chunk_offsets = [0, 5, 0]
|
||||
|
||||
In this example, we have 2 sequences, each with 5 tokens. The physical
|
||||
chunk size is 8 tokens.
|
||||
We have three logical chunks:
|
||||
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
and contains the first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
|
||||
|
||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||
|
||||
# outputs will have length expansion of chunks that do not divide
|
||||
# chunk_size
|
||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||
> 0).sum()
|
||||
chunk_indices = torch.arange(N,
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
chunk_offsets = torch.zeros((N, ),
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
|
||||
p = 0 # num of insertions
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
|
||||
# if does not divide chunk_size, then there is one chunk insertion
|
||||
p += (s % chunk_size > 0)
|
||||
|
||||
# get the dimensions
|
||||
# - the + 1 for _e is to shift the boundary by one chunk
|
||||
# - this shifting is not needed if chunk_size divides e
|
||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
||||
> 0)
|
||||
|
||||
# adjust indices and offsets
|
||||
chunk_indices[_s:_e] -= p
|
||||
chunk_offsets[_s] = s % chunk_size
|
||||
|
||||
return chunk_indices, chunk_offsets
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
@ -125,8 +40,16 @@ class Mamba2AttentionMetadata:
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
seq_idx_p: Optional[torch.Tensor]
|
||||
chunk_indices_p: Optional[torch.Tensor]
|
||||
chunk_offsets_p: Optional[torch.Tensor]
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: Optional[torch.Tensor]
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: Optional[torch.Tensor]
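A worked example of these fields, traced by hand through the builder logic further down (hypothetical values, chunk_size = 8):

# one prefill request with 5 tokens already computed and 12 new tokens;
# the first chunk is shortened to 3 tokens so the request becomes
# chunk-aligned again, then regular chunks follow:
#   cu_chunk_seqlen_p    = [0, 3, 11, 12]   # chunk boundaries in the varlen dim
#   seq_idx_p            = [0, 0, 0]        # every chunk belongs to request 0
#   last_chunk_indices_p = [2]              # index of the request's last chunk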
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
@ -151,13 +74,14 @@ class Mamba2AttentionMetadataBuilder(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc_p = None
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
query_start_loc_p = None
|
||||
seq_idx_p = None
|
||||
chunk_indices_p, chunk_offsets_p = None, None
|
||||
cu_chunk_seqlen_p = None
|
||||
last_chunk_indices_p = None
|
||||
|
||||
# Need flags to indicate if there are initial states
|
||||
# currently we really only support the FlashAttention backend
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
@ -171,7 +95,7 @@ class Mamba2AttentionMetadataBuilder(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold))
|
||||
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
# Compute seq_idx for prefill only
|
||||
if num_prefills > 0:
|
||||
#[batch,]
|
||||
has_initial_states_cpu = (
|
||||
@ -184,21 +108,68 @@ class Mamba2AttentionMetadataBuilder(
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
-num_prefills - 1:] - num_decode_tokens
|
||||
|
||||
seq_idx_p = torch.repeat_interleave(torch.arange(
|
||||
num_prefills,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
num_computed_tokens_p = \
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills:num_reqs]
|
||||
query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[
|
||||
-num_prefills - 1:] - num_decode_tokens
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level
|
||||
# model forward and reuse them in mamba layers. If not needed,
|
||||
# they will be ignored inside mamba kernels.
|
||||
if prep_initial_states:
|
||||
chunk_indices_p, chunk_offsets_p = (
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc_p, self.chunk_size,
|
||||
num_prefill_tokens))
|
||||
# The code below carefully constructs the chunks such that:
|
||||
# 1. Chunks contain tokens from a *single* sequence only.
|
||||
# 2. For every sequence, we are guaranteed that we can
|
||||
# retrieve the mamba state *every* chunk_size tokens.
|
||||
# Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
# Constraint (2) dramatically simplifies the implementation
|
||||
# of prefix caching for mamba2 (wip). We need to take care
|
||||
# of the interaction with chunked prefill in order to
|
||||
# satisfy constraint (2).
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p[req_idx].item()
|
||||
this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item(
|
||||
) - query_start_loc_p_cpu[req_idx].item()
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = cdiv(this_num_computed, self.chunk_size
|
||||
) * self.chunk_size - this_num_computed
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
seq_idx_p = torch.as_tensor(seq_idx,
|
||||
device=query_start_loc_p.device,
|
||||
dtype=torch.int32)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=query_start_loc_p.device,
|
||||
dtype=torch.int32)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=query_start_loc_p.device,
|
||||
dtype=torch.int32)
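As a cross-check of the loop above, a pure-Python sketch of the per-request chunk splitting. The helper name split_into_chunks is made up for illustration and is not part of the change.

def split_into_chunks(num_computed, num_new, chunk_size):
    """Return the chunk lengths covering `num_new` tokens of one request,
    shortening the first chunk so the request becomes chunk-aligned."""
    lens = []
    if num_computed % chunk_size != 0:
        # finish off the partially filled physical chunk first
        head = min(chunk_size - num_computed % chunk_size, num_new)
        lens.append(head)
        num_new -= head
    while num_new > 0:
        lens.append(min(chunk_size, num_new))
        num_new -= lens[-1]
    return lens

# e.g. split_into_chunks(5, 12, 8) -> [3, 8, 1]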
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
@ -222,9 +193,9 @@ class Mamba2AttentionMetadataBuilder(
|
||||
chunk_size=self.chunk_size,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
seq_idx_p=seq_idx_p,
|
||||
chunk_indices_p=chunk_indices_p,
|
||||
chunk_offsets_p=chunk_offsets_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
|
||||