[Kernel] Chunk-aligned mamba2 (#24683)
Commit fea3e476aa (parent 61a3431613)
@@ -502,9 +502,9 @@ class MambaMixer2(MambaBase, CustomOp):
             prep_initial_states = attn_metadata.prep_initial_states
             chunk_size = attn_metadata.chunk_size
             seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
             query_start_loc_p = attn_metadata.query_start_loc_p
+            cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+            last_chunk_indices_p = attn_metadata.last_chunk_indices_p

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -634,9 +634,9 @@ class MambaMixer2(MambaBase, CustomOp):
                 z=None,
                 dt_bias=self.dt_bias,
                 seq_idx=seq_idx_p,
-                chunk_indices=chunk_indices_p,
-                chunk_offsets=chunk_offsets_p,
                 cu_seqlens=query_start_loc_p,
+                cu_chunk_seqlens=cu_chunk_seqlen_p,
+                last_chunk_indices=last_chunk_indices_p,
                 initial_states=initial_states,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
@@ -6,8 +6,6 @@

 # ruff: noqa: E501,SIM102

-import math
-
 import torch

 from vllm.triton_utils import tl, triton
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
     a_ptr,
     b_ptr,
     out_ptr,
-    seq_idx_ptr,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     seqlen,
     chunk_size: tl.constexpr,
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
     stride_out_head: tl.int64,
     stride_outm: tl.int64,
     stride_outn: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
     dot_dtype: tl.constexpr,
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
     if IS_CAUSAL:
         if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
             return
-    a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
-
-    seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
+
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+
+    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
+    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head

     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
                       offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                       offs_n[None, :] * stride_b_seqlen)
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

-    # Zero out the results that are not from the same request
-    # in the varlen batch
-    seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-                        mask=offs_m < chunk_size_limit,
-                        other=-1)
-    seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
-                        mask=offs_n < chunk_size_limit,
-                        other=-2)
-    acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
-
     out = acc.to(out_ptr.dtype.element_ty)
     out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
     out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
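The deleted masking above existed because a physical chunk could previously mix tokens from different requests, so cross-request entries of the chunk-local dot products had to be zeroed. With chunk-aligned chunks every chunk belongs to a single request, which is why the kernel can drop it. As a hedged reading of the kernel (ignoring the causal flag), the per-chunk, per-group output is

    \[
      \mathrm{out}[c, g, i, j] \;=\; \sum_{k} a[s_c + i,\, g,\, k]\; b[s_c + j,\, g,\, k],
      \qquad s_c = \texttt{cu\_chunk\_seqlens}[c],\;\; 0 \le i, j < s_{c+1} - s_c,
    \]

with the output padded to chunk_size in both index dimensions.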
@@ -182,12 +171,18 @@ def _bmm_chunk_fwd_kernel(
                           (offs_n[None, :] < chunk_size))


-def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
+def _bmm_chunk_fwd(a,
+                   b,
+                   chunk_size,
+                   cu_chunk_seqlens,
+                   causal=False,
+                   output_dtype=None):
     """
     Argument:
         a: (seqlen, ngroups, k)
         b: (seqlen, ngroups, k)
-        seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
+        chunk_size: int
+        cu_chunk_seq_lens: (nchunks+1,)
         causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
             guaranteed to be correct.
     Return:
@@ -195,14 +190,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
     """
     seqlen, ngroups, k = a.shape
     assert b.shape == a.shape
-    assert seq_idx is not None
-    assert seq_idx.shape == (seqlen, )
     if a.stride(-1) != 1 and a.stride(0) != 1:
         a = a.contiguous()
     if b.stride(-1) != 1 and b.stride(0) != 1:
         b = b.contiguous()

-    nchunks = math.ceil(seqlen / chunk_size)
+    nchunks = len(cu_chunk_seqlens) - 1
     # Allocates output.
     out_dtype = a.dtype if output_dtype is None else output_dtype
     out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
         a_ptr=a,
         b_ptr=b,
         out_ptr=out,
-        seq_idx_ptr=seq_idx,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
         seqlen=seqlen,
         chunk_size=chunk_size,
         K=k,
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
         stride_out_head=out.stride(1),
         stride_outm=out.stride(-2),
         stride_outn=out.stride(-1),
-        stride_seq_idx_seqlen=seq_idx.stride(0),
         IS_CAUSAL=causal,
         dot_dtype=dot_dtype,
     )
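For readers following the call sites later in this commit, here is a small pure-PyTorch sketch of what the chunk-aligned _bmm_chunk_fwd computes. `bmm_chunk_ref` is a name made up for this illustration, the causal flag and dtype handling are ignored, and the shapes are invented.

    import torch

    def bmm_chunk_ref(a: torch.Tensor, b: torch.Tensor, chunk_size: int,
                      cu_chunk_seqlens: torch.Tensor) -> torch.Tensor:
        """Illustrative reference of the Triton kernel's output layout."""
        seqlen, ngroups, k = a.shape
        nchunks = len(cu_chunk_seqlens) - 1
        out = torch.zeros(nchunks, ngroups, chunk_size, chunk_size, dtype=a.dtype)
        for c in range(nchunks):
            s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
            # out[c, g, i, j] = sum_k a[s + i, g, k] * b[s + j, g, k]
            out[c, :, :e - s, :e - s] = torch.einsum("igk,jgk->gij", a[s:e], b[s:e])
        return out

    # Example: two requests of 5 tokens each; chunks never cross a request boundary.
    a = torch.randn(10, 1, 4)
    b = torch.randn(10, 1, 4)
    cu_chunk_seqlens = torch.tensor([0, 5, 10])
    print(bmm_chunk_ref(a, b, 8, cu_chunk_seqlens).shape)  # torch.Size([2, 1, 8, 8])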
@@ -120,9 +120,7 @@ def _chunk_scan_fwd_kernel(
     states_ptr,
     D_ptr,
     initstates_ptr,
-    chunk_indices_ptr,
-    chunk_offsets_ptr,
-    chunk_meta_num,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     chunk_size: tl.constexpr,
     hdim: tl.constexpr,
@@ -149,7 +147,7 @@ def _chunk_scan_fwd_kernel(
     stride_dA_cs_chunk: tl.int64,
     stride_dA_cs_head: tl.int64,
     stride_dA_cs_csize: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
+    stride_seq_idx_chunk: tl.constexpr,
     stride_C_seqlen: tl.int64,
     stride_C_head: tl.int64,
     stride_C_dstate: tl.constexpr,
@@ -175,170 +173,107 @@ def _chunk_scan_fwd_kernel(
     HAS_INITSTATES: tl.constexpr,
 ):
     pid_c = tl.program_id(axis=1).to(tl.int64)
-    if not HAS_INITSTATES:
-        c_idx = pid_c
-        c_off = 0
-    else:
-        c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
-        c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
-
     pid_h = tl.program_id(axis=2)
     num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
     pid_m = tl.program_id(axis=0) // num_pid_n
     pid_n = tl.program_id(axis=0) % num_pid_n
-    cb_ptr += c_idx * stride_cb_chunk + (pid_h //
+    cb_ptr += pid_c * stride_cb_chunk + (pid_h //
                                          nheads_ngroups_ratio) * stride_cb_head
-    x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
-    dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
-    dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
-    C_ptr += c_idx * chunk_size * stride_C_seqlen + (
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
+    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
+    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+    C_ptr += chunk_seqlen_start * stride_C_seqlen + (
         pid_h // nheads_ngroups_ratio) * stride_C_head

     # M-block offsets and prev states
     # - logic in next block may override these if there is an active offset
-    offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
-    prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
-    prev_states_hdim = stride_states_hdim
-    prev_states_dstate = stride_states_dstate
-
-    chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
-
-    seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
-    # - we only need seq_idx_prev to be aligned to chunk boundary
-    seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
-                           mask=c_idx >= 1,
-                           other=0)
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    seq_idx_ptr += pid_c * stride_seq_idx_chunk
+    seq_idx = tl.load(seq_idx_ptr)
+    seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk,
+                           mask=pid_c >= 1,
+                           other=-1)

-    if HAS_INITSTATES:
-        # if there are init states, we only need seq_idx_m to point
-        # what is the current seq_idx
-
-        # get current seq idx
-        if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
-            seq_idx_m = tl.load(
-                seq_idx_ptr +
-                (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
-
-            # - recall that in ssd_state_passing, for the case c_off == 0
-            # i.e., the very first sequence, we made states_ptr hold its initial state
-            # so this edge case is taken care of
-            if ((c_off == 0) and (seq_idx_prev != seq_idx_m
-                                  )  # if a seq is changed exactly on boundary
-                    or (c_off > 0)  # implies a new example (pseudo chunk)
-                ):
-
-                # - replace prev_states_ptr with init_states
-                prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
-                prev_states_hdim = stride_init_states_hdim  # override strides
-                prev_states_dstate = stride_init_states_dstate
+    if HAS_INITSTATES and (seq_idx != seq_idx_prev):
+        prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
+        prev_states_hdim = stride_init_states_hdim
+        prev_states_dstate = stride_init_states_dstate
+    else:
+        prev_states_ptr = states_ptr + (
+            pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
+        prev_states_hdim = stride_states_hdim
+        prev_states_dstate = stride_states_dstate
+
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
                       mask=offs_m < chunk_size,
                       other=0.0).to(tl.float32)

-    # - handle chunk state limit
-    if HAS_INITSTATES:
-        # have to split this if otherwise compilation will have problems
-        dA_cs_m_boundary = 0.0
-
-        # get the c_idx for the next (logica) chunk
-        c_idx_n = tl.load(
-            chunk_indices_ptr + (pid_c + 1),
-            mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
-            other=-1  # to trigger different chunk
-        )
-
-        # - there are things to consider
-        # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
-        #    contribution of past states
-        # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
-        #    encroach into the next sequence, where c_off_n is the offset of the next
-        #    (logical) chunk.
-        # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
-        #    (logical) chunk indices.
-
-        if (c_idx == c_idx_n) or c_off > 0:
-
-            # get the next offset
-            c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
-                              mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
-                              other=chunk_size)
-
-            # in this case, adjust down the chunk_size_limit
-            if c_idx == c_idx_n:
-                chunk_size_limit = min(c_off_n, chunk_size_limit)
-
-            # get the cs at the offset boundary
-            # - c_off == 0 is a passthrough
-            # - We need dA_cs at the boundary, defined by c_off - no need
-            #   to increase pointer by pid_m (it is a constant offset,
-            #   i.e. the same for all blocks)
-            dA_cs_m_boundary = tl.load(
-                dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
-                mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
-                other=0.0).to(tl.float32)
-    else:
-        # - handle seq idx when HAS_INITSTATES==False
-        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-                            mask=offs_m < chunk_size_limit,
-                            other=-1)

     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

-    # Without the if (pid_c > -1), with Triton 2.1.0, I get
-    # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
-    # With Triton 2.2.0, this works
-    if IS_TRITON_22 or c_idx > -1:
-        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
-        offs_k_dstate = tl.arange(
-            0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
-        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
-                          offs_k_dstate[None, :] * stride_C_dstate)
-
-        prev_states_ptrs = prev_states_ptr + (
-            offs_n[None, :] * prev_states_hdim +
-            offs_k_dstate[:, None] * prev_states_dstate)
-
-        if not HAS_INITSTATES:
-            # - this is for continuous batching where there is no init states
-            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
-        else:
-            # - if there is initstates, we will rely on prev_states, no zeroing
-            # required.
-            scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
-
-        if BLOCK_SIZE_DSTATE <= 128:
-            C = tl.load(C_ptrs,
-                        mask=(offs_m[:, None] < chunk_size_limit) &
-                        (offs_k_dstate[None, :] < dstate),
-                        other=0.0)
-
-            prev_states = tl.load(prev_states_ptrs,
-                                  mask=(offs_k_dstate[:, None] < dstate) &
-                                  (offs_n[None, :] < hdim),
-                                  other=0.0)
-            prev_states = prev_states.to(C_ptr.dtype.element_ty)
-            acc = tl.dot(C, prev_states) * scale_m[:, None]
-        else:
-            for k in range(0, dstate, BLOCK_SIZE_K):
-                C = tl.load(C_ptrs,
-                            mask=(offs_m[:, None] < chunk_size_limit) &
-                            (offs_k_dstate[None, :] < dstate - k),
-                            other=0.0)
-                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
-                prev_states = tl.load(
-                    prev_states_ptrs,
-                    mask=(offs_k_dstate[:, None] < dstate - k) &
-                    (offs_n[None, :] < hdim),
-                    other=0.0)
-                prev_states = prev_states.to(C_ptr.dtype.element_ty)
-                acc += tl.dot(C, prev_states)
-                C_ptrs += BLOCK_SIZE_K
-                prev_states_ptrs += BLOCK_SIZE_K
-            acc *= scale_m[:, None]
+    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
+    offs_k_dstate = tl.arange(
+        0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
+    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
+                      offs_k_dstate[None, :] * stride_C_dstate)
+
+    scale_m = tl.exp(dA_cs_m)
+    if BLOCK_SIZE_DSTATE <= 128:
+        C = tl.load(C_ptrs,
+                    mask=(offs_m[:, None] < chunk_size_limit) &
+                    (offs_k_dstate[None, :] < dstate),
+                    other=0.0)
+
+        if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
+            # if no init states AND starting a new sequence, we need zeros
+            prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N),
+                                   dtype=C_ptr.dtype.element_ty)
+        else:
+            # otherwise read the previous state
+            prev_states_ptrs = prev_states_ptr \
+                + offs_n[None, :] * prev_states_hdim \
+                + offs_k_dstate[:, None] * prev_states_dstate
+            prev_states = tl.load(prev_states_ptrs,
+                                  mask=(offs_k_dstate[:, None] < dstate) &
+                                  (offs_n[None, :] < hdim),
+                                  other=0.0)
+            prev_states = prev_states.to(C_ptr.dtype.element_ty)
+        acc = tl.dot(C, prev_states) * scale_m[:, None]
+    else:
+        prev_states_ptrs = prev_states_ptr \
+            + offs_n[None, :] * prev_states_hdim \
+            + offs_k_dstate[:, None] * prev_states_dstate
+        for k in range(0, dstate, BLOCK_SIZE_K):
+            C = tl.load(C_ptrs,
+                        mask=(offs_m[:, None] < chunk_size_limit) &
+                        (offs_k_dstate[None, :] < dstate - k),
+                        other=0.0)
+            if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
+                prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K),
+                                       dtype=C_ptr.dtype.element_ty)
+            else:
+                prev_states = tl.load(
+                    prev_states_ptrs,
+                    mask=(offs_k_dstate[:, None] < dstate - k) &
+                    (offs_n[None, :] < hdim),
+                    other=0.0)
+                prev_states = prev_states.to(C_ptr.dtype.element_ty)
+            acc += tl.dot(C, prev_states)
+            C_ptrs += BLOCK_SIZE_K
+            prev_states_ptrs += BLOCK_SIZE_K
+        acc *= scale_m[:, None]

-    offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
     cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
                         offs_k[None, :] * stride_cb_csize_k)
     x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
@@ -375,7 +310,7 @@ def _chunk_scan_fwd_kernel(
         dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
         dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

-    offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
+    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

     if HAS_D:
@@ -393,7 +328,7 @@ def _chunk_scan_fwd_kernel(
         acc += x_residual * D

     if HAS_Z:
-        z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
+        z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
         z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
                           stride_z_hdim * offs_out_n[None, :])
         z = tl.load(z_ptrs,
@@ -402,7 +337,7 @@ def _chunk_scan_fwd_kernel(
                     other=0.0).to(tl.float32)
         acc *= z * tl.sigmoid(z)

-    out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
+    out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
     out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
                           offs_out_n[None, :] * stride_out_hdim)
     tl.store(out_ptrs,
@@ -418,12 +353,11 @@ def _chunk_scan_fwd(
     dA_cumsum,
     C,
     states,
+    cu_chunk_seqlens,
     out,
     seq_idx,
     D=None,
     z=None,
-    chunk_indices=None,
-    chunk_offsets=None,
     initial_states=None,
 ):
     assert seq_idx is not None, "this implementation requires seq_idx"
@@ -441,20 +375,10 @@ def _chunk_scan_fwd(
     assert dt.shape == (nheads, nchunks, chunk_size)
     assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
     assert states.shape == (nchunks, nheads, headdim, dstate)
-    assert seq_idx.shape == (seqlen, )
+    assert seq_idx.shape == (nchunks, )

-    if initial_states is not None:
-        # with initial states, we need to take care of how
-        # seq_idx crosses the boundaries
-        assert chunk_indices is not None and chunk_offsets is not None, \
-            "chunk_indices and chunk_offsets should have been set"
-    else:
-        chunk_indices, chunk_offsets = None, None
-
-    grid = lambda META: (
-        triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
-            headdim, META['BLOCK_SIZE_N']), nchunks
-        if chunk_offsets is None else len(chunk_offsets), nheads)
+    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton
+                         .cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads)

     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                  (0, 0, 0))
@@ -476,9 +400,7 @@ def _chunk_scan_fwd(
         states_ptr=states,
         D_ptr=D,
         initstates_ptr=initial_states,
-        chunk_indices_ptr=chunk_indices,
-        chunk_offsets_ptr=chunk_offsets,
-        chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
         chunk_size=chunk_size,
         hdim=headdim,
         dstate=dstate,
@@ -503,7 +425,7 @@ def _chunk_scan_fwd(
         stride_dA_cs_chunk=dA_cumsum.stride(1),
         stride_dA_cs_head=dA_cumsum.stride(0),
         stride_dA_cs_csize=dA_cumsum.stride(2),
-        stride_seq_idx_seqlen=seq_idx.stride(0),
+        stride_seq_idx_chunk=seq_idx.stride(0),
         stride_C_seqlen=C.stride(0),
         stride_C_head=C.stride(1),
         stride_C_dstate=C.stride(2),
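The restructured scan kernel now decides where a chunk's "previous state" comes from once per chunk rather than per token row. A rough Python rendering of that branch, assuming `states[c]` holds the running state at the end of chunk `c` (as produced by the state-passing step), `seq_idx` is per chunk, and `pick_prev_state` is a helper name invented for this note:

    import torch

    def pick_prev_state(c, seq_idx, states, initial_states):
        """Illustrative only: mirrors the seq_idx/seq_idx_prev branch in the kernel."""
        started_new_seq = (c == 0) or (seq_idx[c] != seq_idx[c - 1])
        if started_new_seq and initial_states is not None:
            # first chunk of a request: continue from that request's initial state
            return initial_states[seq_idx[c]]
        if started_new_seq:
            # no initial states: a fresh request starts from zeros
            return torch.zeros_like(states[0])
        # otherwise continue from the state left behind by the previous chunk
        return states[c - 1]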
@@ -6,8 +6,6 @@

 # ruff: noqa: E501

-import math
-
 import torch

 from vllm.triton_utils import tl, triton
@@ -34,6 +32,7 @@ def _chunk_cumsum_fwd_kernel(
     dt_bias_ptr,
     dt_out_ptr,
     dA_cumsum_ptr,
+    cu_chunk_seqlens_ptr,
     # Matrix dimension
     seqlen,
     nheads: tl.constexpr,
@@ -61,7 +60,11 @@ def _chunk_cumsum_fwd_kernel(
     # https://github.com/triton-lang/triton/issues/1058
     pid_c = tl.program_id(axis=0).to(tl.int64)
     pid_h = tl.program_id(axis=1)
-    dt_ptr += pid_c * chunk_size * stride_dt_seqlen
+
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+
+    dt_ptr += chunk_seqlen_start * stride_dt_seqlen
     dt_out_ptr += pid_c * stride_dt_out_chunk
     dA_cumsum_ptr += pid_c * stride_dA_cs_chunk

@@ -74,7 +77,7 @@ def _chunk_cumsum_fwd_kernel(
                                offs_c[None, :] * stride_dt_out_csize)
     dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
                                   offs_c[None, :] * stride_dA_cs_csize)
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

     dt = tl.load(dt_ptrs,
                  mask=(offs_h[:, None] < nheads) &
@@ -188,7 +191,7 @@ def _chunk_state_fwd_kernel(
     states_ptr,
     dt_ptr,
     dA_cumsum_ptr,
-    seq_idx_ptr,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     hdim: tl.constexpr,
     dstate: tl.constexpr,
@@ -212,7 +215,6 @@ def _chunk_state_fwd_kernel(
     stride_dA_cs_head: tl.int64,
     stride_dA_cs_chunk: tl.int64,
     stride_dA_cs_csize: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -223,14 +225,14 @@ def _chunk_state_fwd_kernel(
     num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
     pid_m = tl.program_id(axis=0) // num_pid_n
     pid_n = tl.program_id(axis=0) % num_pid_n
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + (
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+    b_ptr += chunk_seqlen_start * stride_b_seqlen + (
         pid_h // nheads_ngroups_ratio) * stride_b_head
-    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
     dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

-    seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
-
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
@@ -243,10 +245,7 @@ def _chunk_state_fwd_kernel(
                            (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

-    seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
-    seq_idx_last = tl.load(seq_idx_ptr +
-                           (chunk_size_limit - 1) * stride_seq_idx_seqlen)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
@@ -261,15 +260,9 @@ def _chunk_state_fwd_kernel(
         dA_cs_k = tl.load(dA_cumsum_ptrs,
                           mask=offs_k < chunk_size_limit - k,
                           other=0.0).to(tl.float32)

-        seq_idx_k = tl.load(seq_idx_ptrs,
-                            mask=offs_k < chunk_size_limit - k,
-                            other=-1)
         dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
                        other=0.0).to(tl.float32)
-        scale = tl.where(seq_idx_k == seq_idx_last,
-                         tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
+        scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
@@ -278,7 +271,6 @@ def _chunk_state_fwd_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
         dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
         dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
-        seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen

     states = acc.to(states_ptr.dtype.element_ty)

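With request-aligned chunks, the per-token seq_idx masking in the chunk-state kernel disappears because every token of a chunk now belongs to the same request. As a hedged reading of the loop above, each chunk's state is

    \[
      \mathrm{states}_c \;=\; \sum_{k=0}^{L_c-1}
        \exp\!\big(\mathrm{dA\_cs}_{\mathrm{last}} - \mathrm{dA\_cs}_k\big)\,
        \mathrm{dt}_k \; x_k\, B_k^{\top},
      \qquad L_c = \texttt{cu\_chunk\_seqlens}[c{+}1]-\texttt{cu\_chunk\_seqlens}[c],
    \]

where \(x_k \in \mathbb{R}^{\text{hdim}}\), \(B_k \in \mathbb{R}^{\text{dstate}}\), and \(\mathrm{dA\_cs}_{\mathrm{last}}\) is the chunk-final cumulative decay; no `tl.where` against a last-token sequence index is needed.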
@@ -534,6 +526,7 @@ def _chunk_state_varlen_kernel(
 def _chunk_cumsum_fwd(dt,
                       A,
                       chunk_size,
+                      cu_chunk_seqlens,
                       dt_bias=None,
                       dt_softplus=False,
                       dt_limit=(0.0, float("inf"))):
@@ -541,7 +534,7 @@ def _chunk_cumsum_fwd(dt,
     assert A.shape == (nheads, )
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, )
-    nchunks = math.ceil(seqlen / chunk_size)
+    nchunks = cu_chunk_seqlens.shape[0] - 1
     dt_out = torch.empty(nheads,
                          nchunks,
                          chunk_size,
@@ -561,6 +554,7 @@ def _chunk_cumsum_fwd(dt,
         dt_bias_ptr=dt_bias,
         dt_out_ptr=dt_out,
         dA_cumsum_ptr=dA_cumsum,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
         seqlen=seqlen,
         nheads=nheads,
         chunk_size=chunk_size,
@@ -588,7 +582,7 @@ def _chunk_state_fwd(B,
                      x,
                      dt,
                      dA_cumsum,
-                     seq_idx=None,
+                     cu_chunk_seqlens,
                      states=None,
                      states_in_fp32=True):
     seqlen, nheads, headdim = x.shape
@@ -599,9 +593,6 @@ def _chunk_state_fwd(B,
     assert dt.shape == (nheads, nchunks, chunk_size)
     assert dA_cumsum.shape == dt.shape

-    assert seq_idx is not None
-    assert seq_idx.shape == (seqlen, )
-
     if states is not None:
         assert states.shape == (nchunks, nheads, headdim, dstate)
     else:
@@ -619,7 +610,7 @@ def _chunk_state_fwd(B,
         states_ptr=states,
         dt_ptr=dt,
         dA_cumsum_ptr=dA_cumsum,
-        seq_idx_ptr=seq_idx,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
         hdim=headdim,
         dstate=dstate,
         chunk_size=chunk_size,
@@ -641,7 +632,6 @@ def _chunk_state_fwd(B,
         stride_dA_cs_head=dA_cumsum.stride(0),
         stride_dA_cs_chunk=dA_cumsum.stride(1),
         stride_dA_cs_csize=dA_cumsum.stride(2),
-        stride_seq_idx_seqlen=seq_idx.stride(0),
     )
     return states

@@ -14,8 +14,7 @@ from vllm.triton_utils import triton

 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
-from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
-                              chunk_state_varlen)
+from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
 from .ssd_state_passing import _state_passing_fwd

 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
@@ -37,9 +36,9 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    dt_bias=None,
                                    initial_states=None,
                                    seq_idx=None,
-                                   chunk_indices=None,
-                                   chunk_offsets=None,
                                    cu_seqlens=None,
+                                   cu_chunk_seqlens=None,
+                                   last_chunk_indices=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf")),
                                    state_dtype=None):
@@ -56,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     if D is not None:
         assert D.shape == (nheads, headdim) or D.shape == (nheads, )
     if seq_idx is not None:
-        assert seq_idx.shape == (seqlen, )
+        assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, )
     if B.stride(-1) != 1:
         B = B.contiguous()
     if C.stride(-1) != 1:
@@ -89,6 +88,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                       A,
                                       chunk_size,
+                                      cu_chunk_seqlens,
                                       dt_bias=dt_bias,
                                       dt_softplus=dt_softplus,
                                       dt_limit=dt_limit)
@@ -99,36 +99,31 @@ def _mamba_chunk_scan_combined_fwd(x,
                               x,
                               dt,
                               dA_cumsum,
-                              seq_idx=seq_idx,
+                              cu_chunk_seqlens,
                               states_in_fp32=True)

     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
-    # - for handling chunked prefill, this requires i) initial_states
-    #   ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
+    # - for handling chunked prefill, this requires i) initial_states and
+    #   ii) seq_idx to be all specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     #   and switch accordingly to the init_state corresponding to the new seq_idx.
-    # - We will also make sure that the dA_cumsum is taken only from the start of the
-    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
-    # - this will ensure that states will be updated with the rightmost flushed seq_idx
-    #   of the previous chunk. This implies that the first chunk of states is either 0
-    #   or equal to init_states of the first example.
     states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
         dA_cumsum,  # (nheads, nchunks, chunk_size)
+        cu_chunk_seqlens,
         initial_states=rearrange(initial_states, "... p n -> ... (p n)")
         if initial_states is not None else
         None,  # (batch, nheads, headdim*dstate)
         seq_idx=seq_idx,
-        out_dtype=state_dtype if state_dtype is not None else C.dtype,
-        chunk_offsets=chunk_offsets)
+        out_dtype=state_dtype if state_dtype is not None else C.dtype)
     states = rearrange(states, "... (p n) -> ... p n", n=dstate)

     # 4. Compute batched matrix multiply for C_j^T B_i terms
     CB = _bmm_chunk_fwd(C,
                         B,
                         chunk_size,
-                        seq_idx=seq_idx,
+                        cu_chunk_seqlens,
                         output_dtype=torch.float32)

     # 5. Scan and compute the diagonal blocks, taking into
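The comment block above describes the inter-chunk recurrence informally. Below is a compact pure-PyTorch sketch of what `_state_passing_fwd` now does, under the assumptions that `seq_idx` is per chunk, `dA_cs_last[h, c]` is the chunk-final cumulative decay, and states are kept unflattened for readability (the real kernel works on the flattened `(p n)` layout). `state_passing_ref` is a name made up for this illustration.

    import torch

    def state_passing_ref(chunk_states, dA_cs_last, seq_idx, initial_states=None):
        """Illustrative reference. chunk_states: (nchunks, nheads, dim),
        dA_cs_last: (nheads, nchunks), seq_idx: (nchunks,),
        initial_states: (num_seqs, nheads, dim) or None."""
        nchunks, nheads, dim = chunk_states.shape

        def fresh(req: int) -> torch.Tensor:
            if initial_states is not None:
                return initial_states[req].clone().float()
            return torch.zeros(nheads, dim)

        out = torch.empty(nchunks, nheads, dim)
        state = fresh(int(seq_idx[0]))
        prev = int(seq_idx[0])
        for c in range(nchunks):
            if int(seq_idx[c]) != prev:
                # chunk c starts a new request: restart from its initial state
                state = fresh(int(seq_idx[c]))
                prev = int(seq_idx[c])
            # decay the carried state across chunk c, then add chunk c's own state
            state = torch.exp(dA_cs_last[:, c]).unsqueeze(-1) * state + chunk_states[c].float()
            out[c] = state
        return out  # out[c] = SSM state at the end of chunk c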
@@ -148,26 +143,15 @@ def _mamba_chunk_scan_combined_fwd(x,
         dA_cumsum,
         C,
         states,
+        cu_chunk_seqlens,
         out,  # in-place update
         seq_idx,
         D=D,
         z=z,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=initial_states,
     )

-    varlen_states = chunk_state_varlen(
-        B,
-        x,
-        dt,
-        dA_cumsum,
-        cu_seqlens,
-        states,
-        initial_states=initial_states,
-    )
-
-    return varlen_states
+    return states[last_chunk_indices]


 def mamba_chunk_scan_combined_varlen(
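Because each chunk now belongs to exactly one request, the final SSM state of a request is simply the state after its last chunk, which is why the separate `chunk_state_varlen` pass could be dropped. A tiny illustration of the gather, with invented shapes:

    import torch

    # 5 chunks total: request 0 -> chunks 0-1, request 1 -> chunks 2-4.
    nchunks, nheads, headdim, dstate = 5, 8, 64, 128
    states = torch.randn(nchunks, nheads, headdim, dstate)  # from the state-passing step
    last_chunk_indices = torch.tensor([1, 4])
    varlen_states = states[last_chunk_indices]
    print(varlen_states.shape)  # torch.Size([2, 8, 64, 128]) -- one final state per request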
@@ -178,14 +162,14 @@ def mamba_chunk_scan_combined_varlen(
     C,
     chunk_size,
     cu_seqlens,
+    cu_chunk_seqlens,
+    last_chunk_indices,
     seq_idx,
     out,
     D=None,
     z=None,
     dt_bias=None,
     initial_states=None,
-    chunk_indices=None,
-    chunk_offsets=None,
     dt_softplus=False,
     dt_limit=(0.0, float("inf")),
     state_dtype=None,
@@ -198,8 +182,10 @@ def mamba_chunk_scan_combined_varlen(
         B: (seqlen, ngroups, dstate)
         C: (seqlen, ngroups, dstate)
         chunk_size: int
-        seq_idx: (seqlen)
-        cu_seqlens: (batch + 1)
+        cu_seqlens: (batch + 1,)
+        cu_chunk_seqlens: (nchunks + 1,)
+        last_chunk_indices: (batch,)
+        seq_idx: (nchunks,)
         out: (seqlen, nheads, headdim) preallocated output tensor
         D: (nheads, headdim) or (nheads,)
         z: (seqlen, nheads, headdim)
@@ -228,9 +214,9 @@ def mamba_chunk_scan_combined_varlen(
         dt_bias=dt_bias,
         initial_states=initial_states,
         seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
         dt_softplus=dt_softplus,
         dt_limit=dt_limit,
         state_dtype=state_dtype)
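To make the new arguments concrete, here is a hedged example of chunk-aligned metadata for two prefill requests of 7 and 10 tokens with `chunk_size=8`. The split follows the shape conventions in the docstring above (no chunk crosses a request boundary and no chunk exceeds `chunk_size`); the real values are produced by the Mamba2 attention metadata builder, which is not shown in full in this diff.

    import torch

    chunk_size = 8
    # Request 0: tokens [0, 7); request 1: tokens [7, 17).
    cu_seqlens = torch.tensor([0, 7, 17], dtype=torch.int32)
    # Request-aligned chunks: [0:7), [7:15), [15:17).
    cu_chunk_seqlens = torch.tensor([0, 7, 15, 17], dtype=torch.int32)
    seq_idx = torch.tensor([0, 1, 1], dtype=torch.int32)          # one entry per chunk
    last_chunk_indices = torch.tensor([0, 2], dtype=torch.int32)  # one entry per request

    # mamba_chunk_scan_combined_varlen(
    #     ..., chunk_size=chunk_size, cu_seqlens=cu_seqlens,
    #     cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices,
    #     seq_idx=seq_idx, out=out, ...)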
@@ -30,8 +30,7 @@ def _state_passing_fwd_kernel(
     dA_cs_ptr,
     initstates_ptr,
     seq_idx_ptr,
-    chunk_offsets_ptr,
-    chunk_meta_num,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     dim: tl.constexpr,
     nchunks,
@@ -50,94 +49,52 @@ def _state_passing_fwd_kernel(
     stride_initstates_batch: tl.int64,
     stride_initstates_head: tl.int64,
     stride_initstates_dim: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
+    stride_seq_idx_chunk: tl.constexpr,
     # Meta-parameters
     HAS_INITSTATES: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     pid_h = tl.program_id(axis=1)
     pid_m = tl.program_id(axis=0)

     states_ptr += pid_h * stride_states_head
     dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
                                               1) * stride_dA_cs_csize
     out_ptr += pid_h * stride_out_head
-    if HAS_INITSTATES:
-        initstates_ptr += pid_h * stride_initstates_head
-
     offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     states_ptrs = states_ptr + offs_m * stride_states_dim
     out_ptrs = out_ptr + offs_m * stride_out_dim

-    # - states will be the past state of the sequence that continues on the current check
-    if not HAS_INITSTATES:
-        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
-    else:
-        initstates_ptr += offs_m * stride_initstates_dim
-        initstates_ptrs = initstates_ptr
-        # - for cont batches, for the first chunk mean it will be the first batch's
-        # init state
+    if HAS_INITSTATES:
+        initstates_ptrs = initstates_ptr \
+            + pid_h * stride_initstates_head \
+            + offs_m * stride_initstates_dim
         states = tl.load(initstates_ptrs, mask=offs_m < dim,
                          other=0.0).to(tl.float32)
+    else:
+        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)

-    tl.store(out_ptrs, states, mask=offs_m < dim)
-    out_ptrs += stride_out_chunk
-    prev_seq_idx_chunk_end = 0
-    logical_chunk_idx = 0
-    for c in range(nchunks - 1):
+    prev_seq_idx = 0
+    for c in range(nchunks):
         new_states = tl.load(states_ptrs, mask=offs_m < dim,
                              other=0.0).to(tl.float32)
         dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
-        scale_mask = True
-        # - the seq to pass forward is the one that is flushed to the right
-        # boundary.
-        # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
-        seq_idx_chunk_end = tl.load(seq_idx_ptr +
-                                    (min((c + 1) * chunk_size, seqlen) - 1) *
-                                    stride_seq_idx_seqlen)
-
-        if HAS_INITSTATES:
-            if prev_seq_idx_chunk_end != seq_idx_chunk_end:
-                # this means in the current chunk the rightmost flushed seq
-                # has changed.
-                # - so we do not propagate the state from previous chunk
-                # - but rather we load that sequence's init state
-                initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
-
-                # - update state with seq_idx_new's init state
+        seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
+        # we have started a new sequence
+        if prev_seq_idx != seq_idx:
+            if HAS_INITSTATES:
+                initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \
+                    + pid_h * stride_initstates_head \
+                    + offs_m * stride_initstates_dim
                 states = tl.load(initstates_ptrs, mask=offs_m < dim,
                                  other=0.0).to(tl.float32)
-                # - we need to consider the cumsum only of the last sequence in the chunk
-                # - find its starting position (given by c_off of the logical chunk index)
-                # - and subtract the cumsum just before that position from the total cumsum
-                # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
-                # sequence index at the start of the current chunk
-                seq_idx_chunk_start = tl.load(seq_idx_ptr +
-                                              min(c * chunk_size, seqlen) *
-                                              stride_seq_idx_seqlen)
-                logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
-                # - load the chunk offset:
-                c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
-                                mask=logical_chunk_idx < chunk_meta_num,
-                                other=0)
-                # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
-                if c_off > 0:
-                    # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
-                    dA_cs_boundary = tl.load(
-                        dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
-                        (c_off - 1) * stride_dA_cs_csize,
-                        mask=(c_off - 1) > -1 and c_off < chunk_size,
-                        other=0.0)
-                    dA_cs -= dA_cs_boundary
-
-                # - increment logical chunk index for every physical chunk
-                logical_chunk_idx += 1
-        else:
-            scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
-        prev_seq_idx_chunk_end = seq_idx_chunk_end
-
-        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
-        states = scale * states + new_states
+            else:
+                states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+
+            prev_seq_idx = seq_idx
+        states = tl.exp(dA_cs) * states + new_states
         tl.store(out_ptrs, states, mask=offs_m < dim)

         states_ptrs += stride_states_chunk
@@ -148,8 +105,8 @@ def _state_passing_fwd_kernel(
 def _state_passing_fwd(
     states,
     dA_cumsum,
+    cu_chunk_seqlens,
     seq_idx,
-    chunk_offsets,
     initial_states=None,
     out_dtype=None,
 ):
@@ -175,9 +132,7 @@ def _state_passing_fwd(
         dA_cs_ptr=dA_cumsum,
         initstates_ptr=initial_states,
         seq_idx_ptr=seq_idx,
-        chunk_offsets_ptr=chunk_offsets,
-        chunk_meta_num=len(chunk_offsets)
-        if chunk_offsets is not None else 0,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
         dim=dim,
         nchunks=nchunks,
         seqlen=seqlen if seq_idx is not None else 0,
@@ -194,7 +149,7 @@ def _state_passing_fwd(
         stride_initstates_batch=initial_states_strides[0],
         stride_initstates_head=initial_states_strides[1],
         stride_initstates_dim=initial_states_strides[2],
-        stride_seq_idx_seqlen=seq_idx.stride(0),
+        stride_seq_idx_chunk=seq_idx.stride(0),
         HAS_INITSTATES=initial_states is not None,
     )
     return out
@@ -260,9 +260,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             prep_initial_states = attn_metadata.prep_initial_states
             chunk_size = attn_metadata.chunk_size
             seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
             query_start_loc_p = attn_metadata.query_start_loc_p
+            cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+            last_chunk_indices_p = attn_metadata.last_chunk_indices_p

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -368,9 +368,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
                                   self.num_heads // self.tp_size, self.head_dim),
                 dt_bias=self.dt_bias,
                 seq_idx=seq_idx_p,
-                chunk_indices=chunk_indices_p,
-                chunk_offsets=chunk_offsets_p,
                 cu_seqlens=query_start_loc_p,
+                cu_chunk_seqlens=cu_chunk_seqlen_p,
+                last_chunk_indices=last_chunk_indices_p,
                 initial_states=initial_states,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 from dataclasses import dataclass
 from typing import Optional

@@ -8,6 +7,7 @@ import torch

 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.mamba_attn import (
     BaseMambaAttentionMetadataBuilder)
 from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
@ -17,91 +17,6 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
|
||||||
def _query_start_loc_to_chunk_indices_offsets(
|
|
||||||
query_start_loc: torch.Tensor, chunk_size: int,
|
|
||||||
total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
|
|
||||||
lengths, shape (num_seqs + 1,).
|
|
||||||
The first element should be 0. Each entry represents the starting
|
|
||||||
index of a sequence in the flattened token array.
|
|
||||||
chunk_size (int): The size of each physical mamba chunk
|
|
||||||
(number of tokens per chunk).
|
|
||||||
total_seqlens (int): The total number of tokens in the batch.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
|
||||||
- chunk_indices (torch.Tensor): 1D tensor of indices
|
|
||||||
indicating the physical chunk for each logical chunk.
|
|
||||||
- chunk_offsets (torch.Tensor): 1D tensor of offsets
|
|
||||||
indicating the starting index of each logical chunk within
|
|
||||||
its physical chunk.
|
|
||||||
|
|
||||||
This function computes the chunk indices and offsets for the given
|
|
||||||
query_start_loc and chunk_size. Both are tensors of integers with length N,
|
|
||||||
where N is the number of logical (pseudo) chunks.
|
|
||||||
A logical chunk is a sequence of tokens that are all part of the same
|
|
||||||
sequence and are all in the same physical mamba chunk.
|
|
||||||
In other words, a logical chunk changes every time we cross a sequence
|
|
||||||
boundary or a physical mamba chunk boundary.
|
|
||||||
Logical chunks are needed to handle batched requests with initial states
|
|
||||||
(see _state_passing_fwd and _chunk_scan_fwd).
|
|
||||||
The chunk_indices tensor contains the index of the physical chunk for each
|
|
||||||
logical chunk.
|
|
||||||
The chunk_offsets tensor contains the offset (AKA starting index) of the
|
|
||||||
logical chunk in the physical chunk.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
query_start_loc = [0, 5, 10]
|
|
||||||
chunk_size = 8
|
|
||||||
total_seqlens = 10
|
|
||||||
-> chunk_indices = [0, 0, 1]
|
|
||||||
-> chunk_offsets = [0, 5, 0]
|
|
||||||
|
|
||||||
In this example, we have 2 sequences, each with 5 tokens. The physical
|
|
||||||
chunk size is 8 tokens.
|
|
||||||
We have three logical chunks:
|
|
||||||
- the first logical chunk starts at token 0 in the first physical chunk
|
|
||||||
and contains all 5 tokens from the first sequence
|
|
||||||
- the second logical chunk starts at token 5 in the first physical chunk
|
|
||||||
and contains first 3 tokens from the second sequence
|
|
||||||
- the third logical chunk starts at token 0 in the second physical chunk
|
|
||||||
and contains the remaining 2 tokens from the second sequence
|
|
||||||
"""
|
|
||||||
|
|
||||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
|
||||||
|
|
||||||
# outputs will have length expansion of chunks that do not divide
|
|
||||||
# chunk_size
|
|
||||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
|
||||||
> 0).sum()
|
|
||||||
chunk_indices = torch.arange(N,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
chunk_offsets = torch.zeros((N, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
|
|
||||||
p = 0 # num of insertions
|
|
||||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
|
||||||
|
|
||||||
# if does not divide chunk_size, then there is one chunk insertion
|
|
||||||
p += (s % chunk_size > 0)
|
|
||||||
|
|
||||||
# get the dimensions
|
|
||||||
# - the + 1 for _e is to shift the boundary by one chunk
|
|
||||||
# - this shifting is not needed if chunk_size divides e
|
|
||||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
|
||||||
> 0)
|
|
||||||
|
|
||||||
# adjust indices and offsets
|
|
||||||
chunk_indices[_s:_e] -= p
|
|
||||||
chunk_offsets[_s] = s % chunk_size
|
|
||||||
|
|
||||||
return chunk_indices, chunk_offsets
|
|
||||||
|
|
||||||
|
|
||||||
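For comparison, here is a hand-worked sketch (not part of the diff) of how the docstring's example above maps onto the chunk-aligned metadata introduced later in this commit; the assignments below are derived by hand from the builder loop added further down and are purely illustrative.

# Hypothetical standalone illustration: chunk_size=8, two sequences of
# 5 tokens each, nothing previously computed.
query_start_loc = [0, 5, 10]

# Removed logical-chunk bookkeeping (from the docstring above):
#   chunk_indices = [0, 0, 1], chunk_offsets = [0, 5, 0]

# New chunk-aligned metadata: one chunk per sequence, because each
# 5-token sequence fits inside a single physical chunk of size 8.
cu_chunk_seqlen = [0, 5, 10]   # chunk i covers tokens [cu[i], cu[i+1])
seq_idx = [0, 1]               # owning sequence of each chunk
last_chunk_indices = [0, 1]    # index of the final chunk of each sequence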
class Mamba2AttentionBackend(AttentionBackend):
|
class Mamba2AttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -125,8 +40,16 @@ class Mamba2AttentionMetadata:
|
|||||||
# the batch has no prefill request.
|
# the batch has no prefill request.
|
||||||
has_initial_states_p: Optional[torch.Tensor]
|
has_initial_states_p: Optional[torch.Tensor]
|
||||||
seq_idx_p: Optional[torch.Tensor]
|
seq_idx_p: Optional[torch.Tensor]
|
||||||
chunk_indices_p: Optional[torch.Tensor]
|
|
||||||
chunk_offsets_p: Optional[torch.Tensor]
|
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||||
|
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||||
|
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||||
|
# cu_chunk_seqlen_p[i+1].
|
||||||
|
cu_chunk_seqlen_p: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||||
|
# index of the last chunk for every sequence in the (prefill) batch.
|
||||||
|
last_chunk_indices_p: Optional[torch.Tensor]
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
|
|
||||||
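A minimal sketch of how a consumer might index with these two tensors (hypothetical helper names, shapes assumed from the comments above; this is not kernel code from the commit):

import torch

def chunk_tokens(x: torch.Tensor, cu_chunk_seqlen_p: torch.Tensor,
                 chunk_id: int) -> torch.Tensor:
    # Tokens of the chunk_id-th chunk along the varlen token dimension:
    # the i-th chunk spans [cu_chunk_seqlen_p[i], cu_chunk_seqlen_p[i+1]).
    start = int(cu_chunk_seqlen_p[chunk_id])
    end = int(cu_chunk_seqlen_p[chunk_id + 1])
    return x[start:end]

def final_state_per_sequence(chunk_states: torch.Tensor,
                             last_chunk_indices_p: torch.Tensor) -> torch.Tensor:
    # chunk_states: (nchunks, ...) per-chunk SSM states; select the state
    # produced by the last chunk of every sequence in the prefill batch.
    return chunk_states[last_chunk_indices_p.long()]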
@ -151,13 +74,14 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
query_start_loc_p = None
|
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
|
|
||||||
|
query_start_loc_p = None
|
||||||
seq_idx_p = None
|
seq_idx_p = None
|
||||||
chunk_indices_p, chunk_offsets_p = None, None
|
cu_chunk_seqlen_p = None
|
||||||
|
last_chunk_indices_p = None
|
||||||
|
|
||||||
# Need flags to indicate if there are initial states
|
# Need flags to indicate if there are initial states
|
||||||
# currently we really only support the FlashAttention backend
|
|
||||||
has_initial_states_p = None
|
has_initial_states_p = None
|
||||||
prep_initial_states = False
|
prep_initial_states = False
|
||||||
|
|
||||||
@ -171,7 +95,7 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
common_attn_metadata,
|
common_attn_metadata,
|
||||||
decode_threshold=self.reorder_batch_threshold))
|
decode_threshold=self.reorder_batch_threshold))
|
||||||
|
|
||||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
# Compute seq_idx for prefill only
|
||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
#[batch,]
|
#[batch,]
|
||||||
has_initial_states_cpu = (
|
has_initial_states_cpu = (
|
||||||
@ -184,21 +108,68 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||||
-num_prefills - 1:] - num_decode_tokens
|
-num_prefills - 1:] - num_decode_tokens
|
||||||
|
|
||||||
seq_idx_p = torch.repeat_interleave(torch.arange(
|
num_computed_tokens_p = \
|
||||||
num_prefills,
|
common_attn_metadata.num_computed_tokens_cpu[
|
||||||
dtype=torch.int32,
|
num_reqs - num_prefills:num_reqs]
|
||||||
device=query_start_loc_p.device),
|
query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[
|
||||||
query_start_loc_p.diff(),
|
-num_prefills - 1:] - num_decode_tokens
|
||||||
output_size=num_prefill_tokens)
|
|
||||||
|
|
||||||
# We compute metadata for chunked prefill once at the top level
|
# The code below carefully constructs the chunks such that:
|
||||||
# model forward and reuse them in mamba layers. If not needed,
|
# 1. Chunks contain tokens from a *single* sequence only.
|
||||||
# they will be ignored inside mamba kernels.
|
# 2. For every sequence, we are guaranteed that we can
|
||||||
if prep_initial_states:
|
# retrieve the mamba state *every* chunk_size tokens.
|
||||||
chunk_indices_p, chunk_offsets_p = (
|
# Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||||
_query_start_loc_to_chunk_indices_offsets(
|
# Constraint (2) dramatically simplifies the implementation
|
||||||
query_start_loc_p, self.chunk_size,
|
# of prefix caching for mamba2 (wip). We need to take care
|
||||||
num_prefill_tokens))
|
# of the interaction with chunked prefill in order to
|
||||||
|
# satisfy constraint (2).
|
||||||
|
# TODO (tdoublep): This code could probably be optimized.
|
||||||
|
cu_chunk_seqlen = []
|
||||||
|
seq_idx = []
|
||||||
|
last_chunk_indices = []
|
||||||
|
seqlen_pos = 0
|
||||||
|
for req_idx in range(num_prefills):
|
||||||
|
this_num_computed = num_computed_tokens_p[req_idx].item()
|
||||||
|
this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item(
|
||||||
|
) - query_start_loc_p_cpu[req_idx].item()
|
||||||
|
|
||||||
|
# if computed tokens are not chunk-aligned, use the first
|
||||||
|
# chunk to finish it off
|
||||||
|
if this_num_computed % self.chunk_size != 0:
|
||||||
|
seq_idx.append(req_idx)
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
# how many tokens to finish the chunk?
|
||||||
|
chunk_len = cdiv(this_num_computed, self.chunk_size
|
||||||
|
) * self.chunk_size - this_num_computed
|
||||||
|
# we can only use at most this_new_tokens
|
||||||
|
chunk_len = min(chunk_len, this_new_tokens)
|
||||||
|
seqlen_pos += chunk_len
|
||||||
|
this_new_tokens -= chunk_len
|
||||||
|
|
||||||
|
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||||
|
for chunk in range(n_chunks):
|
||||||
|
seq_idx.append(req_idx)
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||||
|
seqlen_pos += chunk_len
|
||||||
|
this_new_tokens -= chunk_len
|
||||||
|
|
||||||
|
assert this_new_tokens == 0
|
||||||
|
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||||
|
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
|
||||||
|
seq_idx_p = torch.as_tensor(seq_idx,
|
||||||
|
device=query_start_loc_p.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
cu_chunk_seqlen_p = torch.as_tensor(
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
device=query_start_loc_p.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
last_chunk_indices_p = torch.as_tensor(
|
||||||
|
last_chunk_indices,
|
||||||
|
device=query_start_loc_p.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||||
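To make the chunk-construction loop above concrete, here is a self-contained re-implementation of the same logic (a sketch that mirrors the builder code rather than importing it; the function name is made up) together with one chunked-prefill example:

from math import ceil

def build_chunk_metadata(num_computed, num_new, chunk_size):
    # Mirrors the per-request loop in Mamba2AttentionMetadataBuilder.build:
    # chunks never cross a sequence boundary, and a sequence whose computed
    # tokens are not chunk-aligned gets a shorter first chunk that re-aligns it.
    cu_chunk_seqlen, seq_idx, last_chunk_indices = [], [], []
    pos = 0
    for req_idx, (computed, new) in enumerate(zip(num_computed, num_new)):
        if computed % chunk_size != 0:
            # partial first chunk that finishes off the current physical chunk
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(pos)
            chunk_len = min(ceil(computed / chunk_size) * chunk_size - computed,
                            new)
            pos += chunk_len
            new -= chunk_len
        for _ in range(ceil(new / chunk_size)):
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(pos)
            chunk_len = min(chunk_size, new)
            pos += chunk_len
            new -= chunk_len
        last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
    cu_chunk_seqlen.append(pos)
    return cu_chunk_seqlen, seq_idx, last_chunk_indices

# chunk_size=8: request 0 resumes with 13 computed tokens and 10 new ones,
# request 1 is a fresh prefill with 6 new tokens.
print(build_chunk_metadata([13, 0], [10, 6], 8))
# ([0, 3, 10, 16], [0, 0, 1], [1, 2])

Because the first chunk of request 0 contains only 3 tokens, that sequence's state boundary lands exactly at 16 computed tokens, which is what the chunk-alignment guarantee (constraint 2 in the comments above) requires.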
@ -222,9 +193,9 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
has_initial_states_p=has_initial_states_p,
|
has_initial_states_p=has_initial_states_p,
|
||||||
seq_idx_p=seq_idx_p,
|
seq_idx_p=seq_idx_p,
|
||||||
chunk_indices_p=chunk_indices_p,
|
|
||||||
chunk_offsets_p=chunk_offsets_p,
|
|
||||||
state_indices_tensor=state_indices_tensor,
|
state_indices_tensor=state_indices_tensor,
|
||||||
|
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||||
|
last_chunk_indices_p=last_chunk_indices_p,
|
||||||
nums_dict=nums_dict,
|
nums_dict=nums_dict,
|
||||||
batch_ptr=batch_ptr,
|
batch_ptr=batch_ptr,
|
||||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||||
|
|||||||