mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 04:55:02 +08:00
[Misc] Remove unused attention prefix prefill ops functions (#26971)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
d5edcb8678
commit
76e4dcf225
@ -335,216 +335,6 @@ def _fwd_kernel(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _fwd_kernel_flash_attn_v2(
|
|
||||||
Q,
|
|
||||||
K,
|
|
||||||
V,
|
|
||||||
K_cache,
|
|
||||||
V_cache,
|
|
||||||
B_Loc,
|
|
||||||
sm_scale,
|
|
||||||
B_Start_Loc,
|
|
||||||
B_Seqlen,
|
|
||||||
B_Ctxlen,
|
|
||||||
block_size,
|
|
||||||
x,
|
|
||||||
Out,
|
|
||||||
stride_b_loc_b,
|
|
||||||
stride_b_loc_s,
|
|
||||||
stride_qbs,
|
|
||||||
stride_qh,
|
|
||||||
stride_qd,
|
|
||||||
stride_kbs,
|
|
||||||
stride_kh,
|
|
||||||
stride_kd,
|
|
||||||
stride_vbs,
|
|
||||||
stride_vh,
|
|
||||||
stride_vd,
|
|
||||||
stride_obs,
|
|
||||||
stride_oh,
|
|
||||||
stride_od,
|
|
||||||
stride_k_cache_bs,
|
|
||||||
stride_k_cache_h,
|
|
||||||
stride_k_cache_d,
|
|
||||||
stride_k_cache_bl,
|
|
||||||
stride_k_cache_x,
|
|
||||||
stride_v_cache_bs,
|
|
||||||
stride_v_cache_h,
|
|
||||||
stride_v_cache_d,
|
|
||||||
stride_v_cache_bl,
|
|
||||||
num_queries_per_kv: int,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
cur_batch = tl.program_id(0)
|
|
||||||
cur_head = tl.program_id(1)
|
|
||||||
start_m = tl.program_id(2)
|
|
||||||
|
|
||||||
cur_kv_head = cur_head // num_queries_per_kv
|
|
||||||
|
|
||||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
||||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
|
||||||
|
|
||||||
block_start_loc = BLOCK_M * start_m
|
|
||||||
|
|
||||||
# initialize offsets
|
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
||||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
||||||
off_q = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
|
||||||
+ cur_head * stride_qh
|
|
||||||
+ offs_d[None, :] * stride_qd
|
|
||||||
)
|
|
||||||
|
|
||||||
q = tl.load(
|
|
||||||
Q + off_q,
|
|
||||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# # initialize pointer to m and l
|
|
||||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
|
||||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
|
||||||
|
|
||||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
||||||
# -- compute qk ----
|
|
||||||
bn = tl.load(
|
|
||||||
B_Loc
|
|
||||||
+ cur_batch * stride_b_loc_b
|
|
||||||
+ ((start_n + offs_n) // block_size) * stride_b_loc_s,
|
|
||||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
|
||||||
other=0,
|
|
||||||
).to(tl.int64)
|
|
||||||
off_k = (
|
|
||||||
bn[None, :] * stride_k_cache_bs
|
|
||||||
+ cur_kv_head * stride_k_cache_h
|
|
||||||
+ (offs_d[:, None] // x) * stride_k_cache_d
|
|
||||||
+ ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
|
|
||||||
+ (offs_d[:, None] % x) * stride_k_cache_x
|
|
||||||
)
|
|
||||||
off_v = (
|
|
||||||
bn[:, None] * stride_v_cache_bs
|
|
||||||
+ cur_kv_head * stride_v_cache_h
|
|
||||||
+ offs_d[None, :] * stride_v_cache_d
|
|
||||||
+ (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
|
|
||||||
)
|
|
||||||
k = tl.load(
|
|
||||||
K_cache + off_k,
|
|
||||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
qk += tl.dot(q, k)
|
|
||||||
qk = tl.where(
|
|
||||||
(start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
|
|
||||||
)
|
|
||||||
qk *= sm_scale
|
|
||||||
|
|
||||||
# -- compute m_ij, p, l_ij
|
|
||||||
m_ij = tl.max(qk, 1)
|
|
||||||
m_i_new = tl.maximum(m_i, m_ij)
|
|
||||||
p = tl.math.exp(qk - m_i_new[:, None])
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
# -- update m_i and l_i
|
|
||||||
|
|
||||||
alpha = tl.math.exp(m_i - m_i_new)
|
|
||||||
l_i_new = alpha * l_i + l_ij
|
|
||||||
# -- update output accumulator --
|
|
||||||
# scale p
|
|
||||||
# scale acc
|
|
||||||
acc_scale = alpha
|
|
||||||
# acc_scale = l_i / l_i_new * alpha
|
|
||||||
acc = acc * acc_scale[:, None]
|
|
||||||
# update acc
|
|
||||||
v = tl.load(
|
|
||||||
V_cache + off_v,
|
|
||||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
p = p.to(v.dtype)
|
|
||||||
acc += tl.dot(p, v)
|
|
||||||
# update m_i and l_i
|
|
||||||
l_i = l_i_new
|
|
||||||
m_i = m_i_new
|
|
||||||
|
|
||||||
off_k = (
|
|
||||||
offs_n[None, :] * stride_kbs
|
|
||||||
+ cur_kv_head * stride_kh
|
|
||||||
+ offs_d[:, None] * stride_kd
|
|
||||||
)
|
|
||||||
off_v = (
|
|
||||||
offs_n[:, None] * stride_vbs
|
|
||||||
+ cur_kv_head * stride_vh
|
|
||||||
+ offs_d[None, :] * stride_vd
|
|
||||||
)
|
|
||||||
k_ptrs = K + off_k
|
|
||||||
v_ptrs = V + off_v
|
|
||||||
|
|
||||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
|
||||||
|
|
||||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
||||||
# -- compute qk ----
|
|
||||||
k = tl.load(
|
|
||||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
|
||||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
qk += tl.dot(q, k)
|
|
||||||
qk *= sm_scale
|
|
||||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
|
||||||
|
|
||||||
# -- compute m_ij, p, l_ij
|
|
||||||
m_ij = tl.max(qk, 1)
|
|
||||||
m_i_new = tl.maximum(m_i, m_ij)
|
|
||||||
p = tl.math.exp(qk - m_i_new[:, None])
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
# -- update m_i and l_i
|
|
||||||
|
|
||||||
alpha = tl.math.exp(m_i - m_i_new)
|
|
||||||
l_i_new = alpha * l_i + l_ij
|
|
||||||
# -- update output accumulator --
|
|
||||||
# scale p
|
|
||||||
# scale acc
|
|
||||||
acc_scale = alpha
|
|
||||||
# acc_scale = l_i / l_i_new * alpha
|
|
||||||
acc = acc * acc_scale[:, None]
|
|
||||||
# update acc
|
|
||||||
v = tl.load(
|
|
||||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
|
||||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
p = p.to(v.dtype)
|
|
||||||
acc += tl.dot(p, v)
|
|
||||||
# update m_i and l_i
|
|
||||||
l_i = l_i_new
|
|
||||||
m_i = m_i_new
|
|
||||||
|
|
||||||
# acc /= l_i[:, None]
|
|
||||||
# initialize pointers to output
|
|
||||||
off_o = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
|
||||||
+ cur_head * stride_oh
|
|
||||||
+ offs_d[None, :] * stride_od
|
|
||||||
)
|
|
||||||
out_ptrs = Out + off_o
|
|
||||||
tl.store(
|
|
||||||
out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel_alibi(
|
def _fwd_kernel_alibi(
|
||||||
Q,
|
Q,
|
||||||
|
|||||||
@ -98,9 +98,6 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||||
def __init_(self, moe: FusedMoEConfig):
|
|
||||||
super().__init__(moe)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_moe_method(
|
def get_moe_method(
|
||||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user