[Kernel] Enable the MoE LoRA kernel to support FP16 (#27468)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Parent: a663f6ae64
Commit: f4e8154076
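This change drops the hard-coded bfloat16 assumptions from the fused MoE LoRA Triton path so the kernel also runs in float16: the stacked LoRA-B pointer cast and the final accumulator downcast now follow the output tensor's element type, the intermediate cache is allocated in the output dtype, new assertions reject mixed or unsupported dtypes, and the kernel test is parametrized over both half-precision dtypes (plus device and seed). A stale commented-out reshape is removed along the way.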
@@ -204,6 +204,11 @@ def use_torch(
     return torch.stack(outputs, dim=0)
 
 
+DTYPES = [torch.float16, torch.bfloat16]
+DEVICES = [f"cuda:{0}"]
+SEED = [42]
+
+
 @pytest.mark.parametrize("num_tokens", [100])
 @pytest.mark.parametrize("top_k_num", [6, 12])
 @pytest.mark.parametrize("num_experts", [64])
@@ -212,6 +217,9 @@ def use_torch(
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
 def test_fused_moe_lora_kernel(
     num_tokens,
     top_k_num,
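Stacked parametrize decorators multiply out, so every existing case now runs once per entry of DTYPES, DEVICES, and SEED. A minimal sketch of that cross-product (a standalone illustration, not part of this diff):

import itertools

import torch

DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
SEED = [42]

# pytest builds the same Cartesian product from the stacked decorators:
# 2 dtypes x 1 device x 1 seed = 2 variants of each base test case.
for dtype, device, seed in itertools.product(DTYPES, DEVICES, SEED):
    print(dtype, device, seed)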
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
     K,
     max_lora_rank,
     block_size,
+    dtype,
+    device,
+    seed,
 ):
-    torch.set_default_device("cuda:0")
-    current_platform.seed_everything(42)
+    torch.set_default_device(device)
+    current_platform.seed_everything(seed)
     # the number of randomly generated sentences.
     num_sequences = 10
     # generate data
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
                 max_lora_rank,
                 K,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     lora_b_stacked = [
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
                 N,
                 max_lora_rank,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     hidden_states = torch.rand(
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
             num_tokens,
             K,
         ),
-        dtype=torch.bfloat16,
+        dtype=dtype,
     )
 
     # fused_moe_lora_kernel output
-    output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
+    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
     use_fused_moe_lora_kernel(
         topk_ids,
         topk_weights,
@@ -2,9 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-import triton
-import triton.language as tl
 
+from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op
 
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
 
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
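The LoRA B matrices reach the kernel as a table of raw int64 addresses (built via _LORA_PTR_DICT), so each loaded address must be reinterpreted as a typed pointer; pinning that cast to tl.bfloat16 is what previously ruled out FP16. Taking the element type from c_ptr makes the cast follow whatever dtype the output tensor has. A self-contained sketch of the pattern (kernel and variable names here are illustrative, not from this diff):

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_via_ptr_table(ptr_table, out_ptr, n, BLOCK: tl.constexpr):
    # ptr_table stores a raw 64-bit address; reinterpret it as a pointer
    # whose element type matches the destination, so the same kernel
    # handles fp16 and bf16 without a hard-coded dtype.
    src = tl.load(ptr_table).to(tl.pointer_type(out_ptr.dtype.element_ty))
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(src + offs, mask=mask), mask=mask)


x = torch.randn(64, device="cuda", dtype=torch.float16)
table = torch.tensor([x.data_ptr()], device="cuda", dtype=torch.int64)
out = torch.empty_like(x)
_copy_via_ptr_table[(1,)](table, out, x.numel(), BLOCK=64)
assert torch.equal(out, x)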
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
     moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
     accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(tl.bfloat16)
+    accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
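The matmul accumulator itself (float32, in the standard tl.dot pattern) is unchanged; only the downcast immediately before the write-back is generalized, so float16 outputs keep the same accumulation precision that bfloat16 had.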
@@ -205,6 +204,10 @@ def _fused_moe_lora(
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
 
+    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+        assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+        assert lora_a.dtype in [torch.float16, torch.bfloat16]
+
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
 
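These guards matter because the kernel now derives its element type from the output tensor: if the stacked LoRA weights or the hidden states disagreed with the output dtype, the pointer cast above would silently reinterpret their memory through the wrong type. Failing fast with an assert is far easier to debug.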
@@ -227,9 +230,9 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
 
-    lora_intermediate_cache1 = torch.zeros(
+    lora_intermediate_cache1 = torch.empty(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-        dtype=torch.bfloat16,
+        dtype=output.dtype,
         device=device,
     )
 
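Switching to torch.empty skips the zero-fill memset, presumably safe because the cache is fully written by the kernel launches before any element is read; dtype=output.dtype replaces the last hard-coded bfloat16 allocation.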
@@ -288,10 +291,6 @@ def _fused_moe_lora(
     K = max_lora_rank
     N = w1_output_dim_size
 
-    # a_intermediate_cache1 = a_intermediate_cache1.view(
-    #     M, -1, a_intermediate_cache1.shape[3]
-    # )
-
     a_intermediate_cache1 = a_intermediate_cache1.view(
         -1, a_intermediate_cache1.shape[3]
     )