[FLA] Introduce Kimi Delta Attention(KDA) to VLLM (#27654)
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>

Commit: e88bdd60d9
Parent: 05e034f085
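What the change does, in brief: the FLA ops gain a KDA path in which the recurrent state decays per key channel (a `gk` vector of length K per token and head) instead of by a single per-head scalar gate `g`. Below is a minimal PyTorch sketch of one recurrent step for a single head, written to mirror the gating math visible in the kernels in this diff (`b_h *= exp(b_g)` versus `b_h *= exp(b_gk)[:, None]`); the beta and output handling follow the standard delta-rule form and are assumptions, not taken from this excerpt.

import torch

def delta_rule_step(h, q, k, v, beta, g=None, gk=None):
    """One recurrent step of the (gated) delta rule for a single head.

    h:    [K, V] recurrent state
    q, k: [K]    query / key
    v:    [V]    value
    beta: scalar write strength
    g:    scalar log-decay (gated delta rule path)
    gk:   [K]    per-key-channel log-decay (KDA path)
    """
    if gk is not None:
        h = h * torch.exp(gk)[:, None]   # per-channel decay (KDA)
    elif g is not None:
        h = h * torch.exp(g)             # scalar per-head decay
    v = v - (h * k[:, None]).sum(0)      # delta correction: v - k^T h
    v = beta * v                         # write strength (assumed form)
    h = h + k[:, None] * v[None, :]      # rank-1 state update
    o = (h * q[:, None]).sum(0)          # output: q^T h (assumed form)
    return h, o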
@@ -36,7 +36,7 @@ def chunk_gated_delta_rule_fwd(
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
     A = chunk_scaled_dot_kkt_fwd(
-        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
+        k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
     )
     A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
     w, u = recompute_w_u_fwd(
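The hunk above builds the chunked forward path: `chunk_local_cumsum` turns the per-token log-gates into chunk-local cumulative sums, which `chunk_scaled_dot_kkt_fwd` now receives as `g` (renamed from `g_cumsum`). A rough reference for the cumsum step, assuming chunk size 64 and ignoring the varlen (`cu_seqlens`) handling:

import torch

def chunk_local_cumsum_ref(g: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    """Reference for a per-chunk inclusive cumsum of log-gates.

    g: [B, T, H] log-decay per token and head. Each chunk of `chunk_size`
    tokens is cumsum'd independently, so g[t] becomes the total log-decay
    from the start of its chunk up to and including t.
    """
    B, T, H = g.shape
    out = g.clone()
    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        out[:, start:end] = g[:, start:end].cumsum(dim=1)
    return out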
@@ -14,14 +14,15 @@ from vllm.triton_utils import tl, triton
 
 from .index import prepare_chunk_indices, prepare_chunk_offsets
 from .op import exp
-from .utils import is_nvidia_hopper, use_cuda_graph
+from .utils import use_cuda_graph
 
-NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
+NUM_WARPS = [2, 4, 8, 16]
 
 
 @triton.heuristics(
     {
         "USE_G": lambda args: args["g"] is not None,
+        "USE_GK": lambda args: args["gk"] is not None,
         "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
         "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
         "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
@@ -35,7 +36,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
         for num_stages in [2, 3, 4]
         for BV in [32, 64]
     ],
-    key=["H", "K", "V", "BT", "USE_G"],
+    key=["H", "K", "V", "BT"],
     use_cuda_graph=use_cuda_graph,
 )
 @triton.jit(do_not_specialize=["T"])
@@ -45,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     w,
     v_new,
     g,
+    gk,
     h,
     h0,
     ht,
@@ -58,6 +60,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     BT: tl.constexpr,
     BV: tl.constexpr,
     USE_G: tl.constexpr,
+    USE_GK: tl.constexpr,
     USE_INITIAL_STATE: tl.constexpr,
     STORE_FINAL_STATE: tl.constexpr,
     SAVE_NEW_VALUE: tl.constexpr,
@@ -88,12 +91,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
     b_h4 = tl.zeros([64, BV], dtype=tl.float32)
 
     # calculate offset
-    h += (boh * H + i_h) * K * V
-    v += (bos * H + i_h) * V
-    k += (bos * Hg + i_h // (H // Hg)) * K
-    w += (bos * H + i_h) * K
+    h += ((boh * H + i_h) * K * V).to(tl.int64)
+    v += ((bos * H + i_h) * V).to(tl.int64)
+    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
+    w += ((bos * H + i_h) * K).to(tl.int64)
     if SAVE_NEW_VALUE:
-        v_new += (bos * H + i_h) * V
+        v_new += ((bos * H + i_h) * V).to(tl.int64)
     stride_v = H * V
     stride_h = H * K * V
     stride_k = Hg * K
@@ -145,92 +148,115 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
             )
             tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
 
-        p_v = tl.make_block_ptr(
-            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
-        )
-        p_v_new = (
-            tl.make_block_ptr(
-                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
-            )
-            if SAVE_NEW_VALUE
-            else None
-        )
-        b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
         p_w = tl.make_block_ptr(
             w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
         )
         b_w = tl.load(p_w, boundary_check=(0, 1))
-        b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
+        b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
         if K > 64:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
+            b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
         if K > 128:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
+            b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
         if K > 192:
             p_w = tl.make_block_ptr(
                 w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
             )
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
+            b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
-        b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
+        p_v = tl.make_block_ptr(
+            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
+        )
+        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
 
         if SAVE_NEW_VALUE:
-            p_v_new = tl.make_block_ptr(
+            p_v = tl.make_block_ptr(
                 v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
             )
-            tl.store(
-                p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
-            )
+            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
 
+        last_idx = min((i_t + 1) * BT, T) - 1
         if USE_G:
             m_t = (i_t * BT + tl.arange(0, BT)) < T
-            last_idx = min((i_t + 1) * BT, T) - 1
             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
             p_g = tl.make_block_ptr(
                 g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
             )
             b_g = tl.load(p_g, boundary_check=(0,))
-            b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
+            b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
             b_g_last = exp(b_g_last)
-            b_h1 = b_h1 * b_g_last
+            b_h1 *= b_g_last
             if K > 64:
-                b_h2 = b_h2 * b_g_last
+                b_h2 *= b_g_last
             if K > 128:
-                b_h3 = b_h3 * b_g_last
+                b_h3 *= b_g_last
             if K > 192:
-                b_h4 = b_h4 * b_g_last
+                b_h4 *= b_g_last
-        b_v_new = b_v_new.to(k.dtype.element_ty)
+        if USE_GK:
+            o_k1 = tl.arange(0, 64)
+            b_gk_last1 = tl.load(
+                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
+                mask=(o_k1 < K),
+                other=0.0,
+            )
+            b_h1 *= exp(b_gk_last1)[:, None]
+            if K > 64:
+                o_k2 = 64 + o_k1
+                b_gk_last2 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
+                    mask=(o_k2 < K),
+                    other=0.0,
+                )
+                b_h2 *= exp(b_gk_last2)[:, None]
+            if K > 128:
+                o_k3 = 128 + o_k1
+                b_gk_last3 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
+                    mask=(o_k3 < K),
+                    other=0.0,
+                )
+                b_h3 *= exp(b_gk_last3)[:, None]
+            if K > 192:
+                o_k4 = 192 + o_k1
+                b_gk_last4 = tl.load(
+                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
+                    mask=(o_k4 < K),
+                    other=0.0,
+                )
+                b_h4 *= exp(b_gk_last4)[:, None]
+        b_v = b_v.to(k.dtype.element_ty)
 
         p_k = tl.make_block_ptr(
             k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
         )
         b_k = tl.load(p_k, boundary_check=(0, 1))
-        b_h1 += tl.dot(b_k, b_v_new)
+        b_h1 += tl.dot(b_k, b_v)
         if K > 64:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h2 += tl.dot(b_k, b_v_new)
+            b_h2 += tl.dot(b_k, b_v)
         if K > 128:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h3 += tl.dot(b_k, b_v_new)
+            b_h3 += tl.dot(b_k, b_v)
         if K > 192:
             p_k = tl.make_block_ptr(
                 k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
             )
             b_k = tl.load(p_k, boundary_check=(0, 1))
-            b_h4 += tl.dot(b_k, b_v_new)
+            b_h4 += tl.dot(b_k, b_v)
 
     # epilogue
     if STORE_FINAL_STATE:
         p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
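For reference, the chunk-level recurrence implemented by the kernel above, per head, with the new `USE_GK` (KDA) decay alongside the existing scalar `USE_G` decay. This is a simplified sketch that ignores the 64-row state blocking (`b_h1`..`b_h4`), boundary masking, and dtype casts:

import torch

def chunk_state_update_ref(h, k_c, w_c, v_c, g_c=None, gk_last=None):
    """One chunk of the state recurrence, per head.

    h:       [K, V]   running state
    k_c:     [BT, K]  keys of the chunk
    w_c:     [BT, K]  WY-representation weights of the chunk
    v_c:     [BT, V]  values of the chunk
    g_c:     [BT]     cumulative per-token log-decay (USE_G path), optional
    gk_last: [K]      per-key-channel log-decay at the chunk's last token
                      (USE_GK / KDA path), optional
    """
    v_new = v_c - w_c @ h                          # b_v = v - w @ h
    if g_c is not None:
        g_last = g_c[-1]
        v_new = v_new * torch.exp(g_last - g_c)[:, None]
        h = h * torch.exp(g_last)                  # scalar decay of the state
    if gk_last is not None:
        h = h * torch.exp(gk_last)[:, None]        # per-channel decay (new USE_GK path)
    h = h + k_c.T @ v_new                          # b_h += k^T @ v_new
    return h, v_new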
@@ -257,12 +283,15 @@ def chunk_gated_delta_rule_fwd_h(
     w: torch.Tensor,
     u: torch.Tensor,
     g: torch.Tensor | None = None,
+    gk: torch.Tensor | None = None,
     initial_state: torch.Tensor | None = None,
     output_final_state: bool = False,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
     cu_seqlens: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    # This kernel is slightly different from fla to support Q/K with different head numbers.
+    # In fla, Q/K always have the same head number, so Hg is always equal to H.
     B, T, Hg, K, V = *k.shape, u.shape[-1]
     H = u.shape[-2]
     BT = chunk_size
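The comment about `Hg` refers to grouped key heads: `k` carries `Hg` heads while `u` and `g` carry `H`, so each key head serves `H // Hg` value heads via the `i_h // (H // Hg)` offset used in the kernel. A small illustration with hypothetical shapes:

import torch

# Hypothetical shapes to illustrate the Hg != H case: keys have Hg head groups
# while values/gates have H heads (GQA-style sharing).
B, T, Hg, H, K, V = 1, 128, 4, 16, 128, 128
k = torch.randn(B, T, Hg, K)   # fewer key heads
u = torch.randn(B, T, H, V)    # full set of value heads

group = H // Hg
for i_h in range(H):
    i_kh = i_h // group        # key head used by value head i_h
    assert 0 <= i_kh < Hg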
@@ -299,6 +328,7 @@ def chunk_gated_delta_rule_fwd_h(
         w=w,
         v_new=v_new,
         g=g,
+        gk=gk,
         h=h,
         h0=initial_state,
         ht=final_state,
@@ -18,8 +18,8 @@ from .op import exp
 
 @triton.heuristics(
     {
+        "USE_G": lambda args: args["g"] is not None,
         "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
-        "USE_G": lambda args: args["g_cumsum"] is not None,
     }
 )
 @triton.autotune(
@@ -35,7 +35,7 @@ from .op import exp
 def chunk_scaled_dot_kkt_fwd_kernel(
     k,
     beta,
-    g_cumsum,
+    g,
     A,
     cu_seqlens,
     chunk_indices,
@@ -85,9 +85,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
     b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
 
     if USE_G:
-        p_g = tl.make_block_ptr(
-            g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
-        )
+        p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
         b_g = tl.load(p_g, boundary_check=(0,))
         b_g_diff = b_g[:, None] - b_g[None, :]
         b_A = b_A * exp(b_g_diff)
@@ -102,8 +100,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
 
 def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
-    beta: torch.Tensor,
-    g_cumsum: torch.Tensor | None = None,
+    g: torch.Tensor | None = None,
+    beta: torch.Tensor | None = None,
     cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
@@ -116,9 +114,8 @@ def chunk_scaled_dot_kkt_fwd(
         The key tensor of shape `[B, T, H, K]`.
     beta (torch.Tensor):
         The beta tensor of shape `[B, T, H]`.
-    g_cumsum (torch.Tensor):
-        The cumulative sum of the gate tensor of shape `[B, T, H]`.
-        Default: None
+    g (torch.Tensor):
+        The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
     cu_seqlens (torch.LongTensor):
         The cumulative sequence lengths of the input tensor.
         Default: None
@@ -130,20 +127,21 @@ def chunk_scaled_dot_kkt_fwd(
     Returns:
         beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
     """
+    # This kernel is slightly different from fla to support Q/K with different head numbers.
+    # In fla, Q/K always have the same head number, so Hg is always equal to H.
     B, T, Hg, K = k.shape
 
     H = beta.shape[-1]
     BT = chunk_size
     chunk_indices = (
         prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     )
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
 
     A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
     chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
         k=k,
+        g=g,
         beta=beta,
-        g_cumsum=g_cumsum,
         A=A,
         cu_seqlens=cu_seqlens,
         chunk_indices=chunk_indices,
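Per its docstring, `chunk_scaled_dot_kkt_fwd` returns `beta * K * K^T` per chunk, scaled by `exp(g_i - g_j)` when the cumulative gate `g` is given. A per-chunk, per-head reference sketch (the triangular masking applied in the part of the kernel not shown here is omitted):

import torch

def chunk_scaled_dot_kkt_ref(k, beta, g=None):
    """Reference for one chunk and one head.

    k:    [BT, K] keys of a chunk
    beta: [BT]    per-token write strength
    g:    [BT]    chunk-local cumulative log-decay (optional)

    Returns A: [BT, BT] with A[i, j] = beta[i] * <k[i], k[j]> * exp(g[i] - g[j]).
    """
    A = (beta[:, None] * k) @ k.T
    if g is not None:
        A = A * torch.exp(g[:, None] - g[None, :])
    return A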
@@ -57,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     IS_VARLEN: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     IS_SPEC_DECODING: tl.constexpr,
+    IS_KDA: tl.constexpr,
 ):
     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
     i_n, i_hv = i_nh // HV, i_nh % HV
@@ -86,7 +87,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         p_beta = beta + (bos * HV + i_hv) * V + o_v
     else:
         p_beta = beta + bos * HV + i_hv
-    p_g = g + bos * HV + i_hv
+
+    if not IS_KDA:
+        p_g = g + bos * HV + i_hv
+    else:
+        p_gk = g + (bos * HV + i_hv) * K + o_k
+
     p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
 
     mask_k = o_k < K
@@ -116,14 +122,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
         b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
         b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
-        b_g = tl.load(p_g).to(tl.float32)
 
         if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
         b_q = b_q * scale
         # [BK, BV]
-        b_h *= exp(b_g)
+        if not IS_KDA:
+            b_g = tl.load(p_g).to(tl.float32)
+            b_h *= exp(b_g)
+        else:
+            b_gk = tl.load(p_gk).to(tl.float32)
+            b_h *= exp(b_gk[:, None])
         # [BV]
         b_v -= tl.sum(b_h * b_k[:, None], 0)
         if IS_BETA_HEADWISE:
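In the decode kernel, `IS_KDA` only changes how the `g` buffer is indexed and applied: one scalar per (token, head) on the gated-delta-rule path versus a length-`K` vector per (token, head) on the KDA path. A small illustration of the two layouts with hypothetical shapes, mirroring the pointer arithmetic above:

import torch

# Hypothetical shapes; bos = 0 since B = 1 here.
B, T, HV, K = 1, 8, 4, 16
g_scalar = torch.randn(B, T, HV)      # gated delta rule: one decay per head
g_kda = torch.randn(B, T, HV, K)      # KDA: one decay per head and key channel

t, i_hv = 3, 2
flat_scalar = g_scalar.reshape(-1)
flat_kda = g_kda.reshape(-1)
# p_g = g + bos * HV + i_hv, advanced by HV per token
assert torch.equal(flat_scalar[t * HV + i_hv], g_scalar[0, t, i_hv])
# p_gk = g + (bos * HV + i_hv) * K + o_k, advanced by HV * K per token
start = (t * HV + i_hv) * K
assert torch.equal(flat_kda[start:start + K], g_kda[0, t, i_hv])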
@@ -155,7 +165,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         p_k += H * K
         p_o += HV * V
         p_v += HV * V
-        p_g += HV
+        if not IS_KDA:
+            p_g += HV
+        else:
+            p_gk += HV * K
         p_beta += HV * (V if IS_BETA_HEADWISE else 1)


@@ -228,6 +241,7 @@ def fused_recurrent_gated_delta_rule_fwd(
         IS_BETA_HEADWISE=beta.ndim == v.ndim,
         USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
         INPLACE_FINAL_STATE=inplace_final_state,
+        IS_KDA=False,
         num_warps=num_warps,
         num_stages=num_stages,
     )
vllm/model_executor/layers/fla/ops/kda.py (new file, 1351 lines)
File diff suppressed because it is too large.