diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index b046a6d3919e..4c8bf9f43997 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -36,7 +36,7 @@ def chunk_gated_delta_rule_fwd( g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. A = chunk_scaled_dot_kkt_fwd( - k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 + k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 ) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) w, u = recompute_w_u_fwd( diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py index 1c14f84c2b89..f0b78b65c4a3 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -14,14 +14,15 @@ from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices, prepare_chunk_offsets from .op import exp -from .utils import is_nvidia_hopper, use_cuda_graph +from .utils import use_cuda_graph -NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] +NUM_WARPS = [2, 4, 8, 16] @triton.heuristics( { "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, "USE_INITIAL_STATE": lambda args: args["h0"] is not None, "STORE_FINAL_STATE": lambda args: args["ht"] is not None, "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, @@ -35,7 +36,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] for num_stages in [2, 3, 4] for BV in [32, 64] ], - key=["H", "K", "V", "BT", "USE_G"], + key=["H", "K", "V", "BT"], use_cuda_graph=use_cuda_graph, ) @triton.jit(do_not_specialize=["T"]) @@ -45,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( w, v_new, g, + gk, h, h0, ht, @@ -58,6 +60,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( BT: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, + USE_GK: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, SAVE_NEW_VALUE: tl.constexpr, @@ -88,12 +91,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( b_h4 = tl.zeros([64, BV], dtype=tl.float32) # calculate offset - h += (boh * H + i_h) * K * V - v += (bos * H + i_h) * V - k += (bos * Hg + i_h // (H // Hg)) * K - w += (bos * H + i_h) * K + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) if SAVE_NEW_VALUE: - v_new += (bos * H + i_h) * V + v_new += ((bos * H + i_h) * V).to(tl.int64) stride_v = H * V stride_h = H * K * V stride_k = Hg * K @@ -145,92 +148,115 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - p_v = tl.make_block_ptr( - v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) - ) - p_v_new = ( - tl.make_block_ptr( - v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) - ) - if SAVE_NEW_VALUE - else None - ) - b_v_new = tl.zeros([BT, BV], dtype=tl.float32) p_w = tl.make_block_ptr( w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) ) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) if K > 64: p_w = tl.make_block_ptr( w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) ) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) if K > 128: p_w = tl.make_block_ptr( w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) ) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) if K > 192: p_w = tl.make_block_ptr( w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) ) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) - b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr( + v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v if SAVE_NEW_VALUE: - p_v_new = tl.make_block_ptr( + p_v = tl.make_block_ptr( v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) ) - tl.store( - p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1) - ) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min((i_t + 1) * BT, T) - 1 if USE_G: m_t = (i_t * BT + tl.arange(0, BT)) < T - last_idx = min((i_t + 1) * BT, T) - 1 b_g_last = tl.load(g + bos * H + last_idx * H + i_h) p_g = tl.make_block_ptr( g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) ) b_g = tl.load(p_g, boundary_check=(0,)) - b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] b_g_last = exp(b_g_last) - b_h1 = b_h1 * b_g_last + b_h1 *= b_g_last if K > 64: - b_h2 = b_h2 * b_g_last + b_h2 *= b_g_last if K > 128: - b_h3 = b_h3 * b_g_last + b_h3 *= b_g_last if K > 192: - b_h4 = b_h4 * b_g_last - b_v_new = b_v_new.to(k.dtype.element_ty) + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + p_k = tl.make_block_ptr( k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h1 += tl.dot(b_k, b_v_new) + b_h1 += tl.dot(b_k, b_v) if K > 64: p_k = tl.make_block_ptr( k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h2 += tl.dot(b_k, b_v_new) + b_h2 += tl.dot(b_k, b_v) if K > 128: p_k = tl.make_block_ptr( k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h3 += tl.dot(b_k, b_v_new) + b_h3 += tl.dot(b_k, b_v) if K > 192: p_k = tl.make_block_ptr( k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h4 += tl.dot(b_k, b_v_new) - + b_h4 += tl.dot(b_k, b_v) # epilogue if STORE_FINAL_STATE: p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) @@ -257,12 +283,15 @@ def chunk_gated_delta_rule_fwd_h( w: torch.Tensor, u: torch.Tensor, g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, initial_state: torch.Tensor | None = None, output_final_state: bool = False, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. B, T, Hg, K, V = *k.shape, u.shape[-1] H = u.shape[-2] BT = chunk_size @@ -299,6 +328,7 @@ def chunk_gated_delta_rule_fwd_h( w=w, v_new=v_new, g=g, + gk=gk, h=h, h0=initial_state, ht=final_state, diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py index 975e119af333..7724fa513d92 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -18,8 +18,8 @@ from .op import exp @triton.heuristics( { + "USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, - "USE_G": lambda args: args["g_cumsum"] is not None, } ) @triton.autotune( @@ -35,7 +35,7 @@ from .op import exp def chunk_scaled_dot_kkt_fwd_kernel( k, beta, - g_cumsum, + g, A, cu_seqlens, chunk_indices, @@ -85,9 +85,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) if USE_G: - p_g = tl.make_block_ptr( - g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) - ) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) b_g_diff = b_g[:, None] - b_g[None, :] b_A = b_A * exp(b_g_diff) @@ -102,8 +100,8 @@ def chunk_scaled_dot_kkt_fwd_kernel( def chunk_scaled_dot_kkt_fwd( k: torch.Tensor, - beta: torch.Tensor, - g_cumsum: torch.Tensor | None = None, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, @@ -116,9 +114,8 @@ def chunk_scaled_dot_kkt_fwd( The key tensor of shape `[B, T, H, K]`. beta (torch.Tensor): The beta tensor of shape `[B, T, H]`. - g_cumsum (torch.Tensor): - The cumulative sum of the gate tensor of shape `[B, T, H]`. - Default: None + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. cu_seqlens (torch.LongTensor): The cumulative sequence lengths of the input tensor. Default: None @@ -130,20 +127,21 @@ def chunk_scaled_dot_kkt_fwd( Returns: beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. """ - + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. B, T, Hg, K = k.shape - H = beta.shape[-1] BT = chunk_size chunk_indices = ( prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( k=k, + g=g, beta=beta, - g_cumsum=g_cumsum, A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index f3de1bfa2821..0f27504780ac 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -57,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hv = i_nh // HV, i_nh % HV @@ -86,7 +87,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta = beta + (bos * HV + i_hv) * V + o_v else: p_beta = beta + bos * HV + i_hv - p_g = g + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v mask_k = o_k < K @@ -116,14 +122,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_g = tl.load(p_g).to(tl.float32) if USE_QK_L2NORM_IN_KERNEL: b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) b_q = b_q * scale # [BK, BV] - b_h *= exp(b_g) + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) # [BV] b_v -= tl.sum(b_h * b_k[:, None], 0) if IS_BETA_HEADWISE: @@ -155,7 +165,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_k += H * K p_o += HV * V p_v += HV * V - p_g += HV + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K p_beta += HV * (V if IS_BETA_HEADWISE else 1) @@ -228,6 +241,7 @@ def fused_recurrent_gated_delta_rule_fwd( IS_BETA_HEADWISE=beta.ndim == v.ndim, USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, num_warps=num_warps, num_stages=num_stages, ) diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py new file mode 100644 index 000000000000..a10847d347d1 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/kda.py @@ -0,0 +1,1351 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + + +import torch +import torch.nn as nn + +from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import cdiv, next_power_of_2 + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .cumsum import chunk_local_cumsum +from .fused_recurrent import fused_recurrent_gated_delta_rule_fwd_kernel +from .index import prepare_chunk_indices +from .l2norm import l2norm_fwd +from .op import exp, log +from .solve_tril import solve_tril +from .utils import is_amd + +BT_LIST_AUTOTUNE = [32, 64, 128] +NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32] + + +def fused_recurrent_kda_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8) + NK, NV = cdiv(K, BK), cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = torch.empty_like(k) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=True, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o, final_state + + +def fused_recurrent_kda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + use_qk_l2norm_in_kernel: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.LongTensor | None = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o, final_state = fused_recurrent_kda_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=None, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + return o, final_state + + +@triton.heuristics( + { + "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None, + "HAS_RESIDUAL": lambda args: args["residual"] is not None, + "HAS_WEIGHT": lambda args: args["w"] is not None, + "HAS_BIAS": lambda args: args["b"] is not None, + } +) +@triton.jit +def layer_norm_gated_fwd_kernel( + x, # pointer to the input + g, # pointer to the gate + y, # pointer to the output + w, # pointer to the weights + b, # pointer to the biases + residual, # pointer to the residual + residual_out, # pointer to the residual + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + T, # number of rows in x + D: tl.constexpr, # number of columns in x + BT: tl.constexpr, + BD: tl.constexpr, + ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + i_t = tl.program_id(0) + + o_d = tl.arange(0, BD) + m_d = o_d < D + + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + if HAS_RESIDUAL: + p_res = tl.make_block_ptr( + residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0) + ) + b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32) + if STORE_RESIDUAL_OUT: + p_res_out = tl.make_block_ptr( + residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0) + ) + tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1)) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=1) / D + p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,)) + b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + else: + b_xbar = tl.where(m_d[None, :], b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,)) + + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=m_d).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=m_d).to(tl.float32) + b_x_hat = ( + (b_x - b_mean[:, None]) * b_rstd[:, None] + if not IS_RMS_NORM + else b_x * b_rstd[:, None] + ) + b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b[None, :] + + # swish/sigmoid output gate + p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + if ACTIVATION == "swish" or ACTIVATION == "silu": + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == "sigmoid": + b_y = b_y * tl.sigmoid(b_g) + + # Write output + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None, + "HAS_RESIDUAL": lambda args: args["residual"] is not None, + "HAS_WEIGHT": lambda args: args["w"] is not None, + "HAS_BIAS": lambda args: args["b"] is not None, + } +) +@triton.jit +def layer_norm_gated_fwd_kernel1( + x, # pointer to the input + g, # pointer to the gate + y, # pointer to the output + w, # pointer to the weights + b, # pointer to the biases + residual, # pointer to the residual + residual_out, # pointer to the residual + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + D: tl.constexpr, # number of columns in x + BD: tl.constexpr, + ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + g += i_t * D + if HAS_RESIDUAL: + residual += i_t * D + if STORE_RESIDUAL_OUT: + residual_out += i_t * D + + o_d = tl.arange(0, BD) + m_d = o_d < D + b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32) + if STORE_RESIDUAL_OUT: + tl.store(residual_out + o_d, b_x, mask=m_d) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=0) / D + tl.store(mean + i_t, b_mean) + b_xbar = tl.where(m_d, b_x - b_mean, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + else: + b_xbar = tl.where(m_d, b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + tl.store(rstd + i_t, b_rstd) + + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=m_d).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=m_d).to(tl.float32) + b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd + b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b + + # swish/sigmoid output gate + b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32) + if ACTIVATION == "swish" or ACTIVATION == "silu": + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == "sigmoid": + b_y = b_y * tl.sigmoid(b_g) + + # Write output + tl.store(y + o_d, b_y, mask=m_d) + + +def layer_norm_gated_fwd( + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = "swish", + eps: float = 1e-5, + residual: torch.Tensor = None, + out_dtype: torch.dtype = None, + residual_dtype: torch.dtype = None, + is_rms_norm: bool = False, +): + if residual is not None: + residual_dtype = residual.dtype + T, D = x.shape + if residual is not None: + assert residual.shape == (T, D) + if weight is not None: + assert weight.shape == (D,) + if bias is not None: + assert bias.shape == (D,) + # allocate output + y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype) + if residual is not None or ( + residual_dtype is not None and residual_dtype != x.dtype + ): + residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype) + else: + residual_out = None + mean = ( + torch.empty((T,), dtype=torch.float, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((T,), dtype=torch.float, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + + if D <= 512: + BT = 32 + layer_norm_gated_fwd_kernel[(cdiv(T, BT),)]( + x=x, + g=g, + y=y, + w=weight, + b=bias, + residual=residual, + residual_out=residual_out, + mean=mean, + rstd=rstd, + eps=eps, + T=T, + D=D, + BD=BD, + BT=BT, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + num_warps=4, + ) + else: + layer_norm_gated_fwd_kernel1[(T,)]( + x=x, + g=g, + y=y, + w=weight, + b=bias, + residual=residual, + residual_out=residual_out, + mean=mean, + rstd=rstd, + eps=eps, + D=D, + BD=BD, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + num_warps=4, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +def rms_norm_gated( + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = "swish", + residual: torch.Tensor | None = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + eps: float = 1e-6, +): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.contiguous().reshape(-1, x.shape[-1]) + g = g.contiguous().reshape(-1, g.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.contiguous().reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float if residual_in_fp32 else None) + ) + y, _, _, residual_out = layer_norm_gated_fwd( + x=x, + g=g, + weight=weight, + bias=bias, + activation=activation, + eps=eps, + residual=residual, + residual_dtype=residual_dtype, + is_rms_norm=True, + ) + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + +class FusedRMSNormGated(nn.Module): + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + activation: str = "swish", + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.activation = activation + + if self.activation not in ["swish", "silu", "sigmoid"]: + raise ValueError(f"Unsupported activation: {self.activation}") + + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: torch.Tensor | None = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + ) -> torch.Tensor: + return rms_norm_gated( + x, + g, + self.weight, + self.bias, + self.activation, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BC"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter( + q, + k, + g, + beta, + A, + Aqk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + A += (bos * H + i_h) * BT + Aqk += (bos * H + i_h) * BT + + p_b = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,) + ) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0) + ) + p_g = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0) + ) + b_kt = tl.make_block_ptr( + k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1) + ) + p_gk = tl.make_block_ptr( + g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1) + ) + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + # [BK,] + b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :]) + # [BK, BC] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kt = tl.load(b_kt, boundary_check=(0, 1)) + # [BC, BC] + b_ktg = b_kt * exp(b_gn[:, None] - b_gk) + b_A += tl.dot(b_k, b_ktg) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * exp(b_g - b_gn[None, :]) * scale + b_Aqk += tl.dot(b_qg, b_ktg) + + b_A *= b_b[:, None] + + p_A = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0) + ) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + p_Aqk = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0) + ) + tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BK", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra( + q, + k, + g, + beta, + A, + Aqk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + o_i) < T + o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC + + p_q = tl.make_block_ptr( + q + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT + i_i * BC, 0), + (BC, BK), + (1, 0), + ) + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT + i_i * BC, 0), + (BC, BK), + (1, 0), + ) + p_g = tl.make_block_ptr( + g + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT + i_i * BC, 0), + (BC, BK), + (1, 0), + ) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h + b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None] + + p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :]) + b_A = tl.sum(b_k * b_ktg, 1) + b_A = tl.where(o_i > j, b_A, 0.0) + b_Aqk = tl.sum(b_q * b_ktg, 1) + b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0) + tl.store(A + o_A + j, b_A, mask=m_A) + tl.store(Aqk + o_A + j, b_Aqk, mask=m_A) + p_kt += H * K + p_gk += H * K + + +def chunk_kda_scaled_dot_kkt_fwd( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + gk (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + B, T, H, K = k.shape + assert K <= 256 + BT = chunk_size + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BC = min(16, BT) + NC = cdiv(BT, BC) + BK = max(next_power_of_2(K), 16) + A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype) + Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype) + grid = (NT, NC * NC, B * H) + chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid]( + q=q, + k=k, + g=gk, + beta=beta, + A=A, + Aqk=Aqk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + ) + + grid = (NT, NC, B * H) + chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid]( + q=q, + k=k, + g=gk, + beta=beta, + A=A, + Aqk=Aqk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + return A, Aqk + + +@triton.heuristics( + { + "STORE_QG": lambda args: args["qg"] is not None, + "STORE_KG": lambda args: args["kg"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_b[:, None] + + p_gk = tl.make_block_ptr( + gk + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kb *= exp(b_gk) + if STORE_QG: + p_q = tl.make_block_ptr( + q + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_qg = tl.make_block_ptr( + qg + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * exp(b_gk) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) + if STORE_KG: + last_idx = min(i_t * BT + BT, T) - 1 + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load( + gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0 + ) + b_kg = b_k * exp(b_gn - b_gk) + + p_kg = tl.make_block_ptr( + kg + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) + + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + q: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 64 + BV = 64 + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = torch.empty_like(k) + u = torch.empty_like(v) + kg = torch.empty_like(k) if gk is not None else None + recompute_w_u_fwd_kernel[(NT, B * H)]( + q=q, + k=k, + qg=None, + kg=kg, + v=v, + beta=beta, + w=w, + u=u, + A=A, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + DOT_PRECISION="ieee", + ) + return w, u, None, kg + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr( + q + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_g = tl.make_block_ptr( + g + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_h = tl.make_block_ptr( + h + (i_tg * H + i_h) * K * V, + (K, V), + (V, 1), + (i_k * BK, i_v * BV), + (BK, BV), + (1, 0), + ) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + o: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) + NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + def grid(meta): + return (cdiv(V, meta["BV"]), NT, B * H) + + chunk_gla_fwd_kernel_o[grid]( + q=q, + v=v, + g=g, + h=h, + o=o, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o + + +def chunk_kda_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, +): + chunk_size = 64 + g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) + # the intra Aqk is kept in fp32 + # the computation has very marginal effect on the entire throughput + A, Aqk = chunk_kda_scaled_dot_kkt_fwd( + q=q, + k=k, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32, + ) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u, _, kg = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + gk=g, + cu_seqlens=cu_seqlens, + ) + del A + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=kg, + w=w, + u=u, + gk=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + del w, u, kg + o = chunk_gla_fwd_o_gk( + q=q, + v=v_new, + g=g, + A=Aqk, + h=h, + o=v, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + del Aqk, v_new, h + return o, final_state + + +def chunk_kda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: torch.LongTensor | None = None, + **kwargs, +): + if scale is None: + scale = k.shape[-1] ** -0.5 + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q.contiguous()) + k = l2norm_fwd(k.contiguous()) + + o, final_state = chunk_kda_fwd( + q=q, + k=k, + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state.contiguous(), + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + return o, final_state + + +@triton.autotune( + configs=[ + triton.Config({"BT": bt}, num_warps=nw, num_stages=ns) + for bt in BT_LIST_AUTOTUNE + for nw in NUM_WARPS_AUTOTUNE + for ns in [2, 3] + ], + key=["H", "D"], +) +@triton.jit +def kda_gate_fwd_kernel( + g, + A, + y, + g_bias, + beta: tl.constexpr, + threshold: tl.constexpr, + T, + H, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + i_t, i_h = tl.program_id(0), tl.program_id(1) + n_t = i_t * BT + + b_a = tl.load(A + i_h).to(tl.float32) + b_a = -tl.exp(b_a) + + stride_row = H * D + stride_col = 1 + + g_ptr = tl.make_block_ptr( + base=g + i_h * D, + shape=(T, D), + strides=(stride_row, stride_col), + offsets=(n_t, 0), + block_shape=(BT, BD), + order=(1, 0), + ) + + y_ptr = tl.make_block_ptr( + base=y + i_h * D, + shape=(T, D), + strides=(stride_row, stride_col), + offsets=(n_t, 0), + block_shape=(BT, BD), + order=(1, 0), + ) + + b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32) + + if HAS_BIAS: + n_d = tl.arange(0, BD) + bias_mask = n_d < D + b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to( + tl.float32 + ) + b_g = b_g + b_bias[None, :] + + # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x)) + # When beta * x > threshold, use linear approximation x + # Use threshold to switch to linear when beta*x > threshold + g_scaled = b_g * beta + use_linear = g_scaled > threshold + sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled))) + b_y = b_a * sp + + tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1)) + + +def kda_gate_fwd( + g: torch.Tensor, + A: torch.Tensor, + head_k_dim: int, + g_bias: torch.Tensor | None = None, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + """ + Forward pass for KDA gate: + input g: [..., H*D] + param A: [H] or [1, 1, H, 1] + beta: softplus beta parameter + threshold: softplus threshold parameter + return : [..., H, D] + """ + orig_shape = g.shape[:-1] + + g = g.view(-1, g.shape[-1]) + T = g.shape[0] + HD = g.shape[1] + H = A.numel() + assert H * head_k_dim == HD + + y = torch.empty_like(g, dtype=torch.float32) + + def grid(meta): + return (cdiv(T, meta["BT"]), H) + + kda_gate_fwd_kernel[grid]( + g, + A, + y, + g_bias, + beta, + threshold, + T, + H, + head_k_dim, + BD=next_power_of_2(head_k_dim), + HAS_BIAS=g_bias is not None, + ) + + y = y.view(*orig_shape, H, head_k_dim) + return y