[Kernel] Enable moe LoRA kernel support FP16 (#27468)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Jee Jee Li on 2025-10-27 19:48:37 +08:00; committed by GitHub.
parent a663f6ae64
commit f4e8154076
2 changed files with 26 additions and 16 deletions
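Taken together, the two diffs below remove every hardcoded torch.bfloat16 from the fused MoE LoRA path and thread a dtype parameter from the test parametrization down into the Triton kernel, so the kernel now serves float16 as well as bfloat16. A minimal sketch of the test-side pattern (DTYPES/DEVICES/SEED and the parametrize axes are taken from the diff; the tiny test body is illustrative, not the real one):

    import pytest
    import torch

    DTYPES = [torch.float16, torch.bfloat16]
    DEVICES = [f"cuda:{0}"]
    SEED = [42]

    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.parametrize("seed", SEED)
    def test_dtype_flows_through(dtype, device, seed):
        # Every buffer takes its dtype from the parametrized value, so one
        # test body exercises both float16 and bfloat16 end to end.
        torch.set_default_device(device)
        x = torch.rand((4, 8), dtype=dtype)
        out = torch.zeros_like(x)
        assert out.dtype is dtype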

File 1 of 2 (the fused MoE LoRA kernel test):

@@ -204,6 +204,11 @@ def use_torch(
     return torch.stack(outputs, dim=0)
 
 
+DTYPES = [torch.float16, torch.bfloat16]
+DEVICES = [f"cuda:{0}"]
+SEED = [42]
+
+
 @pytest.mark.parametrize("num_tokens", [100])
 @pytest.mark.parametrize("top_k_num", [6, 12])
 @pytest.mark.parametrize("num_experts", [64])
@@ -212,6 +217,9 @@ def use_torch(
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
 def test_fused_moe_lora_kernel(
     num_tokens,
     top_k_num,
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
     K,
     max_lora_rank,
     block_size,
+    dtype,
+    device,
+    seed,
 ):
-    torch.set_default_device("cuda:0")
-    current_platform.seed_everything(42)
+    torch.set_default_device(device)
+    current_platform.seed_everything(seed)
     # the number of randomly generated sentences.
     num_sequences = 10
     # generate data
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
                 max_lora_rank,
                 K,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     lora_b_stacked = [
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
                 N,
                 max_lora_rank,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     hidden_states = torch.rand(
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
             num_tokens,
             K,
         ),
-        dtype=torch.bfloat16,
+        dtype=dtype,
     )
     # fused_moe_lora_kernel output
-    output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
+    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
     use_fused_moe_lora_kernel(
         topk_ids,
         topk_weights,
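With dtype as a parametrize axis, one precision can be run in isolation by matching the generated test id, e.g. (plain pytest -k selection; the exact ids depend on collection):

    pytest -k "test_fused_moe_lora_kernel and float16"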

File 2 of 2 (the Triton kernel implementation):

@@ -2,9 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op
 
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(tl.bfloat16)
+    accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
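The two kernel-side edits above are the heart of the change: instead of casting the B-pointer loads and the final accumulator to a hardcoded tl.bfloat16, the kernel takes the element type from the output pointer (c_ptr.dtype.element_ty), which resolves to tl.float16 or tl.bfloat16 depending on the buffers passed in. A self-contained sketch of that pattern (a hypothetical toy kernel, not the fused MoE LoRA kernel itself):

    import torch
    from vllm.triton_utils import tl, triton

    @triton.jit
    def _scale_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        # Accumulate in fp32 for accuracy, as GEMM-style kernels usually do.
        acc = tl.load(x_ptr + offs, mask=mask).to(tl.float32) * 2.0
        # Cast back to whatever element type the output buffer holds
        # (tl.float16 or tl.bfloat16) rather than a hardcoded tl.bfloat16.
        tl.store(out_ptr + offs, acc.to(out_ptr.dtype.element_ty), mask=mask)

    x = torch.rand(1024, device="cuda", dtype=torch.float16)
    out = torch.empty_like(x)
    _scale_kernel[(triton.cdiv(1024, 256),)](x, out, 1024, BLOCK=256)
    assert out.dtype is torch.float16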
@@ -205,6 +204,10 @@ def _fused_moe_lora(
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
+    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+        assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+        assert lora_a.dtype in [torch.float16, torch.bfloat16]
+
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
@@ -227,9 +230,9 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
-    lora_intermediate_cache1 = torch.zeros(
+    lora_intermediate_cache1 = torch.empty(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-        dtype=torch.bfloat16,
+        dtype=output.dtype,
         device=device,
     )
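Two things change in this allocation: the scratch buffer now inherits output.dtype instead of hardcoding bfloat16, and it is created with torch.empty rather than torch.zeros, presumably because the kernel writes every element it later reads, making the zero-fill redundant. The pattern in isolation (illustrative shapes, not the real ones):

    import torch

    output = torch.empty((8, 4), dtype=torch.float16, device="cuda")
    # Inherit dtype/device from the output buffer; skip the zero-fill
    # since the cache is fully written before it is read.
    cache = torch.empty(8 * 4 * 2, dtype=output.dtype, device=output.device)
    assert cache.dtype is output.dtype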
@@ -288,10 +291,6 @@ def _fused_moe_lora(
     K = max_lora_rank
     N = w1_output_dim_size
-    # a_intermediate_cache1 = a_intermediate_cache1.view(
-    #     M, -1, a_intermediate_cache1.shape[3]
-    # )
-
     a_intermediate_cache1 = a_intermediate_cache1.view(
         -1, a_intermediate_cache1.shape[3]
     )