[Kernel] Enable MoE LoRA kernel to support FP16 (#27468)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-10-27 19:48:37 +08:00 committed by GitHub
parent a663f6ae64
commit f4e8154076
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 16 deletions

View File

@ -204,6 +204,11 @@ def use_torch(
return torch.stack(outputs, dim=0) return torch.stack(outputs, dim=0)
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
SEED = [42]
@pytest.mark.parametrize("num_tokens", [100]) @pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6, 12]) @pytest.mark.parametrize("top_k_num", [6, 12])
@pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("num_experts", [64])
@ -212,6 +217,9 @@ def use_torch(
@pytest.mark.parametrize("K", [2048]) @pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_fused_moe_lora_kernel( def test_fused_moe_lora_kernel(
num_tokens, num_tokens,
top_k_num, top_k_num,
@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
K, K,
max_lora_rank, max_lora_rank,
block_size, block_size,
dtype,
device,
seed,
): ):
torch.set_default_device("cuda:0") torch.set_default_device(device)
current_platform.seed_everything(42) current_platform.seed_everything(seed)
# the number of randomly generated sentences. # the number of randomly generated sentences.
num_sequences = 10 num_sequences = 10
# generate data # generate data
@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
max_lora_rank, max_lora_rank,
K, K,
), ),
dtype=torch.bfloat16, dtype=dtype,
) )
] ]
lora_b_stacked = [ lora_b_stacked = [
@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
N, N,
max_lora_rank, max_lora_rank,
), ),
dtype=torch.bfloat16, dtype=dtype,
) )
] ]
hidden_states = torch.rand( hidden_states = torch.rand(
@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
num_tokens, num_tokens,
K, K,
), ),
dtype=torch.bfloat16, dtype=dtype,
) )
# fused_moe_lora_kernel output # fused_moe_lora_kernel output
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16) output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
use_fused_moe_lora_kernel( use_fused_moe_lora_kernel(
topk_ids, topk_ids,
topk_weights, topk_weights,

View File

@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
# get a_ptr,b_ptr,c_ptr # get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16)) cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(tl.bfloat16) accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
@ -205,6 +204,10 @@ def _fused_moe_lora(
assert output.shape[0] == topk_weights.shape[0] assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1] assert top_k_num == topk_weights.shape[1]
for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
assert lora_a.dtype in [torch.float16, torch.bfloat16]
device = qcurr_hidden_states.device device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked) num_slices = len(lora_a_stacked)
@ -227,9 +230,9 @@ def _fused_moe_lora(
num_tokens = M * top_k_num num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2] w1_output_dim_size = w1_lora_b_stacked.shape[2]
lora_intermediate_cache1 = torch.zeros( lora_intermediate_cache1 = torch.empty(
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
dtype=torch.bfloat16, dtype=output.dtype,
device=device, device=device,
) )
@ -288,10 +291,6 @@ def _fused_moe_lora(
K = max_lora_rank K = max_lora_rank
N = w1_output_dim_size N = w1_output_dim_size
# a_intermediate_cache1 = a_intermediate_cache1.view(
# M, -1, a_intermediate_cache1.shape[3]
# )
a_intermediate_cache1 = a_intermediate_cache1.view( a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3] -1, a_intermediate_cache1.shape[3]
) )