[Kernel] Optimization of the mm_k operator. (#28280)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent b06b9470ca
commit 40e2eeeb92
@@ -23,6 +23,7 @@ def mm_k(
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
     USE_GDC: tl.constexpr,
+    base_k,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -47,32 +48,62 @@ def mm_k(
         matrix dtype.
     b_dtype: datatype of the B matrix
     USE_GDC: Whether to use PDL. True indicates use.
+    base_k: Base offset along K dimension for current SPLIT_K group
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+    # Step size along K for each iteration
+    STEP_K = BLOCK_K * SPLIT_K
+
+    # Total number of iterations (compile-time constant)
+    num_iters = tl.cdiv(K, STEP_K)
+
+    for k in range(num_iters):
+        # Current iteration's global K offset
+        iter_k = k * STEP_K + base_k
+
+        # Check if this iteration is completely valid (no masking needed)
+        block_end = iter_k + BLOCK_K
+
         if EVEN_K:
-            # pre-fetech lora weight
+            # K is divisible by BLOCK_K, no masking ever needed
+            # pre-fetch lora weight
             tiled_b = tl.load(b_ptr)
             if USE_GDC:
                 tl.extra.cuda.gdc_wait()
             tiled_a = tl.load(a_ptr)
+            if CAST_TYPE:
+                tiled_a = tiled_a.to(b_dtype)
+            accumulator += tl.dot(tiled_a, tiled_b)
         else:
-            tiled_b = tl.load(
-                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            if USE_GDC:
-                tl.extra.cuda.gdc_wait()
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(b_dtype)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * SPLIT_K * ak_stride
-        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+            # Check if we need element-wise masking
+            if iter_k >= K:
+                # Entire block out of range, skip
+                pass
+            elif block_end <= K:
+                # Entire block in range, no masking needed (fast path)
+                tiled_b = tl.load(b_ptr)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+            else:
+                # Partial block, need masking (only last iteration)
+                k_offsets = tl.arange(0, BLOCK_K)
+                mask = iter_k + k_offsets < K
+                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+
+        a_ptr += STEP_K * ak_stride
+        b_ptr += STEP_K * bk_stride
+
     return accumulator
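For intuition about the rewritten loop, here is a small host-side emulation in plain NumPy (not the Triton kernel itself; the sizes and array names are made up for illustration). It shows the same three-way decision per iteration: skip blocks that start past K, take the unmasked fast path for fully in-range blocks, and zero-pad only the final partial block, so the masked load is paid at most once per SPLIT_K group instead of on every iteration as in the old non-EVEN_K path.

import numpy as np

# Illustrative sizes only; these are not taken from the kernel's launch config.
K, BLOCK_K, SPLIT_K = 100, 16, 2
STEP_K = BLOCK_K * SPLIT_K

A = np.random.rand(8, K).astype(np.float32)  # stands in for the A tile rows
B = np.random.rand(K, 4).astype(np.float32)  # stands in for the B tile columns


def mm_k_emulated(base_k: int) -> np.ndarray:
    """Accumulate one SPLIT_K group's blocks using the new masking logic."""
    acc = np.zeros((A.shape[0], B.shape[1]), dtype=np.float32)
    num_iters = -(-K // STEP_K)  # tl.cdiv(K, STEP_K)
    for k in range(num_iters):
        iter_k = k * STEP_K + base_k  # global K offset of this block
        block_end = iter_k + BLOCK_K
        if iter_k >= K:
            continue  # entire block out of range: skip
        elif block_end <= K:
            a_blk = A[:, iter_k:block_end]  # fast path, no masking
            b_blk = B[iter_k:block_end, :]
        else:
            # partial last block: emulate the masked loads with zero padding
            valid = K - iter_k
            a_blk = np.zeros((A.shape[0], BLOCK_K), dtype=np.float32)
            b_blk = np.zeros((BLOCK_K, B.shape[1]), dtype=np.float32)
            a_blk[:, :valid] = A[:, iter_k:]
            b_blk[:valid, :] = B[iter_k:, :]
        acc += a_blk @ b_blk
    return acc


# Summing the SPLIT_K partial results reproduces the full matmul.
full = sum(mm_k_emulated(pid_sk * BLOCK_K) for pid_sk in range(SPLIT_K))
assert np.allclose(full, A @ B, atol=1e-3)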
@@ -178,6 +209,7 @@ def do_expand_kernel(
        CAST_TYPE,
        cur_lora_ptr.dtype.element_ty,
        USE_GDC,
+       base_k=0,
    )

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
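The expand call passes base_k=0, so the K offsets the new loop visits are unchanged from the old k * (BLOCK_K * SPLIT_K) arithmetic. A trivial sketch with made-up tile sizes:

# Hypothetical tile sizes, purely for illustration.
BLOCK_K, SPLIT_K, num_iters = 16, 2, 4
STEP_K = BLOCK_K * SPLIT_K

old_offsets = [k * (BLOCK_K * SPLIT_K) for k in range(num_iters)]
new_offsets = [k * STEP_K + 0 for k in range(num_iters)]  # base_k = 0
assert old_offsets == new_offsets == [0, 32, 64, 96]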
@@ -284,6 +316,7 @@ def do_shrink_kernel(
        False,
        cur_lora_ptr.dtype.element_ty,
        False,  # USE_GDC is always False in shrink kernel
+       base_k=pid_sk * BLOCK_K,
    )
    # GDC launch dependents hints the runtime system to launch dependent kernels.
    if USE_GDC:
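In the shrink call, each split-K program passes its own base_k = pid_sk * BLOCK_K, which staggers the starting block so that the SPLIT_K programs jointly touch every K index exactly once. A small sketch with hypothetical sizes (how the per-program partial sums are reduced afterwards is outside this snippet; the emulation above simply adds them):

from collections import Counter

# Hypothetical sizes; not the kernel's actual launch configuration.
K, BLOCK_K, SPLIT_K = 100, 16, 2
STEP_K = BLOCK_K * SPLIT_K

counts = Counter()
for pid_sk in range(SPLIT_K):
    base_k = pid_sk * BLOCK_K  # per-program base offset, as in the shrink call
    for k in range(-(-K // STEP_K)):  # tl.cdiv(K, STEP_K)
        iter_k = k * STEP_K + base_k
        if iter_k >= K:
            continue  # block entirely out of range, skipped by the kernel
        counts.update(range(iter_k, min(iter_k + BLOCK_K, K)))

# Every index along K is handled by exactly one (pid_sk, k) pair.
assert all(counts[i] == 1 for i in range(K))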