mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:15:01 +08:00
[Kernel] Accelerate solve_tril with TMA (#26746)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
b63f2143f8
commit
9fce7bee74
@ -11,29 +11,50 @@ import os
|
|||||||
|
|
||||||
from vllm.triton_utils import tl, tldevice, triton
|
from vllm.triton_utils import tl, tldevice, triton
|
||||||
|
|
||||||
|
from .utils import is_gather_supported
|
||||||
|
|
||||||
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
|
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
|
||||||
div = tldevice.fast_dividef
|
|
||||||
exp = tldevice.fast_expf
|
exp = tldevice.fast_expf
|
||||||
log = tldevice.fast_logf
|
log = tldevice.fast_logf
|
||||||
log2 = tldevice.fast_log2f
|
log2 = tldevice.fast_log2f
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def div_normal(x, y):
|
|
||||||
return x / y
|
|
||||||
|
|
||||||
div = div_normal
|
|
||||||
exp = tl.exp
|
exp = tl.exp
|
||||||
log = tl.log
|
log = tl.log
|
||||||
log2 = tl.log2
|
log2 = tl.log2
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(tl, "gather"):
|
if not is_gather_supported:
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def gather(src, index, axis, _builder=None):
|
def gather(src, index, axis, _builder=None):
|
||||||
# This is a fallback implementation when tl.gather is not supported
|
"""
|
||||||
# In order to pass triton compiler, there is no actual gather operation
|
Gather operation that works when tl.gather is not supported.
|
||||||
return src
|
This is a fallback implementation that returns None.
|
||||||
|
Just to make triton compiler happy.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
gather = tl.gather
|
gather = tl.gather
|
||||||
|
|
||||||
|
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
|
||||||
|
# For Triton 3.3.x
|
||||||
|
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
|
||||||
|
elif hasattr(triton.language, "make_tensor_descriptor"):
|
||||||
|
# For Triton 3.4.x and later
|
||||||
|
make_tensor_descriptor = triton.language.make_tensor_descriptor
|
||||||
|
else:
|
||||||
|
"""
|
||||||
|
Fallback implementation when TMA is not supported.
|
||||||
|
Returns None to indicate TMA descriptors are unavailable.
|
||||||
|
Just make triton compiler happy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def make_tensor_descriptor(
|
||||||
|
base,
|
||||||
|
shape,
|
||||||
|
strides,
|
||||||
|
block_shape,
|
||||||
|
_builder=None,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|||||||
@ -8,12 +8,21 @@
|
|||||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .index import prepare_chunk_indices
|
from .index import prepare_chunk_indices
|
||||||
from .utils import input_guard
|
from .op import make_tensor_descriptor
|
||||||
|
from .utils import input_guard, is_amd, is_tma_supported
|
||||||
|
|
||||||
|
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
|
||||||
|
ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"]
|
||||||
|
assert FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS, (
|
||||||
|
f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@ -28,13 +37,15 @@ from .utils import input_guard
|
|||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T"])
|
||||||
def solve_tril_16x16_kernel(
|
def solve_tril_16x16_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
Ai,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
|
USE_TMA: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
|
DOT_PRECISION: tl.constexpr,
|
||||||
):
|
):
|
||||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||||
i_b, i_h = i_bh // H, i_bh % H
|
i_b, i_h = i_bh // H, i_bh % H
|
||||||
@ -50,30 +61,43 @@ def solve_tril_16x16_kernel(
|
|||||||
T = eos - bos
|
T = eos - bos
|
||||||
else:
|
else:
|
||||||
bos, eos = i_b * T, i_b * T + T
|
bos, eos = i_b * T, i_b * T + T
|
||||||
|
o_i = tl.arange(0, 16)
|
||||||
|
m_A = o_i[:, None] > o_i[None, :]
|
||||||
|
m_I = o_i[:, None] == o_i[None, :]
|
||||||
|
|
||||||
A = A + (bos * H + i_h) * BT
|
A = A + (bos * H + i_h) * BT
|
||||||
Ad = Ad + (bos * H + i_h) * 16
|
Ai = Ai + (bos * H + i_h) * 16
|
||||||
|
|
||||||
offset = (i_t * 16) % BT
|
offset = (i_t * 16) % BT
|
||||||
p_A = tl.make_block_ptr(
|
if not USE_TMA:
|
||||||
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
|
p_A = tl.make_block_ptr(
|
||||||
)
|
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
|
||||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
|
)
|
||||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
# [16, 16]
|
||||||
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
else:
|
||||||
|
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
|
||||||
|
desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
|
||||||
|
b_A = desc.load([i_t * 16, offset]).to(tl.float32)
|
||||||
|
b_A = -tl.where(m_A, b_A, 0)
|
||||||
|
|
||||||
o_i = tl.arange(0, 16)
|
for i in range(2, min(16, T - i_t * 16)):
|
||||||
for i in range(1, min(16, T - i_t * 16)):
|
# [16]
|
||||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||||
mask = o_i == i
|
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
|
||||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
b_A += m_I
|
||||||
b_A += o_i[:, None] == o_i[None, :]
|
if not USE_TMA:
|
||||||
tl.store(
|
p_Ai = tl.make_block_ptr(
|
||||||
p_Ai,
|
Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)
|
||||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
)
|
||||||
boundary_check=(0, 1),
|
tl.store(
|
||||||
)
|
p_Ai,
|
||||||
|
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
|
boundary_check=(0, 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@ -88,14 +112,15 @@ def solve_tril_16x16_kernel(
|
|||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T"])
|
||||||
def merge_16x16_to_32x32_inverse_kernel(
|
def merge_16x16_to_32x32_inverse_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
|
||||||
Ai,
|
Ai,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
|
USE_TMA: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
|
DOT_PRECISION: tl.constexpr,
|
||||||
):
|
):
|
||||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||||
i_b, i_h = i_bh // H, i_bh % H
|
i_b, i_h = i_bh // H, i_bh % H
|
||||||
@ -112,50 +137,92 @@ def merge_16x16_to_32x32_inverse_kernel(
|
|||||||
else:
|
else:
|
||||||
bos, eos = i_b * T, i_b * T + T
|
bos, eos = i_b * T, i_b * T + T
|
||||||
|
|
||||||
A += (bos * H + i_h) * 32
|
o_i = tl.arange(0, 16)
|
||||||
Ad += (bos * H + i_h) * 16
|
m_A = o_i[:, None] > o_i[None, :]
|
||||||
Ai += (bos * H + i_h) * 32
|
m_I = o_i[:, None] == o_i[None, :]
|
||||||
|
A += (bos * H + i_h) * BT
|
||||||
|
Ai += (bos * H + i_h) * BT
|
||||||
|
|
||||||
p_A_21 = tl.make_block_ptr(
|
if not USE_TMA:
|
||||||
A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
p_A_11 = tl.make_block_ptr(
|
||||||
)
|
A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
|
||||||
p_Ad_11 = tl.make_block_ptr(
|
)
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
|
p_A_22 = tl.make_block_ptr(
|
||||||
)
|
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
|
||||||
p_Ad_22 = tl.make_block_ptr(
|
)
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
|
||||||
)
|
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
|
||||||
p_Ai_11 = tl.make_block_ptr(
|
else:
|
||||||
Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
|
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
|
||||||
)
|
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
|
||||||
p_Ai_22 = tl.make_block_ptr(
|
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
|
||||||
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
|
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
|
||||||
)
|
|
||||||
p_Ai_21 = tl.make_block_ptr(
|
# [16, 16]
|
||||||
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
|
||||||
|
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
|
||||||
|
|
||||||
|
for i in range(2, min(16, T - i_t * BT)):
|
||||||
|
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
|
||||||
|
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
|
||||||
|
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
|
||||||
|
for i in range(16 + 2, min(32, T - i_t * BT)):
|
||||||
|
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
|
||||||
|
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
|
||||||
|
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
|
||||||
|
|
||||||
|
b_Ai_11 += m_I
|
||||||
|
b_Ai_22 += m_I
|
||||||
|
|
||||||
|
if not USE_TMA:
|
||||||
|
p_A_21 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
else:
|
||||||
|
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
|
||||||
|
|
||||||
|
b_Ai_21 = -tl.dot(
|
||||||
|
tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
|
||||||
|
b_Ai_11,
|
||||||
|
input_precision=DOT_PRECISION,
|
||||||
)
|
)
|
||||||
|
|
||||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
if not USE_TMA:
|
||||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
p_Ai_11 = tl.make_block_ptr(
|
||||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
|
||||||
Ai_21 = -tl.dot(
|
)
|
||||||
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
|
p_Ai_21 = tl.make_block_ptr(
|
||||||
)
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
|
||||||
tl.store(
|
)
|
||||||
p_Ai_11,
|
p_Ai_22 = tl.make_block_ptr(
|
||||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
|
||||||
boundary_check=(0, 1),
|
)
|
||||||
)
|
tl.store(
|
||||||
tl.store(
|
p_Ai_11,
|
||||||
p_Ai_22,
|
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
boundary_check=(0, 1),
|
||||||
boundary_check=(0, 1),
|
)
|
||||||
)
|
tl.store(
|
||||||
tl.store(
|
p_Ai_22,
|
||||||
p_Ai_21,
|
b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
boundary_check=(0, 1),
|
||||||
boundary_check=(0, 1),
|
)
|
||||||
)
|
tl.store(
|
||||||
|
p_Ai_21,
|
||||||
|
b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
|
boundary_check=(0, 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
desc_o.store(
|
||||||
|
[i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
|
)
|
||||||
|
desc_o.store(
|
||||||
|
[i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
|
)
|
||||||
|
desc_o.store(
|
||||||
|
[i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
@ -170,14 +237,15 @@ def merge_16x16_to_32x32_inverse_kernel(
|
|||||||
@triton.jit(do_not_specialize=["T"])
|
@triton.jit(do_not_specialize=["T"])
|
||||||
def merge_16x16_to_64x64_inverse_kernel(
|
def merge_16x16_to_64x64_inverse_kernel(
|
||||||
A,
|
A,
|
||||||
Ad,
|
|
||||||
Ai,
|
Ai,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
chunk_indices,
|
chunk_indices,
|
||||||
T,
|
T,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
BT: tl.constexpr,
|
BT: tl.constexpr,
|
||||||
|
USE_TMA: tl.constexpr,
|
||||||
IS_VARLEN: tl.constexpr,
|
IS_VARLEN: tl.constexpr,
|
||||||
|
DOT_PRECISION: tl.constexpr,
|
||||||
):
|
):
|
||||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||||
i_b, i_h = i_bh // H, i_bh % H
|
i_b, i_h = i_bh // H, i_bh % H
|
||||||
@ -194,213 +262,245 @@ def merge_16x16_to_64x64_inverse_kernel(
|
|||||||
else:
|
else:
|
||||||
bos, eos = i_b * T, i_b * T + T
|
bos, eos = i_b * T, i_b * T + T
|
||||||
|
|
||||||
A += (bos * H + i_h) * 64
|
o_i = tl.arange(0, 16)
|
||||||
Ad += (bos * H + i_h) * 16
|
m_A = o_i[:, None] > o_i[None, :]
|
||||||
Ai += (bos * H + i_h) * 64
|
m_I = o_i[:, None] == o_i[None, :]
|
||||||
|
A += (bos * H + i_h) * BT
|
||||||
|
Ai += (bos * H + i_h) * BT
|
||||||
|
|
||||||
p_A_21 = tl.make_block_ptr(
|
if not USE_TMA:
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
p_A_11 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_22 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_33 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_44 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
else:
|
||||||
|
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
|
||||||
|
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
|
||||||
|
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
|
||||||
|
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
|
||||||
|
b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
|
||||||
|
b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
|
||||||
|
|
||||||
|
# [16, 16]
|
||||||
|
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
|
||||||
|
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
|
||||||
|
b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
|
||||||
|
b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
|
||||||
|
|
||||||
|
for i in range(2, min(16, T - i_t * BT)):
|
||||||
|
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
|
||||||
|
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
|
||||||
|
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
|
||||||
|
for i in range(16 + 2, min(32, T - i_t * BT)):
|
||||||
|
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
|
||||||
|
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
|
||||||
|
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
|
||||||
|
for i in range(32 + 2, min(48, T - i_t * BT)):
|
||||||
|
b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
|
||||||
|
b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
|
||||||
|
b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
|
||||||
|
for i in range(48 + 2, min(64, T - i_t * BT)):
|
||||||
|
b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
|
||||||
|
b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
|
||||||
|
b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
|
||||||
|
b_Ai_11 += m_I
|
||||||
|
b_Ai_22 += m_I
|
||||||
|
b_Ai_33 += m_I
|
||||||
|
b_Ai_44 += m_I
|
||||||
|
|
||||||
|
if not USE_TMA:
|
||||||
|
p_A_21 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_31 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_32 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_41 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_42 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
p_A_43 = tl.make_block_ptr(
|
||||||
|
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
|
||||||
|
)
|
||||||
|
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
else:
|
||||||
|
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
|
||||||
|
b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
|
||||||
|
b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
|
||||||
|
b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
|
||||||
|
b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
|
||||||
|
b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
|
||||||
|
|
||||||
|
b_Ai_21 = -tl.dot(
|
||||||
|
tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
|
||||||
|
b_Ai_11,
|
||||||
|
input_precision=DOT_PRECISION,
|
||||||
)
|
)
|
||||||
p_A_32 = tl.make_block_ptr(
|
b_Ai_32 = -tl.dot(
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
|
tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
|
||||||
|
b_Ai_22,
|
||||||
|
input_precision=DOT_PRECISION,
|
||||||
)
|
)
|
||||||
p_A_31 = tl.make_block_ptr(
|
b_Ai_43 = -tl.dot(
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
|
||||||
)
|
b_Ai_33,
|
||||||
p_A_43 = tl.make_block_ptr(
|
input_precision=DOT_PRECISION,
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_A_42 = tl.make_block_ptr(
|
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_A_41 = tl.make_block_ptr(
|
|
||||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ad_11 = tl.make_block_ptr(
|
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ad_22 = tl.make_block_ptr(
|
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ad_33 = tl.make_block_ptr(
|
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ad_44 = tl.make_block_ptr(
|
|
||||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
b_Ai_31 = -tl.dot(
|
||||||
A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
|
b_Ai_33,
|
||||||
A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
|
tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
|
||||||
A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
|
+ tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
|
||||||
A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
|
input_precision=DOT_PRECISION,
|
||||||
A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
|
|
||||||
|
|
||||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
|
||||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
|
||||||
Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
|
|
||||||
Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
|
|
||||||
|
|
||||||
Ai_21 = -tl.dot(
|
|
||||||
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
|
|
||||||
)
|
)
|
||||||
Ai_32 = -tl.dot(
|
b_Ai_42 = -tl.dot(
|
||||||
tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee"
|
b_Ai_44,
|
||||||
|
tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
|
||||||
|
+ tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
|
||||||
|
input_precision=DOT_PRECISION,
|
||||||
)
|
)
|
||||||
Ai_43 = -tl.dot(
|
b_Ai_41 = -tl.dot(
|
||||||
tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee"
|
b_Ai_44,
|
||||||
|
tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
|
||||||
|
+ tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
|
||||||
|
+ tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
|
||||||
|
input_precision=DOT_PRECISION,
|
||||||
)
|
)
|
||||||
|
|
||||||
Ai_31 = -tl.dot(
|
if not USE_TMA:
|
||||||
Ai_33,
|
p_Ai_11 = tl.make_block_ptr(
|
||||||
tl.dot(A_31, Ai_11, input_precision="ieee")
|
Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
|
||||||
+ tl.dot(A_32, Ai_21, input_precision="ieee"),
|
)
|
||||||
input_precision="ieee",
|
p_Ai_22 = tl.make_block_ptr(
|
||||||
)
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
|
||||||
Ai_42 = -tl.dot(
|
)
|
||||||
Ai_44,
|
p_Ai_33 = tl.make_block_ptr(
|
||||||
tl.dot(A_42, Ai_22, input_precision="ieee")
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
|
||||||
+ tl.dot(A_43, Ai_32, input_precision="ieee"),
|
)
|
||||||
input_precision="ieee",
|
p_Ai_44 = tl.make_block_ptr(
|
||||||
)
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
|
||||||
Ai_41 = -tl.dot(
|
)
|
||||||
Ai_44,
|
p_Ai_21 = tl.make_block_ptr(
|
||||||
tl.dot(A_41, Ai_11, input_precision="ieee")
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
|
||||||
+ tl.dot(A_42, Ai_21, input_precision="ieee")
|
)
|
||||||
+ tl.dot(A_43, Ai_31, input_precision="ieee"),
|
p_Ai_31 = tl.make_block_ptr(
|
||||||
input_precision="ieee",
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
|
||||||
)
|
)
|
||||||
|
p_Ai_32 = tl.make_block_ptr(
|
||||||
p_Ai_11 = tl.make_block_ptr(
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)
|
)
|
||||||
)
|
p_Ai_41 = tl.make_block_ptr(
|
||||||
p_Ai_22 = tl.make_block_ptr(
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)
|
)
|
||||||
)
|
p_Ai_42 = tl.make_block_ptr(
|
||||||
p_Ai_33 = tl.make_block_ptr(
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)
|
)
|
||||||
)
|
p_Ai_43 = tl.make_block_ptr(
|
||||||
p_Ai_44 = tl.make_block_ptr(
|
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)
|
)
|
||||||
)
|
tl.store(
|
||||||
p_Ai_21 = tl.make_block_ptr(
|
p_Ai_11,
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
p_Ai_31 = tl.make_block_ptr(
|
)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
tl.store(
|
||||||
)
|
p_Ai_22,
|
||||||
p_Ai_32 = tl.make_block_ptr(
|
b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
|
boundary_check=(0, 1),
|
||||||
)
|
)
|
||||||
p_Ai_41 = tl.make_block_ptr(
|
tl.store(
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
p_Ai_33,
|
||||||
)
|
b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
p_Ai_42 = tl.make_block_ptr(
|
boundary_check=(0, 1),
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
|
)
|
||||||
)
|
tl.store(
|
||||||
p_Ai_43 = tl.make_block_ptr(
|
p_Ai_44,
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
|
b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_11,
|
tl.store(
|
||||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_21,
|
||||||
boundary_check=(0, 1),
|
b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_22,
|
tl.store(
|
||||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_31,
|
||||||
boundary_check=(0, 1),
|
b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_33,
|
tl.store(
|
||||||
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_32,
|
||||||
boundary_check=(0, 1),
|
b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_44,
|
tl.store(
|
||||||
Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_41,
|
||||||
boundary_check=(0, 1),
|
b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_21,
|
tl.store(
|
||||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_42,
|
||||||
boundary_check=(0, 1),
|
b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_31,
|
tl.store(
|
||||||
Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
|
p_Ai_43,
|
||||||
boundary_check=(0, 1),
|
b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||||
)
|
boundary_check=(0, 1),
|
||||||
tl.store(
|
)
|
||||||
p_Ai_32,
|
else:
|
||||||
Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
|
desc_o.store(
|
||||||
boundary_check=(0, 1),
|
[i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
)
|
)
|
||||||
tl.store(
|
desc_o.store(
|
||||||
p_Ai_41,
|
[i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
|
)
|
||||||
boundary_check=(0, 1),
|
desc_o.store(
|
||||||
)
|
[i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
tl.store(
|
)
|
||||||
p_Ai_42,
|
desc_o.store(
|
||||||
Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
|
[i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
boundary_check=(0, 1),
|
)
|
||||||
)
|
desc_o.store(
|
||||||
tl.store(
|
[i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
p_Ai_43,
|
)
|
||||||
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
|
desc_o.store(
|
||||||
boundary_check=(0, 1),
|
[i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
)
|
)
|
||||||
|
desc_o.store(
|
||||||
fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
|
[i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
p_Ai_12 = tl.make_block_ptr(
|
)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)
|
desc_o.store(
|
||||||
)
|
[i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
p_Ai_13 = tl.make_block_ptr(
|
)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)
|
desc_o.store(
|
||||||
)
|
[i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
p_Ai_14 = tl.make_block_ptr(
|
)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)
|
desc_o.store(
|
||||||
)
|
[i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
|
||||||
p_Ai_23 = tl.make_block_ptr(
|
)
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ai_24 = tl.make_block_ptr(
|
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
p_Ai_34 = tl.make_block_ptr(
|
|
||||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_12,
|
|
||||||
fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_13,
|
|
||||||
fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_14,
|
|
||||||
fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_23,
|
|
||||||
fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_24,
|
|
||||||
fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
p_Ai_34,
|
|
||||||
fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
||||||
boundary_check=(0, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@input_guard
|
@input_guard
|
||||||
@ -410,62 +510,47 @@ def solve_tril(
|
|||||||
output_dtype: torch.dtype = torch.float,
|
output_dtype: torch.dtype = torch.float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute the inverse of the lower triangular matrix
|
Compute the inverse of the matrix I + A
|
||||||
A should be strictly lower triangular, i.e., A.triu() == 0.
|
A should be strictly lower triangular, i.e., A.triu() == 0.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
A (torch.Tensor):
|
A (torch.Tensor):
|
||||||
[B, T, H, K]
|
[B, T, H, BT], where BT should only be 16, 32, or 64.
|
||||||
cu_seqlens (torch.Tensor):
|
cu_seqlens (torch.Tensor):
|
||||||
The cumulative sequence lengths of the input tensor.
|
The cumulative sequence lengths of the input tensor. Default: `None`.
|
||||||
Default: None.
|
|
||||||
output_dtype (torch.dtype):
|
output_dtype (torch.dtype):
|
||||||
The dtype of the output tensor. Default: `torch.float`
|
The dtype of the output tensor. Default: `torch.float`.
|
||||||
|
If `None`, the output dtype will be the same as the input dtype.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(I + A)^-1 with the same shape as A
|
(I + A)^-1 with the same shape as A
|
||||||
"""
|
"""
|
||||||
assert A.shape[-1] in [16, 32, 64]
|
assert A.shape[-1] in [16, 32, 64]
|
||||||
|
output_dtype = A.dtype if output_dtype is None else output_dtype
|
||||||
|
|
||||||
B, T, H, BT = A.shape
|
B, T, H, BT = A.shape
|
||||||
Ad = torch.empty(
|
|
||||||
B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_indices = (
|
|
||||||
prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
|
|
||||||
)
|
|
||||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
|
|
||||||
solve_tril_16x16_kernel[NT, B * H](
|
|
||||||
A=A,
|
|
||||||
Ad=Ad,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
chunk_indices=chunk_indices,
|
|
||||||
T=T,
|
|
||||||
H=H,
|
|
||||||
BT=BT,
|
|
||||||
)
|
|
||||||
if BT == 16:
|
|
||||||
return Ad
|
|
||||||
|
|
||||||
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
|
|
||||||
merge_fn = (
|
|
||||||
merge_16x16_to_32x32_inverse_kernel
|
|
||||||
if BT == 32
|
|
||||||
else merge_16x16_to_64x64_inverse_kernel
|
|
||||||
)
|
|
||||||
chunk_indices = (
|
chunk_indices = (
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||||
)
|
)
|
||||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||||
|
|
||||||
|
Ai = torch.zeros_like(A, dtype=output_dtype)
|
||||||
|
if BT == 16:
|
||||||
|
merge_fn = solve_tril_16x16_kernel
|
||||||
|
elif BT == 32:
|
||||||
|
merge_fn = merge_16x16_to_32x32_inverse_kernel
|
||||||
|
elif BT == 64:
|
||||||
|
merge_fn = merge_16x16_to_64x64_inverse_kernel
|
||||||
|
|
||||||
merge_fn[NT, B * H](
|
merge_fn[NT, B * H](
|
||||||
A=A,
|
A=A,
|
||||||
Ad=Ad,
|
|
||||||
Ai=Ai,
|
Ai=Ai,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
T=T,
|
T=T,
|
||||||
H=H,
|
H=H,
|
||||||
BT=BT,
|
BT=BT,
|
||||||
|
USE_TMA=is_tma_supported,
|
||||||
|
DOT_PRECISION=FLA_TRIL_PRECISION,
|
||||||
)
|
)
|
||||||
return Ai
|
return Ai
|
||||||
|
|||||||
@ -150,6 +150,11 @@ is_nvidia_hopper = is_nvidia and (
|
|||||||
or torch.cuda.get_device_capability()[0] >= 9
|
or torch.cuda.get_device_capability()[0] >= 9
|
||||||
)
|
)
|
||||||
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
|
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
|
||||||
|
is_gather_supported = hasattr(triton.language, "gather")
|
||||||
|
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
|
||||||
|
hasattr(triton.language, "_experimental_make_tensor_descriptor")
|
||||||
|
or hasattr(triton.language, "make_tensor_descriptor")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_all_max_shared_mem():
|
def get_all_max_shared_mem():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user